Skip to content

vllm_omni.diffusion.models.bagel.autoencoder

logger module-attribute

logger = init_logger(__name__)

AttnBlock

Bases: Module

in_channels instance-attribute

in_channels = in_channels

k instance-attribute

k = Conv2d(in_channels, in_channels, kernel_size=1)

norm instance-attribute

norm = GroupNorm(
    num_groups=32,
    num_channels=in_channels,
    eps=1e-06,
    affine=True,
)

proj_out instance-attribute

proj_out = Conv2d(in_channels, in_channels, kernel_size=1)

q instance-attribute

q = Conv2d(in_channels, in_channels, kernel_size=1)

v instance-attribute

v = Conv2d(in_channels, in_channels, kernel_size=1)

attention

attention(h_: Tensor) -> Tensor

forward

forward(x: Tensor) -> Tensor

AutoEncoder

Bases: Module

decoder instance-attribute

decoder = Decoder(
    resolution=resolution,
    in_channels=in_channels,
    ch=ch,
    out_ch=out_ch,
    ch_mult=ch_mult,
    num_res_blocks=num_res_blocks,
    z_channels=z_channels,
)

encoder instance-attribute

encoder = Encoder(
    resolution=resolution,
    in_channels=in_channels,
    ch=ch,
    ch_mult=ch_mult,
    num_res_blocks=num_res_blocks,
    z_channels=z_channels,
)

reg instance-attribute

scale_factor instance-attribute

scale_factor = scale_factor

shift_factor instance-attribute

shift_factor = shift_factor

decode

decode(z: Tensor) -> Tensor

encode

encode(x: Tensor) -> Tensor

forward

forward(x: Tensor) -> Tensor

AutoEncoderParams dataclass

ch instance-attribute

ch: int

ch_mult instance-attribute

ch_mult: list[int]

downsample instance-attribute

downsample: int

in_channels instance-attribute

in_channels: int

num_res_blocks instance-attribute

num_res_blocks: int

out_ch instance-attribute

out_ch: int

resolution instance-attribute

resolution: int

scale_factor instance-attribute

scale_factor: float

shift_factor instance-attribute

shift_factor: float

z_channels instance-attribute

z_channels: int

Decoder

Bases: Module

ch instance-attribute

ch = ch

conv_in instance-attribute

conv_in = Conv2d(
    z_channels, block_in, kernel_size=3, stride=1, padding=1
)

conv_out instance-attribute

conv_out = Conv2d(
    block_in, out_ch, kernel_size=3, stride=1, padding=1
)

ffactor instance-attribute

ffactor = 2 ** (num_resolutions - 1)

in_channels instance-attribute

in_channels = in_channels

mid instance-attribute

mid = Module()

norm_out instance-attribute

norm_out = GroupNorm(
    num_groups=32,
    num_channels=block_in,
    eps=1e-06,
    affine=True,
)

num_res_blocks instance-attribute

num_res_blocks = num_res_blocks

num_resolutions instance-attribute

num_resolutions = len(ch_mult)

resolution instance-attribute

resolution = resolution

up instance-attribute

up = ModuleList()

z_shape instance-attribute

z_shape = (1, z_channels, curr_res, curr_res)

forward

forward(z: Tensor) -> Tensor

DiagonalGaussian

Bases: Module

chunk_dim instance-attribute

chunk_dim = chunk_dim

sample instance-attribute

sample = sample

forward

forward(z: Tensor) -> Tensor

DistributedAutoEncoder

Bases: AutoEncoder, DistributedVaeMixin

spatial_compression_ratio instance-attribute

spatial_compression_ratio = downsample

tile_sample_min_height instance-attribute

tile_sample_min_height = 512

tile_sample_min_width instance-attribute

tile_sample_min_width = 512

tile_sample_stride_height instance-attribute

tile_sample_stride_height = 448

tile_sample_stride_width instance-attribute

tile_sample_stride_width = 448

use_tiling instance-attribute

use_tiling = False

blend_h

blend_h(
    left: Tensor, current: Tensor, blend_extent: int
) -> Tensor

blend_v

blend_v(
    above: Tensor, current: Tensor, blend_extent: int
) -> Tensor

decode

decode(z: Tensor) -> Tensor

decode_tile_exec

decode_tile_exec(task: TileTask) -> Tensor

decode_tile_merge

decode_tile_merge(
    coord_tensor_map: dict[tuple[int, ...], Tensor],
    grid_spec: GridSpec,
) -> Tensor

decode_tile_split

decode_tile_split(
    z: Tensor,
) -> tuple[list[TileTask], GridSpec]

encode

encode(x: Tensor) -> Tensor

encode_tile_exec

encode_tile_exec(task: TileTask) -> Tensor

encode_tile_merge

encode_tile_merge(
    coord_tensor_map: dict[tuple[int, ...], Tensor],
    grid_spec: GridSpec,
) -> Tensor

encode_tile_split

encode_tile_split(
    x: Tensor,
) -> tuple[list[TileTask], GridSpec]

Downsample

Bases: Module

conv instance-attribute

conv = Conv2d(
    in_channels,
    in_channels,
    kernel_size=3,
    stride=2,
    padding=0,
)

forward

forward(x: Tensor)

Encoder

Bases: Module

ch instance-attribute

ch = ch

conv_in instance-attribute

conv_in = Conv2d(
    in_channels, ch, kernel_size=3, stride=1, padding=1
)

conv_out instance-attribute

conv_out = Conv2d(
    block_in,
    2 * z_channels,
    kernel_size=3,
    stride=1,
    padding=1,
)

down instance-attribute

down = ModuleList()

in_ch_mult instance-attribute

in_ch_mult = in_ch_mult

in_channels instance-attribute

in_channels = in_channels

mid instance-attribute

mid = Module()

norm_out instance-attribute

norm_out = GroupNorm(
    num_groups=32,
    num_channels=block_in,
    eps=1e-06,
    affine=True,
)

num_res_blocks instance-attribute

num_res_blocks = num_res_blocks

num_resolutions instance-attribute

num_resolutions = len(ch_mult)

resolution instance-attribute

resolution = resolution

forward

forward(x: Tensor) -> Tensor

ResnetBlock

Bases: Module

conv1 instance-attribute

conv1 = Conv2d(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
)

conv2 instance-attribute

conv2 = Conv2d(
    out_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
)

in_channels instance-attribute

in_channels = in_channels

nin_shortcut instance-attribute

nin_shortcut = Conv2d(
    in_channels,
    out_channels,
    kernel_size=1,
    stride=1,
    padding=0,
)

norm1 instance-attribute

norm1 = GroupNorm(
    num_groups=32,
    num_channels=in_channels,
    eps=1e-06,
    affine=True,
)

norm2 instance-attribute

norm2 = GroupNorm(
    num_groups=32,
    num_channels=out_channels,
    eps=1e-06,
    affine=True,
)

out_channels instance-attribute

out_channels = out_channels

forward

forward(x)

Upsample

Bases: Module

conv instance-attribute

conv = Conv2d(
    in_channels,
    in_channels,
    kernel_size=3,
    stride=1,
    padding=1,
)

forward

forward(x: Tensor)

swish

swish(x: Tensor) -> Tensor