10 Commits

Author SHA1 Message Date
  Shen-Chenhui 510e457f13 Merge branch 'upload/v2.0' of https://github.com/hpcaitech/Open-Sora into upload/v2.0 9 months ago
  Shen-Chenhui 5588eed12f train ae default without wandb 9 months ago
  Hongxin Liu 307d80da78
[hotfix] fix ring attn bwd for fa3 (#803) 9 months ago
  Shen-Chenhui 18e2df8b17 align naming of ae 9 months ago
  Shen-Chenhui 48c6a4a917 update vae & ae paths 9 months ago
  Shen-Chenhui 1b8253c9f6 further cleanup 9 months ago
  zhengzangw 2092cfb060 update ae.md 9 months ago
  Shen-Chenhui f3a2a884d1 Merge branch 'upload/v2.0' of https://github.com/hpcaitech/Open-Sora into upload/v2.0 9 months ago
  Shen-Chenhui 814a6cbbe4 cleaned up vae 9 months ago
  zhengzangw a351ef70d3 update hcae config 9 months ago
42 changed files with 271 additions and 1669 deletions
Split View
  1. +1
    -1
      README.md
  2. +4
    -0
      configs/diffusion/inference/high_compression.py
  3. +2
    -1
      configs/diffusion/train/high_compression.py
  4. +0
    -26
      configs/vae/inference/causal_dcae.py
  5. +0
    -27
      configs/vae/inference/dc_ae.py
  6. +0
    -29
      configs/vae/inference/flux_ae_2d.py
  7. +0
    -20
      configs/vae/inference/hunyuan_4x16x16.py
  8. +0
    -23
      configs/vae/inference/hunyuan_video.py
  9. +33
    -0
      configs/vae/inference/hunyuanvideo_vae.py
  10. +0
    -11
      configs/vae/inference/image.py
  11. +0
    -33
      configs/vae/inference/video.py
  12. +32
    -0
      configs/vae/inference/video_dc_ae.py
  13. +0
    -38
      configs/vae/train/256px.py
  14. +0
    -12
      configs/vae/train/256px_nodisc.py
  15. +0
    -34
      configs/vae/train/512px.py
  16. +0
    -39
      configs/vae/train/debug.py
  17. +0
    -49
      configs/vae/train/hunyuan_4x16x16.py
  18. +0
    -19
      configs/vae/train/hunyuan_8x8x8_disc.py
  19. +0
    -10
      configs/vae/train/hunyuan_residual.py
  20. +0
    -30
      configs/vae/train/hunyuan_video.py
  21. +0
    -15
      configs/vae/train/hunyuan_video_disc.py
  22. +0
    -9
      configs/vae/train/hunyuan_video_disc_fp32.py
  23. +0
    -5
      configs/vae/train/hunyuan_video_disc_hinge.py
  24. +0
    -10
      configs/vae/train/image.py
  25. +0
    -21
      configs/vae/train/image_dc_ae.py
  26. +0
    -102
      configs/vae/train/video.py
  27. +1
    -10
      configs/vae/train/video_dc_ae.py
  28. +4
    -3
      configs/vae/train/video_dc_ae_disc.py
  29. +143
    -7
      docs/ae.md
  30. +8
    -4
      docs/hcae.md
  31. +6
    -62
      opensora/models/dc_ae/ae_model_zoo.py
  32. +13
    -152
      opensora/models/dc_ae/models/dc_ae.py
  33. +2
    -23
      opensora/models/dc_ae/utils/init.py
  34. +3
    -24
      opensora/models/hunyuan_vae/autoencoder_kl_causal_3d.py
  35. +1
    -6
      opensora/models/hunyuan_vae/distributed.py
  36. +0
    -151
      opensora/models/hunyuan_vae/unet_causal_3d_blocks.py
  37. +0
    -8
      opensora/models/hunyuan_vae/vae.py
  38. +0
    -621
      opensora/models/hunyuan_vae/vae_channel.py
  39. +13
    -6
      opensora/models/mmdit/distributed.py
  40. +0
    -12
      opensora/models/vae/losses.py
  41. +4
    -6
      scripts/diffusion/train.py
  42. +1
    -10
      scripts/vae/train.py

+ 1
- 1
README.md View File

@@ -123,7 +123,7 @@ see [here](/assets/texts/t2v_samples.txt) for full prompts.
- **[Tech Report of Open-Sora 2.0]()**
- **[Step by step to train or finetune your own model](docs/train.md)**
- **[Step by step to train and evaluate an video autoencoder](docs/ae.md)**
- **[Visit the high compression video autoencoder](docs/hc-ae.md)**
- **[Visit the high compression video autoencoder](docs/hcae.md)**
- Reports of previous version (better see in according branch):
- [Open-Sora 1.3](docs/report_04.md): shift-window attention, unified spatial-temporal VAE, etc.
- [Open-Sora 1.2](docs/report_03.md), [Tech Report](https://arxiv.org/abs/2412.20404): rectified flow, 3d-VAE, score condition, evaluation, etc.


+ 4
- 0
configs/diffusion/inference/high_compression.py View File

@@ -28,3 +28,7 @@ ae = dict(
tile_overlap_factor=0.25,
)
ae_spatial_compression = 32

sampling_option = dict(
num_frames=128,
)

+ 2
- 1
configs/diffusion/train/high_compression.py View File

@@ -41,10 +41,11 @@ condition_config = dict(
i2v_head=7,
)

grad_ckpt_settings = (100, 100)
patch_size = 1
model = dict(
from_pretrained=None,
grad_ckpt_settings=None,
grad_ckpt_settings=grad_ckpt_settings,
in_channels=128,
cond_embed=True,
patch_size=patch_size,


+ 0
- 26
configs/vae/inference/causal_dcae.py View File

@@ -1,26 +0,0 @@
_base_ = ["video.py"]

model = dict(
_delete_=True,
type="hunyuan_vae",
from_pretrained=None,
in_channels=3,
out_channels=3,
layers_per_block=2,
latent_channels=32,
scale_factor=0.476986,
shift_factor=0,
use_spatial_tiling=True,
use_temporal_tiling=True,
# architecture
channel=True,
time_compression_ratio=4,
spatial_compression_ratio=32,
block_out_channels=(128, 128, 256, 256, 512, 512),
# set the following to True to use residual
encoder_add_residual=True,
decoder_add_residual=True,
# residual slice or pad
encoder_slice_t=False,
decoder_slice_t=True,
)

+ 0
- 27
configs/vae/inference/dc_ae.py View File

@@ -1,27 +0,0 @@
dtype = "bf16"
batch_size = 1
seed = 42
save_dir = "samples/vae_vid"

plugin = "zero2"
dataset = dict(
type="video_text",
transform_name="resize_crop",
fps_max=16,
data_path="/mnt/ddn/sora/meta/test/vid_vae.csv",
)
bucket_config = {
"256px_ar1:1": {32: (1.0, 1)},
}

model = dict(
type="dc_ae",
# model_name="mit-han-lab/dc-ae-f32c32-sana-1.0",
# model_name="dc-ae-f32c32-sana-1.0",
model_name="dc-ae-f128c512-sana-1.0",
from_scratch=True,
)

num_workers = 24
num_bucket_build_workers = 16
prefetch_factor = 4

+ 0
- 29
configs/vae/inference/flux_ae_2d.py View File

@@ -1,29 +0,0 @@
dtype = "bf16"
batch_size = 1
seed = 42
save_dir = "samples/flux_ae_2d"

dataset = dict(
type="video_text",
transform_name="resize_crop",
fps_max=16,
data_path="/mnt/jfs-hdd/sora/meta/validation/img_1k.csv",
)

bucket_config = {
"1024px_ar1:1": {1: (1.0, 1)},
}

model = dict(
type="autoencoder_2d",
from_pretrained="pretrained_models/flux1-dev/ae.safetensors",
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=1.0,
shift_factor=0.0,
)

+ 0
- 20
configs/vae/inference/hunyuan_4x16x16.py View File

@@ -1,20 +0,0 @@
_base_ = ["hunyuan_video.py"]


model = dict(
_delete_=True,
type="hunyuan_vae",
from_pretrained=None,
in_channels=3,
out_channels=3,
layers_per_block=2,
latent_channels=16,
dropout=0.0,
block_out_channels=(128, 256, 256, 512, 512),
time_compression_ratio=4,
spatial_compression_ratio=16,
encoder_add_residual=True,
encoder_slice_t=True,
decoder_add_residual=True,
decoder_slice_t=True,
)

+ 0
- 23
configs/vae/inference/hunyuan_video.py View File

@@ -1,23 +0,0 @@
_base_ = ["video.py"]

model = dict(
_delete_=True,
type="hunyuan_vae",
from_pretrained="/mnt/jfs-hdd/sora/checkpoints/pretrained_models/hunyuan-video-t2v-720p/vae/pytorch_model.pt",
in_channels=3,
out_channels=3,
layers_per_block=2,
latent_channels=16,
scale_factor=0.476986,
shift_factor=0,
use_spatial_tiling=True,
use_temporal_tiling=True,
# set the following to True to use residual
encoder_add_residual=False,
decoder_add_residual=False,
# residual slice or pad
encoder_slice_t=False,
decoder_slice_t=False,
# temporal
time_compression_ratio=4,
)

+ 33
- 0
configs/vae/inference/hunyuanvideo_vae.py View File

@@ -0,0 +1,33 @@
dtype = "bf16"
batch_size = 1
seed = 42
save_dir = "samples/hunyuanvideo_vae"

plugin = "zero2"
dataset = dict(
type="video_text",
transform_name="resize_crop",
fps_max=16,
data_path="datasets/pexels_45k_necessary.csv",
)
bucket_config = {
"512px_ar1:1": {97: (1.0, 1)},
}

num_workers = 24
num_bucket_build_workers = 16
prefetch_factor = 4

model = dict(
type="hunyuan_vae",
from_pretrained="./ckpts/hunyuan_vae.safetensors",
in_channels=3,
out_channels=3,
layers_per_block=2,
latent_channels=16,
scale_factor=0.476986,
shift_factor=0,
use_spatial_tiling=True,
use_temporal_tiling=True,
time_compression_ratio=4,
)

+ 0
- 11
configs/vae/inference/image.py View File

@@ -1,11 +0,0 @@
_base_ = ["video.py"]

save_dir = "samples/vae_img"
dataset = dict(
type="video_text",
transform_name="resize_crop",
data_path="/mnt/ddn/sora/meta/test/image_vae.csv",
)
bucket_config = {
"512px_ar1:1": {1: (1.0, 1)},
}

+ 0
- 33
configs/vae/inference/video.py View File

@@ -1,33 +0,0 @@
dtype = "bf16"
batch_size = 1
seed = 42
save_dir = "samples/vae_vid"

plugin = "zero2"
dataset = dict(
type="video_text",
transform_name="resize_crop",
fps_max=16,
data_path="/mnt/ddn/sora/meta/test/vid_vae.csv",
)
bucket_config = {
"256px_ar1:1": {32: (1.0, 1)},
}

model = dict(
type="autoencoder_3d",
from_pretrained=None,
in_channels=3,
out_ch=3,
ch=128,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=1.0,
shift_factor=0.0,
tiling=4,
)

num_workers = 24
num_bucket_build_workers = 16
prefetch_factor = 4

+ 32
- 0
configs/vae/inference/video_dc_ae.py View File

@@ -0,0 +1,32 @@
dtype = "bf16"
batch_size = 1
seed = 42

dataset = dict(
type="video_text",
transform_name="resize_crop",
fps_max=16,
data_path="datasets/pexels_45k_necessary.csv",
)
bucket_config = {
"512px_ar1:1": {96: (1.0, 1)},
}

model = dict(
type="dc_ae",
model_name="dc-ae-f32t4c128",
from_pretrained="./ckpts/F32T4C128_AE.safetensors",
from_scratch=True,
use_spatial_tiling=True,
use_temporal_tiling=True,
spatial_tile_size=256,
temporal_tile_size=32,
tile_overlap_factor=0.25,
)

save_dir = "samples/video_dc_ae"

num_workers = 24
num_bucket_build_workers = 16
prefetch_factor = 4


+ 0
- 38
configs/vae/train/256px.py View File

@@ -1,38 +0,0 @@
_base_ = ["video.py"]

dtype = "bf16"
ckpt_every = 2500

mixed_image_ratio = 0.01
restart_disc = True

# == loss weights ==
opl_loss_weight = 0.0
vae_loss_config = dict(
kl_loss_weight=1e-6,
perceptual_loss_weight=1.0,
)
gen_loss_config = dict(
disc_factor=1.0,
disc_weight=0.5, # proportion of grad norm w.r.t. nll loss
)
disc_loss_config = dict(
disc_loss_type="hinge",
)

# == optimizer ==
optim = dict(
lr=1e-6,
betas=(0.9, 0.999),
)
optim_discriminator = dict(
lr=1e-6,
betas=(0.9, 0.999),
)
ema_decay = None

# TORCH_COMPILE_DISABLE=1 torchrun --nproc_per_node 8 scripts/vae/train.py configs/vae/train/256px.py --dataset.data-path cache/meta/tmp_vae_bpp_bppmin-0.035.csv --model.from_pretrained /mnt/jfs-hdd/sora/checkpoints/pretrained_models/vae_videoocean_1025.pt --wandb True

# TORCH_COMPILE_DISABLE=1 CUDA_VISIBLE_DEVICES=7 torchrun --master-port 14312 --nproc_per_node 1 scripts/vae/inference.py configs/vae/inference/video.py --dataset.data-path /mnt/jfs-hdd/sora/data/eval_loss/eval_vid.csv --save-dir samples/vae_12_14 --ckpt-path outputs/241214_012637-vae_train_video/epoch0-global_step17000 --eval-setting 32x512x512 --type video

# CUDA_VISIBELE_DEVICES=6 python vae/eval_common_metric.py --batch_size 1 --real_video_dir ~/Video-Ocean/samples/vae_12_13/orig --generated_video_dir ~/Video-Ocean/samples/vae_12_13/recn --device cuda --sample_fps 16 --crop_size 512 --resolution 512 --num_frames 32 --sample_rate 1 --metric ssim psnr lpips

+ 0
- 12
configs/vae/train/256px_nodisc.py View File

@@ -1,12 +0,0 @@
_base_ = ["256px.py"]

vae_loss_config = dict(
perceptual_loss_weight=0.1,
)
gen_loss_config = dict(
disc_factor=0.0,
)
disc_loss_config = dict(
disc_factor=0.0,
)
discriminator = None

+ 0
- 34
configs/vae/train/512px.py View File

@@ -1,34 +0,0 @@
_base_ = ["video.py"]

bucket_config = {
"_delete_": True,
"512px_ar1:1": {32: (1.0, 1)},
}
grad_checkpoint = True

# == loss weights ==
opl_loss_weight = 1e3
vae_loss_config = dict(
kl_loss_weight=5e-4,
perceptual_loss_weight=0.1,
)
gen_loss_config = dict(
disc_factor=1.0,
disc_weight=0.2,
)
optim = dict(
cls="HybridAdam",
lr=5e-6,
eps=1e-8,
weight_decay=0.0,
adamw_mode=True,
betas=(0.9, 0.999),
)

# TORCH_COMPILE_DISABLE=1 CUDA_VISIBLE_DEVICES=7 torchrun --master-port 14312 --nproc_per_node 1 scripts/vae/inference.py configs/vae/inference/video.py --dataset.data-path /mnt/jfs-hdd/sora/data/eval_loss/eval_vid.csv --save-dir samples/vae_12_14 --ckpt-path outputs/241214_012637-vae_train_video/epoch0-global_step17000 --eval-setting 32x512x512 --type video

# CUDA_VISIBELE_DEVICES=6 python vae/eval_common_metric.py --batch_size 1 --real_video_dir ~/Video-Ocean/samples/vae_12_13/orig --generated_video_dir ~/Video-Ocean/samples/vae_12_13/recn --device cuda --sample_fps 16 --crop_size 512 --resolution 512 --num_frames 32 --sample_rate 1 --metric ssim psnr lpips

# TORCH_COMPILE_DISABLE=1 torchrun --nproc_per_node 8 --master_port 30303 scripts/vae/train.py configs/vae/train/video.py --dataset.data-path /mnt/ddn/sora/meta/merge/vo1.1_stage-1_res-256/20241210.csv --model.from_pretrained /mnt/jfs-hdd/sora/checkpoints/pretrained_models/vae_videoocean_1025.pt

# TORCH_COMPILE_DISABLE=1 torchrun --nproc_per_node 8 --master_port 30303 scripts/vae/train.py configs/vae/train/512px.py --dataset.data-path /mnt/ddn/sora/meta/merge/vo1.1_stage-1_res-256/20241210.csv --model.from_pretrained /mnt/jfs-hdd/sora/checkpoints/pretrained_models/vae_videoocean_1025.pt

+ 0
- 39
configs/vae/train/debug.py View File

@@ -1,39 +0,0 @@
_base_ = ["hunyuan_video_disc.py"]

# model = dict(
# # use_temporal_tiling = True,
# use_slicing = True,
# use_spatial_tiling = True,
# sample_tsize = 32,
# )

grad_checkpoint = True
# grad_checkpoint_buffer_size = 15 * 1024**3

bucket_config = {
"_delete_": True,
"512px_ar1:1": {129: (1.0, 1)},
# "360p": {33: (1.0, 1)}
}

mixed_image_ratio = 0.0

log_every = 1

plugin = "hybrid"
plugin_config = dict(
tp_size=8,
pp_size=1,
zero_stage=2,
static_graph=True,
# sp_size=1,
# sequence_parallelism_mode="ring_attn",
# enable_sequence_parallelism=True,
)
ema_decay = None

disc_plugin = "zero2"
disc_plugin_config = dict(
reduce_bucket_size_in_m=128,
overlap_allgather=False,
)

+ 0
- 49
configs/vae/train/hunyuan_4x16x16.py View File

@@ -1,49 +0,0 @@
_base_ = ["video.py"]

dataset = dict(
rand_sample_interval=8,
)

bucket_config = {
"_delete_": True,
"256px_ar1:1": {33: (1.0, 2)},
}

vae_loss_config = dict(
perceptual_loss_weight=0.1,
kl_loss_weight=1e-6,
)

opl_loss_weight = 0

model = dict(
_delete_=True,
type="hunyuan_vae",
from_pretrained=None,
in_channels=3,
out_channels=3,
layers_per_block=2,
latent_channels=16,
dropout=0.0,
block_out_channels=(128, 256, 256, 512, 512),
time_compression_ratio=4,
spatial_compression_ratio=16,
encoder_add_residual=True,
encoder_slice_t=True,
decoder_add_residual=True,
decoder_slice_t=True,
)

discriminator = dict(type="N_Layer_discriminator_3D", from_pretrained=None, input_nc=3, n_layers=5, conv_cls="conv3d")

gen_loss_config = dict(
gen_start=0,
disc_factor=1,
disc_weight=0.05,
)

disc_loss_config = dict(
disc_start=0,
disc_factor=1.0,
disc_loss_type="wgan-gp",
)

+ 0
- 19
configs/vae/train/hunyuan_8x8x8_disc.py View File

@@ -1,19 +0,0 @@
_base_ = ["hunyuan_video.py"]

model = dict(
time_compression_ratio=8,
)

discriminator = dict(type="N_Layer_discriminator_3D", from_pretrained=None, input_nc=3, n_layers=5, conv_cls="conv3d")

gen_loss_config = dict(
gen_start=0,
disc_factor=1,
disc_weight=0.05,
)

disc_loss_config = dict(
disc_start=0,
disc_factor=1.0,
disc_loss_type="wgan-gp",
)

+ 0
- 10
configs/vae/train/hunyuan_residual.py View File

@@ -1,10 +0,0 @@
_base_ = ["hunyuan_video_disc.py"]

model = dict(
encoder_add_residual=True,
encoder_slice_t=True,
decoder_add_residual=True,
decoder_slice_t=True,
)

mixed_image_ratio = 0.25

+ 0
- 30
configs/vae/train/hunyuan_video.py View File

@@ -1,30 +0,0 @@
_base_ = ["video.py"]

dataset = dict(
rand_sample_interval=8,
)

bucket_config = {
"_delete_": True,
"256px_ar1:1": {33: (1.0, 2)},
# "256px_ar1:1": {65: (1.0, 1)},
}

vae_loss_config = dict(
perceptual_loss_weight=0.1,
kl_loss_weight=1e-6,
)

opl_loss_weight = 0

model = dict(
_delete_=True,
type="hunyuan_vae",
# from_pretrained="/mnt/jfs-hdd/sora/checkpoints/pretrained_models/hunyuan-video-t2v-720p/vae/pytorch_model.pt",
from_pretrained=None,
in_channels=3,
out_channels=3,
layers_per_block=2,
latent_channels=16,
dropout=0.0, # TODO: expr with 0.1
)

+ 0
- 15
configs/vae/train/hunyuan_video_disc.py View File

@@ -1,15 +0,0 @@
_base_ = ["hunyuan_video.py"]

discriminator = dict(type="N_Layer_discriminator_3D", from_pretrained=None, input_nc=3, n_layers=5, conv_cls="conv3d")

gen_loss_config = dict(
gen_start=0,
disc_factor=1,
disc_weight=0.05,
)

disc_loss_config = dict(
disc_start=0,
disc_factor=1.0,
disc_loss_type="wgan-gp",
)

+ 0
- 9
configs/vae/train/hunyuan_video_disc_fp32.py View File

@@ -1,9 +0,0 @@
_base_ = ["hunyuan_video_disc.py"]

bucket_config = {
"_delete_": True,
"256px_ar1:1": {33: (1.0, 1)},
# "256px_ar1:1": {65: (1.0, 1)},
}

ema_decay = None

+ 0
- 5
configs/vae/train/hunyuan_video_disc_hinge.py View File

@@ -1,5 +0,0 @@
_base_ = ["hunyuan_video_disc.py"]

disc_loss_config = dict(
disc_loss_type="hinge",
)

+ 0
- 10
configs/vae/train/image.py View File

@@ -1,10 +0,0 @@
_base_ = ["video.py"]

bucket_config = {
"256px_ar1:1": {1: (1.0, 1)},
}

discriminator = None
gen_loss_confg = None
disc_loss_config = None
disc_lr_scheduler = None

+ 0
- 21
configs/vae/train/image_dc_ae.py View File

@@ -1,21 +0,0 @@
_base_ = ["image.py"]

model = dict(
_delete_=True,
type="dc_ae",
model_name="dc-ae-f32c32-sana-1.0",
from_scratch=True,
)
vae_loss_config = dict(
perceptual_loss_weight=0.1,
kl_loss_weight=0,
)
opl_loss_weight = 0

bucket_config = {
"_delete_": True,
"256px_ar1:1": {1: (1.0, 12)},
}
optim = dict(
lr=1e-4,
)

+ 0
- 102
configs/vae/train/video.py View File

@@ -1,102 +0,0 @@
# Define dataset
dataset = dict(
type="video_text",
transform_name="resize_crop",
fps_max=16,
)
bucket_config = {
"256px_ar1:1": {32: (1.0, 1)},
}

grad_checkpoint = False
num_bucket_build_workers = 64
num_workers = 12
prefetch_factor = 2
pin_memory_cache_pre_alloc_numels = [50 * 1024 * 1024] * num_workers * prefetch_factor

# Define model
model = dict(
type="autoencoder_3d",
from_pretrained="/mnt/jfs-hdd/sora/checkpoints/pretrained_models/flux/FLUX.1-dev_convert/ae_central_inflate_zero_init.safetensors",
in_channels=3,
out_ch=3,
ch=128,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=1.0,
shift_factor=0.0,
)
# discriminator = dict(type="N_Layer_discriminator_3D", from_pretrained=None, input_nc=3, n_layers=5, conv_cls="conv3d")

# == loss weights ==
mixed_strategy = "mixed_video_image"
mixed_image_ratio = 0.25 # 1:4

opl_loss_weight = 1e5
# opl_loss_weight = 1e3

vae_loss_config = dict(
perceptual_loss_weight=1.0,
kl_loss_weight=5e-4,
# kl_loss_weight=5e-6,
logvar_init=0.0,
) # reconstruction loss (nll) + lpips similarity loss + kl loss

gen_loss_config = dict(
gen_start=0,
disc_factor=1e5,
disc_weight=0.5,
)

disc_loss_config = dict(
disc_start=0,
disc_factor=1.0,
disc_loss_type="wgan-gp",
)

# == optimizer ==
optim = dict(
cls="HybridAdam",
lr=5e-6,
eps=1e-8,
weight_decay=0.0,
adamw_mode=True,
betas=(0.9, 0.98),
)
optim_discriminator = dict(
cls="HybridAdam",
lr=1e-5,
eps=1e-8,
weight_decay=0.0,
adamw_mode=True,
betas=(0.9, 0.98),
)

lr_scheduler = dict(warmup_steps=0)
disc_lr_scheduler = dict(warmup_steps=0)

update_warmup_steps = True
# start_epoch = start_step = 0

grad_clip = 1.0

# Acceleration settings
dtype = "bf16"
plugin = "zero2"
plugin_config = dict(
reduce_bucket_size_in_m=128,
overlap_allgather=False,
)

# Others
seed = 42
outputs = "outputs"
epochs = 100
log_every = 10
ckpt_every = 2000
keep_n_latest = 50
ema_decay = 0.99

# wandb
wandb_project = "mmdit_vae"

+ 1
- 10
configs/vae/train/video_dc_ae.py View File

@@ -14,6 +14,7 @@ model = dict(
dataset = dict(
type="video_text",
transform_name="resize_crop",
data_path="datasets/pexels_45k_necessary.csv",
fps_max=24,
)

@@ -61,23 +62,13 @@ keep_n_latest = 50
ema_decay = 0.99
wandb_project = "dcae"

discriminator = None
disc_lr_scheduler = None
optim_discriminator = None

update_warmup_steps = True

# ============
# loss config
# ============
opl_loss_weight = 0

vae_loss_config = dict(
perceptual_loss_weight=0.5,
kl_loss_weight=0,
logvar_init=0,
)

gen_loss_config = None
disc_loss_config = None


+ 4
- 3
configs/vae/train/video_dc_ae_disc.py View File

@@ -1,7 +1,6 @@
_base_ = ["video_dc_ae.py"]

discriminator = dict(
_delete_=True,
type="N_Layer_discriminator_3D",
from_pretrained=None,
input_nc=3,
@@ -12,13 +11,11 @@ disc_lr_scheduler = dict(warmup_steps=0)

gen_loss_config = dict(
gen_start=0,
disc_factor=1,
disc_weight=0.05,
)

disc_loss_config = dict(
disc_start=0,
disc_factor=1.0,
disc_loss_type="hinge",
)

@@ -31,3 +28,7 @@ optim_discriminator = dict(
betas=(0.9, 0.98),
)

grad_checkpoint = True
model = dict(
disc_off_grad_ckpt = True, # set to true if your `grad_checkpoint` is True
)

+ 143
- 7
docs/ae.md View File

@@ -1,18 +1,154 @@
# Step by step to train and evaluate an video autoencoder
# Step by step to train and evaluate a video autoencoder (AE)
Inspired by [SANA](https://arxiv.org/abs/2410.10629), we aim to drastically increase the compression ratio in the AE. We propose a video autoencoder architecture based on [DC-AE](https://github.com/mit-han-lab/efficientvit), the __Video DC-AE__, which compresses the video by 4x in the temporal dimension and 32x32 in the spatial dimension. Compared to [HunyuanVideo](https://github.com/Tencent/HunyuanVideo)'s VAE of 4x8x8, our proposed AE has a much higher spatial compression ratio.
Thus, we can effectively reduce the token length in the diffusion model by a total of 16x (assuming the same patch sizes), drastically increasing both training and inference speed.

## Installation
## Data Preparation

```
pip install diffusers==0.31.0
```
Follow this [guide](./train.md#prepare-dataset) to prepare the __DATASET__ for training and inference. You may use our provided dataset or custom ones.
To use custom dataset, pass the argument `--dataset.data_path <your_data_path>` to the following training or inference command.

## Training
The command to launch training is as follows:

We train our __Video DC-AE__ from scratch on 8xGPUs for 3 weeks.

We first train with the following command:

```bash
torchrun --nproc_per_node 8 scripts/vae/train.py configs/vae/train/video_dc_ae.py
```
torchrun --nproc_per_node 8 scripts/vae/train_video_sana.py configs/vae/train/video_dc_ae_train_channel_proj.py --dataset.data-path /home/zhengzangwei/Open-Sora/datasets/pexels_45k_necessary.csv --model.model_name dc-ae-f32t4c128 --wandb True --optim.lr 5e-5 --wandb-project dcae --vae_loss_config.perceptual_loss_weight 0.5 --wandb-expr-name dc-ae-f32t4c64_full_model_ft

When the model is almost converged, we add a discriminator and continue to train the model with the checkpoint `model_ckpt` using the following command:

```bash
torchrun --nproc_per_node 8 scripts/vae/train.py configs/vae/train/video_dc_ae_disc.py --model.from_pretrained <model_ckpt>
```
You may pass the flag `--wandb True` if you have a [wandb](https://wandb.ai/home) account and wish to track the training progress online.

## Inference

Download the relevant weights following [this guide](../README.md#model-download). Alternatively, you may use your own trained model by passing the following flag `--model.from_pretrained <your_model_ckpt_path>`.

### Video DC-AE

Use the following code to reconstruct the videos using our trained `Video DC-AE`:

```bash
torchrun --nproc_per_node 1 --standalone scripts/vae/inference.py configs/vae/inference/video_dc_ae.py --save-dir samples/dcae
```

### Hunyuan Video

Alternatively, we have incorporated [HunyuanVideo vae](https://github.com/Tencent/HunyuanVideo) into our code, you may run inference with the following command:

```bash
torchrun --nproc_per_node 1 --standalone scripts/vae/inference.py configs/vae/inference/hunyuanvideo_vae.py --save-dir samples/hunyuanvideo_vae
```

## Config Interpretation

All AE configs are located in `configs/vae/`, divided into configs for training (`configs/vae/train`) and for inference (`configs/vae/inference`).

### Training Config

For training, the same config rules as [those](./train.md#config) for the diffusion model are applied.

<details>
<summary> <b>Loss Config</b> </summary>
Our __Video DC-AE__ is based on the [DC-AE](https://github.com/mit-han-lab/efficientvit) architecture, which doesn't have a variational component. Thus, our training simply consists of the *reconstruction loss* and the *perceptual loss*.
Experimentally, we found that setting a ratio of 0.5 for the perceptual loss is effective.

```python
vae_loss_config = dict(
perceptual_loss_weight=0.5, # weigh the perceptual loss by 0.5
kl_loss_weight=0, # no KL loss
)
```

In a later stage, we include a discriminator, and the training loss for the ae has an additional generator loss component, where we use a small ratio of 0.05 to weigh the loss calculated:
```python
gen_loss_config = dict(
gen_start=0, # include generator loss from step 0 onwards
disc_weight=0.05, # weigh the loss by 0.05
)
```

The discriminator we use is trained from scratch, and its loss is simply the hinge loss:
```python
disc_loss_config = dict(
disc_start=0, # update the discriminator from step 0 onwards
disc_loss_type="hinge", # the discriminator loss type
)
```
</details>

<details>
<summary> <b> Data Bucket Config </b> </summary>
For the data bucket, we used 32 frames of 256px videos to train our AE.
```python
bucket_config = {
"256px_ar1:1": {32: (1.0, 1)},
}
```
</details>

<details>
<summary> <b>Train with more frames or higher resolutions</b> </summary>

If you train with longer frames or larger resolutions, you may increase the `spatial_tile_size` and `temporal_tile_size` during inference without degrading the AE performance (see [Inference Config](ae.md#inference-config)). This may give you the advantage of faster AE inference, such as when training the diffusion model (although at the cost of slower AE training).

You may increase the video frames to 96 (although any multiple of 4 works, we generally recommend using frame numbers that are multiples of 32):

```python
bucket_config = {
"256px_ar1:1": {96: (1.0, 1)},
}
grad_checkpoint = True
```
or train for higher resolution such as 512px:
```python
bucket_config = {
"512px_ar1:1": {32: (1.0, 1)},
}
grad_checkpoint = True
```
Note that gradient checkpointing needs to be turned on in order to prevent OOM errors.

Moreover, if `grad_checkpointing` is set to `True` in discriminator training, you need to pass the flag `--model.disc_off_grad_ckpt True` or simply set in the config:
```python
grad_checkpoint = True
model = dict(
disc_off_grad_ckpt = True, # set to true if your `grad_checkpoint` is True
)
```
This is to make sure the discriminator loss will have a gradient at the last layer during adaptive loss calculation.
</details>




### Inference Config

For AE inference, we have replicated the tiling mechanism from HunyuanVideo in our Video DC-AE, which can be turned on with the following:

```python
model = dict(
...,
use_spatial_tiling=True,
use_temporal_tiling=True,
spatial_tile_size=256,
temporal_tile_size=32,
tile_overlap_factor=0.25,
...,
)
```

By default, both spatial tiling and temporal tiling are turned on for the best performance.
Since our Video DC-AE is trained on 256px videos of 32 frames only, `spatial_tile_size` should be set to 256 and `temporal_tile_size` should be set to 32.
If you train your own Video DC-AE with other resolutions and length, you may adjust the values accordingly.

You can specify the directory to store output samples with `--save_dir <your_dir>` or setting it in config, for instance:

```python
save_dir = "./samples"
```

+ 8
- 4
docs/hcae.md View File

@@ -2,10 +2,14 @@

## Introduction

## Traini
## Inference

torchrun --nproc_per_node 1 --standalone scripts/diffusion/inference.py configs/diffusion/inference/high_compression.py --dataset.data-path assets/texts/sora.csv --ckpt-path /mnt/jfs-hdd/sora/checkpoints/zhengzangwei/outputs/adapt_video_sana/250304_141109-diffusion_train_dc_ae_video_temporal_compression_i2v/epoch0-global_step2000 --save-dir samples/debug_03_06/adapt_2k_op --sampling_option.num_frames 128
```bash
torchrun --nproc_per_node 1 --standalone scripts/diffusion/inference.py configs/diffusion/inference/high_compression.py --dataset.data-path assets/texts/sora.csv
```

## Training


torchrun --nproc_per_node 8 scripts/diffusion/train.py configs/diffusion/train/high_compression.py --dataset.data-path /mnt/ddn/sora/meta/vo3/stage2/video+image_stage2_nopart3.parquet --wandb True --wandb-project adapt_video_sana --load /mnt/jfs-hdd/sora/checkpoints/zhengzangwei/outputs/adapt_video_sana/250304_141109-diffusion_train_dc_ae_video_temporal_compression_i2v/epoch0-global_step2000 --outputs /mnt/jfs-hdd/sora/checkpoints/zhengzangwei/outputs/adapt_video_sana/ --start-step 0 --start-epoch 0 --seed 2026
```bash
torchrun --nproc_per_node 8 scripts/diffusion/train.py configs/diffusion/train/high_compression.py --dataset.data-path
```

+ 6
- 62
opensora/models/dc_ae/ae_model_zoo.py View File

@@ -24,28 +24,13 @@ from torch import nn
from opensora.registry import MODELS
from opensora.utils.ckpt import load_checkpoint

from .models.dc_ae import DCAE, DCAEConfig, dc_ae_f32, dc_ae_f64c128, dc_ae_f64t4c256, dc_ae_f128c512
from .models.dc_ae import DCAE, DCAEConfig, dc_ae_f32

__all__ = ["create_dc_ae_model_cfg", "DCAE_HF", "AutoencoderKL", "DC_AE"]
__all__ = ["create_dc_ae_model_cfg", "DCAE_HF", "DC_AE"]


REGISTERED_DCAE_MODEL: dict[str, tuple[Callable, Optional[str]]] = {
"dc-ae-f32c32-in-1.0": (dc_ae_f32, None),
"dc-ae-f64c128-in-1.0": (dc_ae_f64c128, None),
"dc-ae-f128c512-in-1.0": (dc_ae_f128c512, None),
#################################################################################################
"dc-ae-f32c32-mix-1.0": (dc_ae_f32, None),
"dc-ae-f64c128-mix-1.0": (dc_ae_f64c128, None),
"dc-ae-f128c512-mix-1.0": (dc_ae_f128c512, None),
#################################################################################################
"dc-ae-f32c32-sana-1.0": (dc_ae_f32, None),
"dc-ae-f32c32-sana-1.0-video": (dc_ae_f32, None),
"dc-ae-f32c32-sana-1.0-video-temporal-compression": (dc_ae_f32, None),
"dc-ae-f32t4c256": (dc_ae_f32, None),
"dc-ae-f32t4c128": (dc_ae_f32, None),
"dc-ae-f32t4c64": (dc_ae_f32, None),
"dc-ae-f64t4c128": (dc_ae_f64c128, None),
"dc-ae-f64t4c256": (dc_ae_f64t4c256, None),
}


@@ -71,15 +56,13 @@ def DC_AE(
from_scratch: bool = False,
from_pretrained: str | None = None,
is_training: bool = False,
rename_keys: dict = None,
use_spatial_tiling: bool = False,
use_temporal_tiling: bool = False,
spatial_tile_size: int = 256,
temporal_tile_size: int = 32,
tile_overlap_factor: float = 0.25,
tune_channel_proj: bool = False,
train_decoder_only: bool = False,
scaling_factor: float = None,
disc_off_grad_ckpt: bool = False,
) -> DCAE_HF:
if not from_scratch:
model = DCAE_HF.from_pretrained(model_name).to(device_map, torch_dtype)
@@ -87,20 +70,9 @@ def DC_AE(
model = DCAE_HF(model_name).to(device_map, torch_dtype)

if from_pretrained is not None:
model = load_checkpoint(model, from_pretrained, device_map=device_map, rename_keys=rename_keys)
model = load_checkpoint(model, from_pretrained, device_map=device_map)
print(f"loaded dc_ae from ckpt path: {from_pretrained}")

if train_decoder_only:
print("freezing all encoder weights!")
# tune the channel projection layer only
if tune_channel_proj:
model.cfg.tune_channel_proj = True
print("freezing all weights except the channel projection!")
for _, param in model.named_parameters():
param.requires_grad = False
for _, param in model.decoder.named_parameters():
param.requires_grad = True

model.cfg.is_training = is_training
model.use_spatial_tiling = use_spatial_tiling
model.use_temporal_tiling = use_temporal_tiling
@@ -109,33 +81,5 @@ def DC_AE(
model.tile_overlap_factor = tile_overlap_factor
if scaling_factor is not None:
model.scaling_factor = scaling_factor
return model


class AutoencoderKL(nn.Module):
def __init__(self, model_name: str):
super().__init__()
self.model_name = model_name
if self.model_name in ["stabilityai/sd-vae-ft-ema"]:
self.model = diffusers.models.AutoencoderKL.from_pretrained(self.model_name)
self.spatial_compression_ratio = 8
elif self.model_name == "flux-vae":
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
self.model = diffusers.models.AutoencoderKL.from_pretrained(pipe.vae.config._name_or_path)
self.spatial_compression_ratio = 8
else:
raise ValueError(f"{self.model_name} is not supported for AutoencoderKL")

def encode(self, x: torch.Tensor) -> torch.Tensor:
if self.model_name in ["stabilityai/sd-vae-ft-ema", "flux-vae"]:
return self.model.encode(x).latent_dist.sample()
else:
raise ValueError(f"{self.model_name} is not supported for AutoencoderKL")

def decode(self, latent: torch.Tensor) -> torch.Tensor:
if self.model_name in ["stabilityai/sd-vae-ft-ema", "flux-vae"]:
return self.model.decode(latent).sample
else:
raise ValueError(f"{self.model_name} is not supported for AutoencoderKL")
model.decoder.disc_off_grad_ckpt = disc_off_grad_ckpt
return model

+ 13
- 152
opensora/models/dc_ae/models/dc_ae.py View File

@@ -41,7 +41,7 @@ from .nn.ops import (
ResidualBlock,
)

__all__ = ["DCAE", "dc_ae_f32", "dc_ae_f64c128", "dc_ae_f128c512", "dc_ae_f64t4c256"]
__all__ = ["DCAE", "dc_ae_f32"]


@dataclass
@@ -62,7 +62,6 @@ class EncoderConfig:
double_latent: bool = False
is_video: bool = False
temporal_downsample: tuple[bool, ...] = ()
tune_channel_proj: bool = False


@dataclass
@@ -82,7 +81,6 @@ class DecoderConfig:
out_act: str = "relu"
is_video: bool = False
temporal_upsample: tuple[bool, ...] = ()
tune_channel_proj: bool = False


@dataclass
@@ -105,8 +103,6 @@ class DCAEConfig:
scaling_factor: Optional[float] = None
is_image_model: bool = False

tune_channel_proj: bool = False
train_decoder_only: bool = False
is_training: bool = False # NOTE: set to True in vae train config

use_spatial_tiling: bool = False
@@ -114,6 +110,7 @@ class DCAEConfig:
spatial_tile_size: int = 256
temporal_tile_size: int = 32
tile_overlap_factor: float = 0.25


def build_block(
@@ -437,11 +434,7 @@ class Encoder(nn.Module):
for stage in self.stages:
if len(stage.op_list) == 0:
continue
# x = stage(x)
if self.cfg.tune_channel_proj:
x = stage(x)
else:
x = auto_grad_checkpoint(stage, x)
x = auto_grad_checkpoint(stage, x)
# x = self.project_out(x)
x = auto_grad_checkpoint(self.project_out, x)
return x
@@ -512,17 +505,17 @@ class Decoder(nn.Module):
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.cfg.tune_channel_proj:
x = self.project_in(x)
else:
x = auto_grad_checkpoint(self.project_in, x)
x = auto_grad_checkpoint(self.project_in, x)
for stage in reversed(self.stages):
if len(stage.op_list) == 0:
continue
# x = stage(x)
x = auto_grad_checkpoint(stage, x)
# x = self.project_out(x)
x = auto_grad_checkpoint(self.project_out, x)

if self.disc_off_grad_ckpt:
x = self.project_out(x)
else:
x = auto_grad_checkpoint(self.project_out, x)
return x


@@ -532,10 +525,6 @@ class DCAE(nn.Module):
self.cfg = cfg
self.encoder = Encoder(cfg.encoder)
self.decoder = Decoder(cfg.decoder)
if cfg.tune_channel_proj: # propergate the tune_channel_proj flag for gradient checkpointing handling
self.encoder.cfg.tune_channel_proj = True
self.decoder.cfg.tune_channel_proj = True

self.scaling_factor = cfg.scaling_factor
self.time_compression_ratio = cfg.time_compression_ratio
self.spatial_compression_ratio = cfg.spatial_compression_ratio
@@ -799,33 +788,7 @@ class DCAE(nn.Module):


def dc_ae_f32(name: str, pretrained_path: str) -> DCAEConfig:
if name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
cfg_str = (
"latent_channels=32 "
"time_compression_ratio=1 "
"spatial_compression_ratio=32 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024] encoder.depth_list=[0,4,8,2,2,2] "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[0,5,10,2,2,2] "
"decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu]"
)
elif name in ["dc-ae-f32c32-sana-1.0"]:
cfg_str = (
"latent_channels=32 "
"time_compression_ratio=1 "
"spatial_compression_ratio=32 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024] encoder.depth_list=[2,2,2,3,3,3] "
"encoder.downsample_block_type=Conv "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[3,3,3,3,3,3] "
"decoder.upsample_block_type=InterpolateConv "
"decoder.norm=rms2d decoder.act=silu "
"scaling_factor=0.41407 "
"is_image_model=True"
)
elif name in ["dc-ae-f32c32-sana-1.0-video", "dc-ae-f32t4c256", "dc-ae-f32t4c128", "dc-ae-f32t4c64"]:
if name in ["dc-ae-f32t4c128"]:
cfg_str = (
"time_compression_ratio=4 "
"spatial_compression_ratio=32 "
@@ -839,83 +802,10 @@ def dc_ae_f32(name: str, pretrained_path: str) -> DCAEConfig:
"decoder.upsample_block_type=InterpolateConv "
"decoder.norm=rms3d decoder.act=silu decoder.out_norm=rms3d "
"decoder.is_video=True "
) # make sure there is no trailing blankspace in the last line
if name in ["dc-ae-f32t4c256", "dc-ae-f32t4c128", "dc-ae-f32t4c64"]:
cfg_str += (
"encoder.temporal_downsample=[False,False,False,True,True,False] "
"decoder.temporal_upsample=[False,False,False,True,True,False]"
) # make sure there is preceding blankspace in the first line
if name in ["dc-ae-f32t4c256"]:
cfg_str += " latent_channels=256 " "scaling_factor=0.46505"
elif name in ["dc-ae-f32t4c128"]:
cfg_str += " latent_channels=128"
elif name in ["dc-ae-f32t4c64"]:
cfg_str += " latent_channels=64"
elif name in ["dc-ae-f32c32-sana-1.0-video"]:
cfg_str += " latent_channels=32"
else:
raise NotImplementedError
cfg = OmegaConf.from_dotlist(cfg_str.split(" "))
cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg))
cfg.pretrained_path = pretrained_path
return cfg


def dc_ae_f64c128(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig:
if name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
cfg_str = (
"latent_channels=128 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024,2048] encoder.depth_list=[0,4,8,2,2,2,2] "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024,2048] decoder.depth_list=[0,5,10,2,2,2,2] "
"decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu,silu]"
)
elif name in ["dc-ae-f64t4c128"]:
cfg_str = (
"time_compression_ratio=4 "
"spatial_compression_ratio=64 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024,1024] encoder.depth_list=[2,2,2,3,3,3,3] "
"encoder.downsample_block_type=Conv "
"encoder.norm=rms3d "
"encoder.is_video=True "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024,1024] decoder.depth_list=[3,3,3,3,3,3,3] "
"decoder.upsample_block_type=InterpolateConv "
"decoder.norm=rms3d decoder.act=silu decoder.out_norm=rms3d "
"decoder.is_video=True "
"encoder.temporal_downsample=[False,False,False,True,True,False,False] "
"decoder.temporal_upsample=[False,False,False,True,True,False,False] "
"encoder.temporal_downsample=[False,False,False,True,True,False] "
"decoder.temporal_upsample=[False,False,False,True,True,False] "
"latent_channels=128"
)
else:
raise NotImplementedError
cfg = OmegaConf.from_dotlist(cfg_str.split(" "))
cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg))
cfg.pretrained_path = pretrained_path
return cfg


def dc_ae_f64t4c256(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig:
if name in ["dc-ae-f64t4c256"]:
cfg_str = (
"time_compression_ratio=4 "
"spatial_compression_ratio=64 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024,1024] encoder.depth_list=[2,2,2,3,3,3,3] "
"encoder.downsample_block_type=Conv "
"encoder.norm=rms3d "
"encoder.is_video=True "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024,1024] decoder.depth_list=[3,3,3,3,3,3,3] "
"decoder.upsample_block_type=InterpolateConv "
"decoder.norm=rms3d decoder.act=silu decoder.out_norm=rms3d "
"decoder.is_video=True "
"encoder.temporal_downsample=[False,False,False,True,True,False,False] "
"decoder.temporal_upsample=[False,False,False,True,True,False,False] "
"latent_channels=256"
)
) # make sure there is no trailing blankspace in the last line
else:
raise NotImplementedError
cfg = OmegaConf.from_dotlist(cfg_str.split(" "))
@@ -923,32 +813,3 @@ def dc_ae_f64t4c256(name: str, pretrained_path: Optional[str] = None) -> DCAECon
cfg.pretrained_path = pretrained_path
return cfg


def dc_ae_f128c512(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig:
if name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
cfg_str = (
"latent_channels=512 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024,2048,2048] encoder.depth_list=[0,4,8,2,2,2,2,2] "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU,EViT_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024,2048,2048] decoder.depth_list=[0,5,10,2,2,2,2,2] "
"decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu,silu,silu]"
)
elif name in ["dc-ae-f128c512-sana-1.0"]:
cfg_str = (
"latent_channels=512 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024,2048,2048] encoder.depth_list=[2,2,2,3,3,3,3,3] "
"encoder.downsample_block_type=Conv "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024,2048,2048] decoder.depth_list=[3,3,3,3,3,3,3,3] "
"decoder.upsample_block_type=InterpolateConv "
"decoder.norm=rms2d decoder.act=silu "
"scaling_factor=0.722656"
)
else:
raise NotImplementedError
cfg = OmegaConf.from_dotlist(cfg_str.split(" "))
cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg))
cfg.pretrained_path = pretrained_path
return cfg

+ 2
- 23
opensora/models/dc_ae/utils/init.py View File

@@ -20,7 +20,7 @@ import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = ["init_modules", "zero_last_gamma"]
__all__ = ["init_modules"]


def init_modules(model: Union[nn.Module, list[nn.Module]], init_type="trunc_normal") -> None:
@@ -60,25 +60,4 @@ def init_modules(model: Union[nn.Module, list[nn.Module]], init_type="trunc_norm
if isinstance(weight, torch.nn.Parameter):
init_func(weight)
if isinstance(bias, torch.nn.Parameter):
bias.data.zero_()


def zero_last_gamma(model: nn.Module, init_val=0) -> None:
import opensora.models.dc_ae.models.nn.ops as ops
for m in model.modules():
if isinstance(m, ops.ResidualBlock) and isinstance(m.shortcut, ops.IdentityLayer):
if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
parent_module = m.main.point_conv
elif isinstance(m.main, ops.ResBlock):
parent_module = m.main.conv2
elif isinstance(m.main, ops.ConvLayer):
parent_module = m.main
elif isinstance(m.main, (ops.LiteMLA)):
parent_module = m.main.proj
else:
parent_module = None
if parent_module is not None:
norm = getattr(parent_module, "norm", None)
if norm is not None:
nn.init.constant_(norm.weight, init_val)

bias.data.zero_()

+ 3
- 24
opensora/models/hunyuan_vae/autoencoder_kl_causal_3d.py View File

@@ -54,10 +54,6 @@ from opensora.models.hunyuan_vae.vae import (
DiagonalGaussianDistribution,
EncoderCausal3D,
)
from opensora.models.hunyuan_vae.vae_channel import (
ChannelDecoderCausal3D,
ChannelEncoderCausal3D,
)


@dataclass
@@ -82,10 +78,6 @@ class AutoEncoder3DConfig:
use_temporal_tiling: bool = False
tile_overlap_factor: float = 0.25
dropout: float = 0.0
encoder_add_residual: bool = False
encoder_slice_t: bool = False
decoder_add_residual: bool = False
decoder_slice_t: bool = False
channel: bool = False


@@ -110,11 +102,7 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
self.spatial_compression_ratio = config.spatial_compression_ratio
self.z_channels = config.latent_channels

# channel True to mimic sana that process info with channel averaging and expansion
encoder_type = ChannelEncoderCausal3D if config.channel else EncoderCausal3D
decoder_type = ChannelDecoderCausal3D if config.channel else DecoderCausal3D

self.encoder = encoder_type(
self.encoder = EncoderCausal3D(
in_channels=config.in_channels,
out_channels=config.latent_channels,
block_out_channels=config.block_out_channels,
@@ -126,11 +114,9 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
spatial_compression_ratio=config.spatial_compression_ratio,
mid_block_add_attention=config.mid_block_add_attention,
dropout=config.dropout,
add_residual=config.encoder_add_residual,
slice_t=config.encoder_slice_t,
)

self.decoder = decoder_type(
self.decoder = DecoderCausal3D(
in_channels=config.latent_channels,
out_channels=config.out_channels,
block_out_channels=config.block_out_channels,
@@ -141,8 +127,6 @@ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
spatial_compression_ratio=config.spatial_compression_ratio,
mid_block_add_attention=config.mid_block_add_attention,
dropout=config.dropout,
add_residual=config.decoder_add_residual,
slice_t=config.decoder_slice_t,
)

self.quant_conv = nn.Conv3d(2 * config.latent_channels, 2 * config.latent_channels, kernel_size=1)
@@ -643,7 +627,6 @@ def CausalVAE3D_HUNYUAN(
from_pretrained: str = None,
device_map: str | torch.device = "cuda",
torch_dtype: torch.dtype = torch.bfloat16,
train_decoder_only: bool = False,
**kwargs,
) -> AutoencoderKLCausal3D:
config = AutoEncoder3DConfig(from_pretrained=from_pretrained, **kwargs)
@@ -651,9 +634,5 @@ def CausalVAE3D_HUNYUAN(
model = AutoencoderKLCausal3D(config).to(torch_dtype)
if from_pretrained:
model = load_checkpoint(model, from_pretrained, device_map=device_map, strict=True)
if train_decoder_only:
for _, param in model.named_parameters():
param.requires_grad = False
for _, param in model.decoder.named_parameters():
param.requires_grad = True

return model

+ 1
- 6
opensora/models/hunyuan_vae/distributed.py View File

@@ -544,8 +544,6 @@ class TPUpDecoderBlockCausal3D(UpsampleCausal3D):
kernel_size=3,
bias=True,
upsample_factor=(2, 2, 2),
add_residual=False,
slice_t=False,
tp_group=None,
split_input: bool = False,
split_output: bool = False,
@@ -553,7 +551,7 @@ class TPUpDecoderBlockCausal3D(UpsampleCausal3D):
shortcut_=None,
):
assert tp_group is not None, "tp_group must be provided"
super().__init__(channels, out_channels, kernel_size, bias, upsample_factor, add_residual, slice_t)
super().__init__(channels, out_channels, kernel_size, bias, upsample_factor)
conv = conv_ if conv_ is not None else self.conv.conv
self.conv.conv = Conv3dTPRow.from_native_module(
conv, tp_group, split_input=split_input, split_output=split_output
@@ -562,8 +560,6 @@ class TPUpDecoderBlockCausal3D(UpsampleCausal3D):
tp_size = dist.get_world_size(group=self.tp_group)
assert self.channels % tp_size == 0, f"channels {self.channels} must be divisible by tp_size {tp_size}"
self.channels = self.channels // tp_size
if add_residual:
self.shortcut = shortcut_

def forward(self, input_tensor):
input_tensor = split_forward_gather_backward(input_tensor, 1, self.tp_group)
@@ -577,7 +573,6 @@ class TPUpDecoderBlockCausal3D(UpsampleCausal3D):
conv.kernel_size[0],
conv.bias is not None,
module.upsample_factor,
module.add_residual,
conv_=conv,
shortcut_=getattr(module, "shortcut", None),
tp_group=process_group,


+ 0
- 151
opensora/models/hunyuan_vae/unet_causal_3d_blocks.py View File

@@ -95,61 +95,6 @@ class CausalConv3d(nn.Module):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)


class ChannelDuplicatingPixelShuffleUpSampleLayer(nn.Module):
def __init__(
self,
factor=(1, 2, 2),
slice_t=False, # either slice T or pad T
):
super().__init__()
self.factor = factor
self.slice_t = slice_t

def forward(self, x: torch.Tensor) -> torch.Tensor:
T = x.size(2)
if self.factor[0] == 2:
if T == 1: # image
x = x.repeat_interleave(self.factor[1] * self.factor[2], dim=1)
residual = rearrange(
x, "B (C fh fw) T H W -> B C T (H fh) (W fw)", fh=self.factor[1], fw=self.factor[2]
)
else: # video
if self.slice_t:
# slice T and process differently
first_f, other_f = x.split((1, T - 1), dim=2)
first_f = first_f.repeat_interleave(self.factor[1] * self.factor[2], dim=1)
first_f = rearrange(
first_f, "B (C fh fw) T H W -> B C T (H fh) (W fw)", fh=self.factor[1], fw=self.factor[2]
)
other_f = other_f.repeat_interleave(self.factor[0] * self.factor[1] * self.factor[2], dim=1)
other_f = rearrange(
other_f,
"B (C ft fh fw) T H W -> B C (T ft) (H fh) (W fw)",
ft=self.factor[0],
fh=self.factor[1],
fw=self.factor[2],
)
residual = torch.cat((first_f, other_f), dim=2)
else:
x = x.repeat_interleave(self.factor[0] * self.factor[1] * self.factor[2], dim=1)
residual = rearrange(
x,
"B (C ft fh fw) T H W -> B C (T ft) (H fh) (W fw)",
ft=self.factor[0],
fh=self.factor[1],
fw=self.factor[2],
)
residual = residual[:, :, 1:] # remove 1st frame TODO: this may not be wise
elif self.factor[0] == 1:
x = x.repeat_interleave(self.factor[1] * self.factor[2], dim=1)
residual = rearrange(x, "B (C fh fw) T H W -> B C T (H fh) (W fw)", fh=self.factor[1], fw=self.factor[2])
else:
raise NotImplementedError

return residual


class UpsampleCausal3D(nn.Module):
"""
A 3D upsampling layer with an optional convolution.
@@ -162,17 +107,12 @@ class UpsampleCausal3D(nn.Module):
kernel_size: int = 3,
bias=True,
upsample_factor=(2, 2, 2),
add_residual=False,
slice_t=False,
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.upsample_factor = upsample_factor
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
self.add_residual = add_residual
if self.add_residual:
self.shortcut = ChannelDuplicatingPixelShuffleUpSampleLayer(factor=upsample_factor, slice_t=slice_t)

def forward(
self,
@@ -215,82 +155,8 @@ class UpsampleCausal3D(nn.Module):

hidden_states = self.conv(hidden_states)

#######################
# handle residual
#######################
if self.add_residual:
residual = self.shortcut(input_tensor)
hidden_states += residual

return hidden_states


class PixelUnshuffleChannelAveragingDownSampleLayer(nn.Module):
"""
residual for downsample layer;
if has downsample in T dim, add reshaping for T as well.
Note: (T-1),H,W must be multiples of 2
"""

def __init__(
self,
factor=(1, 2, 2), # can be (1,2,2) or (2,2,2)
slice_t=False, # either slice T or pad T if need to reduce the T dimension
):
super().__init__()
self.factor = factor
self.slice_t = slice_t
self.time_causal_padding = (0, 0, 0, 0, 1, 0) # W, H, T

def forward(self, x: torch.Tensor) -> torch.Tensor:
assert self.factor[0] == 1 or self.factor[0] == 2, f"unsupported temporal reduction {self.factor[0]}"
# shape check
T, H, W = x.shape[-3:]
assert (
(T - 1) % self.factor[0] == H % self.factor[1] == W % self.factor[2] == 0
), f"{T}-1, {W}, {H} not divisible by {self.factor}"
if self.factor[0] == 2: # temporal reduction
if self.slice_t:
if T > 1: # video
# slice T and process differently
first_f, other_f = x.split((1, T - 1), dim=2)
first_f = rearrange(
first_f, "B C T (H fh) (W fw) -> B C (fh fw) T H W", fw=self.factor[1], fh=self.factor[2]
)
first_f = first_f.mean(dim=2)
other_f = rearrange(
other_f,
"B C (T ft) (H fh) (W fw) -> B C (ft fh fw) T H W",
ft=self.factor[0],
fw=self.factor[1],
fh=self.factor[2],
)
other_f = other_f.mean(dim=2)
residual = torch.cat((first_f, other_f), dim=2)
else: # image, only work on H & W
x = rearrange(x, "B C T (H fh) (W fw) -> B C (fh fw) T H W", fw=self.factor[1], fh=self.factor[2])
residual = x.mean(dim=2)
else: # use padding to handle temporal reduction
x = F.pad(x, self.time_causal_padding, mode="replicate")
# reshape and take average for shortcut
x = rearrange(
x,
"B C (T ft) (H fh) (W fw) -> B C (ft fh fw) T H W",
ft=self.factor[0],
fw=self.factor[1],
fh=self.factor[2],
)
residual = x.mean(dim=2)
elif self.factor[0] == 1: # no temporal reduction
# reshape and take average for shortcut
x = rearrange(x, "B C T (H fh) (W fw) -> B C (fh fw) T H W", fw=self.factor[1], fh=self.factor[2])
residual = x.mean(dim=2)
else:
raise NotImplementedError

return residual


class DownsampleCausal3D(nn.Module):
"""
A 3D downsampling layer with an optional convolution.
@@ -302,25 +168,16 @@ class DownsampleCausal3D(nn.Module):
kernel_size=3,
bias=True,
stride=2,
add_residual=False,
slice_t=False,
):
super().__init__()
self.channels = channels
self.out_channels = channels
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
self.add_residual = add_residual
if self.add_residual:
self.shortcut = PixelUnshuffleChannelAveragingDownSampleLayer(factor=stride, slice_t=slice_t)

def forward(self, input_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert input_tensor.shape[1] == self.channels
hidden_states = self.conv(input_tensor)

if self.add_residual:
residual = self.shortcut(input_tensor)
hidden_states += residual

return hidden_states


@@ -512,8 +369,6 @@ class DownEncoderBlockCausal3D(nn.Module):
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_stride: int = 2,
add_residual: bool = False,
slice_t: bool = False,
):
super().__init__()
resnets = []
@@ -541,8 +396,6 @@ class DownEncoderBlockCausal3D(nn.Module):
DownsampleCausal3D(
out_channels,
stride=downsample_stride,
add_residual=add_residual,
slice_t=slice_t,
)
]
)
@@ -575,8 +428,6 @@ class UpDecoderBlockCausal3D(nn.Module):
output_scale_factor: float = 1.0,
add_upsample: bool = True,
upsample_scale_factor=(2, 2, 2),
add_residual: bool = False,
slice_t: bool = False,
):
super().__init__()
resnets = []
@@ -606,8 +457,6 @@ class UpDecoderBlockCausal3D(nn.Module):
out_channels,
out_channels=out_channels,
upsample_factor=upsample_scale_factor,
add_residual=add_residual,
slice_t=slice_t,
)
]
)


+ 0
- 8
opensora/models/hunyuan_vae/vae.py View File

@@ -55,8 +55,6 @@ class EncoderCausal3D(nn.Module):
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
dropout: float = 0.0,
add_residual: bool = False,
slice_t: bool = False,
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -98,8 +96,6 @@ class EncoderCausal3D(nn.Module):
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_residual=add_residual,
slice_t=slice_t,
)

self.down_blocks.append(down_block)
@@ -171,8 +167,6 @@ class DecoderCausal3D(nn.Module):
time_compression_ratio: int = 4,
spatial_compression_ratio: int = 8,
dropout: float = 0.0,
add_residual: bool = False,
slice_t: bool = False,
):
super().__init__()
self.layers_per_block = layers_per_block
@@ -227,8 +221,6 @@ class DecoderCausal3D(nn.Module):
resnet_eps=1e-6,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
add_residual=add_residual,
slice_t=slice_t,
)

self.up_blocks.append(up_block)


+ 0
- 621
opensora/models/hunyuan_vae/vae_channel.py View File

@@ -1,621 +0,0 @@
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from opensora.acceleration.checkpoint import auto_grad_checkpoint, checkpoint
from opensora.models.hunyuan_vae.unet_causal_3d_blocks import (
CausalConv3d,
ResnetBlockCausal3D,
UNetMidBlockCausal3D,
chunk_nearest_interpolate,
)


def pixel_shuffle_channel_averaging(
input,
channel_factor=2,
factor=(1, 2, 2),
):
B, C, T, H, W = input.size()
assert T % factor[0] == 0 and H % factor[1] == 0 and W % factor[2] == 0
assert (factor[0] * factor[1] * factor[2]) % channel_factor == 0
output = input.view(B, C, T // factor[0], factor[0], H // factor[1], factor[1], W // factor[2], factor[2])
output = output.permute(0, 1, 3, 5, 7, 2, 4, 6)
output = output.contiguous().view(B, C * channel_factor, -1, T // factor[0], H // factor[1], W // factor[2])
output = output.mean(dim=2)
return output


def channel_repeat_pixel_shuffle(
input,
channel_factor=2,
factor=(1, 2, 2),
):
B, C, T, H, W = input.size()
assert factor[0] * factor[1] * factor[2] % channel_factor == 0
repeat = factor[0] * factor[1] * factor[2] // channel_factor
output = input.repeat_interleave(repeat, dim=1)
output = output.view(B, C // channel_factor, factor[0], factor[1], factor[2], T, H, W)
output = output.permute(0, 1, 5, 2, 6, 3, 7, 4)
output = output.contiguous().view(B, C // channel_factor, T * factor[0], H * factor[1], W * factor[2])
return output


class PixelUnshuffleChannelAveragingDownSampleLayer(nn.Module):
"""
residual for downsample layer;
if has downsample in T dim, add reshaping for T as well.
Note: (T-1),H,W must be multiples of 2
"""

def __init__(
self,
factor=(1, 2, 2), # can be (1,2,2) or (2,2,2)
slice_t=False, # either slice T or pad T if need to reduce the T dimension
channel_factor=2,
):
super().__init__()
self.factor = factor
self.slice_t = slice_t
self.time_causal_padding = (0, 0, 0, 0, 1, 0) # W, H, T
self.channel_factor = channel_factor

def forward(self, x: torch.Tensor) -> torch.Tensor:
assert self.factor[0] == 1 or self.factor[0] == 2, f"unsupported temporal reduction {self.factor[0]}"
# shape check
T, H, W = x.shape[-3:]
assert (
(T - 1) % self.factor[0] == H % self.factor[1] == W % self.factor[2] == 0
), f"{T}-1, {W}, {H} not divisible by {self.factor}"
if self.factor[0] == 2: # temporal reduction
if self.slice_t:
if T > 1: # video
# slice T and process differently
first_f, other_f = x.split((1, T - 1), dim=2)
first_f = pixel_shuffle_channel_averaging(
first_f, channel_factor=self.channel_factor, factor=(1, self.factor[1], self.factor[2])
)
other_f = pixel_shuffle_channel_averaging(
other_f, channel_factor=self.channel_factor, factor=self.factor
)
residual = torch.cat((first_f, other_f), dim=2)
else: # image, only work on H & W
residual = pixel_shuffle_channel_averaging(
x, channel_factor=self.channel_factor, factor=self.factor
)
else: # use padding to handle temporal reduction
x = F.pad(x, self.time_causal_padding, mode="replicate")
# reshape and take average for shortcut
residual = pixel_shuffle_channel_averaging(x, channel_factor=self.channel_factor, factor=self.factor)
elif self.factor[0] == 1: # no temporal reduction
# reshape and take average for shortcut
residual = pixel_shuffle_channel_averaging(x, channel_factor=self.channel_factor, factor=self.factor)
else:
raise NotImplementedError

return residual


class ChannelDuplicatingPixelShuffleUpSampleLayer(nn.Module):
def __init__(
self,
factor=(1, 2, 2),
slice_t=False, # either slice T or pad T
channel_factor=2,
):
super().__init__()
self.factor = factor
self.slice_t = slice_t
self.channel_factor = channel_factor

def forward(self, x: torch.Tensor) -> torch.Tensor:
T = x.size(2)
if self.factor[0] == 2:
if T == 1: # image
residual = channel_repeat_pixel_shuffle(
x, channel_factor=self.channel_factor, factor=(1, self.factor[1], self.factor[2])
)
else: # video
if self.slice_t:
# slice T and process differently
first_f, other_f = x.split((1, T - 1), dim=2)
first_f = channel_repeat_pixel_shuffle(
first_f, channel_factor=self.channel_factor, factor=(1, self.factor[1], self.factor[2])
)
other_f = channel_repeat_pixel_shuffle(
other_f, channel_factor=self.channel_factor, factor=self.factor
)
residual = torch.cat((first_f, other_f), dim=2)
else:
residual = channel_repeat_pixel_shuffle(x, channel_factor=self.channel_factor, factor=self.factor)
residual = residual[:, :, 1:] # remove 1st frame TODO: this may not be wise
elif self.factor[0] == 1:
residual = channel_repeat_pixel_shuffle(x, channel_factor=self.channel_factor, factor=self.factor)
else:
raise NotImplementedError

return residual


class DownsampleCausal3D(nn.Module):
"""
A 3D downsampling layer with an optional convolution.
"""

def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3,
bias=True,
stride=2,
add_residual=False,
slice_t=False,
):
super().__init__()
self.channels = in_channels
self.out_channels = out_channels
assert self.out_channels % self.channels == 0
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias)
self.add_residual = add_residual
if self.add_residual:
self.shortcut = PixelUnshuffleChannelAveragingDownSampleLayer(
factor=stride, slice_t=slice_t, channel_factor=self.out_channels // self.channels
)

def forward(self, input_tensor: torch.FloatTensor) -> torch.FloatTensor:
assert input_tensor.shape[1] == self.channels
hidden_states = self.conv(input_tensor)

if self.add_residual:
residual = self.shortcut(input_tensor)
hidden_states += residual

return hidden_states


class DownEncoderBlockCausal3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor: float = 1.0,
add_downsample: bool = True,
downsample_stride: int = 2,
add_residual: bool = False,
slice_t: bool = False,
):
super().__init__()
resnets = []

for i in range(num_layers):
if add_downsample:
target_channel = in_channels
else:
target_channel = in_channels if i < num_layers - 1 else out_channels
resnets.append(
ResnetBlockCausal3D(
in_channels=in_channels,
out_channels=target_channel,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)

self.resnets = nn.ModuleList(resnets)

if add_downsample:
self.downsamplers = nn.ModuleList(
[
DownsampleCausal3D(
in_channels,
out_channels,
stride=downsample_stride,
add_residual=add_residual,
slice_t=slice_t,
)
]
)
else:
self.downsamplers = None

def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
for resnet in self.resnets:
hidden_states = auto_grad_checkpoint(resnet, hidden_states)

if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = auto_grad_checkpoint(downsampler, hidden_states)

return hidden_states


class ChannelEncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.

    The input is a 5D video tensor (B, C, T, H, W); it is compressed spatially by
    ``spatial_compression_ratio`` and temporally by ``time_compression_ratio`` through a
    stack of causal down blocks, followed by a mid block and an output projection
    (doubled channels when ``double_z`` for mean/logvar of the posterior).
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
        dropout: float = 0.0,
        add_residual: bool = False,
        slice_t: bool = False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down: one DownEncoderBlockCausal3D per entry in block_out_channels; each
        # block decides independently whether it downsamples space and/or time.
        output_channel = block_out_channels[0]
        for i, _ in enumerate(block_out_channels):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # number of 2x downsample steps needed to reach each target ratio
            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 1:
                # spatial-only compression; never downsample time
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                add_time_downsample = False
            elif time_compression_ratio == 4:
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                # temporal downsampling happens in the last blocks, but never in the final one
                add_time_downsample = bool(
                    i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
                )
            elif time_compression_ratio == 8:
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                # NOTE(review): the time decision compares against
                # num_spatial_downsample_layers, not num_time_downsample_layers; the two
                # agree only when the spatial ratio is also 8 — confirm this is intended.
                add_time_downsample = bool(i < num_spatial_downsample_layers)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            # stride layout is (T, H, W)
            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
            down_block = DownEncoderBlockCausal3D(
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                dropout=dropout,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                add_residual=add_residual,
                slice_t=slice_t,
            )

            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        # double the channels when the latent carries mean and logvar
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class.

        Args:
            sample: video tensor of shape (B, C, T, H, W).

        Returns:
            The encoded latent, compressed in T/H/W according to the compression ratios.
        """
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)

        # middle
        sample = auto_grad_checkpoint(self.mid_block, sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample


class UpsampleCausal3D(nn.Module):
    """
    A causal 3D upsampling layer: nearest-neighbour interpolation followed by a
    causal convolution, with an optional parameter-free pixel-shuffle shortcut.
    """

    def __init__(
        self,
        channels: int,
        out_channels: int,
        kernel_size: int = 3,
        bias=True,
        upsample_factor=(2, 2, 2),
        add_residual=False,
        slice_t=False,
    ):
        super().__init__()
        assert channels % out_channels == 0
        self.channels = channels
        self.out_channels = out_channels
        self.upsample_factor = upsample_factor  # (T, H, W) scale factors
        self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
        self.add_residual = add_residual
        if self.add_residual:
            self.shortcut = ChannelDuplicatingPixelShuffleUpSampleLayer(
                factor=upsample_factor, slice_t=slice_t, channel_factor=channels // out_channels
            )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
    ) -> torch.FloatTensor:
        assert input_tensor.shape[1] == self.channels

        states = input_tensor
        # upsample_nearest_nhwc fails with large batch sizes,
        # see https://github.com/huggingface/diffusers/issues/984
        if states.shape[0] >= 64:
            states = states.contiguous()

        # Causal upsampling: the first frame is only interpolated in H & W, while
        # every later frame is interpolated in T, H & W.
        num_frames = states.size(2)
        first_frame, later_frames = states.split((1, num_frames - 1), dim=2)
        first_frame = chunk_nearest_interpolate(
            first_frame.squeeze(2), scale_factor=self.upsample_factor[1:]
        ).unsqueeze(2)
        if num_frames > 1:
            later_frames = chunk_nearest_interpolate(later_frames, scale_factor=self.upsample_factor)
            states = torch.cat((first_frame, later_frames), dim=2)
        else:
            states = first_frame

        states = self.conv(states)

        # optional residual computed from the *pre-interpolation* input
        if self.add_residual:
            states += self.shortcut(input_tensor)

        return states


class UpDecoderBlockCausal3D(nn.Module):
    """A stack of causal-3D resnet blocks followed by an optional causal upsampling layer."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        upsample_scale_factor=(2, 2, 2),
        add_residual: bool = False,
        slice_t: bool = False,
    ):
        super().__init__()

        # When an upsampler follows, it performs the channel change, so every resnet
        # keeps `in_channels`; otherwise the last resnet widens to `out_channels`.
        blocks = []
        for layer_idx in range(num_layers):
            if add_upsample or layer_idx < num_layers - 1:
                layer_out_channels = in_channels
            else:
                layer_out_channels = out_channels
            blocks.append(
                ResnetBlockCausal3D(
                    in_channels=in_channels,
                    out_channels=layer_out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
        self.resnets = nn.ModuleList(blocks)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    UpsampleCausal3D(
                        in_channels,
                        out_channels=out_channels,
                        upsample_factor=upsample_scale_factor,
                        add_residual=add_residual,
                        slice_t=slice_t,
                    )
                ]
            )

        self.resolution_idx = resolution_idx

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Apply the resnet stack, then the upsampler(s), with gradient checkpointing when enabled."""
        for block in self.resnets:
            hidden_states = auto_grad_checkpoint(block, hidden_states)

        if self.upsamplers is not None:
            for up in self.upsamplers:
                hidden_states = auto_grad_checkpoint(up, hidden_states)

        return hidden_states


class ChannelDecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Mirrors `ChannelEncoderCausal3D`: an input conv and mid block, followed by a
    stack of causal up blocks that expand the latent spatially by
    ``spatial_compression_ratio`` and temporally by ``time_compression_ratio``,
    then a final norm/activation/conv projection.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
        dropout: float = 0.0,
        add_residual: bool = False,
        slice_t: bool = False,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            add_attention=mid_block_add_attention,
        )

        # up: one UpDecoderBlockCausal3D per entry in block_out_channels (reversed);
        # each block decides independently whether it upsamples space and/or time.
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, _ in enumerate(block_out_channels):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # number of 2x upsample steps needed to reach each target ratio
            num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_upsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 1:
                # spatial-only expansion; never upsample time
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                add_time_upsample = False
            elif time_compression_ratio == 4:
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                # temporal upsampling happens in the last blocks, but never in the final one
                add_time_upsample = bool(
                    i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
                )
            elif time_compression_ratio == 8:
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                # NOTE(review): the time decision compares against
                # num_spatial_upsample_layers, not num_time_upsample_layers; the two
                # agree only when the spatial ratio is also 8 — confirm this is
                # intended (same pattern exists in the encoder).
                add_time_upsample = bool(i < num_spatial_upsample_layers)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            # scale-factor layout is (T, H, W)
            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
            upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
            up_block = UpDecoderBlockCausal3D(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                resolution_idx=None,
                dropout=dropout,
                add_upsample=bool(add_spatial_upsample or add_time_upsample),
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                add_residual=add_residual,
                slice_t=slice_t,
            )

            self.up_blocks.append(up_block)
            # (removed a dead `prev_output_channel = output_channel` here: the loop
            # header re-derives prev_output_channel from output_channel every iteration)

        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

    def post_process(self, sample: torch.Tensor) -> torch.Tensor:
        """Final norm + activation, split out so it can be gradient-checkpointed."""
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        return sample

    def forward(
        self,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class.

        Args:
            sample: latent tensor of shape (B, C, T, H, W).

        Returns:
            The decoded sample, expanded in T/H/W according to the compression ratios.
        """
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."

        sample = self.conv_in(sample)

        # cast the mid-block output to the up-blocks' parameter dtype (they may
        # differ under mixed precision)
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        # middle
        sample = auto_grad_checkpoint(self.mid_block, sample)
        sample = sample.to(upscale_dtype)

        # up
        for up_block in self.up_blocks:
            sample = up_block(sample)

        # post-process
        if getattr(self, "grad_checkpointing", False):
            sample = checkpoint(self.post_process, sample, use_reentrant=True)
        else:
            sample = self.post_process(sample)

        sample = self.conv_out(sample)

        return sample

+ 13
- 6
opensora/models/mmdit/distributed.py View File

@@ -4,19 +4,24 @@ from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.shardformer.layer import FusedLinear1D_Col, FusedLinear1D_Row, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import (FusedLinear1D_Col, FusedLinear1D_Row,
Linear1D_Col, Linear1D_Row)
from colossalai.shardformer.layer._operation import all_to_all_comm
from colossalai.shardformer.layer.attn import RingComm, _rescale_out_lse
from colossalai.shardformer.layer.utils import is_share_sp_tp
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.policies.base_policy import (
ModulePolicyDescription, Policy, SubModuleReplacementDescription)
from colossalai.shardformer.shard import ShardConfig
from einops import rearrange
from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
from flash_attn.flash_attn_interface import (_flash_attn_backward,
_flash_attn_forward)
from liger_kernel.ops.rope import LigerRopeFunction

try:
from flash_attn_interface import _flash_attn_backward as _flash_attn_backward_v3
from flash_attn_interface import _flash_attn_forward as _flash_attn_forward_v3
from flash_attn_interface import \
_flash_attn_backward as _flash_attn_backward_v3
from flash_attn_interface import \
_flash_attn_forward as _flash_attn_forward_v3

SUPPORT_FA3 = True
except:
@@ -179,6 +184,7 @@ def _fa_backward(
v,
out,
softmax_lse,
None, None, None, None, None, None,
dq,
dk,
dv,
@@ -201,7 +207,8 @@ def _fa_backward(
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=False,
window_size=(-1, -1),
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
alibi_slopes=None,
deterministic=deterministic,


+ 0
- 12
opensora/models/vae/losses.py View File

@@ -221,15 +221,3 @@ class DiscriminatorLoss(nn.Module):
weighted_discriminator_loss = 0

return weighted_discriminator_loss


def cal_opl_loss(z: Tensor, weight: float = 1e5):
z = rearrange(z, "b c t h w -> (b t) c h w")
opl_loss = (
((z - z.mean(dim=(2, 3), keepdim=True)).norm(dim=1) - 3 * z.std(dim=(2, 3), keepdim=True).norm(dim=1))
.clamp(min=0)
.mean()
)
# opl_loss = opl_loss.mean()
opl_loss = weight * opl_loss
return opl_loss

+ 4
- 6
scripts/diffusion/train.py View File

@@ -222,7 +222,7 @@ def main():
del model_ae.decoder
log_cuda_memory("autoencoder")
log_model_params(model_ae)
# model_ae = torch.compile(model_ae, mode="max-autotune", fullgraph=True, dynamic=True)
model_ae.encode = torch.compile(model_ae.encoder, dynamic=True)

if not cfg.get("cached_text", False):
# == build text encoder (t5) ==
@@ -312,8 +312,9 @@ def main():
logger.info("Loaded checkpoint %s at epoch %s step %s", cfg.load, ret[0], ret[1])

# load optimizer and scheduler will overwrite some of the hyperparameters, so we need to reset them
if cfg.get("lr", None) is not None:
set_lr(optimizer, lr_scheduler, cfg.lr, cfg.get("initial_lr", None))
set_lr(optimizer, lr_scheduler, cfg.optim.lr, cfg.get("initial_lr", None))
set_eps(optimizer, cfg.optim.eps)

if cfg.get("update_warmup_steps", False):
assert (
cfg.get("warmup_steps", None) is not None
@@ -321,9 +322,6 @@ def main():
# set_warmup_steps(lr_scheduler, cfg.warmup_steps)
lr_scheduler.step(start_epoch * num_steps_per_epoch + start_step)
logger.info("The learning rate starts from %s", optimizer.param_groups[0]["lr"])
if cfg.get("eps", False):
set_eps(optimizer, cfg.eps)

if start_step is not None:
# if start step exceeds data length, go to next epoch
if start_step > num_steps_per_epoch:


+ 1
- 10
scripts/vae/train.py View File

@@ -24,7 +24,7 @@ from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.datasets.pin_memory_cache import PinMemoryCache
from opensora.models.vae.losses import DiscriminatorLoss, GeneratorLoss, VAELoss, cal_opl_loss
from opensora.models.vae.losses import DiscriminatorLoss, GeneratorLoss, VAELoss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt import CheckpointIO, model_sharding, record_model_param_shape, rm_checkpoints
from opensora.utils.config import config_to_name, create_experiment_workspace, parse_configs
@@ -226,7 +226,6 @@ def main():
nll_rec=0.0,
nll_per=0.0,
kl=0.0,
opl=0.0,
gen=0.0,
gen_w=0.0,
disc=0.0,
@@ -402,12 +401,6 @@ def main():
vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
loss_dict = {} # loss at every step

# == opl loss ==
opl_loss_weight = cfg.get("opl_loss_weight", 0)
if opl_loss_weight > 0:
opl_loss = cal_opl_loss(z, opl_loss_weight)
vae_loss += opl_loss

# == reconstruction loss ==
ret = vae_loss_fn(x, x_rec, posterior)
nll_loss = ret["nll_loss"]
@@ -469,8 +462,6 @@ def main():
log_loss("nll_rec", recon_loss, loss_dict, use_video)
log_loss("nll_per", perceptual_loss, loss_dict, use_video)
log_loss("kl", kl_loss, loss_dict, use_video)
if opl_loss_weight > 0:
log_loss("opl", opl_loss, loss_dict, use_video)
if use_discriminator:
log_loss("gen_w", generator_loss, loss_dict, use_video)
log_loss("gen", g_loss, loss_dict, use_video)


Loading…
Cancel
Save
Baidu
map