14 Commits

71 changed files with 2165 additions and 2698 deletions
Split View
  1. +1
    -0
      .gitignore
  2. +25
    -3
      README.md
  3. +1
    -1
      configs/diffusion/inference/256px.py
  4. +20
    -16
      configs/diffusion/inference/high_compression.py
  5. +0
    -89
      configs/diffusion/train/cache_stage2.py
  6. +0
    -30
      configs/diffusion/train/dcae.py
  7. +0
    -8
      configs/diffusion/train/debug.py
  8. +12
    -0
      configs/diffusion/train/demo.py
  9. +70
    -0
      configs/diffusion/train/high_compression.py
  10. +23
    -25
      configs/diffusion/train/image.py
  11. +9
    -20
      configs/diffusion/train/stage1.py
  12. +5
    -22
      configs/diffusion/train/stage1_i2v.py
  13. +0
    -12
      configs/diffusion/train/stage1_v2v.py
  14. +8
    -29
      configs/diffusion/train/stage2.py
  15. +0
    -55
      configs/diffusion/train/stage2_cache.py
  16. +3
    -30
      configs/diffusion/train/stage2_i2v.py
  17. +0
    -53
      configs/diffusion/train/stage2_v1.py
  18. +0
    -43
      configs/diffusion/train/stage2_v2.py
  19. +0
    -35
      configs/vae/train/causal_dcae.py
  20. +83
    -0
      configs/vae/train/video_dc_ae.py
  21. +33
    -0
      configs/vae/train/video_dc_ae_disc.py
  22. +18
    -0
      docs/ae.md
  23. +11
    -0
      docs/hcae.md
  24. +201
    -0
      docs/train.md
  25. +3
    -1
      opensora/acceleration/parallel_states.py
  26. +24
    -8
      opensora/datasets/aspect.py
  27. +1
    -1
      opensora/models/dc_ae/__init__.py
  28. +43
    -6
      opensora/models/dc_ae/ae_model_zoo.py
  29. +0
    -0
      opensora/models/dc_ae/efficientvit/__init__.py
  30. +0
    -0
      opensora/models/dc_ae/efficientvit/apps/__init__.py
  31. +0
    -102
      opensora/models/dc_ae/efficientvit/apps/setup.py
  32. +0
    -1
      opensora/models/dc_ae/efficientvit/apps/trainer/__init__.py
  33. +0
    -128
      opensora/models/dc_ae/efficientvit/apps/trainer/run_config.py
  34. +0
    -10
      opensora/models/dc_ae/efficientvit/apps/utils/__init__.py
  35. +0
    -91
      opensora/models/dc_ae/efficientvit/apps/utils/dist.py
  36. +0
    -54
      opensora/models/dc_ae/efficientvit/apps/utils/ema.py
  37. +0
    -58
      opensora/models/dc_ae/efficientvit/apps/utils/export.py
  38. +0
    -190
      opensora/models/dc_ae/efficientvit/apps/utils/image.py
  39. +0
    -79
      opensora/models/dc_ae/efficientvit/apps/utils/lr.py
  40. +0
    -47
      opensora/models/dc_ae/efficientvit/apps/utils/metric.py
  41. +0
    -114
      opensora/models/dc_ae/efficientvit/apps/utils/misc.py
  42. +0
    -42
      opensora/models/dc_ae/efficientvit/apps/utils/opt.py
  43. +0
    -0
      opensora/models/dc_ae/efficientvit/models/__init__.py
  44. +0
    -102
      opensora/models/dc_ae/efficientvit/models/nn/drop.py
  45. +0
    -183
      opensora/models/dc_ae/efficientvit/models/nn/norm.py
  46. +0
    -207
      opensora/models/dc_ae/efficientvit/models/nn/triton_rms_norm.py
  47. +0
    -3
      opensora/models/dc_ae/efficientvit/models/utils/__init__.py
  48. +0
    -111
      opensora/models/dc_ae/efficientvit/models/utils/network.py
  49. +0
    -79
      opensora/models/dc_ae/efficientvit/models/utils/random.py
  50. +0
    -0
      opensora/models/dc_ae/models/__init__.py
  51. +423
    -54
      opensora/models/dc_ae/models/dc_ae.py
  52. +0
    -2
      opensora/models/dc_ae/models/nn/__init__.py
  53. +2
    -1
      opensora/models/dc_ae/models/nn/act.py
  54. +98
    -0
      opensora/models/dc_ae/models/nn/norm.py
  55. +189
    -46
      opensora/models/dc_ae/models/nn/ops.py
  56. +244
    -0
      opensora/models/dc_ae/models/nn/vo_ops.py
  57. +3
    -0
      opensora/models/dc_ae/utils/__init__.py
  58. +2
    -2
      opensora/models/dc_ae/utils/init.py
  59. +1
    -0
      opensora/models/dc_ae/utils/list.py
  60. +40
    -11
      opensora/models/mmdit/model.py
  61. +29
    -8
      opensora/utils/ckpt.py
  62. +10
    -2
      opensora/utils/config.py
  63. +107
    -33
      opensora/utils/sampling.py
  64. +45
    -36
      opensora/utils/train.py
  65. +1
    -0
      requirements.txt
  66. +0
    -90
      scripts/cnv/extend_csv.py
  67. +70
    -0
      scripts/cnv/meta.py
  68. +23
    -4
      scripts/cnv/shard.py
  69. +0
    -59
      scripts/cnv/txt2csv.py
  70. +25
    -11
      scripts/diffusion/train.py
  71. +259
    -251
      scripts/vae/train.py

+ 1
- 0
.gitignore View File

@@ -195,3 +195,4 @@ package.json
exps
ckpts
flash-attention
datasets

+ 25
- 3
README.md View File

@@ -37,7 +37,7 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi

## 📰 News

- **[2025.03.17]** 🔥 We released **Open-Sora 2.0** (11B). with MMDiT structure and optimized for image-to-video generation, the model generates high quality of videos (t2v, i2v, t2i2v) with 256x256 and 768x768 resolution. An attempt to adapt for a high-compression autoencoder is also presented. 😚 All training codes are released! [[report]]()
- **[2025.03.17]** 🔥 We released **Open-Sora 2.0** (11B). With MMDiT structure and optimized for image-to-video generation, the model generates high quality of videos (t2v, i2v, t2i2v) with 256x256 and 768x768 resolution. An attempt to adapt for a high-compression autoencoder is also presented. 😚 All training codes are released! [[report]]()
- **[2025.02.20]** 🔥 We released **Open-Sora 1.3** (1B). With the upgraded VAE and Transformer architecture, the quality of our generated videos has been greatly improved 🚀. [[checkpoints]](#open-sora-13-model-weights) [[report]](/docs/report_04.md) [[demo]](https://huggingface.co/spaces/hpcai-tech/open-sora)
- **[2024.12.23]** The development cost of video generation models has saved by 50%! Open-source solutions are now available with H200 GPU vouchers. [[blog]](https://company.hpc-ai.com/blog/the-development-cost-of-video-generation-models-has-saved-by-50-open-source-solutions-are-now-available-with-h200-gpu-vouchers) [[code]](https://github.com/hpcaitech/Open-Sora/blob/main/scripts/train.py) [[vouchers]](https://colossalai.org/zh-Hans/docs/get_started/bonus/)
- **[2024.06.17]** We released **Open-Sora 1.2**, which includes **3D-VAE**, **rectified flow**, and **score condition**. The video quality is greatly improved. [[checkpoints]](#open-sora-12-model-weights) [[report]](/docs/report_03.md) [[arxiv]](https://arxiv.org/abs/2412.20404)
@@ -58,8 +58,12 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi

More samples and corresponding prompts are available in our [Gallery](https://hpcaitech.github.io/Open-Sora/).

| | | |
| --------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [<img src="https://github.com/hpcaitech/Open-Sora-Demo/blob/main/demo/v2.0/demo.gif" width="">](https://streamable.com/e/r0imrp?quality=highest&amp;autoplay=1) | [<img src="https://github.com/hpcaitech/Open-Sora-Demo/blob/main/demo/v2.0/demo.gif" width="">](https://streamable.com/e/hfvjkh?quality=highest&amp;autoplay=1) | [<img src="https://github.com/hpcaitech/Open-Sora-Demo/blob/main/demo/v2.0/demo.gif" width="">](https://streamable.com/e/kutmma?quality=highest&amp;autoplay=1) |

<details>
<summary>OpenSora 1.2 Demo</summary>
<summary>OpenSora 1.3 Demo</summary>

| **5s 720×1280** | **5s 720×1280** | **5s 720×1280** |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
@@ -225,6 +229,22 @@ torchrun --nproc_per_node 8 --standalone scripts/diffusion/inference.py configs/

### Motion Score

During training, we inject the motion score into the text prompt. During inference, you can use the following command to generate videos with a motion score (the default score is 4):

```bash
torchrun --nproc_per_node 1 --standalone scripts/diffusion/inference.py configs/diffusion/inference/t2i2v_256px.py --save-dir samples --prompt "raining, sea" --motion-score 4
```

We also provide a dynamic motion score evaluator. After setting your OpenAI API key, you can use the following command to evaluate the motion score of a video:

```bash
torchrun --nproc_per_node 1 --standalone scripts/diffusion/inference.py configs/diffusion/inference/t2i2v_256px.py --save-dir samples --prompt "raining, sea" --motion-score dynamic
```

| Score | 1 | 4 | 7 |
| ----- | ------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------- |
| | <img src="https://github.com/hpcaitech/Open-Sora-Demo/blob/main/demo/v2.0/motion_score_1.gif" width=""> | <img src="https://github.com/hpcaitech/Open-Sora-Demo/blob/main/demo/v2.0/motion_score_4.gif" width=""> | <img src="https://github.com/hpcaitech/Open-Sora-Demo/blob/main/demo/v2.0/motion_score_7.gif" width=""> |

### Prompt Refine

We take advantage of ChatGPT to refine the prompt. You can use the following command to refine the prompt. The function is available for both text-to-video and image-to-video generation.
@@ -242,9 +262,11 @@ To make the results reproducible, you can set the random seed by:
torchrun --nproc_per_node 1 --standalone scripts/diffusion/inference.py configs/diffusion/inference/t2i2v_256px.py --save-dir samples --prompt "raining, sea" --sampling_option.seed 42 --seed 42
```

Use `--num-sample k` to generate `k` samples for each prompt.

## Computational Efficiency

We test the computational efficiency on H100/H800 GPU. For 256x256, we use colossalai's tensor parallelism. For 768x768, we use colossalai's sequence parallelism. All use number of steps 50. The results are presented in the format: $\color{blue}{\text{Total time (s)}}/\color{red}{\text{peak GPU memory (GB)}}$
We test the computational efficiency of text-to-video on H100/H800 GPU. For 256x256, we use colossalai's tensor parallelism. For 768x768, we use colossalai's sequence parallelism. All use number of steps 50. The results are presented in the format: $\color{blue}{\text{Total time (s)}}/\color{red}{\text{peak GPU memory (GB)}}$

| Resolution | 1x GPU | 2x GPUs | 4x GPUs | 8x GPUs |
| ---------- | -------------------------------------- | ------------------------------------- | ------------------------------------- | ------------------------------------- |


+ 1
- 1
configs/diffusion/inference/256px.py View File

@@ -29,7 +29,7 @@ sampling_option = dict(
method="i2v", # hard-coded for now
seed=None, # random seed for z
)
motion_score = 4 # motion score for video generation
motion_score = "4" # motion score for video generation
fps_save = 24 # fps for video generation and saving

# Define model components


+ 20
- 16
configs/diffusion/inference/high_compression.py View File

@@ -1,26 +1,30 @@
_base_ = ["t2i2v_768px.py"]

# no need for parallelism
plugin = None
plugin_config = None
plugin_ae = None
plugin_config_ae = None

# model settings
patch_size = 1
model = dict(
from_pretrained=None,
grad_ckpt_settings=None,
in_channels=512,
in_channels=128,
cond_embed=True,
patch_size=1,
)

# AE settings
ae = dict(
_delete_=True,
type="dc_ae",
model_name="dc-ae-f128c512-sana-1.0",
from_scratch=True,
from_pretrained="/home/chenli/luchen/Open-Sora-Dev/outputs/250211_114721-vae_train_sana_2d_32channel/epoch13-global_step2000/model/model-00001.safetensors",
)

sampling_option = dict(
resolution="1024px",
aspect_ratio="1:1",
num_frames=16,
num_steps=50,
temporal_reduction=1,
is_causal_vae=False,
seed=42,
model_name="dc-ae-f32t4c128",
from_pretrained="./ckpts/F32T4C128_AE.safetensors",
use_spatial_tiling=True,
use_temporal_tiling=True,
spatial_tile_size=256,
temporal_tile_size=32,
tile_overlap_factor=0.25,
)
fps_save = 24
ae_spatial_compression = 32

+ 0
- 89
configs/diffusion/train/cache_stage2.py View File

@@ -1,89 +0,0 @@
to_cache_text = False
to_cache_video = True
cached_text = False
cached_video = False
# exist_handling = "ignore"

dataset = dict(
type="cached_video_text",
transform_name="resize_crop",
fps_max=24,
return_latents_path=to_cache_text or to_cache_video,
cached_text=False,
cached_video=False,
)

bucket_config = {
"768px": {
1: (1.0, 50), # 20s/it
5: (1.0, 15),
9: (1.0, 15),
13: (1.0, 15),
17: (1.0, 15),
21: (1.0, 15),
25: (1.0, 15),
29: (1.0, 15),
33: (1.0, 15),
37: (1.0, 10),
41: (1.0, 10),
45: (1.0, 10),
49: (1.0, 10),
53: (1.0, 10),
57: (1.0, 10),
61: (1.0, 10),
65: (1.0, 10),
69: (1.0, 7),
73: (1.0, 7),
77: (1.0, 7),
81: (1.0, 7),
85: (1.0, 7),
89: (1.0, 7),
93: (1.0, 7),
97: (1.0, 7),
101: (1.0, 6),
105: (1.0, 6),
109: (1.0, 6),
113: (1.0, 6),
117: (1.0, 6),
121: (1.0, 6),
125: (1.0, 6),
129: (1.0, 6), # 46s
},
}
pin_memory_cache_pre_alloc_numels = [(260 + 20) * 1024 * 1024] * 24 + [(34 + 20) * 1024 * 1024] * 4
# record_time = True
# record_barrier = True

ae = dict(
type="hunyuan_vae",
from_pretrained="/mnt/jfs-hdd/sora/checkpoints/pretrained_models/hunyuan-video-t2v-720p/vae/pytorch_model.pt",
in_channels=3,
out_channels=3,
layers_per_block=2,
latent_channels=16,
use_spatial_tiling=True,
use_temporal_tiling=False,
)
t5 = dict(
type="text_embedder",
from_pretrained="google/t5-v1_1-xxl",
cache_dir="/mnt/ddn/sora/tmp_load/huggingface/hub/",
max_length=512,
shardformer=True,
)
clip = dict(
type="text_embedder",
from_pretrained="openai/clip-vit-large-patch14",
cache_dir="/mnt/ddn/sora/tmp_load/huggingface/hub/",
max_length=77,
)

# Acceleration settings
prefetch_factor = 2
num_workers = 6
num_bucket_build_workers = 64
dtype = "bf16"

# Other settings
seed = 42
wandb_project = "mmdit"

+ 0
- 30
configs/diffusion/train/dcae.py View File

@@ -1,30 +0,0 @@
_base_ = ["image.py"]

bucket_config = {
"_delete_": True,
"1024px_ar1:1": {16: (1.0, 16)},
}

patch_size = 1
model = dict(
from_pretrained="/mnt/ddn/sora/tmp_load/vo2_1_768px_t2v_adapt.pt",
grad_ckpt_settings=None,
in_channels=512,
)
ae = dict(
_delete_=True,
type="dc_ae",
model_name="dc-ae-f128c512-sana-1.0",
from_scratch=True,
from_pretrained="/home/chenli/luchen/Open-Sora-Dev/outputs/250210_125346-vae_train_sana_2d_256channel/epoch0-global_step15500",
)

pin_memory_cache_pre_alloc_numels = [(260 + 20) * 1024 * 1024] * 24 + [(34 + 20) * 1024 * 1024] * 4
lr = 5e-5
optim = dict(
lr=lr,
)
ema_decay = None
ckpt_every = 500 # save every 4 hours
keep_n_latest = 20
wandb_project = "dcae-adapt"

+ 0
- 8
configs/diffusion/train/debug.py View File

@@ -1,8 +0,0 @@
_base_ = ["stage1_i2v.py"]

bucket_config = {
"_delete_": True,
"256px": {
129: (1.0, 1),
},
}

+ 12
- 0
configs/diffusion/train/demo.py View File

@@ -0,0 +1,12 @@
_base_ = ["stage1.py"]


bucket_config = {
"_delete_": True,
"256px": {
1: (1.0, 1),
33: (1.0, 1),
97: (1.0, 1),
129: (1.0, 1),
},
}

+ 70
- 0
configs/diffusion/train/high_compression.py View File

@@ -0,0 +1,70 @@
_base_ = ["image.py"]

bucket_config = {
"_delete_": True,
"768px": {
1: (1.0, 20),
16: (1.0, 8),
20: (1.0, 8),
24: (1.0, 8),
28: (1.0, 8),
32: (1.0, 8),
36: (1.0, 4),
40: (1.0, 4),
44: (1.0, 4),
48: (1.0, 4),
52: (1.0, 4),
56: (1.0, 4),
60: (1.0, 4),
64: (1.0, 4),
68: (1.0, 3),
72: (1.0, 3),
76: (1.0, 3),
80: (1.0, 3),
84: (1.0, 3),
88: (1.0, 3),
92: (1.0, 3),
96: (1.0, 3),
100: (1.0, 2),
104: (1.0, 2),
108: (1.0, 2),
112: (1.0, 2),
116: (1.0, 2),
120: (1.0, 2),
124: (1.0, 2),
128: (1.0, 2), # 30s
},
}

condition_config = dict(
t2v=1,
i2v_head=7,
)

patch_size = 1
model = dict(
from_pretrained=None,
grad_ckpt_settings=None,
in_channels=128,
cond_embed=True,
patch_size=patch_size,
)
ae = dict(
_delete_=True,
type="dc_ae",
model_name="dc-ae-f32t4c128",
from_pretrained="./ckpts/F32T4C128_AE.safetensors",
from_scratch=True,
scaling_factor=0.493,
use_spatial_tiling=True,
use_temporal_tiling=True,
spatial_tile_size=256,
temporal_tile_size=32,
tile_overlap_factor=0.25,
)
is_causal_vae = False
ae_spatial_compression = 32

ckpt_every = 250
lr = 3e-5
optim = dict(lr=lr)

+ 23
- 25
configs/diffusion/train/image.py View File

@@ -2,34 +2,20 @@
dataset = dict(
type="video_text",
transform_name="resize_crop",
fps_max=24,
vmaf=True,
fps_max=24, # the desired fps for training
vmaf=True, # load vmaf scores into text
)

# new config
grad_ckpt_settings = (8, 100)
grad_ckpt_settings = (8, 100) # set the grad checkpoint settings
bucket_config = {
"256px": {
1: (1.0, 50),
},
"768px": {
1: (0.5, 11),
},
"1024px": {
1: (0.5, 7),
},
"256px": {1: (1.0, 50)},
"768px": {1: (0.5, 11)},
"1024px": {1: (0.5, 7)},
}
# 6s/it (4x8 GPUs)

# record_time = True
# record_barrier = True
warmup_ae = False
pin_memory_cache_pre_alloc_numels = None

# Define model components
model = dict(
type="flux",
# from_pretrained="/mnt/ddn/sora/tmp_load/flux1-dev-fused-rope.safetensors",‘
from_pretrained=None,
strict_load=False,
guidance_embed=False,
@@ -49,13 +35,13 @@ model = dict(
theta=10_000,
qkv_bias=True,
)
dropout_ratio = {
dropout_ratio = { # probability for dropout text embedding
"t5": 0.31622777,
"clip": 0.31622777,
}
ae = dict(
type="hunyuan_vae",
from_pretrained="/mnt/jfs-hdd/sora/checkpoints/pretrained_models/hunyuan-video-t2v-720p/vae/pytorch_model.pt",
from_pretrained="./ckpts/hunyuan_vae.safetensors",
in_channels=3,
out_channels=3,
layers_per_block=2,
@@ -63,6 +49,7 @@ ae = dict(
use_spatial_tiling=True,
use_temporal_tiling=False,
)
is_causal_vae = True
t5 = dict(
type="text_embedder",
from_pretrained="google/t5-v1_1-xxl",
@@ -77,9 +64,9 @@ clip = dict(
max_length=77,
)

lr = 1e-5 # this will updated optim again after it finishes loading, important
eps = 1e-15 # this will updated optim again after it finishes loading, important
# Optimization settings
lr = 1e-5
eps = 1e-15
optim = dict(
cls="HybridAdam",
lr=lr,
@@ -92,7 +79,7 @@ update_warmup_steps = True

grad_clip = 1.0
accumulation_steps = 1
ema_decay = 0.99
ema_decay = None

# Acceleration settings
prefetch_factor = 2
@@ -105,6 +92,10 @@ plugin_config = dict(
reduce_bucket_size_in_m=128,
overlap_allgather=False,
)
pin_memory_cache_pre_alloc_numels = [(260 + 20) * 1024 * 1024] * 24 + [
(34 + 20) * 1024 * 1024
] * 4
async_io = False

# Other settings
seed = 42
@@ -114,3 +105,10 @@ log_every = 10
ckpt_every = 100
keep_n_latest = 20
wandb_project = "mmdit"

save_master_weights = True
load_master_weights = True

# For debugging
# record_time = True
# record_barrier = True

+ 9
- 20
configs/diffusion/train/stage1.py View File

@@ -1,15 +1,13 @@
_base_ = ["image.py"]

dataset = dict(
memory_efficient=True,
)
dataset = dict(memory_efficient=False)

# new config
grad_ckpt_settings = (8, 100)
bucket_config = {
"_delete_": True,
"256px": {
1: (1.0, 45), # 6.22 s
1: (1.0, 45),
5: (1.0, 12),
9: (1.0, 12),
13: (1.0, 12),
@@ -17,7 +15,7 @@ bucket_config = {
21: (1.0, 12),
25: (1.0, 12),
29: (1.0, 12),
33: (1.0, 12), # 7.02 s
33: (1.0, 12),
37: (1.0, 6),
41: (1.0, 6),
45: (1.0, 6),
@@ -25,7 +23,7 @@ bucket_config = {
53: (1.0, 6),
57: (1.0, 6),
61: (1.0, 6),
65: (1.0, 6), # 6.79 s
65: (1.0, 6),
69: (1.0, 4),
73: (1.0, 4),
77: (1.0, 4),
@@ -33,7 +31,7 @@ bucket_config = {
85: (1.0, 4),
89: (1.0, 4),
93: (1.0, 4),
97: (1.0, 4), # 6.84 s
97: (1.0, 4),
101: (1.0, 3),
105: (1.0, 3),
109: (1.0, 3),
@@ -41,7 +39,7 @@ bucket_config = {
117: (1.0, 3),
121: (1.0, 3),
125: (1.0, 3),
129: (1.0, 3), # 7.48 s
129: (1.0, 3),
},
"768px": {
1: (0.5, 13),
@@ -50,18 +48,9 @@ bucket_config = {
1: (0.5, 7),
},
}
pin_memory_cache_pre_alloc_numels = [(260 + 20) * 1024 * 1024] * 24 + [(34 + 20) * 1024 * 1024] * 4

model = dict(
from_pretrained=None,
grad_ckpt_settings=grad_ckpt_settings,
)
model = dict(grad_ckpt_settings=grad_ckpt_settings)
lr = 5e-5
optim = dict(
lr=lr,
)
ema_decay = 0.999
ckpt_every = 2000 # save every 4 hours
optim = dict(lr=lr)
ckpt_every = 2000
keep_n_latest = 20
wandb_project = "mmdit-vo3"
async_io = False

+ 5
- 22
configs/diffusion/train/stage1_i2v.py View File

@@ -1,31 +1,14 @@
_base_ = ["stage1.py"]

dataset = dict(memory_efficient=False)

# Define model components
model = dict(
cond_embed=True,
)
model = dict(cond_embed=True)

condition_config = dict(
t2v=1,
i2v_head=5,
i2v_loop=1,
i2v_tail=1,
i2v_head=5, # train i2v (image as first frame) with weight 5
i2v_loop=1, # train image connection with weight 1
i2v_tail=1, # train i2v (image as last frame) with weight 1
)
is_causal_vae = True

lr = 1e-5
optim = dict(
lr=lr,
)
ema_decay = None
async_io = False

plugin = "hybrid"
plugin_config = dict(
tp_size=1,
pp_size=1,
sp_size=1,
zero_stage=2,
)
optim = dict(lr=lr)

+ 0
- 12
configs/diffusion/train/stage1_v2v.py View File

@@ -1,12 +0,0 @@
_base_ = ["stage1_i2v.py"]

condition_config = dict(
t2v=1,
i2v_head=5,
i2v_loop=1,
i2v_tail=1,
v2v_head=1,
v2v_head_easy=1,
v2v_tail=0.5,
v2v_tail_easy=0.5,
)

+ 8
- 29
configs/diffusion/train/stage2.py View File

@@ -1,8 +1,8 @@
_base_ = ["image.py"]

# new config
grad_ckpt_settings = (100, 100) # one GPU
# grad_ckpt_buffer_size = 20 * 1024**3
grad_ckpt_settings = (100, 100)
plugin = "hybrid"
plugin_config = dict(
tp_size=1,
@@ -25,10 +25,7 @@ bucket_config = {
21: (1.0, 14),
25: (1.0, 14),
29: (1.0, 14),
33: (
1.0,
14,
), # 7.02 s iter: 4.17 s | encode_video: 1.42 s | encode_text: 0.29 s | forward: 0.53 s | backward: 1.67 s | 135GB
33: (1.0, 14),
37: (1.0, 10),
41: (1.0, 10),
45: (1.0, 10),
@@ -36,21 +33,14 @@ bucket_config = {
53: (1.0, 10),
57: (1.0, 10),
61: (1.0, 10),
65: (
1.0,
10,
), # 6.79 s iter: 10.42 s | encode_video: 4.02 s | encode_text: 0.43 s | forward: 1.31 s | backward: 4.21 s | 125GB
69: (1.0, 7),
65: (1.0, 10),
73: (1.0, 7),
77: (1.0, 7),
81: (1.0, 7),
85: (1.0, 7),
89: (1.0, 7),
93: (1.0, 7),
97: (
1.0,
7,
), # 6.84 s iter: 5.26 s | encode_video: 2.16 s | encode_text: 0.16 s | forward: 0.64 s | backward: 2.08 s | 127GB
97: (1.0, 7),
101: (1.0, 6),
105: (1.0, 6),
109: (1.0, 6),
@@ -58,10 +48,7 @@ bucket_config = {
117: (1.0, 6),
121: (1.0, 6),
125: (1.0, 6),
129: (
1.0,
6,
), # 7.48 s iter: 9.67 s | encode_video: 3.78 s | encode_text: 0.21 s | forward: 1.36 s | backward: 2.78 s | 130.3GB
129: (1.0, 6),
},
"768px": {
1: (1.0, 38),
@@ -99,17 +86,9 @@ bucket_config = {
129: (1.0, 2),
},
}
pin_memory_cache_pre_alloc_numels = [(260 + 20) * 1024 * 1024] * 24 + [(34 + 20) * 1024 * 1024] * 4

model = dict(
from_pretrained=None,
grad_ckpt_settings=grad_ckpt_settings,
)
model = dict(grad_ckpt_settings=grad_ckpt_settings)
lr = 5e-5
optim = dict(
lr=lr,
)
ema_decay = 0.99
optim = dict(lr=lr)
ckpt_every = 200
keep_n_latest = 20
wandb_project = "mmdit-vo3"

+ 0
- 55
configs/diffusion/train/stage2_cache.py View File

@@ -1,55 +0,0 @@
_base_ = ["stage2.py"]

# Dataset settings
cached_text = False
cached_video = True
dataset = dict(
type="cached_video_text",
transform_name="resize_crop",
fps_max=24,
cached_text=cached_text,
cached_video=cached_video,
)


# == stage2: 256px ==
bucket_config = {
"768px": {
1: (1.0, 38),
5: (1.0, 6),
9: (1.0, 6),
13: (1.0, 6),
17: (1.0, 6),
21: (1.0, 6),
25: (1.0, 6),
29: (1.0, 6),
33: (1.0, 6),
37: (1.0, 4),
41: (1.0, 4),
45: (1.0, 4),
49: (1.0, 4),
53: (1.0, 4),
57: (1.0, 4),
61: (1.0, 4),
65: (1.0, 4),
69: (1.0, 3),
73: (1.0, 3),
77: (1.0, 3),
81: (1.0, 3),
85: (1.0, 3),
89: (1.0, 3),
93: (1.0, 3),
97: (1.0, 3),
101: (1.0, 2),
105: (1.0, 2),
109: (1.0, 2),
113: (1.0, 2),
117: (1.0, 2),
121: (1.0, 2),
125: (1.0, 2),
129: (1.0, 2),
},
}
pin_memory_cache_pre_alloc_numels = [(260 + 20) * 1024 * 1024] * 24 + [(34 + 20) * 1024 * 1024] * 4
# record_time = True
# record_barrier = True

+ 3
- 30
configs/diffusion/train/stage2_i2v.py View File

@@ -1,19 +1,14 @@
_base_ = ["stage2_v1.py"]
_base_ = ["stage2.py"]

# Define model components
model = dict(
cond_embed=True,
)
model = dict(cond_embed=True)
grad_ckpt_buffer_size = 25 * 1024**3

condition_config = dict(
t2v=1,
i2v_head=5,
i2v_loop=1,
i2v_tail=1,
# v2v_head=1, # 32
# v2v_tail=0.5, # 32
# v2v_head_easy=1, # 64
# v2v_tail_easy=0.5, # 64
)
is_causal_vae = True

@@ -21,8 +16,6 @@ bucket_config = {
"_delete_": True,
"256px": {
1: (1.0, 195),
# old: bs = 130, speed = 19s/it
# new: bs = 195, speed = 28s/it
5: (1.0, 80),
9: (1.0, 80),
13: (1.0, 80),
@@ -31,8 +24,6 @@ bucket_config = {
25: (1.0, 80),
29: (1.0, 80),
33: (1.0, 80),
# old: bs = 14, speed = 6.8s/it
# new: bs = 80, speed = 36.62s/it
37: (1.0, 40),
41: (1.0, 40),
45: (1.0, 40),
@@ -41,8 +32,6 @@ bucket_config = {
57: (1.0, 40),
61: (1.0, 40),
65: (1.0, 40),
# old: bs = 10 , speed = 10s/it
# new: bs = 40, speed = 39.49s/it
69: (1.0, 28),
73: (1.0, 28),
77: (1.0, 28),
@@ -51,8 +40,6 @@ bucket_config = {
89: (1.0, 28),
93: (1.0, 28),
97: (1.0, 28),
# old: bs = 7, speed = 10s/it
# new: bs = 28, speed = 35.39s/it
101: (1.0, 23),
105: (1.0, 23),
109: (1.0, 23),
@@ -61,8 +48,6 @@ bucket_config = {
121: (1.0, 23),
125: (1.0, 23),
129: (1.0, 23),
# old: bs = 6, speed = 11s/it
# new: bs = 23, speed = 40.6s/it
},
"768px": {
1: (0.5, 38),
@@ -74,7 +59,6 @@ bucket_config = {
25: (0.5, 10),
29: (0.5, 10),
33: (0.5, 10),
# speed = 60.20s/it
37: (0.5, 5),
41: (0.5, 5),
45: (0.5, 5),
@@ -83,7 +67,6 @@ bucket_config = {
57: (0.5, 5),
61: (0.5, 5),
65: (0.5, 5),
# speed = 49.25s/it
69: (0.5, 3),
73: (0.5, 3),
77: (0.5, 3),
@@ -92,7 +75,6 @@ bucket_config = {
89: (0.5, 3),
93: (0.5, 3),
97: (0.5, 3),
# speed = 50.69s/it
101: (0.5, 2),
105: (0.5, 2),
109: (0.5, 2),
@@ -101,14 +83,5 @@ bucket_config = {
121: (0.5, 2),
125: (0.5, 2),
129: (0.5, 2),
# speed = 47.24s/it
},
}


lr = 1e-5
optim = dict(
lr=lr,
)
ckpt_every = 200
async_io = False

+ 0
- 53
configs/diffusion/train/stage2_v1.py View File

@@ -1,53 +0,0 @@
_base_ = ["stage2.py"]

plugin_config = dict(
sp_size=4,
)
dataset = dict(
memory_efficient=True,
)
grad_ckpt_buffer_size = 25 * 1024**3
async_io = False
bucket_config = {
"_delete_": True,
"768px": {
1: (1.0, 38),
5: (1.0, 12),
9: (1.0, 12),
13: (1.0, 12),
17: (1.0, 12),
21: (1.0, 12),
25: (1.0, 12),
29: (1.0, 12),
33: (1.0, 12),
37: ((1.0, 0.9), 5),
41: ((1.0, 0.9), 5),
45: ((1.0, 0.9), 5),
49: ((1.0, 0.9), 5),
53: ((1.0, 0.9), 5),
57: ((1.0, 0.9), 5),
61: ((1.0, 0.9), 5),
65: ((1.0, 0.9), 5),
69: ((1.0, 0.975), 3),
73: ((1.0, 0.975), 3),
77: ((1.0, 0.975), 3),
81: ((1.0, 0.975), 3),
85: ((1.0, 0.975), 3),
89: ((1.0, 0.975), 3),
93: ((1.0, 0.975), 3),
97: ((1.0, 0.975), 3),
101: ((1.0, 0.992), 2),
105: ((1.0, 0.992), 2),
109: ((1.0, 0.992), 2),
113: ((1.0, 0.992), 2),
117: ((1.0, 0.992), 2),
121: ((1.0, 0.975), 2),
125: ((1.0, 0.975), 2),
129: ((1.0, 0.975), 2),
},
}
ema_decay = None

# cp=4, ckpt act mem: 20.79 GB, 33: (1.0, 12), 36.38s/it
# record_time = True
# record_barrier = True

+ 0
- 43
configs/diffusion/train/stage2_v2.py View File

@@ -1,43 +0,0 @@
_base_ = ["stage2_v1.py"]

bucket_config = {
"_delete_": True,
"768px": {
1: (1.0, 38),
5: (1.0, 12),
9: (1.0, 12),
13: (1.0, 12),
17: (1.0, 12),
21: (1.0, 12),
25: (1.0, 12),
29: (1.0, 12),
33: (1.0, 12),
37: (1.0, 5),
41: (1.0, 5),
45: (1.0, 5),
49: (1.0, 5),
53: (1.0, 5),
57: (1.0, 5),
61: (1.0, 5),
65: (1.0, 5),
69: (1.0, 3),
73: (1.0, 3),
77: (1.0, 3),
81: (1.0, 3),
85: (1.0, 3),
89: (1.0, 3),
93: (1.0, 3),
97: (1.0, 3),
101: (1.0, 2),
105: (1.0, 2),
109: (1.0, 2),
113: (1.0, 2),
117: (1.0, 2),
121: (1.0, 2),
125: (1.0, 2),
129: (1.0, 2),
},
}

record_time = True
# record_barrier = True

+ 0
- 35
configs/vae/train/causal_dcae.py View File

@@ -1,35 +0,0 @@
_base_ = ["video.py"]

dataset = dict(
rand_sample_interval=8,
)

bucket_config = {
"_delete_": True,
"256px_ar1:1": {33: (1.0, 2)},
}

vae_loss_config = dict(
perceptual_loss_weight=0.1,
kl_loss_weight=1e-6,
)

opl_loss_weight = 0

model = dict(
_delete_=True,
type="hunyuan_vae",
from_pretrained=None,
in_channels=3,
out_channels=3,
layers_per_block=2,
latent_channels=32,
channel=True,
time_compression_ratio=4,
spatial_compression_ratio=32,
block_out_channels=(128, 128, 256, 256, 512, 512),
encoder_add_residual=True,
decoder_add_residual=True,
encoder_slice_t=False,
decoder_slice_t=True,
)

+ 83
- 0
configs/vae/train/video_dc_ae.py View File

@@ -0,0 +1,83 @@
# ============
# model config
# ============
model = dict(
type="dc_ae",
model_name="dc-ae-f32t4c128",
from_scratch=True,
from_pretrained=None,
)

# ============
# data config
# ============
dataset = dict(
type="video_text",
transform_name="resize_crop",
fps_max=24,
)

bucket_config = {
"256px_ar1:1": {32: (1.0, 1)},
}

num_bucket_build_workers = 64
num_workers = 12
prefetch_factor = 2

# ============
# train config
# ============
optim = dict(
cls="HybridAdam",
lr=5e-5,
eps=1e-8,
weight_decay=0.0,
adamw_mode=True,
betas=(0.9, 0.98),
)
lr_scheduler = dict(warmup_steps=0)

mixed_strategy = "mixed_video_image"
mixed_image_ratio = 0.2 # 1:4

dtype = "bf16"
plugin = "zero2"
plugin_config = dict(
reduce_bucket_size_in_m=128,
overlap_allgather=False,
)

grad_clip = 1.0
grad_checkpoint = False
pin_memory_cache_pre_alloc_numels = [50 * 1024 * 1024] * num_workers * prefetch_factor

seed = 42
outputs = "outputs"
epochs = 100
log_every = 10
ckpt_every = 3000
keep_n_latest = 50
ema_decay = 0.99
wandb_project = "dcae"

discriminator = None
disc_lr_scheduler = None
optim_discriminator = None

update_warmup_steps = True

# ============
# loss config
# ============
opl_loss_weight = 0

vae_loss_config = dict(
perceptual_loss_weight=0.5,
kl_loss_weight=0,
logvar_init=0,
)

gen_loss_config = None
disc_loss_config = None


+ 33
- 0
configs/vae/train/video_dc_ae_disc.py View File

@@ -0,0 +1,33 @@
_base_ = ["video_dc_ae.py"]

discriminator = dict(
_delete_=True,
type="N_Layer_discriminator_3D",
from_pretrained=None,
input_nc=3,
n_layers=5,
conv_cls="conv3d"
)
disc_lr_scheduler = dict(warmup_steps=0)

gen_loss_config = dict(
gen_start=0,
disc_factor=1,
disc_weight=0.05,
)

disc_loss_config = dict(
disc_start=0,
disc_factor=1.0,
disc_loss_type="hinge",
)

optim_discriminator = dict(
cls="HybridAdam",
lr=1e-4,
eps=1e-8,
weight_decay=0.0,
adamw_mode=True,
betas=(0.9, 0.98),
)


+ 18
- 0
docs/ae.md View File

@@ -0,0 +1,18 @@
# Step by step to train and evaluate a video autoencoder

## Installation

```
pip install diffusers==0.31.0
```

## Training
The command to launch training is as follows:
```
torchrun --nproc_per_node 8 scripts/vae/train_video_sana.py configs/vae/train/video_dc_ae_train_channel_proj.py --dataset.data-path /home/zhengzangwei/Open-Sora/datasets/pexels_45k_necessary.csv --model.model_name dc-ae-f32t4c128 --wandb True --optim.lr 5e-5 --wandb-project dcae --vae_loss_config.perceptual_loss_weight 0.5 --wandb-expr-name dc-ae-f32t4c64_full_model_ft

```

## Inference

## Config Interpretation

+ 11
- 0
docs/hcae.md View File

@@ -0,0 +1,11 @@
# Visit the high compression video autoencoder

## Introduction

## Training

torchrun --nproc_per_node 1 --standalone scripts/diffusion/inference.py configs/diffusion/inference/high_compression.py --dataset.data-path assets/texts/sora.csv --ckpt-path /mnt/jfs-hdd/sora/checkpoints/zhengzangwei/outputs/adapt_video_sana/250304_141109-diffusion_train_dc_ae_video_temporal_compression_i2v/epoch0-global_step2000 --save-dir samples/debug_03_06/adapt_2k_op --sampling_option.num_frames 128



torchrun --nproc_per_node 8 scripts/diffusion/train.py configs/diffusion/train/high_compression.py --dataset.data-path /mnt/ddn/sora/meta/vo3/stage2/video+image_stage2_nopart3.parquet --wandb True --wandb-project adapt_video_sana --load /mnt/jfs-hdd/sora/checkpoints/zhengzangwei/outputs/adapt_video_sana/250304_141109-diffusion_train_dc_ae_video_temporal_compression_i2v/epoch0-global_step2000 --outputs /mnt/jfs-hdd/sora/checkpoints/zhengzangwei/outputs/adapt_video_sana/ --start-step 0 --start-epoch 0 --seed 2026

+ 201
- 0
docs/train.md View File

@@ -0,0 +1,201 @@
# Step by step to train or finetune your own model

## Installation

Besides the installation steps on the main page, you need to install the following packages:

```bash
pip install git+https://github.com/hpcaitech/TensorNVMe.git # requires cmake, for checkpoint saving
pip install pandarallel # for parallel processing
```

## Prepare dataset

The dataset should be presented in a `csv` or `parquet` file. To better illustrate the process, we will use a 45k [pexels dataset](https://huggingface.co/datasets/hpcai-tech/open-sora-pexels-45k) as an example. This dataset contains clipped, score-filtered, high-quality videos from [Pexels](https://www.pexels.com/).

First, download the dataset to your local machine:

```bash
mkdir datasets
cd datasets
# For Chinese users, export HF_ENDPOINT=https://hf-mirror.com to speed up the download
huggingface-cli download --repo-type dataset hpcai-tech/open-sora-pexels-45k --local-dir open-sora-pexels-45k # 250GB

cd open-sora-pexels-45k
cat tar/pexels_45k.tar.* > pexels_45k.tar
tar -xvf pexels_45k.tar
mv pexels_45k .. # make sure the path is Open-Sora/datasets/pexels_45k
```

There are three `csv` files provided:

- `pexels_45k.csv`: contains only path and text, which needs to be processed for training.
- `pexels_45k_necessary.csv`: contains necessary information for training.
- `pexels_45k_score.csv`: contains score information for each video. The 45k videos are filtered out based on the score. See tech report for more details.

If you want to use custom dataset, at least the following columns are required:

```csv
path,text,num_frames,height,width,aspect_ratio,resolution,fps
```

We provide a script to process the `pexels_45k.csv` to `pexels_45k_necessary.csv`:

```bash
# single process
python scripts/cnv/meta.py --input datasets/pexels_45k.csv --output datasets/pexels_45k_necessary.csv --num_workers 0
# parallel process
python scripts/cnv/meta.py --input datasets/pexels_45k.csv --output datasets/pexels_45k_necessary.csv --num_workers 64
```

> The process may take a while, depending on the number of videos in the dataset. The process is necessary for training on arbitrary aspect ratios, resolutions, and numbers of frames.

## Training

The command format to launch training is as follows:

```bash
torchrun --nproc_per_node 8 scripts/diffusion/train.py [path/to/config] --dataset.data-path [path/to/dataset] [override options]
```

For example, to train a model with stage 1 config from scratch using pexels dataset:

```bash
torchrun --nproc_per_node 8 scripts/diffusion/train.py configs/diffusion/train/stage1.py --dataset.data-path datasets/pexels_45k_necessary.csv
```

### Config

All configs are located in `configs/diffusion/train/`. The following rules are applied:

- `_base_ = ["config_to_inherit"]`: inherit from another config via mmengine's config inheritance. Variables are overwritten by the new config; dictionaries are merged unless the `_delete_` key is present.
- command line arguments override the config file. For example, `--lr 1e-5` will override the `lr` in the config file. `--dataset.data-path datasets/pexels_45k_necessary.csv` will override the `data-path` value in the dictionary `dataset`.

The `bucket_config` is used to control different training stages. It is a dictionary of dictionaries. The tuple means (sampling probability, batch size). For example:

```python
bucket_config = {
"256px": {
1: (1.0, 45), # for 256px images, use 100% of the data with batch size 45
33: (1.0, 12), # for 256px videos with no less than 33 frames, use 100% of the data with batch size 12
65: (1.0, 6), # for 256px videos with no less than 65 frames, use 100% of the data with batch size 6
97: (1.0, 4), # for 256px videos with no less than 97 frames, use 100% of the data with batch size 4
129: (1.0, 3), # for 256px videos with no less than 129 frames, use 100% of the data with batch size 3
},
"768px": {
1: (0.5, 13), # for 768px images, use 50% of the data with batch size 13
},
"1024px": {
1: (0.5, 7), # for 1024px images, use 50% of the data with batch size 7
},
}
```

We provide the following configs; the batch sizes are tuned on H200 GPUs with 140GB memory:

- `image.py`: train on images only.
- `stage1.py`: train on videos with 256px resolution.
- `stage2.py`: train on videos with 768px resolution with sequence parallelism (default 4).
- `stage1_i2v.py`: train t2v and i2v with 256px resolution.
- `stage2_i2v.py`: train t2v and i2v with 768px resolution.

We also provide a demo config `demo.py` with small batch size for debugging.

### Fine-tuning

To finetune from Open-Sora v2, run:

```bash
torchrun --nproc_per_node 8 scripts/diffusion/train.py configs/diffusion/train/stage1.py --dataset.data-path datasets/pexels_45k_necessary.csv --model.from_pretrained ckpts/Open_Sora_v2.safetensors
```

To finetune from flux-dev, we provide a converted flux-dev [checkpoint](https://huggingface.co/hpcai-tech/flux1-dev-fused-rope). Download it to `ckpts` and run:

```bash
torchrun --nproc_per_node 8 scripts/diffusion/train.py configs/diffusion/train/stage1.py --dataset.data-path datasets/pexels_45k_necessary.csv --model.from_pretrained ckpts/flux1-dev-fused-rope.safetensors
```

### Multi-GPU

To train on multiple GPUs, use `colossalai run`:

```bash
colossalai run --hostfile hostfiles --nproc_per_node 8 scripts/diffusion/train.py configs/diffusion/train/stage1.py --dataset.data-path datasets/pexels_45k_necessary.csv --model.from_pretrained ckpts/Open_Sora_v2.safetensors
```

`hostfiles` is a file that contains the IP addresses of the nodes. For example:

```bash
xxx.xxx.xxx.xxx
yyy.yyy.yyy.yyy
zzz.zzz.zzz.zzz
```

Use `--wandb True` to log the training process to [wandb](https://wandb.ai/).

### Resume training

To resume training, use `--load`. It will load the optimizer state and dataloader state.

```bash
torchrun --nproc_per_node 8 scripts/diffusion/train.py configs/diffusion/train/stage1.py --dataset.data-path datasets/pexels_45k_necessary.csv --load outputs/your_experiment/epoch*-global_step*
```

If you want to load the optimizer state but not the dataloader state, use:

```bash
torchrun --nproc_per_node 8 scripts/diffusion/train.py configs/diffusion/train/stage1.py --dataset.data-path datasets/pexels_45k_necessary.csv --load outputs/your_experiment/epoch*-global_step* --start-step 0 --start-epoch 0
```

> Note that if the dataset, batch size, or number of GPUs changes, the dataloader state will no longer be meaningful.

## Inference

The inference is the same as described in the main page. The command format is as follows:

```bash
torchrun --nproc_per_node 1 --standalone scripts/diffusion/inference.py configs/diffusion/inference/t2i2v_256px.py --save-dir samples --prompt "raining, sea" --model.from_pretrained outputs/your_experiment/epoch*-global_step*
```

## Advanced Usage

More details are provided in the tech report. If an explanation of some techniques is needed, feel free to open an issue.

- Tensor parallelism and sequence parallelism
- Zero 2
- Pin memory organization
- Garbage collection organization
- Data prefetching
- Communication bucket optimization
- Shardformer for T5

### Gradient Checkpointing

We support selective gradient checkpointing to save memory. The `grad_ckpt_setting` is a tuple: the first element is the number of double-stream (dual) layers to apply gradient checkpointing to, and the second is the number of single-stream layers to apply gradient checkpointing to. A very large number applies gradient checkpointing to all layers.

```python
grad_ckpt_setting = (100, 100)
model = dict(
grad_ckpt_setting=grad_ckpt_setting,
)
```

To further save memory, you can offload gradient checkpointing to CPU by:

```python
grad_ckpt_buffer_size = 25 * 1024**3 # 25GB
```

### Asynchronous Checkpoint Saving

With `--async-io True`, the checkpoint will be saved asynchronously with the support of ColossalAI. This will save time for checkpoint saving.

### Dataset

With a very large dataset, the `csv` file or even `parquet` file may be too large to fit in memory. We provide a script to split the dataset into smaller chunks:

```bash
python scripts/cnv/shard.py /path/to/dataset.parquet
```

Then a folder with shards will be created. You can use the `--dataset.memory_efficient True` to load the dataset shard by shard.

+ 3
- 1
opensora/acceleration/parallel_states.py View File

@@ -7,7 +7,9 @@ def set_data_parallel_group(group: dist.ProcessGroup):
_GLOBAL_PARALLEL_GROUPS["data"] = group


def get_data_parallel_group():
def get_data_parallel_group(get_mixed_dp_pg : bool = False):
if get_mixed_dp_pg and "mixed_dp_group" in _GLOBAL_PARALLEL_GROUPS:
return _GLOBAL_PARALLEL_GROUPS["mixed_dp_group"]
return _GLOBAL_PARALLEL_GROUPS.get("data", dist.group.WORLD)




+ 24
- 8
opensora/datasets/aspect.py View File

@@ -1,7 +1,6 @@
import math
import os

D = int(os.environ.get("VO_ASPECT_DIV", 16))
ASPECT_RATIO_LD_LIST = [ # width:height
"2.39:1", # cinemascope, 2.39
"2:1", # rare, 2
@@ -20,7 +19,10 @@ def get_ratio(name: str) -> float:
return height / width


def get_aspect_ratios_dict(total_pixels: int = 256 * 256, training: bool = True) -> dict[str, tuple[int, int]]:
def get_aspect_ratios_dict(
total_pixels: int = 256 * 256, training: bool = True
) -> dict[str, tuple[int, int]]:
D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
aspect_ratios_dict = {}
aspect_ratios_vertical_dict = {}
for ratio in ASPECT_RATIO_LD_LIST:
@@ -58,6 +60,7 @@ def get_num_pexels(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str,


def get_num_tokens(aspect_ratios_dict: dict[str, tuple[int, int]]) -> dict[str, int]:
D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
return {ratio: h * w // D // D for ratio, (h, w) in aspect_ratios_dict.items()}


@@ -74,7 +77,9 @@ def get_num_pexels_from_name(resolution: str) -> int:
return num_pexels


def get_resolution_with_aspect_ratio(resolution: str) -> tuple[int, dict[str, tuple[int, int]]]:
def get_resolution_with_aspect_ratio(
resolution: str,
) -> tuple[int, dict[str, tuple[int, int]]]:
"""Get resolution with aspect ratio

Args:
@@ -90,7 +95,9 @@ def get_resolution_with_aspect_ratio(resolution: str) -> tuple[int, dict[str, tu
setting = ""
else:
resolution, setting = keys
assert setting == "max" or setting.startswith("ar"), f"Invalid setting {setting}"
assert setting == "max" or setting.startswith(
"ar"
), f"Invalid setting {setting}"

# get resolution
num_pexels = get_num_pexels_from_name(resolution)
@@ -100,11 +107,16 @@ def get_resolution_with_aspect_ratio(resolution: str) -> tuple[int, dict[str, tu

# handle setting
if setting == "max":
aspect_ratio = max(aspect_ratio_dict, key=lambda x: aspect_ratio_dict[x][0] * aspect_ratio_dict[x][1])
aspect_ratio = max(
aspect_ratio_dict,
key=lambda x: aspect_ratio_dict[x][0] * aspect_ratio_dict[x][1],
)
aspect_ratio_dict = {aspect_ratio: aspect_ratio_dict[aspect_ratio]}
elif setting.startswith("ar"):
aspect_ratio = setting[2:]
assert aspect_ratio in aspect_ratio_dict, f"Aspect ratio {aspect_ratio} not found"
assert (
aspect_ratio in aspect_ratio_dict
), f"Aspect ratio {aspect_ratio} not found"
aspect_ratio_dict = {aspect_ratio: aspect_ratio_dict[aspect_ratio]}

return num_pexels, aspect_ratio_dict
@@ -112,11 +124,15 @@ def get_resolution_with_aspect_ratio(resolution: str) -> tuple[int, dict[str, tu

def get_closest_ratio(height: float, width: float, ratios: dict) -> str:
aspect_ratio = height / width
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(aspect_ratio - get_ratio(ratio)))
closest_ratio = min(
ratios.keys(), key=lambda ratio: abs(aspect_ratio - get_ratio(ratio))
)
return closest_ratio


def get_image_size(resolution: str, ar_ratio: str, training: bool = True) -> tuple[int, int]:
def get_image_size(
resolution: str, ar_ratio: str, training: bool = True
) -> tuple[int, int]:
num_pexels = get_num_pexels_from_name(resolution)
ar_dict = get_aspect_ratios_dict(num_pexels, training)
assert ar_ratio in ar_dict, f"Aspect ratio {ar_ratio} not found"


+ 1
- 1
opensora/models/dc_ae/__init__.py View File

@@ -1 +1 @@
from .efficientvit.ae_model_zoo import DC_AE
from .ae_model_zoo import DC_AE

opensora/models/dc_ae/efficientvit/ae_model_zoo.py → opensora/models/dc_ae/ae_model_zoo.py View File

@@ -24,22 +24,28 @@ from torch import nn
from opensora.registry import MODELS
from opensora.utils.ckpt import load_checkpoint

from .models.efficientvit.dc_ae import DCAE, DCAEConfig, dc_ae_f32c32, dc_ae_f64c128, dc_ae_f128c512
from .models.dc_ae import DCAE, DCAEConfig, dc_ae_f32, dc_ae_f64c128, dc_ae_f64t4c256, dc_ae_f128c512

__all__ = ["create_dc_ae_model_cfg", "DCAE_HF", "AutoencoderKL", "DC_AE"]


REGISTERED_DCAE_MODEL: dict[str, tuple[Callable, Optional[str]]] = {
"dc-ae-f32c32-in-1.0": (dc_ae_f32c32, None),
"dc-ae-f32c32-in-1.0": (dc_ae_f32, None),
"dc-ae-f64c128-in-1.0": (dc_ae_f64c128, None),
"dc-ae-f128c512-in-1.0": (dc_ae_f128c512, None),
#################################################################################################
"dc-ae-f32c32-mix-1.0": (dc_ae_f32c32, None),
"dc-ae-f32c32-mix-1.0": (dc_ae_f32, None),
"dc-ae-f64c128-mix-1.0": (dc_ae_f64c128, None),
"dc-ae-f128c512-mix-1.0": (dc_ae_f128c512, None),
#################################################################################################
"dc-ae-f32c32-sana-1.0": (dc_ae_f32c32, None),
"dc-ae-f128c512-sana-1.0": (dc_ae_f128c512, None),
"dc-ae-f32c32-sana-1.0": (dc_ae_f32, None),
"dc-ae-f32c32-sana-1.0-video": (dc_ae_f32, None),
"dc-ae-f32c32-sana-1.0-video-temporal-compression": (dc_ae_f32, None),
"dc-ae-f32t4c256": (dc_ae_f32, None),
"dc-ae-f32t4c128": (dc_ae_f32, None),
"dc-ae-f32t4c64": (dc_ae_f32, None),
"dc-ae-f64t4c128": (dc_ae_f64c128, None),
"dc-ae-f64t4c256": (dc_ae_f64t4c256, None),
}


@@ -64,6 +70,16 @@ def DC_AE(
torch_dtype: torch.dtype = torch.bfloat16,
from_scratch: bool = False,
from_pretrained: str | None = None,
is_training: bool = False,
rename_keys: dict = None,
use_spatial_tiling: bool = False,
use_temporal_tiling: bool = False,
spatial_tile_size: int = 256,
temporal_tile_size: int = 32,
tile_overlap_factor: float = 0.25,
tune_channel_proj: bool = False,
train_decoder_only: bool = False,
scaling_factor: float = None,
) -> DCAE_HF:
if not from_scratch:
model = DCAE_HF.from_pretrained(model_name).to(device_map, torch_dtype)
@@ -71,7 +87,28 @@ def DC_AE(
model = DCAE_HF(model_name).to(device_map, torch_dtype)

if from_pretrained is not None:
model = load_checkpoint(model, from_pretrained, device_map=device_map)
model = load_checkpoint(model, from_pretrained, device_map=device_map, rename_keys=rename_keys)
print(f"loaded dc_ae from ckpt path: {from_pretrained}")

if train_decoder_only:
print("freezing all encoder weights!")
# tune the channel projection layer only
if tune_channel_proj:
model.cfg.tune_channel_proj = True
print("freezing all weights except the channel projection!")
for _, param in model.named_parameters():
param.requires_grad = False
for _, param in model.decoder.named_parameters():
param.requires_grad = True

model.cfg.is_training = is_training
model.use_spatial_tiling = use_spatial_tiling
model.use_temporal_tiling = use_temporal_tiling
model.spatial_tile_size = spatial_tile_size
model.temporal_tile_size = temporal_tile_size
model.tile_overlap_factor = tile_overlap_factor
if scaling_factor is not None:
model.scaling_factor = scaling_factor
return model



+ 0
- 0
opensora/models/dc_ae/efficientvit/__init__.py View File


+ 0
- 0
opensora/models/dc_ae/efficientvit/apps/__init__.py View File


+ 0
- 102
opensora/models/dc_ae/efficientvit/apps/setup.py View File

@@ -1,102 +0,0 @@
import os
import time
from copy import deepcopy
from typing import Optional

import torch.backends.cudnn
import torch.distributed
import torch.nn as nn

from ..apps.utils import (
dist_init,
dump_config,
get_dist_local_rank,
get_dist_rank,
init_modules,
is_master,
load_config,
partial_update_config,
zero_last_gamma,
)
from ..models.utils import load_state_dict_from_file

__all__ = [
"save_exp_config",
"setup_dist_env",
"setup_seed",
"setup_exp_config",
"init_model",
]


def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
if not is_master():
return
dump_config(exp_config, os.path.join(path, name))


def setup_dist_env(gpu: Optional[str] = None) -> None:
if gpu is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
if not torch.distributed.is_initialized():
dist_init()
torch.backends.cudnn.benchmark = True
torch.cuda.set_device(get_dist_local_rank())


def setup_seed(manual_seed: int, resume: bool) -> None:
if resume:
manual_seed = int(time.time())
manual_seed = get_dist_rank() + manual_seed
torch.manual_seed(manual_seed)
torch.cuda.manual_seed_all(manual_seed)


def setup_exp_config(config_path: str, recursive=True, opt_args: Optional[dict] = None) -> dict:
# load config
if not os.path.isfile(config_path):
raise ValueError(config_path)

fpaths = [config_path]
if recursive:
extension = os.path.splitext(config_path)[1]
while os.path.dirname(config_path) != config_path:
config_path = os.path.dirname(config_path)
fpath = os.path.join(config_path, "default" + extension)
if os.path.isfile(fpath):
fpaths.append(fpath)
fpaths = fpaths[::-1]

default_config = load_config(fpaths[0])
exp_config = deepcopy(default_config)
for fpath in fpaths[1:]:
partial_update_config(exp_config, load_config(fpath))
# update config via args
if opt_args is not None:
partial_update_config(exp_config, opt_args)

return exp_config


def init_model(
network: nn.Module,
init_from: Optional[str] = None,
backbone_init_from: Optional[str] = None,
rand_init="trunc_normal",
last_gamma=None,
) -> None:
# initialization
init_modules(network, init_type=rand_init)
# zero gamma of last bn in each block
if last_gamma is not None:
zero_last_gamma(network, last_gamma)

# load weight
if init_from is not None and os.path.isfile(init_from):
network.load_state_dict(load_state_dict_from_file(init_from))
print(f"Loaded init from {init_from}")
elif backbone_init_from is not None and os.path.isfile(backbone_init_from):
network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
print(f"Loaded backbone init from {backbone_init_from}")
else:
print(f"Random init ({rand_init}) with last gamma {last_gamma}")

+ 0
- 1
opensora/models/dc_ae/efficientvit/apps/trainer/__init__.py View File

@@ -1 +0,0 @@
from .run_config import *

+ 0
- 128
opensora/models/dc_ae/efficientvit/apps/trainer/run_config.py View File

@@ -1,128 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Any

import numpy as np
import torch.nn as nn

from ...apps.utils import CosineLRwithWarmup, build_optimizer

__all__ = ["Scheduler", "RunConfig"]


class Scheduler:
PROGRESS = 0


class RunConfig:
n_epochs: int
init_lr: float
warmup_epochs: int
warmup_lr: float
lr_schedule_name: str
lr_schedule_param: dict
optimizer_name: str
optimizer_params: dict
weight_decay: float
no_wd_keys: list
grad_clip: float # allow none to turn off grad clipping
reset_bn: bool
reset_bn_size: int
reset_bn_batch_size: int
eval_image_size: list # allow none to use image_size in data_provider

@property
def none_allowed(self):
return ["grad_clip", "eval_image_size"]

def __init__(self, **kwargs): # arguments must be passed as kwargs
for k, val in kwargs.items():
setattr(self, k, val)

# check that all relevant configs are there
annotations = {}
for clas in type(self).mro():
if hasattr(clas, "__annotations__"):
annotations.update(clas.__annotations__)
for k, k_type in annotations.items():
assert hasattr(self, k), f"Key {k} with type {k_type} required for initialization."
attr = getattr(self, k)
if k in self.none_allowed:
k_type = (k_type, type(None))
assert isinstance(attr, k_type), f"Key {k} must be type {k_type}, provided={attr}."

self.global_step = 0
self.batch_per_epoch = 1

def build_optimizer(self, network: nn.Module) -> tuple[Any, Any]:
r"""require setting 'batch_per_epoch' before building optimizer & lr_scheduler"""
param_dict = {}
for name, param in network.named_parameters():
if param.requires_grad:
opt_config = [self.weight_decay, self.init_lr]
if self.no_wd_keys is not None and len(self.no_wd_keys) > 0:
if np.any([key in name for key in self.no_wd_keys]):
opt_config[0] = 0
opt_key = json.dumps(opt_config)
param_dict[opt_key] = param_dict.get(opt_key, []) + [param]

net_params = []
for opt_key, param_list in param_dict.items():
wd, lr = json.loads(opt_key)
net_params.append({"params": param_list, "weight_decay": wd, "lr": lr})

optimizer = build_optimizer(net_params, self.optimizer_name, self.optimizer_params, self.init_lr)
# build lr scheduler
if self.lr_schedule_name == "cosine":
decay_steps = []
for epoch in self.lr_schedule_param.get("step", []):
decay_steps.append(epoch * self.batch_per_epoch)
decay_steps.append(self.n_epochs * self.batch_per_epoch)
decay_steps.sort()
lr_scheduler = CosineLRwithWarmup(
optimizer,
self.warmup_epochs * self.batch_per_epoch,
self.warmup_lr,
decay_steps,
)
else:
raise NotImplementedError
return optimizer, lr_scheduler

def update_global_step(self, epoch, batch_id=0) -> None:
self.global_step = epoch * self.batch_per_epoch + batch_id
Scheduler.PROGRESS = self.progress

@property
def progress(self) -> float:
warmup_steps = self.warmup_epochs * self.batch_per_epoch
steps = max(0, self.global_step - warmup_steps)
return steps / (self.n_epochs * self.batch_per_epoch)

def step(self) -> None:
self.global_step += 1
Scheduler.PROGRESS = self.progress

def get_remaining_epoch(self, epoch, post=True) -> int:
return self.n_epochs + self.warmup_epochs - epoch - int(post)

def epoch_format(self, epoch: int) -> str:
epoch_format = f"%.{len(str(self.n_epochs))}d"
epoch_format = f"[{epoch_format}/{epoch_format}]"
epoch_format = epoch_format % (epoch + 1 - self.warmup_epochs, self.n_epochs)
return epoch_format

+ 0
- 10
opensora/models/dc_ae/efficientvit/apps/utils/__init__.py View File

@@ -1,10 +0,0 @@
from .dist import *
from .ema import *

# from .export import *
from .image import *
from .init import *
from .lr import *
from .metric import *
from .misc import *
from .opt import *

+ 0
- 91
opensora/models/dc_ae/efficientvit/apps/utils/dist.py View File

@@ -1,91 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Union

import torch
import torch.distributed

from ...models.utils.list import list_mean, list_sum

__all__ = [
"dist_init",
"is_dist_initialized",
"get_dist_rank",
"get_dist_size",
"is_master",
"dist_barrier",
"get_dist_local_rank",
"sync_tensor",
]


def dist_init() -> None:
if is_dist_initialized():
return
try:
torch.distributed.init_process_group(backend="nccl")
assert torch.distributed.is_initialized()
except Exception:
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_RANK"] = "0"
print("warning: dist not init")


def is_dist_initialized() -> bool:
return torch.distributed.is_initialized()


def get_dist_rank() -> int:
return int(os.environ["RANK"])


def get_dist_size() -> int:
return int(os.environ["WORLD_SIZE"])


def is_master() -> bool:
return get_dist_rank() == 0


def dist_barrier() -> None:
if is_dist_initialized():
torch.distributed.barrier()


def get_dist_local_rank() -> int:
return int(os.environ["LOCAL_RANK"])


def sync_tensor(tensor: Union[torch.Tensor, float], reduce="mean") -> Union[torch.Tensor, list[torch.Tensor]]:
if not is_dist_initialized():
return tensor
if not isinstance(tensor, torch.Tensor):
tensor = torch.Tensor(1).fill_(tensor).cuda()
tensor_list = [torch.empty_like(tensor) for _ in range(get_dist_size())]
torch.distributed.all_gather(tensor_list, tensor.contiguous(), async_op=False)
if reduce == "mean":
return list_mean(tensor_list)
elif reduce == "sum":
return list_sum(tensor_list)
elif reduce == "cat":
return torch.cat(tensor_list, dim=0)
elif reduce == "root":
return tensor_list[0]
else:
return tensor_list

+ 0
- 54
opensora/models/dc_ae/efficientvit/apps/utils/ema.py View File

@@ -1,54 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import copy
import math

import torch
import torch.nn as nn

from ...models.utils import is_parallel

__all__ = ["EMA"]


def update_ema(ema: nn.Module, new_state_dict: dict[str, torch.Tensor], decay: float) -> None:
for k, v in ema.state_dict().items():
if v.dtype.is_floating_point:
v -= (1.0 - decay) * (v - new_state_dict[k].detach())


class EMA:
def __init__(self, model: nn.Module, decay: float, warmup_steps=2000):
self.shadows = copy.deepcopy(model.module if is_parallel(model) else model).eval()
self.decay = decay
self.warmup_steps = warmup_steps

for p in self.shadows.parameters():
p.requires_grad = False

def step(self, model: nn.Module, global_step: int) -> None:
with torch.no_grad():
msd = (model.module if is_parallel(model) else model).state_dict()
update_ema(self.shadows, msd, self.decay * (1 - math.exp(-global_step / self.warmup_steps)))

def state_dict(self) -> dict[float, dict[str, torch.Tensor]]:
return {self.decay: self.shadows.state_dict()}

def load_state_dict(self, state_dict: dict[float, dict[str, torch.Tensor]]) -> None:
for decay in state_dict:
if decay == self.decay:
self.shadows.load_state_dict(state_dict[decay])

+ 0
- 58
opensora/models/dc_ae/efficientvit/apps/utils/export.py View File

@@ -1,58 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import io
import os
from typing import Any

import onnx
import torch
import torch.nn as nn
from onnxsim import simplify as simplify_func

__all__ = ["export_onnx"]


def export_onnx(model: nn.Module, export_path: str, sample_inputs: Any, simplify=True, opset=11) -> None:
"""Export a model to a platform-specific onnx format.

Args:
model: a torch.nn.Module object.
export_path: export location.
sample_inputs: Any.
simplify: a flag to turn on onnx-simplifier
opset: int
"""
model.eval()

buffer = io.BytesIO()
with torch.no_grad():
torch.onnx.export(model, sample_inputs, buffer, opset_version=opset)
buffer.seek(0, 0)
if simplify:
onnx_model = onnx.load_model(buffer)
onnx_model, success = simplify_func(onnx_model)
assert success
new_buffer = io.BytesIO()
onnx.save(onnx_model, new_buffer)
buffer = new_buffer
buffer.seek(0, 0)

if buffer.getbuffer().nbytes > 0:
save_dir = os.path.dirname(export_path)
os.makedirs(save_dir, exist_ok=True)
with open(export_path, "wb") as f:
f.write(buffer.read())

+ 0
- 190
opensora/models/dc_ae/efficientvit/apps/utils/image.py View File

@@ -1,190 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import os
import pathlib
from typing import Any, Callable, Optional, Union

import numpy as np
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.datasets import ImageFolder

__all__ = ["load_image", "load_image_from_dir", "DMCrop", "CustomImageFolder", "ImageDataset"]


def load_image(data_path: str, mode="rgb") -> Image.Image:
img = Image.open(data_path)
if mode == "rgb":
img = img.convert("RGB")
return img


def load_image_from_dir(
dir_path: str,
suffix: Union[str, tuple[str, ...], list[str]] = (".jpg", ".JPEG", ".png"),
return_mode="path",
k: Optional[int] = None,
shuffle_func: Optional[Callable] = None,
) -> Union[list, tuple[list, list]]:
suffix = [suffix] if isinstance(suffix, str) else suffix

file_list = []
for dirpath, _, fnames in os.walk(dir_path):
for fname in fnames:
if pathlib.Path(fname).suffix not in suffix:
continue
image_path = os.path.join(dirpath, fname)
file_list.append(image_path)

if shuffle_func is not None and k is not None:
shuffle_file_list = shuffle_func(file_list)
file_list = shuffle_file_list or file_list
file_list = file_list[:k]

file_list = sorted(file_list)

if return_mode == "path":
return file_list
else:
files = []
path_list = []
for file_path in file_list:
try:
files.append(load_image(file_path))
path_list.append(file_path)
except Exception:
print(f"Fail to load {file_path}")
if return_mode == "image":
return files
else:
return path_list, files


class DMCrop:
"""center/random crop used in diffusion models"""

def __init__(self, size: int) -> None:
self.size = size

def __call__(self, pil_image: Image.Image) -> Image.Image:
"""
Center cropping implementation from ADM.
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
"""
image_size = self.size
if pil_image.size == (image_size, image_size):
return pil_image

while min(*pil_image.size) >= 2 * image_size:
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)

scale = image_size / min(*pil_image.size)
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)

arr = np.array(pil_image)
crop_y = (arr.shape[0] - image_size) // 2
crop_x = (arr.shape[1] - image_size) // 2
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])


class CustomImageFolder(ImageFolder):
    """ImageFolder variant that can return each sample as a metadata dict."""

    def __init__(self, root: str, transform: Optional[Callable] = None, return_dict: bool = False):
        # When True, __getitem__ yields a dict instead of the (image, label) tuple.
        self.return_dict = return_dict
        super().__init__(os.path.expanduser(root), transform)

    def __getitem__(self, index: int) -> Union[dict[str, Any], tuple[Any, Any]]:
        path, target = self.samples[index]
        image = load_image(path)
        if self.transform is not None:
            image = self.transform(image)
        if not self.return_dict:
            return image, target
        return {
            "index": index,
            "image_path": path,
            "image": image,
            "label": target,
        }


class ImageDataset(Dataset):
    """Image dataset backed by one or more directories, optionally restricted by split files.

    A split file lists image paths relative to its data directory, one per line.
    When no split is given for a directory, every file under it (recursively)
    whose suffix matches ``suffix`` is used.
    """

    def __init__(
        self,
        data_dirs: Union[str, list[str]],
        splits: Optional[Union[str, list[Optional[str]]]] = None,
        transform: Optional[Callable] = None,
        suffix=(".jpg", ".JPEG", ".png"),
        pil=True,
        return_dict=True,
    ) -> None:
        super().__init__()

        self.data_dirs = [data_dirs] if isinstance(data_dirs, str) else data_dirs
        if isinstance(splits, list):
            # One split (or None) per data directory.
            assert len(splits) == len(self.data_dirs)
            self.splits = splits
        elif isinstance(splits, str):
            assert len(self.data_dirs) == 1
            self.splits = [splits]
        else:
            self.splits = [None for _ in range(len(self.data_dirs))]

        self.transform = transform
        self.pil = pil  # load_image returns a PIL image when True — TODO confirm against load_image
        self.return_dict = return_dict

        # Collect all image paths, either from a split file or by scanning the directory.
        self.samples = []
        for data_dir, split in zip(self.data_dirs, self.splits):
            if split is None:
                samples = load_image_from_dir(data_dir, suffix, return_mode="path")
            else:
                samples = []
                with open(split) as fin:
                    for line in fin:
                        # rstrip("\n"), not line[:-1]: a final line without a trailing
                        # newline must keep its last character. Skip blank lines, which
                        # would otherwise resolve to the data directory itself.
                        relative_path = line.rstrip("\n")
                        if not relative_path:
                            continue
                        samples.append(os.path.join(data_dir, relative_path))
            self.samples += samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int, skip_image=False) -> dict[str, Any]:
        """Return one sample; with ``skip_image=True`` only metadata is produced.

        Raises:
            OSError: if the image file cannot be loaded (original error chained).
        """
        image_path = self.samples[index]

        if skip_image:
            image = None
        else:
            try:
                image = load_image(image_path, return_pil=self.pil)
            except Exception as e:
                print(f"Fail to load {image_path}")
                # Chain the original exception so the root cause stays visible;
                # callers catching OSError are unaffected.
                raise OSError(f"Fail to load {image_path}") from e
        if self.transform is not None:
            image = self.transform(image)
        if self.return_dict:
            return {
                "index": index,
                "image_path": image_path,
                "image_name": os.path.basename(image_path),
                "data": image,
            }
        else:
            return image

+ 0
- 79
opensora/models/dc_ae/efficientvit/apps/utils/lr.py View File

@@ -1,79 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import math
from typing import Union

import torch

from ...models.utils.list import val2list

__all__ = ["CosineLRwithWarmup", "ConstantLRwithWarmup"]


class CosineLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
warmup_lr: float,
decay_steps: Union[int, list[int]],
last_epoch: int = -1,
) -> None:
self.warmup_steps = warmup_steps
self.warmup_lr = warmup_lr
self.decay_steps = val2list(decay_steps)
super().__init__(optimizer, last_epoch)

def get_lr(self) -> list[float]:
if self.last_epoch < self.warmup_steps:
return [
(base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr
for base_lr in self.base_lrs
]
else:
current_steps = self.last_epoch - self.warmup_steps
decay_steps = [0] + self.decay_steps
idx = len(decay_steps) - 2
for i, decay_step in enumerate(decay_steps[:-1]):
if decay_step <= current_steps < decay_steps[i + 1]:
idx = i
break
current_steps -= decay_steps[idx]
decay_step = decay_steps[idx + 1] - decay_steps[idx]
return [0.5 * base_lr * (1 + math.cos(math.pi * current_steps / decay_step)) for base_lr in self.base_lrs]


class ConstantLRwithWarmup(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
warmup_lr: float,
last_epoch: int = -1,
) -> None:
self.warmup_steps = warmup_steps
self.warmup_lr = warmup_lr
super().__init__(optimizer, last_epoch)

def get_lr(self) -> list[float]:
if self.last_epoch < self.warmup_steps:
return [
(base_lr - self.warmup_lr) * (self.last_epoch + 1) / self.warmup_steps + self.warmup_lr
for base_lr in self.base_lrs
]
else:
return self.base_lrs

+ 0
- 47
opensora/models/dc_ae/efficientvit/apps/utils/metric.py View File

@@ -1,47 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Union

import torch

from ...apps.utils.dist import sync_tensor

__all__ = ["AverageMeter"]


class AverageMeter:
"""Computes and stores the average and current value."""

def __init__(self, is_distributed=True):
self.is_distributed = is_distributed
self.sum = 0
self.count = 0

def _sync(self, val: Union[torch.Tensor, int, float]) -> Union[torch.Tensor, int, float]:
return sync_tensor(val, reduce="sum") if self.is_distributed else val

def update(self, val: Union[torch.Tensor, int, float], delta_n=1):
self.count += self._sync(delta_n)
self.sum += self._sync(val * delta_n)

def get_count(self) -> Union[torch.Tensor, int, float]:
return self.count.item() if isinstance(self.count, torch.Tensor) and self.count.numel() == 1 else self.count

@property
def avg(self):
avg = -1 if self.count == 0 else self.sum / self.count
return avg.item() if isinstance(avg, torch.Tensor) and avg.numel() == 1 else avg

+ 0
- 114
opensora/models/dc_ae/efficientvit/apps/utils/misc.py View File

@@ -1,114 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Union

import yaml

__all__ = [
"parse_with_yaml",
"parse_unknown_args",
"partial_update_config",
"resolve_and_load_config",
"load_config",
"dump_config",
]


def parse_with_yaml(config_str: str) -> Union[str, dict]:
try:
# add space manually for dict
if "{" in config_str and "}" in config_str and ":" in config_str:
out_str = config_str.replace(":", ": ")
else:
out_str = config_str
return yaml.safe_load(out_str)
except ValueError:
# return raw string if parsing fails
return config_str


def parse_unknown_args(unknown: list) -> dict:
"""Parse unknown args."""
index = 0
parsed_dict = {}
while index < len(unknown):
key, val = unknown[index], unknown[index + 1]
index += 2
if not key.startswith("--"):
continue
key = key[2:]

# try parsing with either dot notation or full yaml notation
# Note that the vanilla case "--key value" will be parsed the same
if "." in key:
# key == a.b.c, val == val --> parsed_dict[a][b][c] = val
keys = key.split(".")
dict_to_update = parsed_dict
for key in keys[:-1]:
if not (key in dict_to_update and isinstance(dict_to_update[key], dict)):
dict_to_update[key] = {}
dict_to_update = dict_to_update[key]
dict_to_update[keys[-1]] = parse_with_yaml(val) # so we can parse lists, bools, etc...
else:
parsed_dict[key] = parse_with_yaml(val)
return parsed_dict


def partial_update_config(config: dict, partial_config: dict) -> dict:
for key in partial_config:
if key in config and isinstance(partial_config[key], dict) and isinstance(config[key], dict):
partial_update_config(config[key], partial_config[key])
else:
config[key] = partial_config[key]
return config


def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
path = os.path.realpath(os.path.expanduser(path))
if os.path.isdir(path):
config_path = os.path.join(path, config_name)
else:
config_path = path
if os.path.isfile(config_path):
pass
else:
raise Exception(f"Cannot find a valid config at {path}")
config = load_config(config_path)
return config


class SafeLoaderWithTuple(yaml.SafeLoader):
"""A yaml safe loader with python tuple loading capabilities."""

def construct_python_tuple(self, node):
return tuple(self.construct_sequence(node))


SafeLoaderWithTuple.add_constructor("tag:yaml.org,2002:python/tuple", SafeLoaderWithTuple.construct_python_tuple)


def load_config(filename: str) -> dict:
"""Load a yaml file."""
filename = os.path.realpath(os.path.expanduser(filename))
return yaml.load(open(filename), Loader=SafeLoaderWithTuple)


def dump_config(config: dict, filename: str) -> None:
"""Dump a config file"""
filename = os.path.realpath(os.path.expanduser(filename))
yaml.dump(config, open(filename, "w"), sort_keys=False)

+ 0
- 42
opensora/models/dc_ae/efficientvit/apps/utils/opt.py View File

@@ -1,42 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

import torch

__all__ = ["REGISTERED_OPTIMIZER_DICT", "build_optimizer"]

# register optimizer here
# name: optimizer, kwargs with default values
REGISTERED_OPTIMIZER_DICT: dict[str, tuple[type, dict[str, Any]]] = {
"sgd": (torch.optim.SGD, {"momentum": 0.9, "nesterov": True}),
"adam": (torch.optim.Adam, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
"adamw": (torch.optim.AdamW, {"betas": (0.9, 0.999), "eps": 1e-8, "amsgrad": False}),
}


def build_optimizer(
net_params, optimizer_name: str, optimizer_params: Optional[dict], init_lr: float
) -> torch.optim.Optimizer:
optimizer_class, default_params = REGISTERED_OPTIMIZER_DICT[optimizer_name]
optimizer_params = {} if optimizer_params is None else optimizer_params

for key in default_params:
if key in optimizer_params:
default_params[key] = optimizer_params[key]
optimizer = optimizer_class(net_params, init_lr, **default_params)
return optimizer

+ 0
- 0
opensora/models/dc_ae/efficientvit/models/__init__.py View File


+ 0
- 102
opensora/models/dc_ae/efficientvit/models/nn/drop.py View File

@@ -1,102 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn

from ...apps.trainer.run_config import Scheduler
from ...models.nn.ops import IdentityLayer, ResidualBlock
from ...models.utils import build_kwargs_from_config

__all__ = ["apply_drop_func"]


def apply_drop_func(network: nn.Module, drop_config: Optional[dict[str, Any]]) -> None:
if drop_config is None:
return

drop_lookup_table = {
"droppath": apply_droppath,
}

drop_func = drop_lookup_table[drop_config["name"]]
drop_kwargs = build_kwargs_from_config(drop_config, drop_func)

drop_func(network, **drop_kwargs)


def apply_droppath(
network: nn.Module,
drop_prob: float,
linear_decay=True,
scheduled=True,
skip=0,
) -> None:
all_valid_blocks = []
for m in network.modules():
for name, sub_module in m.named_children():
if isinstance(sub_module, ResidualBlock) and isinstance(sub_module.shortcut, IdentityLayer):
all_valid_blocks.append((m, name, sub_module))
all_valid_blocks = all_valid_blocks[skip:]
for i, (m, name, sub_module) in enumerate(all_valid_blocks):
prob = drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob
new_module = DropPathResidualBlock(
sub_module.main,
sub_module.shortcut,
sub_module.post_act,
sub_module.pre_norm,
prob,
scheduled,
)
m._modules[name] = new_module


class DropPathResidualBlock(ResidualBlock):
def __init__(
self,
main: nn.Module,
shortcut: Optional[nn.Module],
post_act=None,
pre_norm: Optional[nn.Module] = None,
######################################
drop_prob: float = 0,
scheduled=True,
):
super().__init__(main, shortcut, post_act, pre_norm)

self.drop_prob = drop_prob
self.scheduled = scheduled

def forward(self, x: torch.Tensor) -> torch.Tensor:
if not self.training or self.drop_prob == 0 or not isinstance(self.shortcut, IdentityLayer):
return ResidualBlock.forward(self, x)
else:
drop_prob = self.drop_prob
if self.scheduled:
drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1)
keep_prob = 1 - drop_prob

shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize

res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x)
if self.post_act:
res = self.post_act(res)
return res

+ 0
- 183
opensora/models/dc_ae/efficientvit/models/nn/norm.py View File

@@ -1,183 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from ...models.nn.triton_rms_norm import TritonRMSNorm2dFunc
from ...models.utils import build_kwargs_from_config

__all__ = ["LayerNorm2d", "TritonRMSNorm2d", "build_norm", "reset_bn", "set_norm_eps"]


class LayerNorm2d(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = x - torch.mean(x, dim=1, keepdim=True)
out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
if self.elementwise_affine:
out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
return out


class TritonRMSNorm2d(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return TritonRMSNorm2dFunc.apply(x, self.weight, self.bias, self.eps)


class RMSNorm2d(nn.Module):
def __init__(
self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True
) -> None:
super().__init__()
self.num_features = num_features
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = torch.nn.parameter.Parameter(torch.empty(self.num_features))
if bias:
self.bias = torch.nn.parameter.Parameter(torch.empty(self.num_features))
else:
self.register_parameter("bias", None)
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = (x / torch.sqrt(torch.square(x.float()).mean(dim=1, keepdim=True) + self.eps)).to(x.dtype)
if self.elementwise_affine:
x = x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
return x


# register normalization function here
REGISTERED_NORM_DICT: dict[str, type] = {
"bn2d": nn.BatchNorm2d,
"ln": nn.LayerNorm,
"ln2d": LayerNorm2d,
"trms2d": TritonRMSNorm2d,
"rms2d": RMSNorm2d,
}


def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]:
if name in ["ln", "ln2d", "trms2d"]:
kwargs["normalized_shape"] = num_features
else:
kwargs["num_features"] = num_features
if name in REGISTERED_NORM_DICT:
norm_cls = REGISTERED_NORM_DICT[name]
args = build_kwargs_from_config(kwargs, norm_cls)
return norm_cls(**args)
else:
return None


def reset_bn(
model: nn.Module,
data_loader: list,
sync=True,
progress_bar=False,
) -> None:
import copy

import torch.nn.functional as F
from efficientvit.apps.utils import AverageMeter, is_master, sync_tensor
from efficientvit.models.utils import get_device, list_join
from tqdm import tqdm

bn_mean = {}
bn_var = {}

tmp_model = copy.deepcopy(model)
for name, m in tmp_model.named_modules():
if isinstance(m, _BatchNorm):
bn_mean[name] = AverageMeter(is_distributed=False)
bn_var[name] = AverageMeter(is_distributed=False)

def new_forward(bn, mean_est, var_est):
def lambda_forward(x):
x = x.contiguous()
if sync:
batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1
batch_mean = sync_tensor(batch_mean, reduce="cat")
batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)

batch_var = (x - batch_mean) * (x - batch_mean)
batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
batch_var = sync_tensor(batch_var, reduce="cat")
batch_var = torch.mean(batch_var, dim=0, keepdim=True)
else:
batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) # 1, C, 1, 1
batch_var = (x - batch_mean) * (x - batch_mean)
batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)

batch_mean = torch.squeeze(batch_mean)
batch_var = torch.squeeze(batch_var)

mean_est.update(batch_mean.data, x.size(0))
var_est.update(batch_var.data, x.size(0))

# bn forward using calculated mean & var
_feature_dim = batch_mean.shape[0]
return F.batch_norm(
x,
batch_mean,
batch_var,
bn.weight[:_feature_dim],
bn.bias[:_feature_dim],
False,
0.0,
bn.eps,
)

return lambda_forward

m.forward = new_forward(m, bn_mean[name], bn_var[name])

# skip if there is no batch normalization layers in the network
if len(bn_mean) == 0:
return

tmp_model.eval()
with torch.no_grad():
with tqdm(total=len(data_loader), desc="reset bn", disable=not progress_bar or not is_master()) as t:
for images in data_loader:
images = images.to(get_device(tmp_model))
tmp_model(images)
t.set_postfix(
{
"bs": images.size(0),
"res": list_join(images.shape[-2:], "x"),
}
)
t.update()

for name, m in model.named_modules():
if name in bn_mean and bn_mean[name].count > 0:
feature_dim = bn_mean[name].avg.size(0)
assert isinstance(m, _BatchNorm)
m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
m.running_var.data[:feature_dim].copy_(bn_var[name].avg)


def set_norm_eps(model: nn.Module, eps: Optional[float] = None) -> None:
for m in model.modules():
if isinstance(m, (nn.GroupNorm, nn.LayerNorm, _BatchNorm)):
if eps is not None:
m.eps = eps

+ 0
- 207
opensora/models/dc_ae/efficientvit/models/nn/triton_rms_norm.py View File

@@ -1,207 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import torch
import triton
import triton.language as tl

__all__ = ["TritonRMSNorm2dFunc"]


@triton.jit
def _rms_norm_2d_fwd_fused(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
Rrms, # pointer to the 1/rms
M,
C,
N,
num_blocks, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
m_n = tl.program_id(0)
m, n = m_n // num_blocks, m_n % num_blocks

Y += m * C * N
X += m * C * N
# Compute mean

cols = n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cols < N

x_sum_square = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, C):
x = tl.load(X + off * N + cols, mask=mask, other=0.0).to(tl.float32)
x_sum_square += x * x
mean_square = x_sum_square / C
rrms = 1 / tl.sqrt(mean_square + eps)
# Write rstd
tl.store(Rrms + m * N + cols, rrms, mask=mask)
# Normalize and apply linear transformation
for off in range(0, C):
pos = off * N + cols
w = tl.load(W + off)
b = tl.load(B + off)
x = tl.load(X + pos, mask=mask, other=0.0).to(tl.float32)
x_hat = x * rrms
y = x_hat * w + b
# Write output
tl.store(Y + pos, y, mask=mask)


@triton.jit
def _rms_norm_2d_bwd_dx_fused(
DX, # pointer to the input gradient
DY, # pointer to the output gradient
DW, # pointer to the partial sum of weights gradient
DB, # pointer to the partial sum of biases gradient
X, # pointer to the input
W, # pointer to the weights
B, # pointer to the biases
Rrms, # pointer to the 1/rms
M,
C,
N, # number of columns in X
num_blocks,
eps, # epsilon to avoid division by zero
GROUP_SIZE_M: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
BLOCK_SIZE_C: tl.constexpr,
):
# Map the program id to the elements of X, DX, and DY it should compute.
m_n = tl.program_id(0)
m, n = m_n // num_blocks, m_n % num_blocks
X += m * C * N
DY += m * C * N
DX += m * C * N
Rrms += m * N

cols = n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cols < N
# Offset locks and weights/biases gradient pointer for parallel reduction
DW = DW + m_n * C
DB = DB + m_n * C
rrms = tl.load(Rrms + cols, mask=mask, other=1)
# Load data to SRAM
c1 = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, C):
pos = off * N + cols
x = tl.load(X + pos, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + pos, mask=mask, other=0).to(tl.float32)
w = tl.load(W + off).to(tl.float32)
# Compute dx
xhat = x * rrms
wdy = w * dy
xhat = tl.where(mask, xhat, 0.0)
wdy = tl.where(mask, wdy, 0.0)
c1 += xhat * wdy
# Accumulate partial sums for dw/db
tl.store(DW + off, tl.sum((dy * xhat).to(w.dtype), axis=0))
tl.store(DB + off, tl.sum(dy.to(w.dtype), axis=0))

c1 /= C
for off in range(0, C):
pos = off * N + cols
x = tl.load(X + pos, mask=mask, other=0).to(tl.float32)
dy = tl.load(DY + pos, mask=mask, other=0).to(tl.float32)
w = tl.load(W + off).to(tl.float32)
xhat = x * rrms
wdy = w * dy
dx = (wdy - (xhat * c1)) * rrms
# Write dx
tl.store(DX + pos, dx, mask=mask)


class TritonRMSNorm2dFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, bias, eps):
# allocate output
y = torch.empty_like(x)
# reshape input data into 2D tensor
x_arg = x.reshape(x.shape[0], x.shape[1], -1)
M, C, N = x_arg.shape
rrms = torch.empty((M, N), dtype=torch.float32, device="cuda")
# Less than 64KB per feature: enqueue fused kernel
BLOCK_SIZE = 256
num_blocks = triton.cdiv(N, BLOCK_SIZE)
num_warps = 8
# enqueue kernel
_rms_norm_2d_fwd_fused[(M * num_blocks,)]( #
x_arg,
y,
weight,
bias,
rrms, #
M,
C,
N,
num_blocks,
eps, #
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
num_ctas=1,
)
ctx.save_for_backward(x, weight, bias, rrms)
ctx.BLOCK_SIZE = BLOCK_SIZE
ctx.num_blocks = num_blocks
ctx.num_warps = num_warps
ctx.eps = eps
return y

@staticmethod
def backward(ctx, dy):
x, w, b, rrms = ctx.saved_tensors
num_blocks = ctx.num_blocks

x_arg = x.reshape(x.shape[0], x.shape[1], -1)
M, C, N = x_arg.shape
# GROUP_SIZE_M = 64
GROUP_SIZE_M = M * num_blocks
# allocate output
_dw = torch.empty((GROUP_SIZE_M, C), dtype=x.dtype, device=w.device)
_db = torch.empty((GROUP_SIZE_M, C), dtype=x.dtype, device=w.device)
dw = torch.empty((C,), dtype=w.dtype, device=w.device)
db = torch.empty((C,), dtype=w.dtype, device=w.device)
dx = torch.empty_like(dy)
# enqueue kernel using forward pass heuristics
# also compute partial sums for DW and DB
# print(f"M={M}, num_blocks={num_blocks}, dx={dx.shape}, dy={dy.shape}, _dw={_dw.shape}, _db={_db.shape}, x={x.shape}, w={w.shape}, b={b.shape}, m={m.shape}, v={v.shape}, M={M}, C={C}, N={N}")
_rms_norm_2d_bwd_dx_fused[(M * num_blocks,)]( #
dx,
dy,
_dw,
_db,
x,
w,
b,
rrms, #
M,
C,
N,
num_blocks,
ctx.eps, #
BLOCK_SIZE=ctx.BLOCK_SIZE,
GROUP_SIZE_M=GROUP_SIZE_M, #
BLOCK_SIZE_C=triton.next_power_of_2(C),
num_warps=ctx.num_warps,
)
dw = _dw.sum(dim=0)
db = _db.sum(dim=0)
return dx, dw, db, None

+ 0
- 3
opensora/models/dc_ae/efficientvit/models/utils/__init__.py View File

@@ -1,3 +0,0 @@
from .list import *
from .network import *
from .random import *

+ 0
- 111
opensora/models/dc_ae/efficientvit/models/utils/network.py View File

@@ -1,111 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import collections
import os
from inspect import signature
from typing import Any, Callable, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = [
"is_parallel",
"get_device",
"get_same_padding",
"resize",
"build_kwargs_from_config",
"load_state_dict_from_file",
"get_submodule_weights",
]


def is_parallel(model: nn.Module) -> bool:
return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))


def get_device(model: nn.Module) -> torch.device:
return model.parameters().__next__().device


def get_dtype(model: nn.Module) -> torch.dtype:
return model.parameters().__next__().dtype


def get_same_padding(kernel_size: Union[int, tuple[int, ...]]) -> Union[int, tuple[int, ...]]:
if isinstance(kernel_size, tuple):
return tuple([get_same_padding(ks) for ks in kernel_size])
else:
assert kernel_size % 2 > 0, "kernel size should be odd number"
return kernel_size // 2


def resize(
x: torch.Tensor,
size: Optional[Any] = None,
scale_factor: Optional[list[float]] = None,
mode: str = "bicubic",
align_corners: Optional[bool] = False,
) -> torch.Tensor:
if mode in {"bilinear", "bicubic"}:
return F.interpolate(
x,
size=size,
scale_factor=scale_factor,
mode=mode,
align_corners=align_corners,
)
elif mode in {"nearest", "area"}:
return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
else:
raise NotImplementedError(f"resize(mode={mode}) not implemented.")


def build_kwargs_from_config(config: dict, target_func: Callable) -> dict[str, Any]:
valid_keys = list(signature(target_func).parameters)
kwargs = {}
for key in config:
if key in valid_keys:
kwargs[key] = config[key]
return kwargs


def load_state_dict_from_file(file: str, only_state_dict=True) -> dict[str, torch.Tensor]:
file = os.path.realpath(os.path.expanduser(file))
checkpoint = torch.load(file, map_location="cpu", weights_only=True)
if only_state_dict and "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
return checkpoint


def get_submodule_weights(weights: collections.OrderedDict, prefix: str):
submodule_weights = collections.OrderedDict()
len_prefix = len(prefix)
for key, weight in weights.items():
if key.startswith(prefix):
submodule_weights[key[len_prefix:]] = weight
return submodule_weights


def get_dtype_from_str(dtype: str) -> torch.dtype:
if dtype == "fp32":
return torch.float32
if dtype == "fp16":
return torch.float16
if dtype == "bf16":
return torch.bfloat16
raise NotImplementedError(f"dtype {dtype} is not supported")

+ 0
- 79
opensora/models/dc_ae/efficientvit/models/utils/random.py View File

@@ -1,79 +0,0 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Optional, Union

import numpy as np
import torch

__all__ = [
"torch_randint",
"torch_random",
"torch_shuffle",
"torch_uniform",
"torch_random_choices",
]


def torch_randint(low: int, high: int, generator: Optional[torch.Generator] = None) -> int:
"""uniform: [low, high)"""
if low == high:
return low
else:
assert low < high
return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))


def torch_random(generator: Optional[torch.Generator] = None) -> float:
"""uniform distribution on the interval [0, 1)"""
return float(torch.rand(1, generator=generator))


def torch_shuffle(src_list: list[Any], generator: Optional[torch.Generator] = None) -> list[Any]:
rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
return [src_list[i] for i in rand_indexes]


def torch_uniform(low: float, high: float, generator: Optional[torch.Generator] = None) -> float:
"""uniform distribution on the interval [low, high)"""
rand_val = torch_random(generator)
return (high - low) * rand_val + low


def torch_random_choices(
src_list: list[Any],
generator: Optional[torch.Generator] = None,
k=1,
weight_list: Optional[list[float]] = None,
) -> Union[Any, list]:
if weight_list is None:
rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,))
out_list = [src_list[i] for i in rand_idx]
else:
assert len(weight_list) == len(src_list)
accumulate_weight_list = np.cumsum(weight_list)

out_list = []
for _ in range(k):
val = torch_uniform(0, accumulate_weight_list[-1], generator)
active_id = 0
for i, weight_val in enumerate(accumulate_weight_list):
active_id = i
if weight_val > val:
break
out_list.append(src_list[active_id])

return out_list[0] if k == 1 else out_list

opensora/models/dc_ae/efficientvit/models/efficientvit/__init__.py → opensora/models/dc_ae/models/__init__.py View File


opensora/models/dc_ae/efficientvit/models/efficientvit/dc_ae.py → opensora/models/dc_ae/models/dc_ae.py View File

@@ -22,11 +22,13 @@ import torch.nn as nn
from omegaconf import MISSING, OmegaConf
from torch import Tensor

from ...apps.setup import init_model
from ...models.nn.act import build_act
from ...models.nn.norm import build_norm
from ...models.nn.ops import (
ChannelDuplicatingPixelUnshuffleUpSampleLayer,
from opensora.acceleration.checkpoint import auto_grad_checkpoint

from ..utils import init_modules
from .nn.act import build_act
from .nn.norm import build_norm
from .nn.ops import (
ChannelDuplicatingPixelShuffleUpSampleLayer,
ConvLayer,
ConvPixelShuffleUpSampleLayer,
ConvPixelUnshuffleDownSampleLayer,
@@ -39,7 +41,7 @@ from ...models.nn.ops import (
ResidualBlock,
)

__all__ = ["DCAE", "dc_ae_f32c32", "dc_ae_f64c128", "dc_ae_f128c512"]
__all__ = ["DCAE", "dc_ae_f32", "dc_ae_f64c128", "dc_ae_f128c512", "dc_ae_f64t4c256"]


@dataclass
@@ -58,6 +60,9 @@ class EncoderConfig:
out_act: Optional[str] = None
out_shortcut: Optional[str] = "averaging"
double_latent: bool = False
is_video: bool = False
temporal_downsample: tuple[bool, ...] = ()
tune_channel_proj: bool = False


@dataclass
@@ -75,12 +80,17 @@ class DecoderConfig:
upsample_shortcut: str = "duplicating"
out_norm: str = "rms2d"
out_act: str = "relu"
is_video: bool = False
temporal_upsample: tuple[bool, ...] = ()
tune_channel_proj: bool = False


@dataclass
class DCAEConfig:
in_channels: int = 3
latent_channels: int = 32
time_compression_ratio: int = 1
spatial_compression_ratio: int = 32
encoder: EncoderConfig = field(
default_factory=lambda: EncoderConfig(in_channels="${..in_channels}", latent_channels="${..latent_channels}")
)
@@ -93,10 +103,21 @@ class DCAEConfig:
pretrained_source: str = "dc-ae"

scaling_factor: Optional[float] = None
is_image_model: bool = False

tune_channel_proj: bool = False
train_decoder_only: bool = False
is_training: bool = False # NOTE: set to True in vae train config

use_spatial_tiling: bool = False
use_temporal_tiling: bool = False
spatial_tile_size: int = 256
temporal_tile_size: int = 32
tile_overlap_factor: float = 0.25


def build_block(
block_type: str, in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str]
block_type: str, in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], is_video: bool
) -> nn.Module:
if block_type == "ResBlock":
assert in_channels == out_channels
@@ -108,21 +129,26 @@ def build_block(
use_bias=(True, False),
norm=(None, norm),
act_func=(act, None),
is_video=is_video,
)
block = ResidualBlock(main_block, IdentityLayer())
elif block_type == "EViT_GLU":
assert in_channels == out_channels
block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=())
block = EfficientViTBlock(
in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=(), is_video=is_video
)
elif block_type == "EViTS5_GLU":
assert in_channels == out_channels
block = EfficientViTBlock(in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=(5,))
block = EfficientViTBlock(
in_channels, norm=norm, act_func=act, local_module="GLUMBConv", scales=(5,), is_video=is_video
)
else:
raise ValueError(f"block_type {block_type} is not supported")
return block


def build_stage_main(
width: int, depth: int, block_type: str | list[str], norm: str, act: str, input_width: int
width: int, depth: int, block_type: str | list[str], norm: str, act: str, input_width: int, is_video: bool
) -> list[nn.Module]:
assert isinstance(block_type, str) or (isinstance(block_type, list) and depth == len(block_type))
stage = []
@@ -134,23 +160,45 @@ def build_stage_main(
out_channels=width,
norm=norm,
act=act,
is_video=is_video,
)
stage.append(block)
return stage


def build_downsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module:
def build_downsample_block(
block_type: str,
in_channels: int,
out_channels: int,
shortcut: Optional[str],
is_video: bool,
temporal_downsample: bool = False,
) -> nn.Module:
"""
Spatial downsample is always performed. Temporal downsample is optional.
"""

if block_type == "Conv":
if is_video:
if temporal_downsample:
stride = (2, 2, 2)
else:
stride = (1, 2, 2)
else:
stride = 2
block = ConvLayer(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
stride=stride,
use_bias=True,
norm=None,
act_func=None,
is_video=is_video,
)
elif block_type == "ConvPixelUnshuffle":
if is_video:
raise NotImplementedError("ConvPixelUnshuffle downsample is not supported for video")
block = ConvPixelUnshuffleDownSampleLayer(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
)
@@ -160,7 +208,7 @@ def build_downsample_block(block_type: str, in_channels: int, out_channels: int,
pass
elif shortcut == "averaging":
shortcut_block = PixelUnshuffleChannelAveragingDownSampleLayer(
in_channels=in_channels, out_channels=out_channels, factor=2
in_channels=in_channels, out_channels=out_channels, factor=2, temporal_downsample=temporal_downsample
)
block = ResidualBlock(block, shortcut_block)
else:
@@ -168,22 +216,36 @@ def build_downsample_block(block_type: str, in_channels: int, out_channels: int,
return block


def build_upsample_block(block_type: str, in_channels: int, out_channels: int, shortcut: Optional[str]) -> nn.Module:
def build_upsample_block(
block_type: str,
in_channels: int,
out_channels: int,
shortcut: Optional[str],
is_video: bool,
temporal_upsample: bool = False,
) -> nn.Module:
if block_type == "ConvPixelShuffle":
if is_video:
raise NotImplementedError("ConvPixelShuffle upsample is not supported for video")
block = ConvPixelShuffleUpSampleLayer(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
)
elif block_type == "InterpolateConv":
block = InterpolateConvUpSampleLayer(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, factor=2
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
factor=2,
is_video=is_video,
temporal_upsample=temporal_upsample,
)
else:
raise ValueError(f"block_type {block_type} is not supported for upsampling")
if shortcut is None:
pass
elif shortcut == "duplicating":
shortcut_block = ChannelDuplicatingPixelUnshuffleUpSampleLayer(
in_channels=in_channels, out_channels=out_channels, factor=2
shortcut_block = ChannelDuplicatingPixelShuffleUpSampleLayer(
in_channels=in_channels, out_channels=out_channels, factor=2, temporal_upsample=temporal_upsample
)
block = ResidualBlock(block, shortcut_block)
else:
@@ -191,7 +253,9 @@ def build_upsample_block(block_type: str, in_channels: int, out_channels: int, s
return block


def build_encoder_project_in_block(in_channels: int, out_channels: int, factor: int, downsample_block_type: str):
def build_encoder_project_in_block(
in_channels: int, out_channels: int, factor: int, downsample_block_type: str, is_video: bool
):
if factor == 1:
block = ConvLayer(
in_channels=in_channels,
@@ -201,8 +265,11 @@ def build_encoder_project_in_block(in_channels: int, out_channels: int, factor:
use_bias=True,
norm=None,
act_func=None,
is_video=is_video,
)
elif factor == 2:
if is_video:
raise NotImplementedError("Downsample during project_in is not supported for video")
block = build_downsample_block(
block_type=downsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None
)
@@ -212,7 +279,12 @@ def build_encoder_project_in_block(in_channels: int, out_channels: int, factor:


def build_encoder_project_out_block(
in_channels: int, out_channels: int, norm: Optional[str], act: Optional[str], shortcut: Optional[str]
in_channels: int,
out_channels: int,
norm: Optional[str],
act: Optional[str],
shortcut: Optional[str],
is_video: bool,
):
block = OpSequential(
[
@@ -226,6 +298,7 @@ def build_encoder_project_out_block(
use_bias=True,
norm=None,
act_func=None,
is_video=is_video,
),
]
)
@@ -241,7 +314,7 @@ def build_encoder_project_out_block(
return block


def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut: Optional[str]):
def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut: Optional[str], is_video: bool):
block = ConvLayer(
in_channels=in_channels,
out_channels=out_channels,
@@ -250,11 +323,12 @@ def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut
use_bias=True,
norm=None,
act_func=None,
is_video=is_video,
)
if shortcut is None:
pass
elif shortcut == "duplicating":
shortcut_block = ChannelDuplicatingPixelUnshuffleUpSampleLayer(
shortcut_block = ChannelDuplicatingPixelShuffleUpSampleLayer(
in_channels=in_channels, out_channels=out_channels, factor=1
)
block = ResidualBlock(block, shortcut_block)
@@ -264,7 +338,13 @@ def build_decoder_project_in_block(in_channels: int, out_channels: int, shortcut


def build_decoder_project_out_block(
in_channels: int, out_channels: int, factor: int, upsample_block_type: str, norm: Optional[str], act: Optional[str]
in_channels: int,
out_channels: int,
factor: int,
upsample_block_type: str,
norm: Optional[str],
act: Optional[str],
is_video: bool,
):
layers: list[nn.Module] = [
build_norm(norm, in_channels),
@@ -280,9 +360,12 @@ def build_decoder_project_out_block(
use_bias=True,
norm=None,
act_func=None,
is_video=is_video,
)
)
elif factor == 2:
if is_video:
raise NotImplementedError("Upsample during project_out is not supported for video")
layers.append(
build_upsample_block(
block_type=upsample_block_type, in_channels=in_channels, out_channels=out_channels, shortcut=None
@@ -310,13 +393,20 @@ class Encoder(nn.Module):
out_channels=cfg.width_list[0] if cfg.depth_list[0] > 0 else cfg.width_list[1],
factor=1 if cfg.depth_list[0] > 0 else 2,
downsample_block_type=cfg.downsample_block_type,
is_video=cfg.is_video,
)

self.stages: list[OpSequential] = []
for stage_id, (width, depth) in enumerate(zip(cfg.width_list, cfg.depth_list)):
block_type = cfg.block_type[stage_id] if isinstance(cfg.block_type, list) else cfg.block_type
stage = build_stage_main(
width=width, depth=depth, block_type=block_type, norm=cfg.norm, act=cfg.act, input_width=width
width=width,
depth=depth,
block_type=block_type,
norm=cfg.norm,
act=cfg.act,
input_width=width,
is_video=cfg.is_video,
)

if stage_id < num_stages - 1 and depth > 0:
@@ -325,6 +415,8 @@ class Encoder(nn.Module):
in_channels=width,
out_channels=cfg.width_list[stage_id + 1] if cfg.downsample_match_channel else width,
shortcut=cfg.downsample_shortcut,
is_video=cfg.is_video,
temporal_downsample=cfg.temporal_downsample[stage_id] if cfg.temporal_downsample != [] else False,
)
stage.append(downsample_block)
self.stages.append(OpSequential(stage))
@@ -336,15 +428,22 @@ class Encoder(nn.Module):
norm=cfg.out_norm,
act=cfg.out_act,
shortcut=cfg.out_shortcut,
is_video=cfg.is_video,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.project_in(x)
# x = auto_grad_checkpoint(self.project_in, x)
for stage in self.stages:
if len(stage.op_list) == 0:
continue
x = stage(x)
x = self.project_out(x)
# x = stage(x)
if self.cfg.tune_channel_proj:
x = stage(x)
else:
x = auto_grad_checkpoint(stage, x)
# x = self.project_out(x)
x = auto_grad_checkpoint(self.project_out, x)
return x


@@ -366,6 +465,7 @@ class Decoder(nn.Module):
in_channels=cfg.latent_channels,
out_channels=cfg.width_list[-1],
shortcut=cfg.in_shortcut,
is_video=cfg.is_video,
)

self.stages: list[OpSequential] = []
@@ -377,6 +477,8 @@ class Decoder(nn.Module):
in_channels=cfg.width_list[stage_id + 1],
out_channels=width if cfg.upsample_match_channel else cfg.width_list[stage_id + 1],
shortcut=cfg.upsample_shortcut,
is_video=cfg.is_video,
temporal_upsample=cfg.temporal_upsample[stage_id] if cfg.temporal_upsample != [] else False,
)
stage.append(upsample_block)

@@ -393,6 +495,7 @@ class Decoder(nn.Module):
input_width=(
width if cfg.upsample_match_channel else cfg.width_list[min(stage_id + 1, num_stages - 1)]
),
is_video=cfg.is_video,
)
)
self.stages.insert(0, OpSequential(stage))
@@ -405,15 +508,21 @@ class Decoder(nn.Module):
upsample_block_type=cfg.upsample_block_type,
norm=cfg.out_norm,
act=cfg.out_act,
is_video=cfg.is_video,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.project_in(x)
if self.cfg.tune_channel_proj:
x = self.project_in(x)
else:
x = auto_grad_checkpoint(self.project_in, x)
for stage in reversed(self.stages):
if len(stage.op_list) == 0:
continue
x = stage(x)
x = self.project_out(x)
# x = stage(x)
x = auto_grad_checkpoint(stage, x)
# x = self.project_out(x)
x = auto_grad_checkpoint(self.project_out, x)
return x


@@ -423,13 +532,31 @@ class DCAE(nn.Module):
self.cfg = cfg
self.encoder = Encoder(cfg.encoder)
self.decoder = Decoder(cfg.decoder)
self.scaling_factor = cfg.scaling_factor
if cfg.tune_channel_proj: # propagate the tune_channel_proj flag for gradient checkpointing handling
self.encoder.cfg.tune_channel_proj = True
self.decoder.cfg.tune_channel_proj = True

self.scaling_factor = cfg.scaling_factor
self.time_compression_ratio = cfg.time_compression_ratio
self.spatial_compression_ratio = cfg.spatial_compression_ratio
self.use_spatial_tiling = cfg.use_spatial_tiling
self.use_temporal_tiling = cfg.use_temporal_tiling
self.spatial_tile_size = cfg.spatial_tile_size
self.temporal_tile_size = cfg.temporal_tile_size
assert (
cfg.spatial_tile_size // cfg.spatial_compression_ratio
), f"spatial tile size {cfg.spatial_tile_size} must be divisible by spatial compression of {cfg.spatial_compression_ratio}"
self.spatial_tile_latent_size = cfg.spatial_tile_size // cfg.spatial_compression_ratio
assert (
cfg.temporal_tile_size // cfg.time_compression_ratio
), f"temporal tile size {cfg.temporal_tile_size} must be divisible by temporal compression of {cfg.time_compression_ratio}"
self.temporal_tile_latent_size = cfg.temporal_tile_size // cfg.time_compression_ratio
self.tile_overlap_factor = cfg.tile_overlap_factor
if self.cfg.pretrained_path is not None:
self.load_model()

self.to(torch.float32)
init_model(self)
init_modules(self, init_type="trunc_normal")

def load_model(self):
if self.cfg.pretrained_source == "dc-ae":
@@ -438,20 +565,22 @@ class DCAE(nn.Module):
else:
raise NotImplementedError

@property
def spatial_compression_ratio(self) -> int:
return 2 ** (self.decoder.num_stages - 1)
def get_last_layer(self):
return self.decoder.project_out.op_list[2].conv.weight

def encode_single(self, x: torch.Tensor) -> torch.Tensor:
# @property
# def spatial_compression_ratio(self) -> int:
# return 2 ** (self.decoder.num_stages - 1)

def encode_single(self, x: torch.Tensor, is_video_encoder: bool = False) -> torch.Tensor:
assert x.shape[0] == 1
is_video = x.dim() == 5
if is_video:
if is_video and not is_video_encoder:
b, c, f, h, w = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w)

z = self.encoder(x)

if is_video:
if is_video and not is_video_encoder:
z = z.unsqueeze(dim=0).permute(0, 2, 1, 3, 4)

if self.scaling_factor is not None:
@@ -459,58 +588,222 @@ class DCAE(nn.Module):

return z

def encode(self, x: torch.Tensor) -> torch.Tensor:
def _encode(self, x: torch.Tensor) -> torch.Tensor:
if self.cfg.is_training:
return self.encoder(x)
is_video_encoder = self.encoder.cfg.is_video if self.encoder.cfg.is_video is not None else False
x_ret = []
for i in range(x.shape[0]):
x_ret.append(self.encode_single(x[i : i + 1]))
x_ret.append(self.encode_single(x[i : i + 1], is_video_encoder))
return torch.cat(x_ret, dim=0)

def decode_single(self, z: torch.Tensor) -> torch.Tensor:
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Linearly blend the bottom rows of ``a`` into the top rows of ``b``.

    Mutates ``b`` in place and returns it; the blend weight ramps from fully
    ``a`` at the first overlapped row to fully ``b`` at the last.
    """
    extent = min(a.shape[-2], b.shape[-2], blend_extent)
    for row in range(extent):
        w = row / extent
        b[:, :, :, row, :] = a[:, :, :, row - extent, :] * (1.0 - w) + b[:, :, :, row, :] * w
    return b

def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Linearly blend the rightmost columns of ``a`` into the leftmost columns of ``b``.

    Mutates ``b`` in place and returns it; the blend weight ramps from fully
    ``a`` at the first overlapped column to fully ``b`` at the last.
    """
    extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for col in range(extent):
        w = col / extent
        b[:, :, :, :, col] = a[:, :, :, :, col - extent] * (1.0 - w) + b[:, :, :, :, col] * w
    return b

def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Linearly blend the last frames of ``a`` into the first frames of ``b``.

    Mutates ``b`` in place and returns it; the blend weight ramps from fully
    ``a`` at the first overlapped frame to fully ``b`` at the last.
    """
    extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for frame in range(extent):
        w = frame / extent
        b[:, :, frame, :, :] = a[:, :, frame - extent, :, :] * (1.0 - w) + b[:, :, frame, :, :] * w
    return b

def spatial_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    """Encode a spatially large input by splitting it into overlapping tiles.

    Each tile is encoded independently with ``_encode``; neighbouring tiles are
    linearly blended over the overlap region (``blend_v``/``blend_h``) to hide
    seams, then cropped and concatenated back together.
    Assumes x is laid out (B, C, T, H, W) — TODO confirm with callers.
    """
    # Stride between tile origins in pixel space; tiles overlap by tile_overlap_factor.
    net_size = int(self.spatial_tile_size * (1 - self.tile_overlap_factor))
    # Overlap width, expressed in latent pixels, used for blending.
    blend_extent = int(self.spatial_tile_latent_size * self.tile_overlap_factor)
    # Portion of each latent tile kept after blending.
    row_limit = self.spatial_tile_latent_size - blend_extent

    # Split video into tiles and encode them separately.
    rows = []
    for i in range(0, x.shape[-2], net_size):
        row = []
        for j in range(0, x.shape[-1], net_size):
            tile = x[:, :, :, i : i + self.spatial_tile_size, j : j + self.spatial_tile_size]
            tile = self._encode(tile)
            row.append(tile)
        rows.append(row)
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent)
            result_row.append(tile[:, :, :, :row_limit, :row_limit])
        result_rows.append(torch.cat(result_row, dim=-1))

    return torch.cat(result_rows, dim=-2)

def temporal_tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    """Encode a temporally long input by splitting it into overlapping temporal tiles.

    Each tile is encoded independently (falling back to ``spatial_tiled_encode``
    when a tile is also spatially large), blended with its predecessor over the
    temporal overlap via ``blend_t``, cropped, and concatenated along time.
    Assumes x is laid out (B, C, T, H, W) — TODO confirm with callers.
    """
    # Stride between tile start frames; tiles overlap by tile_overlap_factor.
    overlap_size = int(self.temporal_tile_size * (1 - self.tile_overlap_factor))
    # Number of latent frames blended between consecutive tiles.
    blend_extent = int(self.temporal_tile_latent_size * self.tile_overlap_factor)
    # Latent frames kept from each tile after blending.
    t_limit = self.temporal_tile_latent_size - blend_extent

    # Split the video into tiles and encode them separately.
    row = []
    for i in range(0, x.shape[2], overlap_size):
        tile = x[:, :, i : i + self.temporal_tile_size, :, :]
        if self.use_spatial_tiling and (
            tile.shape[-1] > self.spatial_tile_size or tile.shape[-2] > self.spatial_tile_size
        ):
            tile = self.spatial_tiled_encode(tile)
        else:
            tile = self._encode(tile)
        row.append(tile)
    result_row = []
    for i, tile in enumerate(row):
        if i > 0:
            # Blend with the previous tile over the temporal overlap.
            tile = self.blend_t(row[i - 1], tile, blend_extent)
        result_row.append(tile[:, :, :t_limit, :, :])

    return torch.cat(result_row, dim=2)

def encode(self, x: torch.Tensor) -> torch.Tensor:
    """Encode ``x``, dispatching to tiled encoders when the input exceeds the tile size.

    Temporal tiling takes precedence over spatial tiling; otherwise the plain
    ``_encode`` path is used.
    """
    if self.use_temporal_tiling and x.shape[2] > self.temporal_tile_size:
        return self.temporal_tiled_encode(x)
    spatially_large = x.shape[-1] > self.spatial_tile_size or x.shape[-2] > self.spatial_tile_size
    if self.use_spatial_tiling and spatially_large:
        return self.spatial_tiled_encode(x)
    return self._encode(x)

def spatial_tiled_decode(self, z: torch.FloatTensor) -> torch.Tensor:
    """Decode a spatially large latent by splitting it into overlapping tiles.

    Mirror of ``spatial_tiled_encode``: tiles are decoded independently with
    ``_decode`` and linearly blended over the overlap region to avoid visible
    seams, then cropped and concatenated back together.
    Assumes z is a (B, C, T, H', W') latent — TODO confirm with callers.
    """
    # Stride between tile origins in latent space.
    net_size = int(self.spatial_tile_latent_size * (1 - self.tile_overlap_factor))
    # Overlap width, expressed in output pixels, used for blending.
    blend_extent = int(self.spatial_tile_size * self.tile_overlap_factor)
    # Portion of each decoded tile kept after blending.
    row_limit = self.spatial_tile_size - blend_extent

    # Split z into overlapping tiles and decode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, z.shape[-2], net_size):
        row = []
        for j in range(0, z.shape[-1], net_size):
            tile = z[:, :, :, i : i + self.spatial_tile_latent_size, j : j + self.spatial_tile_latent_size]
            decoded = self._decode(tile)
            row.append(decoded)
        rows.append(row)
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent)
            result_row.append(tile[:, :, :, :row_limit, :row_limit])
        result_rows.append(torch.cat(result_row, dim=-1))

    return torch.cat(result_rows, dim=-2)

def temporal_tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
    """Decode a temporally long latent by splitting it into overlapping temporal tiles.

    Mirror of ``temporal_tiled_encode``: each temporal tile is decoded (falling
    back to ``spatial_tiled_decode`` when a tile is also spatially large),
    blended with its predecessor over the temporal overlap, cropped, and
    concatenated along the time axis.
    """
    # Stride between tile start frames in latent space.
    overlap_size = int(self.temporal_tile_latent_size * (1 - self.tile_overlap_factor))
    # Number of decoded frames blended between consecutive tiles.
    blend_extent = int(self.temporal_tile_size * self.tile_overlap_factor)
    # Decoded frames kept from each tile after blending.
    t_limit = self.temporal_tile_size - blend_extent

    row = []
    for i in range(0, z.shape[2], overlap_size):
        tile = z[:, :, i : i + self.temporal_tile_latent_size, :, :]
        if self.use_spatial_tiling and (
            tile.shape[-1] > self.spatial_tile_latent_size or tile.shape[-2] > self.spatial_tile_latent_size
        ):
            decoded = self.spatial_tiled_decode(tile)
        else:
            decoded = self._decode(tile)
        row.append(decoded)
    result_row = []
    for i, tile in enumerate(row):
        if i > 0:
            # Blend with the previous decoded tile over the temporal overlap.
            tile = self.blend_t(row[i - 1], tile, blend_extent)
        result_row.append(tile[:, :, :t_limit, :, :])

    return torch.cat(result_row, dim=2)

def decode_single(self, z: torch.Tensor, is_video_decoder: bool = False) -> torch.Tensor:
assert z.shape[0] == 1
is_video = z.dim() == 5
if is_video:
if is_video and not is_video_decoder:
b, c, f, h, w = z.shape
z = z.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w)

if self.scaling_factor is not None:
z = z * self.scaling_factor

x = self.decoder(z)

if is_video:
if is_video and not is_video_decoder:
x = x.unsqueeze(dim=0).permute(0, 2, 1, 3, 4)
return x

def decode(self, z: torch.Tensor) -> torch.Tensor:
def _decode(self, z: torch.Tensor) -> torch.Tensor:
if self.cfg.is_training:
return self.decoder(z)
is_video_decoder = self.decoder.cfg.is_video if self.decoder.cfg.is_video is not None else False
x_ret = []
for i in range(z.shape[0]):
x_ret.append(self.decode_single(z[i : i + 1]))
x_ret.append(self.decode_single(z[i : i + 1], is_video_decoder))
return torch.cat(x_ret, dim=0)

def decode(self, z: torch.Tensor) -> torch.Tensor:
    """Decode ``z``, dispatching to tiled decoders when the latent exceeds the tile size.

    Temporal tiling takes precedence over spatial tiling; otherwise the plain
    ``_decode`` path is used.
    """
    if self.use_temporal_tiling and z.shape[2] > self.temporal_tile_latent_size:
        return self.temporal_tiled_decode(z)
    spatially_large = (
        z.shape[-1] > self.spatial_tile_latent_size or z.shape[-2] > self.spatial_tile_latent_size
    )
    if self.use_spatial_tiling and spatially_large:
        return self.spatial_tiled_decode(z)
    return self._decode(z)

def forward(self, x: torch.Tensor) -> tuple[Any, Tensor, dict[Any, Any]]:
x_type = x.dtype
is_video = x.dim() == 5
is_image_model = self.cfg.__dict__.get("is_image_model", False)
x = x.to(self.encoder.project_in.conv.weight.dtype)

if is_video:
b, c, f, h, w = x.shape
if is_image_model:
b, c, _, h, w = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(-1, c, h, w)

z = self.encoder(x)
dec = self.decoder(z)
z = self.encode(x)
dec = self.decode(z)

if is_video:
dec = dec.reshape(b, f, c, h, w).permute(0, 2, 1, 3, 4)
if is_image_model:
dec = dec.reshape(b, 1, c, h, w).permute(0, 2, 1, 3, 4)
z = z.unsqueeze(dim=0).permute(0, 2, 1, 3, 4)

dec = dec.to(x_type)
return dec, None, z

def get_latent_size(self, input_size: list[int]) -> list[int]:
    """Map a (T, H, W) input size to its latent size.

    Each axis is ceil-divided by the corresponding compression ratio
    (time for axis 0, spatial for axes 1 and 2).
    """
    t = (input_size[0] - 1) // self.time_compression_ratio + 1
    h = (input_size[1] - 1) // self.spatial_compression_ratio + 1
    w = (input_size[2] - 1) // self.spatial_compression_ratio + 1
    return [t, h, w]

def dc_ae_f32c32(name: str, pretrained_path: str) -> DCAEConfig:

def dc_ae_f32(name: str, pretrained_path: str) -> DCAEConfig:
if name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
cfg_str = (
"latent_channels=32 "
"time_compression_ratio=1 "
"spatial_compression_ratio=32 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024] encoder.depth_list=[0,4,8,2,2,2] "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViT_GLU,EViT_GLU,EViT_GLU] "
@@ -520,6 +813,8 @@ def dc_ae_f32c32(name: str, pretrained_path: str) -> DCAEConfig:
elif name in ["dc-ae-f32c32-sana-1.0"]:
cfg_str = (
"latent_channels=32 "
"time_compression_ratio=1 "
"spatial_compression_ratio=32 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024] encoder.depth_list=[2,2,2,3,3,3] "
"encoder.downsample_block_type=Conv "
@@ -527,8 +822,37 @@ def dc_ae_f32c32(name: str, pretrained_path: str) -> DCAEConfig:
"decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[3,3,3,3,3,3] "
"decoder.upsample_block_type=InterpolateConv "
"decoder.norm=rms2d decoder.act=silu "
"scaling_factor=0.41407"
"scaling_factor=0.41407 "
"is_image_model=True"
)
elif name in ["dc-ae-f32c32-sana-1.0-video", "dc-ae-f32t4c256", "dc-ae-f32t4c128", "dc-ae-f32t4c64"]:
cfg_str = (
"time_compression_ratio=4 "
"spatial_compression_ratio=32 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024] encoder.depth_list=[2,2,2,3,3,3] "
"encoder.downsample_block_type=Conv "
"encoder.norm=rms3d "
"encoder.is_video=True "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[3,3,3,3,3,3] "
"decoder.upsample_block_type=InterpolateConv "
"decoder.norm=rms3d decoder.act=silu decoder.out_norm=rms3d "
"decoder.is_video=True "
) # make sure there is no trailing blankspace in the last line
if name in ["dc-ae-f32t4c256", "dc-ae-f32t4c128", "dc-ae-f32t4c64"]:
cfg_str += (
"encoder.temporal_downsample=[False,False,False,True,True,False] "
"decoder.temporal_upsample=[False,False,False,True,True,False]"
) # make sure there is preceding blankspace in the first line
if name in ["dc-ae-f32t4c256"]:
cfg_str += " latent_channels=256 " "scaling_factor=0.46505"
elif name in ["dc-ae-f32t4c128"]:
cfg_str += " latent_channels=128"
elif name in ["dc-ae-f32t4c64"]:
cfg_str += " latent_channels=64"
elif name in ["dc-ae-f32c32-sana-1.0-video"]:
cfg_str += " latent_channels=32"
else:
raise NotImplementedError
cfg = OmegaConf.from_dotlist(cfg_str.split(" "))
@@ -547,6 +871,51 @@ def dc_ae_f64c128(name: str, pretrained_path: Optional[str] = None) -> DCAEConfi
"decoder.width_list=[128,256,512,512,1024,1024,2048] decoder.depth_list=[0,5,10,2,2,2,2] "
"decoder.norm=[bn2d,bn2d,bn2d,trms2d,trms2d,trms2d,trms2d] decoder.act=[relu,relu,relu,silu,silu,silu,silu]"
)
elif name in ["dc-ae-f64t4c128"]:
cfg_str = (
"time_compression_ratio=4 "
"spatial_compression_ratio=64 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024,1024] encoder.depth_list=[2,2,2,3,3,3,3] "
"encoder.downsample_block_type=Conv "
"encoder.norm=rms3d "
"encoder.is_video=True "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024,1024] decoder.depth_list=[3,3,3,3,3,3,3] "
"decoder.upsample_block_type=InterpolateConv "
"decoder.norm=rms3d decoder.act=silu decoder.out_norm=rms3d "
"decoder.is_video=True "
"encoder.temporal_downsample=[False,False,False,True,True,False,False] "
"decoder.temporal_upsample=[False,False,False,True,True,False,False] "
"latent_channels=128"
)
else:
raise NotImplementedError
cfg = OmegaConf.from_dotlist(cfg_str.split(" "))
cfg: DCAEConfig = OmegaConf.to_object(OmegaConf.merge(OmegaConf.structured(DCAEConfig), cfg))
cfg.pretrained_path = pretrained_path
return cfg


def dc_ae_f64t4c256(name: str, pretrained_path: Optional[str] = None) -> DCAEConfig:
if name in ["dc-ae-f64t4c256"]:
cfg_str = (
"time_compression_ratio=4 "
"spatial_compression_ratio=64 "
"encoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"encoder.width_list=[128,256,512,512,1024,1024,1024] encoder.depth_list=[2,2,2,3,3,3,3] "
"encoder.downsample_block_type=Conv "
"encoder.norm=rms3d "
"encoder.is_video=True "
"decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
"decoder.width_list=[128,256,512,512,1024,1024,1024] decoder.depth_list=[3,3,3,3,3,3,3] "
"decoder.upsample_block_type=InterpolateConv "
"decoder.norm=rms3d decoder.act=silu decoder.out_norm=rms3d "
"decoder.is_video=True "
"encoder.temporal_downsample=[False,False,False,True,True,False,False] "
"decoder.temporal_upsample=[False,False,False,True,True,False,False] "
"latent_channels=256"
)
else:
raise NotImplementedError
cfg = OmegaConf.from_dotlist(cfg_str.split(" "))

opensora/models/dc_ae/efficientvit/models/nn/__init__.py → opensora/models/dc_ae/models/nn/__init__.py View File

@@ -1,5 +1,3 @@
from .act import *
from .drop import *
from .norm import *
from .ops import *
from .triton_rms_norm import *

opensora/models/dc_ae/efficientvit/models/nn/act.py → opensora/models/dc_ae/models/nn/act.py View File

@@ -19,7 +19,8 @@ from typing import Optional

import torch.nn as nn

from ...models.utils import build_kwargs_from_config
from ..nn.vo_ops import build_kwargs_from_config


__all__ = ["build_act"]


+ 98
- 0
opensora/models/dc_ae/models/nn/norm.py View File

@@ -0,0 +1,98 @@
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

from ..nn.vo_ops import build_kwargs_from_config

__all__ = ["LayerNorm2d", "build_norm", "set_norm_eps"]


class LayerNorm2d(nn.LayerNorm):
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = x - torch.mean(x, dim=1, keepdim=True)
out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
if self.elementwise_affine:
out = out * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
return out



class RMSNorm2d(nn.Module):
    """Root-mean-square normalization over the channel dimension of NCHW tensors.

    Normalizes x by sqrt(mean(x^2) + eps) computed across dim=1 (in float32 for
    numerical stability), then optionally applies a learned per-channel affine.
    """

    def __init__(
        self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True
    ) -> None:
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            # Identity init (scale=1, shift=0). The previous torch.empty() left the
            # parameters uninitialized, producing garbage activations whenever no
            # external init or checkpoint load overwrote them.
            self.weight = torch.nn.parameter.Parameter(torch.ones(self.num_features))
            if bias:
                self.bias = torch.nn.parameter.Parameter(torch.zeros(self.num_features))
            else:
                self.register_parameter("bias", None)
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the RMS statistic in float32, then cast back to the input dtype.
        x = (x / torch.sqrt(torch.square(x.float()).mean(dim=1, keepdim=True) + self.eps)).to(x.dtype)
        if self.elementwise_affine:
            x = x * self.weight.view(1, -1, 1, 1)
            # bias is None when constructed with bias=False; the original code
            # dereferenced self.bias unconditionally and crashed in that case.
            if self.bias is not None:
                x = x + self.bias.view(1, -1, 1, 1)
        return x


class RMSNorm3d(RMSNorm2d):
    """RMSNorm2d variant for 5D (N, C, T, H, W) tensors.

    Identical statistic (channel-dim RMS in float32) with the affine parameters
    broadcast over the extra temporal dimension.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = (x / torch.sqrt(torch.square(x.float()).mean(dim=1, keepdim=True) + self.eps)).to(x.dtype)
        if self.elementwise_affine:
            x = x * self.weight.view(1, -1, 1, 1, 1)
            # bias is None when constructed with bias=False; the original code
            # dereferenced self.bias unconditionally and crashed in that case.
            if self.bias is not None:
                x = x + self.bias.view(1, -1, 1, 1, 1)
        return x


# register normalization function here
# Maps a config-string norm name to its layer class; consumed by build_norm()
# below, which returns None for names absent from this registry.
REGISTERED_NORM_DICT: dict[str, type] = {
    "bn2d": nn.BatchNorm2d,
    "ln": nn.LayerNorm,
    "ln2d": LayerNorm2d,
    "rms2d": RMSNorm2d,
    "rms3d": RMSNorm3d,
}


def build_norm(name="bn2d", num_features=None, **kwargs) -> Optional[nn.Module]:
    """Instantiate a registered normalization layer by name.

    Returns None when ``name`` is not in REGISTERED_NORM_DICT. Extra kwargs
    are filtered down to the constructor's signature.
    """
    norm_cls = REGISTERED_NORM_DICT.get(name)
    if norm_cls is None:
        return None
    # nn.LayerNorm-style classes take `normalized_shape`; the others take `num_features`.
    feature_key = "normalized_shape" if name in ("ln", "ln2d") else "num_features"
    kwargs[feature_key] = num_features
    args = build_kwargs_from_config(kwargs, norm_cls)
    return norm_cls(**args)


def set_norm_eps(model: nn.Module, eps: Optional[float] = None) -> None:
    """Override ``eps`` on every GroupNorm/LayerNorm/BatchNorm module in ``model``.

    No-op when ``eps`` is None.
    """
    if eps is None:
        return
    norm_types = (nn.GroupNorm, nn.LayerNorm, _BatchNorm)
    for module in model.modules():
        if isinstance(module, norm_types):
            module.eps = eps

opensora/models/dc_ae/efficientvit/models/nn/ops.py → opensora/models/dc_ae/models/nn/ops.py View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0 # upsample on the temporal dimension as well

from typing import Optional

@@ -20,9 +20,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from opensora.models.vae.utils import ChannelChunkConv3d

from ...models.nn.act import build_act
from ...models.nn.norm import build_norm
from ...models.utils import get_same_padding, list_sum, resize, val2list, val2tuple
from ...models.nn.vo_ops import chunked_interpolate, get_same_padding, pixel_shuffle_3d, pixel_unshuffle_3d, resize
from ...utils import list_sum, val2list, val2tuple

__all__ = [
"ConvLayer",
@@ -30,7 +33,7 @@ __all__ = [
"ConvPixelUnshuffleDownSampleLayer",
"PixelUnshuffleChannelAveragingDownSampleLayer",
"ConvPixelShuffleUpSampleLayer",
"ChannelDuplicatingPixelUnshuffleUpSampleLayer",
"ChannelDuplicatingPixelShuffleUpSampleLayer",
"LinearLayer",
"IdentityLayer",
"DSConv",
@@ -63,29 +66,68 @@ class ConvLayer(nn.Module):
dropout=0,
norm="bn2d",
act_func="relu",
is_video=False,
pad_mode_3d="constant",
):
super().__init__()
self.is_video = is_video

if self.is_video:
assert dilation == 1, "only support dilation=1 for 3d conv"
assert kernel_size % 2 == 1, "only support odd kernel size for 3d conv"
self.pad_mode_3d = pad_mode_3d # 3d padding follows CausalConv3d by Hunyuan
# padding = (
# kernel_size // 2,
# kernel_size // 2,
# kernel_size // 2,
# kernel_size // 2,
# kernel_size - 1,
# 0,
# ) # W, H, T
# non-causal padding
padding = (
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
kernel_size // 2,
)
self.padding = padding
self.dropout = nn.Dropout3d(dropout, inplace=False) if dropout > 0 else None
assert isinstance(stride, (int, tuple)), "stride must be an integer or 3-tuple for 3d conv"
self.conv = ChannelChunkConv3d( # padding is handled by F.pad() in forward()
in_channels,
out_channels,
kernel_size=(kernel_size, kernel_size, kernel_size),
stride=(stride, stride, stride) if isinstance(stride, int) else stride,
groups=groups,
bias=use_bias,
)
else:
padding = get_same_padding(kernel_size)
padding *= dilation
self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=padding,
dilation=(dilation, dilation),
groups=groups,
bias=use_bias,
)

padding = get_same_padding(kernel_size)
padding *= dilation

self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=(kernel_size, kernel_size),
stride=(stride, stride),
padding=padding,
dilation=(dilation, dilation),
groups=groups,
bias=use_bias,
)
self.norm = build_norm(norm, num_features=out_channels)
self.act = build_act(act_func)
self.pad = F.pad

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.dropout is not None:
x = self.dropout(x)
if self.is_video: # custom padding for 3d conv
x = self.pad(x, self.padding, mode=self.pad_mode_3d) # "constant" padding defaults to 0
x = self.conv(x)
if self.norm:
x = self.norm(x)
@@ -150,21 +192,44 @@ class PixelUnshuffleChannelAveragingDownSampleLayer(nn.Module):
in_channels: int,
out_channels: int,
factor: int,
temporal_downsample: bool = False, # temporal downsample for 5d input tensor
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor = factor
assert in_channels * factor**2 % out_channels == 0
self.group_size = in_channels * factor**2 // out_channels
self.temporal_downsample = temporal_downsample

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pixel_unshuffle(x, self.factor)
B, C, H, W = x.shape
x = x.view(B, self.out_channels, self.group_size, H, W)
x = x.mean(dim=2)
if x.dim() == 4:
assert self.in_channels * self.factor**2 % self.out_channels == 0
group_size = self.in_channels * self.factor**2 // self.out_channels
x = F.pixel_unshuffle(x, self.factor)
B, C, H, W = x.shape
x = x.view(B, self.out_channels, group_size, H, W)
x = x.mean(dim=2)
elif x.dim() == 5: # [B, C, T, H, W]
_, _, T, _, _ = x.shape
if self.temporal_downsample and T != 1: # 3d pixel unshuffle
x = pixel_unshuffle_3d(x, self.factor)
assert self.in_channels * self.factor**3 % self.out_channels == 0
group_size = self.in_channels * self.factor**3 // self.out_channels
else: # 2d pixel unshuffle
x = x.permute(0, 2, 1, 3, 4) # [B, T, C, H, W]
x = F.pixel_unshuffle(x, self.factor)
x = x.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
assert self.in_channels * self.factor**2 % self.out_channels == 0
group_size = self.in_channels * self.factor**2 // self.out_channels
B, C, T, H, W = x.shape
x = x.view(B, self.out_channels, group_size, T, H, W)
x = x.mean(dim=2)
else:
raise ValueError(f"Unsupported input dimension: {x.dim()}")
return x

def __repr__(self):
return f"PixelUnshuffleChannelAveragingDownSampleLayer(in_channels={self.in_channels}, out_channels={self.out_channels}, factor={self.factor}), temporal_downsample={self.temporal_downsample}"


class ConvPixelShuffleUpSampleLayer(nn.Module):
def __init__(
@@ -200,10 +265,13 @@ class InterpolateConvUpSampleLayer(nn.Module):
kernel_size: int,
factor: int,
mode: str = "nearest",
is_video: bool = False,
temporal_upsample: bool = False,
) -> None:
super().__init__()
self.factor = factor
self.mode = mode
self.temporal_upsample = temporal_upsample
self.conv = ConvLayer(
in_channels=in_channels,
out_channels=out_channels,
@@ -211,33 +279,66 @@ class InterpolateConvUpSampleLayer(nn.Module):
use_bias=True,
norm=None,
act_func=None,
is_video=is_video,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.interpolate(x, scale_factor=self.factor, mode=self.mode)
if x.dim() == 4:
x = F.interpolate(x, scale_factor=self.factor, mode=self.mode)
elif x.dim() == 5:
# [B, C, T, H, W] -> [B, C, T*factor, H*factor, W*factor]
if self.temporal_upsample and x.size(2) != 1: # temporal upsample for video input
x = chunked_interpolate(x, scale_factor=[self.factor, self.factor, self.factor], mode=self.mode)
else:
x = chunked_interpolate(x, scale_factor=[1, self.factor, self.factor], mode=self.mode)
x = self.conv(x)
return x

def __repr__(self):
return f"InterpolateConvUpSampleLayer(factor={self.factor}, mode={self.mode}, temporal_upsample={self.temporal_upsample})"


class ChannelDuplicatingPixelUnshuffleUpSampleLayer(nn.Module):
class ChannelDuplicatingPixelShuffleUpSampleLayer(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
factor: int,
temporal_upsample: bool = False, # upsample on the temporal dimension as well
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.factor = factor
assert out_channels * factor**2 % in_channels == 0
self.repeats = out_channels * factor**2 // in_channels
self.temporal_upsample = temporal_upsample

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.repeat_interleave(self.repeats, dim=1)
x = F.pixel_shuffle(x, self.factor)
if x.dim() == 5:
B, C, T, H, W = x.shape
assert C == self.in_channels

if self.temporal_upsample and T != 1: # video input
repeats = self.out_channels * self.factor**3 // self.in_channels
else:
repeats = self.out_channels * self.factor**2 // self.in_channels

x = x.repeat_interleave(repeats, dim=1)

if x.dim() == 4: # original image-only training
x = F.pixel_shuffle(x, self.factor)
elif x.dim() == 5: # [B, C, T, H, W]
if self.temporal_upsample and T != 1: # video input
x = pixel_shuffle_3d(x, self.factor)
else:
x = x.permute(0, 2, 1, 3, 4) # [B, T, C, H, W]
x = F.pixel_shuffle(x, self.factor) # on H and W only
x = x.permute(0, 2, 1, 3, 4) # [B, C, T, H, W]
return x

def __repr__(self):
return f"ChannelDuplicatingPixelShuffleUpSampleLayer(in_channels={self.in_channels}, out_channels={self.out_channels}, factor={self.factor}, temporal_upsample={self.temporal_upsample})"


class LinearLayer(nn.Module):
def __init__(
@@ -438,6 +539,7 @@ class GLUMBConv(nn.Module):
use_bias=False,
norm=(None, None, "ln2d"),
act_func=("silu", "silu", None),
is_video=False,
):
super().__init__()
use_bias = val2tuple(use_bias, 3)
@@ -454,6 +556,7 @@ class GLUMBConv(nn.Module):
use_bias=use_bias[0],
norm=norm[0],
act_func=act_func[0],
is_video=is_video,
)
self.depth_conv = ConvLayer(
mid_channels * 2,
@@ -464,6 +567,7 @@ class GLUMBConv(nn.Module):
use_bias=use_bias[1],
norm=norm[1],
act_func=None,
is_video=is_video,
)
self.point_conv = ConvLayer(
mid_channels,
@@ -472,6 +576,7 @@ class GLUMBConv(nn.Module):
use_bias=use_bias[2],
norm=norm[2],
act_func=act_func[2],
is_video=is_video,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -498,6 +603,7 @@ class ResBlock(nn.Module):
use_bias=False,
norm=("bn2d", "bn2d"),
act_func=("relu6", None),
is_video=False,
):
super().__init__()
use_bias = val2tuple(use_bias, 2)
@@ -514,6 +620,7 @@ class ResBlock(nn.Module):
use_bias=use_bias[0],
norm=norm[0],
act_func=act_func[0],
is_video=is_video,
)
self.conv2 = ConvLayer(
mid_channels,
@@ -523,6 +630,7 @@ class ResBlock(nn.Module):
use_bias=use_bias[1],
norm=norm[1],
act_func=act_func[1],
is_video=is_video,
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -547,6 +655,7 @@ class LiteMLA(nn.Module):
kernel_func="relu",
scales: tuple[int, ...] = (5,),
eps=1.0e-15,
is_video=False,
):
super().__init__()
self.eps = eps
@@ -566,11 +675,13 @@ class LiteMLA(nn.Module):
use_bias=use_bias[0],
norm=norm[0],
act_func=act_func[0],
is_video=is_video,
)
conv_class = nn.Conv2d if not is_video else ChannelChunkConv3d
self.aggreg = nn.ModuleList(
[
nn.Sequential(
nn.Conv2d(
conv_class(
3 * total_dim,
3 * total_dim,
scale,
@@ -578,7 +689,7 @@ class LiteMLA(nn.Module):
groups=3 * total_dim,
bias=use_bias[0],
),
nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
conv_class(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
)
for scale in scales
]
@@ -592,24 +703,41 @@ class LiteMLA(nn.Module):
use_bias=use_bias[1],
norm=norm[1],
act_func=act_func[1],
is_video=is_video,
)

@torch.autocast(device_type="cuda", enabled=False)
def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
B, _, H, W = list(qkv.size())
if qkv.ndim == 5:
B, _, T, H, W = list(qkv.size())
is_video = True
else:
B, _, H, W = list(qkv.size())
is_video = False

if qkv.dtype == torch.float16:
qkv = qkv.float()

qkv = torch.reshape(
qkv,
(
B,
-1,
3 * self.dim,
H * W,
),
)
if qkv.ndim == 4:
qkv = torch.reshape(
qkv,
(
B,
-1,
3 * self.dim,
H * W,
),
)
elif qkv.ndim == 5:
qkv = torch.reshape(
qkv,
(
B,
-1,
3 * self.dim,
H * W * T,
),
)
q, k, v = (
qkv[:, :, 0 : self.dim],
qkv[:, :, self.dim : 2 * self.dim],
@@ -630,7 +758,10 @@ class LiteMLA(nn.Module):
out = out.float()
out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)

out = torch.reshape(out, (B, -1, H, W))
if not is_video:
out = torch.reshape(out, (B, -1, H, W))
else:
out = torch.reshape(out, (B, -1, T, H, W))
return out

@torch.autocast(device_type="cuda", enabled=False)
@@ -674,11 +805,19 @@ class LiteMLA(nn.Module):
multi_scale_qkv.append(op(qkv))
qkv = torch.cat(multi_scale_qkv, dim=1)

H, W = list(qkv.size())[-2:]
if H * W > self.dim:
out = self.relu_linear_att(qkv).to(qkv.dtype)
else:
out = self.relu_quadratic_att(qkv)
if qkv.ndim == 4:
H, W = list(qkv.size())[-2:]
# num_tokens = H * W
elif qkv.ndim == 5:
_, _, T, H, W = list(qkv.size())
# num_tokens = H * W * T

# if num_tokens > self.dim:
out = self.relu_linear_att(qkv).to(qkv.dtype)
# else:
# if self.is_video:
# raise NotImplementedError("Video is not supported for quadratic attention")
# out = self.relu_quadratic_att(qkv)
out = self.proj(out)

return out
@@ -696,6 +835,7 @@ class EfficientViTBlock(nn.Module):
act_func: str = "hswish",
context_module: str = "LiteMLA",
local_module: str = "MBConv",
is_video: bool = False,
):
super().__init__()
if context_module == "LiteMLA":
@@ -707,6 +847,7 @@ class EfficientViTBlock(nn.Module):
dim=dim,
norm=(None, norm),
scales=scales,
is_video=is_video,
),
IdentityLayer(),
)
@@ -721,6 +862,7 @@ class EfficientViTBlock(nn.Module):
use_bias=(True, True, False),
norm=(None, None, norm),
act_func=(act_func, act_func, None),
is_video=is_video,
),
IdentityLayer(),
)
@@ -733,6 +875,7 @@ class EfficientViTBlock(nn.Module):
use_bias=(True, True, False),
norm=(None, None, norm),
act_func=(act_func, act_func, None),
is_video=is_video,
),
IdentityLayer(),
)

+ 244
- 0
opensora/models/dc_ae/models/nn/vo_ops.py View File

@@ -0,0 +1,244 @@
import math
from inspect import signature
from typing import Any, Callable, Optional, Union

import torch
import torch.nn.functional as F

VERBOSE = False


def pixel_shuffle_3d(x, upscale_factor):
    """3D pixel shuffle: trade channels for spatio-temporal resolution.

    Rearranges a ``(B, C, T, H, W)`` tensor into
    ``(B, C // r**3, T * r, H * r, W * r)`` — the 3D analogue of
    ``torch.nn.functional.pixel_shuffle``. Inverse of ``pixel_unshuffle_3d``.

    Args:
        x (torch.Tensor): Input of shape (B, C, T, H, W); C must be divisible
            by ``upscale_factor ** 3``.
        upscale_factor (int): Factor ``r`` applied to T, H and W.

    Returns:
        torch.Tensor: Tensor of shape (B, C // r**3, T*r, H*r, W*r).
    """
    B, C, T, H, W = x.shape
    r = upscale_factor
    assert C % (r * r * r) == 0, "通道数必须是上采样因子的立方倍数"

    C_new = C // (r * r * r)
    # Split channels into (C_new, r_t, r_h, r_w) blocks...
    x = x.view(B, C_new, r, r, r, T, H, W)
    # ...then interleave each factor next to its matching axis:
    # (B, C_new, T, r_t, H, r_h, W, r_w)
    x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)
    # Merge (T, r_t) -> T*r, (H, r_h) -> H*r, (W, r_w) -> W*r.
    # (Removed dead `if VERBOSE:` debug prints that were left in.)
    y = x.reshape(B, C_new, T * r, H * r, W * r)
    return y


def pixel_unshuffle_3d(x, downsample_factor):
    """3D pixel unshuffle: fold r*r*r spatio-temporal blocks into channels.

    Rearranges a ``(B, C, T, H, W)`` tensor into
    ``(B, C * r**3, T // r, H // r, W // r)``; inverse of ``pixel_shuffle_3d``.
    T, H and W must each be divisible by ``downsample_factor``.
    """
    B, C, T, H, W = x.shape

    r = downsample_factor
    assert T % r == 0, f"时间维度必须是下采样因子的倍数, got shape {x.shape}"
    assert H % r == 0, f"高度维度必须是下采样因子的倍数, got shape {x.shape}"
    assert W % r == 0, f"宽度维度必须是下采样因子的倍数, got shape {x.shape}"

    t_out, h_out, w_out = T // r, H // r, W // r
    c_out = C * r * r * r

    # Expose each r-sized sub-block: (B, C, T', r_t, H', r_h, W', r_w),
    # then move the block factors next to the channel axis.
    blocks = x.view(B, C, t_out, r, h_out, r, w_out, r)
    blocks = blocks.permute(0, 1, 3, 5, 7, 2, 4, 6)
    return blocks.reshape(B, c_out, t_out, h_out, w_out)


def test_pixel_shuffle_3d():
    """Manual smoke test: print pixel_shuffle_3d output and verify the round
    trip through pixel_unshuffle_3d (final print should be True)."""
    # Input tensor (B, C, T, H, W) = (1, 16, 2, 4, 4)
    x = torch.arange(1, 1 + 1 * 16 * 2 * 4 * 4).view(1, 16, 2, 4, 4).float()
    print("x:")
    print(x)
    print("x.shape:")
    print(x.shape)

    upscale_factor = 2

    # Apply the custom 3D pixel shuffle
    y = pixel_shuffle_3d(x, upscale_factor)
    print("pixelshuffle_3d 结果:")
    print(y)
    print("输出形状:", y.shape)
    # Expected output shape: (1, 2, 4, 8, 8), because:
    # - channels shrink 16 -> 2 (16 / (2*2*2))
    # - temporal dim grows 2 -> 4 (2*2)
    # - height grows 4 -> 8 (4*2)
    # - width grows 4 -> 8 (4*2)
    # NOTE(review): the original comment claimed shape (1, 1, 4, 8, 8) /
    # "channels 8 -> 1", but the input has 16 channels, so the factor-8
    # reduction yields 2 output channels.

    # Round-trip check: pixel_unshuffle_3d must invert pixel_shuffle_3d.
    print(torch.allclose(x, pixel_unshuffle_3d(y, upscale_factor)))


def chunked_interpolate(x, scale_factor, mode="nearest"):
    """Interpolate large 5D tensors by chunking along the channel dimension.

    ``F.interpolate`` can overflow int32 indexing on very large 3D inputs
    (https://discuss.pytorch.org/t/error-using-f-interpolate-for-large-3d-input/207859),
    so the input is split channel-wise into chunks whose output stays below
    2**31 - 1 elements, interpolated independently, and concatenated back.
    Only 'nearest' interpolation mode is supported.

    Args:
        x (torch.Tensor): Input tensor (B, C, D, H, W).
        scale_factor: Sequence of scaling factors (d, h, w).
        mode (str): Interpolation mode; only "nearest" is accepted.

    Returns:
        torch.Tensor: Interpolated tensor.

    Raises:
        ValueError: If ``x`` is not 5D, or no chunks could be generated.
    """
    assert (
        mode == "nearest"
    ), "Only the nearest mode is supported"  # other modes are theoretically supported but not tested
    if len(x.shape) != 5:
        raise ValueError("Expected 5D input tensor (B, C, D, H, W)")

    # Keep each chunk's OUTPUT element count under max int32 (2**31 - 1).
    max_elements_per_chunk = 2**31 - 1

    # ceil over-estimates the true output size (F.interpolate floors),
    # which only makes the chunking more conservative.
    out_d = math.ceil(x.shape[2] * scale_factor[0])
    out_h = math.ceil(x.shape[3] * scale_factor[1])
    out_w = math.ceil(x.shape[4] * scale_factor[2])

    # Max channels per chunk that keeps the output under the limit;
    # clamp to [1, C] so we always make progress.
    elements_per_channel = out_d * out_h * out_w
    max_channels = max_elements_per_chunk // (x.shape[0] * elements_per_channel)
    chunk_size = max(1, min(max_channels, x.shape[1]))

    # Fix: the interpolate call previously hardcoded mode="nearest" instead of
    # using the validated `mode` parameter. (Dead VERBOSE debug prints removed.)
    chunks = [
        F.interpolate(x[:, start : start + chunk_size, :, :, :], scale_factor=scale_factor, mode=mode)
        for start in range(0, x.shape[1], chunk_size)
    ]

    if not chunks:
        raise ValueError(f"No chunks were generated. Input shape: {x.shape}")

    # Concatenate chunks along channel dimension
    return torch.cat(chunks, dim=1)


def test_chunked_interpolate():
    """Smoke-test chunked_interpolate against F.interpolate for equality.

    NOTE(review): every input is allocated with .cuda(), so this requires a
    CUDA device. Case numbering skips 2 and 6, and case 11 has no body —
    presumably removed or unfinished cases; confirm with the author.
    """
    # Test case 1: Basic upscaling with scale_factor
    x1 = torch.randn(2, 16, 16, 32, 32).cuda()
    scale_factor = (2.0, 2.0, 2.0)
    assert torch.allclose(
        chunked_interpolate(x1, scale_factor=scale_factor), F.interpolate(x1, scale_factor=scale_factor, mode="nearest")
    )

    # Test case 3: Downscaling with scale_factor
    x3 = torch.randn(2, 16, 32, 64, 64).cuda()
    scale_factor = (0.5, 0.5, 0.5)
    assert torch.allclose(
        chunked_interpolate(x3, scale_factor=scale_factor), F.interpolate(x3, scale_factor=scale_factor, mode="nearest")
    )

    # Test case 4: Different scales per dimension
    x4 = torch.randn(2, 16, 16, 32, 32).cuda()
    scale_factor = (2.0, 1.5, 1.5)
    assert torch.allclose(
        chunked_interpolate(x4, scale_factor=scale_factor), F.interpolate(x4, scale_factor=scale_factor, mode="nearest")
    )

    # Test case 5: Large input tensor
    x5 = torch.randn(2, 16, 64, 128, 128).cuda()
    scale_factor = (2.0, 2.0, 2.0)
    assert torch.allclose(
        chunked_interpolate(x5, scale_factor=scale_factor), F.interpolate(x5, scale_factor=scale_factor, mode="nearest")
    )

    # Test case 7: Chunk size equal to input depth
    x7 = torch.randn(2, 16, 8, 32, 32).cuda()
    scale_factor = (2.0, 2.0, 2.0)
    assert torch.allclose(
        chunked_interpolate(x7, scale_factor=scale_factor), F.interpolate(x7, scale_factor=scale_factor, mode="nearest")
    )

    # Test case 8: Single channel input
    x8 = torch.randn(2, 1, 16, 32, 32).cuda()
    scale_factor = (2.0, 2.0, 2.0)
    assert torch.allclose(
        chunked_interpolate(x8, scale_factor=scale_factor), F.interpolate(x8, scale_factor=scale_factor, mode="nearest")
    )

    # Test case 9: Minimal batch size
    x9 = torch.randn(1, 16, 32, 64, 64).cuda()
    scale_factor = (0.5, 0.5, 0.5)
    assert torch.allclose(
        chunked_interpolate(x9, scale_factor=scale_factor), F.interpolate(x9, scale_factor=scale_factor, mode="nearest")
    )

    # Test case 10: Non-power-of-2 dimensions
    x10 = torch.randn(2, 16, 15, 31, 31).cuda()
    scale_factor = (2.0, 2.0, 2.0)
    assert torch.allclose(
        chunked_interpolate(x10, scale_factor=scale_factor),
        F.interpolate(x10, scale_factor=scale_factor, mode="nearest"),
    )

    # Test case 11: large output tensor


def get_same_padding(kernel_size: Union[int, tuple[int, ...]]) -> Union[int, tuple[int, ...]]:
    """Return the per-side padding that keeps spatial size unchanged ("same").

    Accepts a single odd kernel size or a tuple of them (handled recursively).
    """
    if isinstance(kernel_size, tuple):
        return tuple(get_same_padding(k) for k in kernel_size)
    assert kernel_size % 2 > 0, "kernel size should be odd number"
    return kernel_size // 2


def resize(
    x: torch.Tensor,
    size: Optional[Any] = None,
    scale_factor: Optional[list[float]] = None,
    mode: str = "bicubic",
    align_corners: Optional[bool] = False,
) -> torch.Tensor:
    """Resize *x* with ``F.interpolate``, routing ``align_corners`` correctly.

    'bilinear'/'bicubic' accept ``align_corners``; 'nearest'/'area' reject it,
    so it is omitted for those modes. Any other mode raises.
    """
    if mode in {"nearest", "area"}:
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode)
    if mode in {"bilinear", "bicubic"}:
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
        )
    raise NotImplementedError(f"resize(mode={mode}) not implemented.")


def build_kwargs_from_config(config: dict, target_func: Callable) -> dict[str, Any]:
    """Filter *config* down to the keyword arguments *target_func* accepts.

    Inspects the callable's signature and keeps only matching keys, so a large
    shared config dict can be passed safely to constructors with different
    parameter lists.
    """
    accepted = set(signature(target_func).parameters)
    return {key: value for key, value in config.items() if key in accepted}


# Manual smoke test; requires a CUDA device (test_chunked_interpolate
# allocates every tensor with .cuda()).
if __name__ == "__main__":
    test_chunked_interpolate()

+ 3
- 0
opensora/models/dc_ae/utils/__init__.py View File

@@ -0,0 +1,3 @@
from .init import *
from .list import *


opensora/models/dc_ae/efficientvit/apps/utils/init.py → opensora/models/dc_ae/utils/init.py View File

@@ -64,8 +64,7 @@ def init_modules(model: Union[nn.Module, list[nn.Module]], init_type="trunc_norm


def zero_last_gamma(model: nn.Module, init_val=0) -> None:
import efficientvit.models.nn.ops as ops

import opensora.models.dc_ae.models.nn.ops as ops
for m in model.modules():
if isinstance(m, ops.ResidualBlock) and isinstance(m.shortcut, ops.IdentityLayer):
if isinstance(m.main, (ops.DSConv, ops.MBConv, ops.FusedMBConv)):
@@ -82,3 +81,4 @@ def zero_last_gamma(model: nn.Module, init_val=0) -> None:
norm = getattr(parent_module, "norm", None)
if norm is not None:
nn.init.constant_(norm.weight, init_val)


opensora/models/dc_ae/efficientvit/models/utils/list.py → opensora/models/dc_ae/utils/list.py View File

@@ -65,3 +65,4 @@ def squeeze_list(x: Optional[list]) -> Union[list, Any]:
return x[0]
else:
return x


+ 40
- 11
opensora/models/mmdit/model.py View File

@@ -57,6 +57,7 @@ class MMDiTConfig:
fused_qkv: bool = True
grad_ckpt_settings: tuple[int, int] | None = None
use_liger_rope: bool = False
patch_size: int = 2

def get(self, attribute_name, default=None):
return getattr(self, attribute_name, default)
@@ -74,28 +75,41 @@ class MMDiTModel(nn.Module):
self.config = config
self.in_channels = config.in_channels
self.out_channels = self.in_channels
self.patch_size = config.patch_size

if config.hidden_size % config.num_heads != 0:
raise ValueError(f"Hidden size {config.hidden_size} must be divisible by num_heads {config.num_heads}")
raise ValueError(
f"Hidden size {config.hidden_size} must be divisible by num_heads {config.num_heads}"
)

pe_dim = config.hidden_size // config.num_heads
if sum(config.axes_dim) != pe_dim:
raise ValueError(f"Got {config.axes_dim} but expected positional dim {pe_dim}")
raise ValueError(
f"Got {config.axes_dim} but expected positional dim {pe_dim}"
)

self.hidden_size = config.hidden_size
self.num_heads = config.num_heads
pe_embedder_cls = LigerEmbedND if config.use_liger_rope else EmbedND
self.pe_embedder = pe_embedder_cls(dim=pe_dim, theta=config.theta, axes_dim=config.axes_dim)
self.pe_embedder = pe_embedder_cls(
dim=pe_dim, theta=config.theta, axes_dim=config.axes_dim
)

self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(config.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if config.guidance_embed else nn.Identity()
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
if config.guidance_embed
else nn.Identity()
)
self.cond_in = (
nn.Linear(self.in_channels + 4, self.hidden_size, bias=True) if config.cond_embed else nn.Identity()
) # +4 is due to the 4 channels of the mask
nn.Linear(
self.in_channels + self.patch_size**2, self.hidden_size, bias=True
)
if config.cond_embed
else nn.Identity()
)
self.txt_in = nn.Linear(config.context_in_dim, self.hidden_size)

self.double_blocks = nn.ModuleList(
@@ -114,7 +128,10 @@ class MMDiTModel(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(
self.hidden_size, self.num_heads, mlp_ratio=config.mlp_ratio, fused_qkv=config.fused_qkv
self.hidden_size,
self.num_heads,
mlp_ratio=config.mlp_ratio,
fused_qkv=config.fused_qkv,
)
for _ in range(config.depth_single_blocks)
]
@@ -165,7 +182,9 @@ class MMDiTModel(nn.Module):
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.config.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y_vec)

@@ -198,7 +217,9 @@ class MMDiTModel(nn.Module):
guidance: Tensor | None = None,
**kwargs,
) -> Tensor:
img, txt, vec, pe = self.prepare_block_inputs(img, img_ids, txt, txt_ids, timesteps, y_vec, cond, guidance)
img, txt, vec, pe = self.prepare_block_inputs(
img, img_ids, txt, txt_ids, timesteps, y_vec, cond, guidance
)

for block in self.double_blocks:
img, txt = auto_grad_checkpoint(block, img, txt, vec, pe)
@@ -223,7 +244,9 @@ class MMDiTModel(nn.Module):
guidance: Tensor | None = None,
**kwargs,
) -> Tensor:
img, txt, vec, pe = self.prepare_block_inputs(img, img_ids, txt, txt_ids, timesteps, y_vec, cond, guidance)
img, txt, vec, pe = self.prepare_block_inputs(
img, img_ids, txt, txt_ids, timesteps, y_vec, cond, guidance
)

ckpt_depth_double = self.config.grad_ckpt_settings[0]
for block in self.double_blocks[:ckpt_depth_double]:
@@ -270,5 +293,11 @@ def Flux(
else:
model = model.to(torch_dtype)
if from_pretrained:
model = load_checkpoint(model, from_pretrained, cache_dir=cache_dir, device_map=device_map, strict=strict_load)
model = load_checkpoint(
model,
from_pretrained,
cache_dir=cache_dir,
device_map=device_map,
strict=strict_load,
)
return model

+ 29
- 8
opensora/utils/ckpt.py View File

@@ -16,6 +16,7 @@ from colossalai.utils.safetensors import save as async_save
from colossalai.zero.low_level import LowLevelZeroOptimizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from tensornvme.async_file_io import AsyncFileWriter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

@@ -87,6 +88,7 @@ def load_checkpoint(
device_map: torch.device | str = "cpu",
cai_model_name: str = "model",
strict: bool = False,
rename_keys: dict = None, # rename keys in the checkpoint to support fine-tuning with a different model architecture; map old_key_prefix to new_key_prefix
) -> nn.Module:
"""
Loads a checkpoint into model from a path. Support three types of checkpoints:
@@ -111,11 +113,26 @@ def load_checkpoint(

log_message(f"Loading checkpoint from {path}")
if path.endswith(".safetensors"):
ckpt = load_file(path, device="cpu")
# ckpt = load_file(path, device=str(device_map))
ckpt = load_file(path, device=torch.cuda.current_device())

if rename_keys is not None:
# rename keys in the loaded state_dict with old_key_prefix to with new_key_prefix.
renamed_ckpt = {}
for old_key, v in ckpt.items():
new_key = old_key
for old_key_prefix, new_key_prefix in rename_keys.items():
if old_key_prefix in old_key:
new_key = old_key.replace(old_key_prefix, new_key_prefix)
print(f"Renamed {old_key} to {new_key} in the loaded state_dict")
break
renamed_ckpt[new_key] = v
ckpt = renamed_ckpt

missing, unexpected = model.load_state_dict(ckpt, strict=strict)
print_load_warning(missing, unexpected)
elif path.endswith(".pt") or path.endswith(".pth"):
ckpt = torch.load(path, map_location="cpu")
ckpt = torch.load(path, map_location=device_map)
missing, unexpected = model.load_state_dict(ckpt, strict=strict)
print_load_warning(missing, unexpected)
else:
@@ -287,7 +304,7 @@ def master_weights_gathering(model: torch.nn.Module, optimizer: LowLevelZeroOpti
model_shape_dict (dict): The shape of the model parameters.
device (torch.device): The device to gather the model to.
"""
pg = get_data_parallel_group()
pg = get_data_parallel_group(get_mixed_dp_pg=True)
world_size = dist.get_world_size(pg)
w2m = optimizer.get_working_to_master_map()
for name, param in model.named_parameters():
@@ -303,29 +320,33 @@ def master_weights_gathering(model: torch.nn.Module, optimizer: LowLevelZeroOpti


def load_master_weights(model: torch.nn.Module, optimizer: LowLevelZeroOptimizer, state_dict: dict) -> None:
pg = get_data_parallel_group()
pg = get_data_parallel_group(get_mixed_dp_pg=True)
world_size = dist.get_world_size(pg)
rank = dist.get_rank(pg)
w2m = optimizer.get_working_to_master_map()
for name, param in model.named_parameters():
master_p = w2m[id(param)]
assert param.numel() == len(master_p)
target_chunk = state_dict[name].chunk(world_size)[rank]
state = state_dict[name].view(-1)
padding_size = len(master_p) * world_size - len(state)
state = torch.nn.functional.pad(state, [0, padding_size])
target_chunk = state.chunk(world_size)[rank].to(master_p.dtype)
master_p[: len(target_chunk)].copy_(target_chunk)


class CheckpointIO:
def __init__(self, n_write_entries: int = 32):
self.n_write_entries = n_write_entries
self.writer
self.writer: Optional[AsyncFileWriter] = None
self.pinned_state_dict: Optional[Dict[str, torch.Tensor]] = None
self.master_pinned_state_dict: Optional[Dict[str, torch.Tensor]] = None
self.master_writer
self.master_writer: Optional[AsyncFileWriter] = None

def _sync_io(self):
if self.writer is not None:
self.writer.synchronize()
self.writer = None
if self.master_writer is not None:
self.master_writer.synchronize()
self.master_writer = None

def __del__(self):


+ 10
- 2
opensora/utils/config.py View File

@@ -48,6 +48,10 @@ def parse_configs() -> Config:
cfg = read_config(config)
cfg = merge_args(cfg, args)
cfg.config_path = config

# hard-coded for spatial compression
if cfg.get("ae_spatial_compression", None) is not None:
os.environ["AE_SPATIAL_COMPRESSION"] = str(cfg.ae_spatial_compression)
return cfg


@@ -142,7 +146,9 @@ def sync_string(value: str):
bytes_value = value.encode("utf-8")
max_len = 256
bytes_tensor = torch.zeros(max_len, dtype=torch.uint8).cuda()
bytes_tensor[: len(bytes_value)] = torch.tensor(list(bytes_value), dtype=torch.uint8)
bytes_tensor[: len(bytes_value)] = torch.tensor(
list(bytes_value), dtype=torch.uint8
)
torch.distributed.broadcast(bytes_tensor, 0)
synced_value = bytes_tensor.cpu().numpy().tobytes().decode("utf-8").rstrip("\x00")
return synced_value
@@ -167,7 +173,9 @@ def create_experiment_workspace(
experiment_index = datetime.now().strftime("%y%m%d_%H%M%S")
experiment_index = sync_string(experiment_index)
# Create an experiment folder
model_name = "-" + model_name.replace("/", "-") if model_name is not None else ""
model_name = (
"-" + model_name.replace("/", "-") if model_name is not None else ""
)
exp_name = f"{experiment_index}{model_name}"
exp_dir = f"{output_dir}/{exp_name}"
if is_main_process():


+ 107
- 33
opensora/utils/sampling.py View File

@@ -14,7 +14,11 @@ from opensora.datasets.aspect import get_image_size
from opensora.models.mmdit.model import MMDiTModel
from opensora.models.text.conditioner import HFEmbedder
from opensora.registry import MODELS, build_module
from opensora.utils.inference import SamplingMethod, collect_references_batch, prepare_inference_condition
from opensora.utils.inference import (
SamplingMethod,
collect_references_batch,
prepare_inference_condition,
)

# ======================================================
# Sampling Options
@@ -85,9 +89,13 @@ def sanitize_sampling_option(sampling_option: SamplingOption) -> SamplingOption:
Returns:
SamplingOption: The sanitized sampling options.
"""
if sampling_option.resolution is not None or sampling_option.aspect_ratio is not None:
if (
sampling_option.resolution is not None
or sampling_option.aspect_ratio is not None
):
assert (
sampling_option.resolution is not None and sampling_option.aspect_ratio is not None
sampling_option.resolution is not None
and sampling_option.aspect_ratio is not None
), "Both resolution and aspect ratio must be provided"
resolution = sampling_option.resolution
aspect_ratio = sampling_option.aspect_ratio
@@ -137,7 +145,12 @@ class Denoiser(ABC):

@abstractmethod
def prepare_guidance(
self, text: list[str], optional_models: dict[str, nn.Module], device: torch.device, dtype: torch.dtype, **kwargs
self,
text: list[str],
optional_models: dict[str, nn.Module],
device: torch.device,
dtype: torch.dtype,
**kwargs,
) -> dict[str, Tensor]:
"""Prepare the guidance for the model. This method will alter text."""

@@ -159,13 +172,20 @@ class I2VDenoiser(Denoiser):
image_osci = kwargs.pop("image_osci", False)
scale_temporal_osci = kwargs.pop("scale_temporal_osci", False)

guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
# patch size
patch_size = kwargs.pop("patch_size", 2)

guidance_vec = torch.full(
(img.shape[0],), guidance, device=img.device, dtype=img.dtype
)
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
# timesteps
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
t_vec = torch.full(
(img.shape[0],), t_curr, dtype=img.dtype, device=img.device
)
b, c, t, w, h = masked_ref.size()
cond = torch.cat((masks, masked_ref), dim=1)
cond = pack(cond)
cond = pack(cond, patch_size=patch_size)
kwargs["cond"] = torch.cat([cond, cond, torch.zeros_like(cond)], dim=0)

# forward preparation
@@ -182,14 +202,18 @@ class I2VDenoiser(Denoiser):

# prepare guidance
text_gs = get_oscillation_gs(guidance, i) if text_osci else guidance
image_gs = get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
image_gs = (
get_oscillation_gs(guidance_img, i) if image_osci else guidance_img
)
cond, uncond, uncond_2 = pred.chunk(3, dim=0)
if image_gs > 1.0 and scale_temporal_osci:
# image_gs decrease with each denoising step
step_upper_image_gs = torch.linspace(image_gs, 1.0, len(timesteps))[i]
# image_gs increase along the temporal axis of the latent video
image_gs = torch.linspace(1.0, step_upper_image_gs, t)[None, None, :, None, None].repeat(b, c, 1, h, w)
image_gs = pack(image_gs).to(cond.device, cond.dtype)
image_gs = torch.linspace(1.0, step_upper_image_gs, t)[
None, None, :, None, None
].repeat(b, c, 1, h, w)
image_gs = pack(image_gs, patch_size=patch_size).to(cond.device, cond.dtype)

# update
pred = uncond_2 + image_gs * (uncond - uncond_2) + text_gs * (cond - uncond)
@@ -202,7 +226,12 @@ class I2VDenoiser(Denoiser):
return img

def prepare_guidance(
self, text: list[str], optional_models: dict[str, nn.Module], device: torch.device, dtype: torch.dtype, **kwargs
self,
text: list[str],
optional_models: dict[str, nn.Module],
device: torch.device,
dtype: torch.dtype,
**kwargs,
) -> tuple[list[str], dict[str, Tensor]]:
ret = {}

@@ -222,10 +251,14 @@ class DistilledDenoiser(Denoiser):
timesteps = kwargs.pop("timesteps")
guidance = kwargs.pop("guidance")

guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
guidance_vec = torch.full(
(img.shape[0],), guidance, device=img.device, dtype=img.dtype
)
for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
# timesteps
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
t_vec = torch.full(
(img.shape[0],), t_curr, dtype=img.dtype, device=img.device
)
# forward
pred = model(
img=img,
@@ -238,7 +271,12 @@ class DistilledDenoiser(Denoiser):
return img

def prepare_guidance(
self, text: list[str], optional_models: dict[str, nn.Module], device: torch.device, dtype: torch.dtype, **kwargs
self,
text: list[str],
optional_models: dict[str, nn.Module],
device: torch.device,
dtype: torch.dtype,
**kwargs,
) -> tuple[list[str], dict[str, Tensor]]:
return text, {}

@@ -258,7 +296,9 @@ def time_shift(alpha: float, t: Tensor) -> Tensor:
return alpha * t / (1 + (alpha - 1) * t)


def get_res_lin_function(x1: float = 256, y1: float = 1, x2: float = 4096, y2: float = 3) -> callable:
def get_res_lin_function(
x1: float = 256, y1: float = 1, x2: float = 4096, y2: float = 3
) -> callable:
m = (y2 - y1) / (x2 - x1)
b = y1 - m * x1
return lambda x: m * x + b
@@ -281,7 +321,9 @@ def get_schedule(
if shift_alpha is None:
# estimate mu based on linear estimation between two points
# spatial scale
shift_alpha = get_res_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
shift_alpha = get_res_lin_function(y1=base_shift, y2=max_shift)(
image_seq_len
)
# temporal scale
shift_alpha *= math.sqrt(num_frames)
# calculate shifted timesteps
@@ -290,9 +332,6 @@ def get_schedule(
return timesteps.tolist()


D = int(os.environ.get("VO_ASPECT_DIV", 16))


def get_noise(
num_samples: int,
height: int,
@@ -319,6 +358,7 @@ def get_noise(
Returns:
Tensor: The noise tensor.
"""
D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
return torch.randn(
num_samples,
channel,
@@ -333,10 +373,15 @@ def get_noise(


def pack(x: Tensor, patch_size: int = 2) -> Tensor:
return rearrange(x, "b c t (h ph) (w pw) -> b (t h w) (c ph pw)", ph=patch_size, pw=patch_size)
return rearrange(
x, "b c t (h ph) (w pw) -> b (t h w) (c ph pw)", ph=patch_size, pw=patch_size
)


def unpack(x: Tensor, height: int, width: int, num_frames: int, patch_size: int = 2) -> Tensor:
def unpack(
x: Tensor, height: int, width: int, num_frames: int, patch_size: int = 2
) -> Tensor:
D = int(os.environ.get("AE_SPATIAL_COMPRESSION", 16))
return rearrange(
x,
"b (t h w) (c ph pw) -> b c t (h ph) (w pw)",
@@ -383,7 +428,9 @@ def prepare(
if bs != len(prompt):
bs = len(prompt)

img = rearrange(img, "b c t (h ph) (w pw) -> b (t h w) (c ph pw)", ph=patch_size, pw=patch_size)
img = rearrange(
img, "b c t (h ph) (w pw) -> b (t h w) (c ph pw)", ph=patch_size, pw=patch_size
)
if img.shape[0] != bs:
img = repeat(img, "b ... -> (repeat b) ...", repeat=bs // img.shape[0])

@@ -478,20 +525,34 @@ def prepare_models(
Returns:
tuple[nn.Module, nn.Module, nn.Module, nn.Module, dict[str, nn.Module]]: The models. They are the diffusion model, the autoencoder model, the T5 model, the CLIP model, and the optional models.
"""
model_device = "cpu" if offload_model and cfg.get("img_flux", None) is not None else device
model_device = (
"cpu" if offload_model and cfg.get("img_flux", None) is not None else device
)

model = build_module(cfg.model, MODELS, device_map=model_device, torch_dtype=dtype).eval()
model_ae = build_module(cfg.ae, MODELS, device_map=model_device, torch_dtype=dtype).eval()
model = build_module(
cfg.model, MODELS, device_map=model_device, torch_dtype=dtype
).eval()
model_ae = build_module(
cfg.ae, MODELS, device_map=model_device, torch_dtype=dtype
).eval()
model_t5 = build_module(cfg.t5, MODELS, device_map=device, torch_dtype=dtype).eval()
model_clip = build_module(cfg.clip, MODELS, device_map=device, torch_dtype=dtype).eval()
model_clip = build_module(
cfg.clip, MODELS, device_map=device, torch_dtype=dtype
).eval()
if cfg.get("pretrained_lora_path", None) is not None:
model = PeftModel.from_pretrained(model, cfg.pretrained_lora_path, is_trainable=False)
model = PeftModel.from_pretrained(
model, cfg.pretrained_lora_path, is_trainable=False
)

# optional models
optional_models = {}
if cfg.get("img_flux", None) is not None:
model_img_flux = build_module(cfg.img_flux, MODELS, device_map=device, torch_dtype=dtype).eval()
model_ae_img_flux = build_module(cfg.img_flux_ae, MODELS, device_map=device, torch_dtype=dtype).eval()
model_img_flux = build_module(
cfg.img_flux, MODELS, device_map=device, torch_dtype=dtype
).eval()
model_ae_img_flux = build_module(
cfg.img_flux_ae, MODELS, device_map=device, torch_dtype=dtype
).eval()
optional_models["img_flux"] = model_img_flux
optional_models["img_flux_ae"] = model_ae_img_flux

@@ -549,9 +610,15 @@ def prepare_api(
# random seed if not provided
seed = opt.seed if opt.seed is not None else random.randint(0, 2**32 - 1)
if opt.is_causal_vae:
num_frames = 1 if opt.num_frames == 1 else (opt.num_frames - 1) // opt.temporal_reduction + 1
num_frames = (
1
if opt.num_frames == 1
else (opt.num_frames - 1) // opt.temporal_reduction + 1
)
else:
num_frames = 1 if opt.num_frames == 1 else opt.num_frames // opt.temporal_reduction
num_frames = (
1 if opt.num_frames == 1 else opt.num_frames // opt.temporal_reduction
)

z = get_noise(
len(text),
@@ -571,7 +638,11 @@ def prepare_api(
if cond_type != "t2v" and "ref" in kwargs:
reference_path_list = kwargs.pop("ref")
references = collect_references_batch(
reference_path_list, cond_type, model_ae, (opt.height, opt.width), is_causal=opt.is_causal_vae
reference_path_list,
cond_type,
model_ae,
(opt.height, opt.width),
is_causal=opt.is_causal_vae,
)
elif cond_type != "t2v":
print(
@@ -603,7 +674,9 @@ def prepare_api(

if opt.method in [SamplingMethod.I2V]:
# prepare references
masks, masked_ref = prepare_inference_condition(z, cond_type, ref_list=references, causal=opt.is_causal_vae)
masks, masked_ref = prepare_inference_condition(
z, cond_type, ref_list=references, causal=opt.is_causal_vae
)
inp["masks"] = masks
inp["masked_ref"] = masked_ref
inp["sigma_min"] = sigma_min
@@ -619,6 +692,7 @@ def prepare_api(
opt.scale_temporal_osci and "i2v" in cond_type
), # don't use temporal osci for v2v or t2v
flow_shift=opt.flow_shift,
patch_size=patch_size,
)

x = unpack(x, opt.height, opt.width, num_frames, patch_size=patch_size)


+ 45
- 36
opensora/utils/train.py View File

@@ -184,7 +184,7 @@ def dropout_condition(prob: float, txt: torch.Tensor, null_txt: torch.Tensor) ->


def prepare_visual_condition_uncausal(
x: torch.Tensor, condition_config: dict, model_ae: torch.nn.Module
x: torch.Tensor, condition_config: dict, model_ae: torch.nn.Module, pad: bool = False
) -> torch.Tensor:
"""
Prepare the visual condition for the model.
@@ -199,7 +199,7 @@ def prepare_visual_condition_uncausal(
"""
# x has shape [b, c, t, h, w], where b is the batch size
B = x.shape[0]
C = model_ae.z_channels
C = model_ae.cfg.latent_channels
T, H, W = model_ae.get_latent_size(x.shape[-3:])

# Initialize masks tensor to match the shape of x, but only the time dimension will be masked
@@ -211,12 +211,12 @@ def prepare_visual_condition_uncausal(
x_0 = torch.zeros(B, C, T, H, W).to(x.device, x.dtype)
if T > 1: # video
# certain v2v conditions not are applicable for short videos
if T <= 32 // model_ae.compression[0]:
if T <= 32 // model_ae.time_compression_ratio:
condition_config.pop("v2v_head", None) # given first 32 frames
condition_config.pop("v2v_tail", None) # given last 32 frames
condition_config.pop("v2v_head_easy", None) # given first 64 frames
condition_config.pop("v2v_tail_easy", None) # given last 64 frames
if T <= 64 // model_ae.compression[0]:
if T <= 64 // model_ae.time_compression_ratio:
condition_config.pop("v2v_head_easy", None) # given first 64 frames
condition_config.pop("v2v_tail_easy", None) # given last 64 frames

@@ -230,9 +230,12 @@ def prepare_visual_condition_uncausal(
if mask_cond == "i2v_head": # NOTE: modify video, mask first latent frame
# padded video such that the first latent frame correspond to image only
masks[i, :, 0, :, :] = 1
pad_num = model_ae.compression[0] - 1 # 32 --> new video: 7 + (1+31-7)
padded_x = torch.cat([x[i, :, :1]] * pad_num + [x[i, :, :-pad_num]], dim=1).unsqueeze(0)
x_0[i] = model_ae.encode(padded_x)[0]
if pad:
pad_num = model_ae.time_compression_ratio - 1 # 32 --> new video: 7 + (1+31-7)
padded_x = torch.cat([x[i, :, :1]] * pad_num + [x[i, :, :-pad_num]], dim=1).unsqueeze(0)
x_0[i] = model_ae.encode(padded_x)[0]
else:
x_0[i] = model_ae.encode(x[i : i + 1])[0]
# condition: encode the image only
latent[i, :, :1, :, :] = model_ae.encode(
x[i, :, :1, :, :].unsqueeze(0)
@@ -241,47 +244,55 @@ def prepare_visual_condition_uncausal(
# pad video such that first and last latent frame correspond to image only
masks[i, :, 0, :, :] = 1
masks[i, :, -1, :, :] = 1
pad_num = model_ae.compression[0] - 1
padded_x = torch.cat(
[x[i, :, :1]] * pad_num
+ [x[i, :, : -pad_num * 2]]
+ [x[i, :, -pad_num * 2 - 1].unsqueeze(1)] * pad_num,
dim=1,
).unsqueeze(
0
) # remove the last pad_num * 2 frames from the end of the video
x_0[i] = model_ae.encode(padded_x)[0]
# condition: encode the image only
latent[i, :, :1, :, :] = model_ae.encode(x[i, :, :1, :, :].unsqueeze(0))
latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -pad_num * 2 - 1, :, :].unsqueeze(1).unsqueeze(0))
if pad:
pad_num = model_ae.time_compression_ratio - 1
padded_x = torch.cat(
[x[i, :, :1]] * pad_num
+ [x[i, :, : -pad_num * 2]]
+ [x[i, :, -pad_num * 2 - 1].unsqueeze(1)] * pad_num,
dim=1,
).unsqueeze(
0
) # remove the last pad_num * 2 frames from the end of the video
x_0[i] = model_ae.encode(padded_x)[0]
# condition: encode the image only
latent[i, :, :1, :, :] = model_ae.encode(x[i, :, :1, :, :].unsqueeze(0))
latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -pad_num * 2 - 1, :, :].unsqueeze(1).unsqueeze(0))
else:
x_0[i] = model_ae.encode(x[i : i + 1])[0]
latent[i, :, :1, :, :] = model_ae.encode(x[i, :, :1, :, :].unsqueeze(0))
latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -1:, :, :].unsqueeze(0))
elif mask_cond == "i2v_tail": # mask the last latent frame
masks[i, :, -1, :, :] = 1
pad_num = model_ae.compression[0] - 1
padded_x = torch.cat([x[i, :, pad_num:]] + [x[i, :, -1:]] * pad_num, dim=1).unsqueeze(0)
x_0[i] = model_ae.encode(padded_x)[0]
# condition: encode the image only
latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -1:, :, :].unsqueeze(0))
if pad:
pad_num = model_ae.time_compression_ratio - 1
padded_x = torch.cat([x[i, :, pad_num:]] + [x[i, :, -1:]] * pad_num, dim=1).unsqueeze(0)
x_0[i] = model_ae.encode(padded_x)[0]
latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -pad_num * 2 - 1, :, :].unsqueeze(1).unsqueeze(0))
else:
x_0[i] = model_ae.encode(x[i : i + 1])[0]
latent[i, :, -1:, :, :] = model_ae.encode(x[i, :, -1:, :, :].unsqueeze(0))
elif mask_cond == "v2v_head": # mask the first 32 video frames
assert T > 32 // model_ae.compression[0]
conditioned_t = 32 // model_ae.compression[0]
assert T > 32 // model_ae.time_compression_ratio
conditioned_t = 32 // model_ae.time_compression_ratio
masks[i, :, :conditioned_t, :, :] = 1
x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
latent[i, :, :conditioned_t, :, :] = x_0[i, :, :conditioned_t, :, :]
elif mask_cond == "v2v_tail": # mask the last 32 video frames
assert T > 32 // model_ae.compression[0]
conditioned_t = 32 // model_ae.compression[0]
assert T > 32 // model_ae.time_compression_ratio
conditioned_t = 32 // model_ae.time_compression_ratio
masks[i, :, -conditioned_t:, :, :] = 1
x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
latent[i, :, -conditioned_t:, :, :] = x_0[i, :, -conditioned_t:, :, :]
elif mask_cond == "v2v_head_easy": # mask the first 64 video frames
assert T > 64 // model_ae.compression[0]
conditioned_t = 64 // model_ae.compression[0]
assert T > 64 // model_ae.time_compression_ratio
conditioned_t = 64 // model_ae.time_compression_ratio
masks[i, :, :conditioned_t, :, :] = 1
x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
latent[i, :, :conditioned_t, :, :] = x_0[i, :, :conditioned_t, :, :]
elif mask_cond == "v2v_tail_easy": # mask the last 64 video frames
assert T > 64 // model_ae.compression[0]
conditioned_t = 64 // model_ae.compression[0]
assert T > 64 // model_ae.time_compression_ratio
conditioned_t = 64 // model_ae.time_compression_ratio
masks[i, :, -conditioned_t:, :, :] = 1
x_0[i] = model_ae.encode(x[i].unsqueeze(0))[0]
latent[i, :, -conditioned_t:, :, :] = x_0[i, :, -conditioned_t:, :, :]
@@ -300,7 +311,6 @@ def prepare_visual_condition_uncausal(
# merge the masks and the masked_x into a single tensor
cond = torch.cat((masks, latent), dim=1)
return x_0, cond
# return x_0, cond, masks[:, 0, :, 0, 0] # if we want to disregard padded part in loss calc, need the masks


def prepare_visual_condition_causal(x: torch.Tensor, condition_config: dict, model_ae: torch.nn.Module) -> torch.Tensor:
@@ -317,7 +327,7 @@ def prepare_visual_condition_causal(x: torch.Tensor, condition_config: dict, mod
"""
# x has shape [b, c, t, h, w], where b is the batch size
B = x.shape[0]
C = model_ae.z_channels
C = model_ae.cfg.latent_channels
T, H, W = model_ae.get_latent_size(x.shape[-3:])

# Initialize masks tensor to match the shape of x, but only the time dimension will be masked
@@ -395,7 +405,6 @@ def prepare_visual_condition_causal(x: torch.Tensor, condition_config: dict, mod
# merge the masks and the masked_x into a single tensor
cond = torch.cat((masks, latent), dim=1)
return x_0, cond
# return x_0, cond, masks[:, 0, :, 0, 0] # if we want to disregard padded part in loss calc, need the masks


def get_batch_loss(model_pred, v_t, masks=None):


+ 1
- 0
requirements.txt View File

@@ -13,3 +13,4 @@ wandb>=0.17.0
tensorboard>=2.14.0
pre-commit>=3.5.0
omegaconf>=2.3.0
pyarrow

+ 0
- 90
scripts/cnv/extend_csv.py View File

@@ -1,90 +0,0 @@
import argparse
import os

import dask.dataframe as dd
import pandas as pd
from tqdm import tqdm

tqdm.pandas()

try:
pass

PANDA_USE_PARALLEL = True
except ImportError:
PANDA_USE_PARALLEL = False


PANDA_USE_PARALLEL = False


def apply(df, func, **kwargs):
    """Apply *func* over *df*, using pandarallel when it is available.

    Falls back to tqdm's progress_apply when PANDA_USE_PARALLEL is False.
    """
    runner = df.parallel_apply if PANDA_USE_PARALLEL else df.progress_apply
    return runner(func, **kwargs)


def read_file(input_path):
    """Load a table from a CSV or Parquet file into a pandas DataFrame."""
    if input_path.endswith(".parquet"):
        # Parquet goes through dask so sharded datasets load transparently.
        return dd.read_parquet(input_path).compute()
    if input_path.endswith(".csv"):
        return pd.read_csv(input_path)
    raise NotImplementedError(f"Unsupported file format: {input_path}")


def save_file(df, output_path):
    """Persist a DataFrame to CSV or Parquet, chosen by the file extension."""
    writers = {
        ".csv": lambda: df.to_csv(output_path, index=False),
        ".parquet": lambda: df.to_parquet(output_path, index=False),
    }
    for extension, write in writers.items():
        if output_path.endswith(extension):
            write()
            return
    raise NotImplementedError(f"Unsupported file format: {output_path}")


def process_path(path, suffix):
    """Map a raw data path to its derived-artifact path.

    Rewrites the path component at index 4 from "data" to "latents",
    replaces the trailing file extension with *suffix*, and ensures the
    destination directory exists.

    Args:
        path: Source file path whose component at split index 4 is "data".
        suffix: Replacement for the trailing extension, including its
            leading marker (e.g. ".pt", "_t5.pt", "_clip.pt").

    Returns:
        The rewritten path.
    """
    # Split path into parts
    parts = path.split("/")
    assert parts[4] == "data", f"Invalid path: {path}"
    parts[4] = "latents"

    # Replace only the TRAILING extension. The previous str.replace-based
    # code substituted the first occurrence of ".ext" anywhere in the name,
    # corrupting filenames like "v.mp4.mp4".
    filename = parts[-1]
    ext = filename.split(".")[-1]
    if ext != filename:  # filename actually has an extension
        filename = filename[: -(len(ext) + 1)] + suffix
    parts[-1] = filename

    new_path = "/".join(parts)

    # Create directory if not exists
    directory = os.path.dirname(new_path)
    os.makedirs(directory, exist_ok=True)

    return new_path


def extend_csv(source_csv, target_csv):
    """Augment a metadata table with derived latent/text-embedding paths."""
    # Read CSV file
    df = read_file(source_csv)

    # One output column per cached artifact type (order matters: columns
    # are appended in this order).
    suffix_by_column = {
        "latents_path": ".pt",
        "text_t5_path": "_t5.pt",
        "text_clip_path": "_clip.pt",
    }
    for column, suffix in suffix_by_column.items():
        # Bind suffix as a default arg to avoid late-binding in the lambda.
        df[column] = apply(df["path"], lambda x, s=suffix: process_path(x, s))

    # Save modified CSV
    save_file(df, target_csv)
    print(f"CSV file saved to: {target_csv}")


def main():
    """CLI entry point: read --source_csv, write the extended table to --target_csv."""
    parser = argparse.ArgumentParser(description="Extend CSV file by adding latents and text encoding paths")
    parser.add_argument("--source_csv", required=True, help="Path to input CSV file")
    # --target_csv is required; the previous "default=None" plus "(optional)"
    # help text was contradictory — argparse ignores defaults on required args.
    parser.add_argument("--target_csv", required=True, help="Path to output CSV file")

    args = parser.parse_args()
    extend_csv(args.source_csv, args.target_csv)


if __name__ == "__main__":
main()

+ 70
- 0
scripts/cnv/meta.py View File

@@ -0,0 +1,70 @@
import argparse

import numpy as np
import pandas as pd
from pandarallel import pandarallel
from torchvision.io.video import read_video
from tqdm import tqdm


def set_parallel(num_workers: int = None) -> callable:
    """Return an apply dispatcher for pandas objects.

    With num_workers == 0 the returned callable uses tqdm's progress_apply
    (the caller must have registered it via tqdm.pandas()); otherwise
    pandarallel is initialized and parallel_apply is used.
    """
    if num_workers == 0:
        # Serial path: no pandarallel initialization at all.
        return lambda x, *args, **kwargs: x.progress_apply(*args, **kwargs)
    init_kwargs = {"progress_bar": True}
    if num_workers is not None:
        init_kwargs["nb_workers"] = num_workers
    pandarallel.initialize(**init_kwargs)
    return lambda x, *args, **kwargs: x.parallel_apply(*args, **kwargs)


def get_video_info(path: str) -> pd.Series:
    """Read a video file and summarize its geometry and timing.

    Returns a Series indexed by: height, width, fps, num_frames,
    aspect_ratio, resolution.
    """
    frames, _, info = read_video(path, pts_unit="sec", output_format="TCHW")
    num_frames, _, height, width = frames.shape
    fps = round(info["video_fps"], 3)
    # Guard against zero-width videos when computing the aspect ratio.
    aspect_ratio = height / width if width > 0 else np.nan

    return pd.Series(
        {
            "height": height,
            "width": width,
            "fps": fps,
            "num_frames": num_frames,
            "aspect_ratio": aspect_ratio,
            "resolution": height * width,
        },
        dtype=object,
    )


def parse_args():
    """Parse CLI options: --input/--output CSV paths and --num_workers."""
    parser = argparse.ArgumentParser()
    for flag, options in (
        ("--input", dict(type=str, required=True, help="Input file path")),
        ("--output", dict(type=str, required=True, help="Output file path")),
        ("--num_workers", dict(type=int, default=None, help="Number of workers")),
    ):
        parser.add_argument(flag, **options)
    return parser.parse_args()


def main():
    """Annotate every video listed in the input CSV with metadata columns."""
    args = parse_args()

    df = pd.read_csv(args.input)
    tqdm.pandas()  # register progress_apply for the serial (0-worker) path
    run_apply = set_parallel(args.num_workers)

    # One metadata row per video path; merge the columns back into df.
    info = run_apply(df["path"], get_video_info)
    for column in info.columns:
        df[column] = info[column]

    df.to_csv(args.output, index=False)


if __name__ == "__main__":
main()

+ 23
- 4
scripts/cnv/shard.py View File

@@ -1,8 +1,15 @@
import os

import dask.dataframe as dd
import pandas as pd
from tqdm import tqdm

try:
import dask.dataframe as dd

SUPPORT_DASK = True
except:
SUPPORT_DASK = False


def shard_parquet(input_path, k):
    # Check that the input path exists
@@ -10,10 +17,20 @@ def shard_parquet(input_path, k):
raise FileNotFoundError(f"Input file {input_path} does not exist.")

    # Read the Parquet file into a pandas DataFrame
df = dd.read_parquet(input_path).compute()
if SUPPORT_DASK:
df = dd.read_parquet(input_path).compute()
else:
df = pd.read_parquet(input_path)

    # Remove the specified metadata columns
columns_to_remove = ["num_frames", "height", "width", "aspect_ratio", "fps", "resolution"]
columns_to_remove = [
"num_frames",
"height",
"width",
"aspect_ratio",
"fps",
"resolution",
]
df = df.drop(columns=[col for col in columns_to_remove if col in df.columns])

    # Compute the size of each shard
@@ -48,7 +65,9 @@ if __name__ == "__main__":

parser = argparse.ArgumentParser(description="Shard a Parquet file.")
parser.add_argument("input_path", type=str, help="Path to the input Parquet file.")
parser.add_argument("k", type=int, help="Number of shards to create.", default=10000)
parser.add_argument(
"k", type=int, help="Number of shards to create.", default=10000
)

args = parser.parse_args()



+ 0
- 59
scripts/cnv/txt2csv.py View File

@@ -1,59 +0,0 @@
import argparse

import pandas as pd


def txt_to_csv(input_txt: str, output_csv: str) -> None:
    """
    Convert a .txt file to a .csv file with a single 'text' column.

    Each input line becomes one row, with surrounding whitespace stripped.
    Errors are reported on stdout rather than raised (best-effort CLI tool).

    Args:
        input_txt (str): Path to the input .txt file.
        output_csv (str): Path to the output .csv file.

    Returns:
        None
    """
    try:
        # Stream the file line by line, stripping whitespace/newlines.
        with open(input_txt, "r", encoding="utf-8") as f:
            rows = [line.strip() for line in f]

        # Single-column DataFrame written straight to CSV.
        pd.DataFrame(rows, columns=["text"]).to_csv(
            output_csv, index=False, encoding="utf-8"
        )
        print(f"CSV file '{output_csv}' created successfully.")

    except FileNotFoundError:
        print(f"Error: The file {input_txt} was not found.")
    except Exception as e:
        print(f"An error occurred: {e}")


def parse_arguments() -> argparse.Namespace:
    """
    Parse the command-line arguments.

    Returns:
        argparse.Namespace: Parsed arguments containing input_txt and output_csv.
    """
    parser = argparse.ArgumentParser(description="Convert a .txt file to a .csv file.")
    for name, help_text in (
        ("input_txt", "Path to the input .txt file."),
        ("output_csv", "Path to the output .csv file."),
    ):
        parser.add_argument(name, type=str, help=help_text)
    return parser.parse_args()


if __name__ == "__main__":
# Parse command-line arguments
args: argparse.Namespace = parse_arguments()

# Call the conversion function with parsed arguments
txt_to_csv(args.input_txt, args.output_csv)

+ 25
- 11
scripts/diffusion/train.py View File

@@ -15,21 +15,33 @@ gc.disable()
import torch
import torch.distributed as dist
import torch.nn.functional as F
import wandb
from colossalai.booster import Booster
from colossalai.utils import set_seed
from peft import LoraConfig
from tqdm import tqdm

import wandb
from opensora.acceleration.checkpoint import GLOBAL_ACTIVATION_MANAGER, set_grad_checkpoint
from opensora.acceleration.checkpoint import (
GLOBAL_ACTIVATION_MANAGER,
set_grad_checkpoint,
)
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.aspect import bucket_to_shapes
from opensora.datasets.dataloader import prepare_dataloader
from opensora.datasets.pin_memory_cache import PinMemoryCache
from opensora.models.mmdit.distributed import MMDiTPolicy
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt import CheckpointIO, model_sharding, record_model_param_shape, rm_checkpoints
from opensora.utils.config import config_to_name, create_experiment_workspace, parse_configs
from opensora.utils.ckpt import (
CheckpointIO,
model_sharding,
record_model_param_shape,
rm_checkpoints,
)
from opensora.utils.config import (
config_to_name,
create_experiment_workspace,
parse_configs,
)
from opensora.utils.logger import create_logger
from opensora.utils.misc import (
NsysProfiler,
@@ -45,7 +57,13 @@ from opensora.utils.misc import (
to_torch_dtype,
)
from opensora.utils.optimizer import create_lr_scheduler, create_optimizer
from opensora.utils.sampling import get_res_lin_function, pack, prepare, prepare_ids, time_shift
from opensora.utils.sampling import (
get_res_lin_function,
pack,
prepare,
prepare_ids,
time_shift,
)
from opensora.utils.train import (
create_colossalai_plugin,
dropout_condition,
@@ -357,11 +375,7 @@ def main():
if cfg.get("condition_config", None) is not None:
# condition for i2v & v2v
x_0, cond = prepare_visual_condition(x, cfg.condition_config, model_ae)
if cfg.get("no_i2v_ref_loss", False):
inp["masks"] = cond[
:, 0, :, :, :
] # record the padded frames in I2V, so they can be ignored in loss calculation
cond = pack(cond)
cond = pack(cond, patch_size=cfg.get("patch_size", 2))
inp["cond"] = cond
else:
if cfg.get("cached_video", False):
@@ -384,7 +398,7 @@ def main():
with nsys.range("encode_text"), timers["encode_text"]:
inp_ = prepare_ids(x_0, t5_embedding, clip_embedding)
inp.update(inp_)
x_0 = pack(x_0)
x_0 = pack(x_0, patch_size=cfg.get("patch_size", 2))
else:
# == encode text ==
with nsys.range("encode_text"), timers["encode_text"]:


+ 259
- 251
scripts/vae/train.py View File

@@ -16,6 +16,7 @@ import torch
import torch.distributed as dist
from colossalai.booster import Booster
from colossalai.utils import set_seed
from torch.profiler import ProfilerActivity, profile, schedule
from tqdm import tqdm

import wandb
@@ -23,13 +24,13 @@ from opensora.acceleration.checkpoint import set_grad_checkpoint
from opensora.acceleration.parallel_states import get_data_parallel_group
from opensora.datasets.dataloader import prepare_dataloader
from opensora.datasets.pin_memory_cache import PinMemoryCache
from opensora.models.hunyuan_vae.policy import HunyuanVaePolicy
from opensora.models.vae.losses import DiscriminatorLoss, GeneratorLoss, VAELoss, cal_opl_loss
from opensora.registry import DATASETS, MODELS, build_module
from opensora.utils.ckpt import CheckpointIO, model_sharding, record_model_param_shape, rm_checkpoints
from opensora.utils.config import config_to_name, create_experiment_workspace, parse_configs
from opensora.utils.logger import create_logger
from opensora.utils.misc import (
Timer,
all_reduce_sum,
create_tensorboard_writer,
is_log_process,
@@ -41,7 +42,15 @@ from opensora.utils.train import create_colossalai_plugin, set_lr, set_warmup_st

torch.backends.cudnn.benchmark = True

from opensora.acceleration.checkpoint import GLOBAL_ACTIVATION_MANAGER
WAIT = 1
WARMUP = 10
ACTIVE = 20

my_schedule = schedule(
wait=WAIT, # number of warmup steps
warmup=WARMUP, # number of warmup steps with profiling
active=ACTIVE, # number of active steps with profiling
)


def main():
@@ -54,9 +63,6 @@ def main():
# == get dtype & device ==
dtype = to_torch_dtype(cfg.get("dtype", "bf16"))
device, coordinator = setup_device()
grad_ckpt_buffer_size = cfg.get("grad_checkpoint_buffer_size", 0)
if grad_ckpt_buffer_size > 0:
GLOBAL_ACTIVATION_MANAGER.setup_buffer(grad_ckpt_buffer_size, dtype=dtype)
checkpoint_io = CheckpointIO()
set_seed(cfg.get("seed", 1024))
PinMemoryCache.force_dtype = dtype
@@ -66,15 +72,15 @@ def main():
# == init ColossalAI booster ==
plugin_type = cfg.get("plugin", "zero2")
plugin_config = cfg.get("plugin_config", {})
plugin_kwargs = {}
if plugin_type == "hybrid":
plugin_kwargs["custom_policy"] = HunyuanVaePolicy
plugin = create_colossalai_plugin(
plugin=plugin_type,
dtype=cfg.get("dtype", "bf16"),
grad_clip=cfg.get("grad_clip", 0),
**plugin_config,
**plugin_kwargs,
plugin = (
create_colossalai_plugin(
plugin=plugin_type,
dtype=cfg.get("dtype", "bf16"),
grad_clip=cfg.get("grad_clip", 0),
**plugin_config,
)
if plugin_type != "none"
else None
)
booster = Booster(plugin=plugin)

@@ -161,17 +167,6 @@ def main():
generator_loss_fn = GeneratorLoss(**cfg.gen_loss_config)
discriminator_loss_fn = DiscriminatorLoss(**cfg.disc_loss_config)

disc_plugin_type = cfg.get("disc_plugin", "zero2")
disc_plugin_config = cfg.get("disc_plugin_config", {})
disc_plugin = create_colossalai_plugin(
plugin=disc_plugin_type,
dtype=cfg.get("dtype", "bf16"),
grad_clip=cfg.get("grad_clip", 0),
**disc_plugin_config,
)
disc_booster = Booster(plugin=disc_plugin)
booster = Booster(plugin=plugin)

# == setup optimizer ==
optimizer = create_optimizer(model, cfg.optim)

@@ -204,7 +199,7 @@ def main():
)

if use_discriminator:
discriminator, disc_optimizer, _, _, disc_lr_scheduler = disc_booster.boost(
discriminator, disc_optimizer, _, _, disc_lr_scheduler = booster.boost(
model=discriminator,
optimizer=disc_optimizer,
lr_scheduler=disc_lr_scheduler,
@@ -216,7 +211,6 @@ def main():
cfg_epochs = cfg.get("epochs", 1000)
mixed_strategy = cfg.get("mixed_strategy", None)
mixed_image_ratio = cfg.get("mixed_image_ratio", 0.0)
alter_train_gen_disc = bool(cfg.get("alter_train_gen_disc", False))
# modulate mixed image ratio since we force rank 0 to be video
num_ranks = dist.get_world_size()
modulated_mixed_image_ratio = (
@@ -225,9 +219,8 @@ def main():
if is_log_process(plugin_type, plugin_config):
print("modulated mixed image ratio:", modulated_mixed_image_ratio)

start_epoch = start_step = log_gen_step = log_disc_step = 0 # log_step = acc_step = 0

running_loss = dict(
start_epoch = start_step = log_step = acc_step = 0
running_loss = dict( # loss accumulated over config.log_every steps
all=0.0,
nll=0.0,
nll_rec=0.0,
@@ -263,8 +256,8 @@ def main():
cfg.load,
model=model,
ema=ema,
optimizer=optimizer if not cfg.get("reset_optimizer", False) else None,
lr_scheduler=lr_scheduler if not cfg.get("reset_optimizer", False) else None,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
sampler=(
None if start_step is not None else sampler
), # if specify start step, set last_micro_batch_access_index of a new sampler instead
@@ -288,7 +281,6 @@ def main():
and os.path.exists(os.path.join(cfg.load, "discriminator"))
and not cfg.get("restart_disc", False)
):
logger.info("loading discriminator...")
booster.load_model(discriminator, os.path.join(cfg.load, "discriminator"))
if cfg.get("load_optimizer", True):
booster.load_optimizer(disc_optimizer, os.path.join(cfg.load, "disc_optimizer"))
@@ -318,6 +310,12 @@ def main():
model_sharding(ema)
ema = ema.to(device)

if cfg.get("freeze_layers", None) == "all":
for param in model.module.parameters():
param.requires_grad = False
print("all layers frozen")

# model.module.requires_grad_(False)
# =======================================================
# 5. training loop
# =======================================================
@@ -328,7 +326,7 @@ def main():
sampler.set_epoch(epoch)
dataiter = iter(dataloader)
logger.info("Beginning epoch %s...", epoch)
random.seed(1024 + dist.get_rank(get_data_parallel_group())) # load vid/img for each rank
random.seed(1024 + dist.get_rank()) # load vid/img for each rank

# == training loop in an epoch ==
with tqdm(
@@ -348,111 +346,122 @@ def main():

batch_, step_, pinned_video_ = fetch_data()

for _ in range(start_step, num_steps_per_epoch):
# == load data ===
batch, step, pinned_video = batch_, step_, pinned_video_
if step + 1 < num_steps_per_epoch:
batch_, step_, pinned_video_ = fetch_data()

# == log config ==
global_step = epoch * num_steps_per_epoch + step
# actual_update_step = (global_step + 1) // accumulation_steps
# log_step += 1
# acc_step += 1

# === whether to train generator and discriminator in alternative steps
if alter_train_gen_disc:
train_gen = (global_step + 1) % 2 == 1
train_disc = (global_step + 1) % 2 == 0
gen_step = global_step // 2 + 1
disc_step = (global_step + 1) // 2
actual_gen_update_step = gen_step // accumulation_steps
actual_disc_update_step = disc_step // accumulation_steps
else:
gen_step = disc_step = global_step + 1
train_gen = train_disc = True
actual_gen_update_step = actual_disc_update_step = (global_step + 1) // accumulation_steps

# == mixed strategy ==
x = batch["video"]
t_length = x.size(2)
use_video = 1
if mixed_strategy == "mixed_video_image":
if random.random() < modulated_mixed_image_ratio and dist.get_rank() != 0:
# NOTE: enable the first rank to use video
t_length = 1
use_video = 0
elif mixed_strategy == "mixed_video_random":
t_length = random.randint(1, x.size(2))
x = x[:, :, :t_length, :, :]

if train_gen:
# == forward pass ==
x_rec, posterior, z = model(x)
if cache_pin_memory:
dataiter.remove_cache(pinned_video)

log_gen_step += 1

# == loss initialization ==
vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
loss_dict = {}

# == opl loss ==
opl_loss_weight = cfg.get("opl_loss_weight", 0)
if opl_loss_weight > 0:
opl_loss = cal_opl_loss(z, opl_loss_weight)
vae_loss += opl_loss

# == reconstruction loss ==
ret = vae_loss_fn(x, x_rec, posterior)
nll_loss = ret["nll_loss"]
kl_loss = ret["kl_loss"]
recon_loss = ret["recon_loss"]
perceptual_loss = ret["perceptual_loss"]
vae_loss += nll_loss + kl_loss

# == generator loss ==
if use_discriminator:
# turn off grad update for disc
discriminator.requires_grad_(False)
fake_logits = discriminator(x_rec.contiguous())

generator_loss, g_loss = generator_loss_fn(
fake_logits,
nll_loss,
model.unwrap().get_last_layer(),
actual_gen_update_step,
is_training=model.training,
)
profiler_ctxt = (
profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=my_schedule,
on_trace_ready=torch.profiler.tensorboard_trace_handler("./log/profile"),
record_shapes=True,
profile_memory=True,
with_stack=True,
)
if cfg.get("profile", False)
else nullcontext()
)

vae_loss += generator_loss
# turn on disc training
discriminator.requires_grad_(True)

# == generator backward & update ==
ctx = (
booster.no_sync(model, optimizer)
if cfg.get("plugin", "zero2") in ("zero1", "zero1-seq") and (step + 1) % accumulation_steps != 0
else nullcontext()
)
with ctx:
booster.backward(loss=vae_loss / accumulation_steps, optimizer=optimizer)
if (gen_step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
if lr_scheduler is not None:
lr_scheduler.step(
actual_gen_update_step,
)
# == update EMA ==
if ema is not None:
update_ema(
ema,
model.unwrap(),
optimizer=optimizer,
decay=cfg.get("ema_decay", 0.9999),
with profiler_ctxt:
for _ in range(start_step, num_steps_per_epoch):
if cfg.get("profile", False) and _ == WARMUP + ACTIVE + WAIT + 3:
break

# == load data ===
batch, step, pinned_video = batch_, step_, pinned_video_
if step + 1 < num_steps_per_epoch:
batch_, step_, pinned_video_ = fetch_data()

# == log config ==
global_step = epoch * num_steps_per_epoch + step
actual_update_step = (global_step + 1) // accumulation_steps
log_step += 1
acc_step += 1

# == mixed strategy ==
x = batch["video"]
t_length = x.size(2)
use_video = 1
if mixed_strategy == "mixed_video_image":
if random.random() < modulated_mixed_image_ratio and dist.get_rank() != 0:
# NOTE: enable the first rank to use video
t_length = 1
use_video = 0
elif mixed_strategy == "mixed_video_random":
t_length = random.randint(1, x.size(2))
x = x[:, :, :t_length, :, :]

with Timer("model", log=True) if cfg.get("profile", False) else nullcontext():
# == forward pass ==
x_rec, posterior, z = model(x)

if cfg.get("profile", False):
profiler_ctxt.step()

if cache_pin_memory:
dataiter.remove_cache(pinned_video)

# == loss initialization ==
vae_loss = torch.tensor(0.0, device=device, dtype=dtype)
loss_dict = {} # loss at every step

# == opl loss ==
opl_loss_weight = cfg.get("opl_loss_weight", 0)
if opl_loss_weight > 0:
opl_loss = cal_opl_loss(z, opl_loss_weight)
vae_loss += opl_loss

# == reconstruction loss ==
ret = vae_loss_fn(x, x_rec, posterior)
nll_loss = ret["nll_loss"]
kl_loss = ret["kl_loss"]
recon_loss = ret["recon_loss"]
perceptual_loss = ret["perceptual_loss"]
vae_loss += nll_loss + kl_loss

# == generator loss ==
if use_discriminator:
# turn off grad update for disc
discriminator.requires_grad_(False)
fake_logits = discriminator(x_rec.contiguous())

generator_loss, g_loss = generator_loss_fn(
fake_logits,
nll_loss,
model.module.get_last_layer(),
actual_update_step,
is_training=model.training,
)
# print(f"generator_loss: {generator_loss}, recon_loss: {recon_loss}, perceptual_loss: {perceptual_loss}")

vae_loss += generator_loss
# turn on disc training
discriminator.requires_grad_(True)

# == generator backward & update ==
ctx = (
booster.no_sync(model, optimizer)
if cfg.get("plugin", "zero2") in ("zero1", "zero1-seq")
and (step + 1) % accumulation_steps != 0
else nullcontext()
)
with Timer("backward", log=True) if cfg.get("profile", False) else nullcontext():
with ctx:
booster.backward(loss=vae_loss / accumulation_steps, optimizer=optimizer)

with Timer("optimizer", log=True) if cfg.get("profile", False) else nullcontext():
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
if lr_scheduler is not None:
lr_scheduler.step(
actual_update_step,
)
# == update EMA ==
if ema is not None:
update_ema(
ema,
model.unwrap(),
optimizer=optimizer,
decay=cfg.get("ema_decay", 0.9999),
)

# == logging ==
log_loss("all", vae_loss, loss_dict, use_video)
@@ -466,129 +475,128 @@ def main():
log_loss("gen_w", generator_loss, loss_dict, use_video)
log_loss("gen", g_loss, loss_dict, use_video)

# == loss: discriminator adversarial ==
if train_disc and use_discriminator:
# need to calc x_rec without accumulating compuation graph to avoid OOM
if not train_gen:
with torch.no_grad():
x_rec, posterior, z = model(x)
if cache_pin_memory:
dataiter.remove_cache(pinned_video)

log_disc_step += 1
real_logits = discriminator(x.detach().contiguous())
fake_logits = discriminator(x_rec.detach().contiguous())
disc_loss = discriminator_loss_fn(
real_logits,
fake_logits,
actual_disc_update_step,
)

# == discriminator backward & update ==
ctx = (
booster.no_sync(discriminator, disc_optimizer)
if cfg.get("plugin", "zero2") in ("zero1", "zero1-seq") and (step + 1) % accumulation_steps != 0
else nullcontext()
)
with ctx:
booster.backward(loss=disc_loss / accumulation_steps, optimizer=disc_optimizer)
if (disc_step + 1) % accumulation_steps == 0:
disc_optimizer.step()
disc_optimizer.zero_grad()
if disc_lr_scheduler is not None:
disc_lr_scheduler.step(actual_disc_update_step)

# log
log_loss("disc", disc_loss, loss_dict, use_video)

# == logging ==
if (gen_step + 1) % accumulation_steps == 0 and actual_gen_update_step == actual_disc_update_step:
if coordinator.is_master() and actual_gen_update_step % cfg.get("log_every", 1) == 0:
avg_loss = {k: v / log_gen_step for k, v in running_loss.items()}
# progress bar
pbar.set_postfix(
{
# "step": step,
# "global_step": global_step,
# "actual_update_step": actual_gen_update_step,
# "lr": optimizer.param_groups[0]["lr"],
**{k: f"{v:.2f}" for k, v in avg_loss.items()},
}
# == loss: discriminator adversarial ==
if use_discriminator:
real_logits = discriminator(x.detach().contiguous())
fake_logits = discriminator(x_rec.detach().contiguous())
disc_loss = discriminator_loss_fn(
real_logits,
fake_logits,
actual_update_step,
)

# == discriminator backward & update ==
ctx = (
booster.no_sync(discriminator, disc_optimizer)
if cfg.get("plugin", "zero2") in ("zero1", "zero1-seq")
and (step + 1) % accumulation_steps != 0
else nullcontext()
)
# tensorboard
tb_writer.add_scalar("loss", vae_loss.item(), actual_gen_update_step)
# wandb
if cfg.get("wandb", False):
wandb.log(
with ctx:
booster.backward(loss=disc_loss / accumulation_steps, optimizer=disc_optimizer)
if (step + 1) % accumulation_steps == 0:
disc_optimizer.step()
disc_optimizer.zero_grad()
if disc_lr_scheduler is not None:
disc_lr_scheduler.step(actual_update_step)

# log
log_loss("disc", disc_loss, loss_dict, use_video)

# == logging ==
if (global_step + 1) % accumulation_steps == 0:
if coordinator.is_master() and actual_update_step % cfg.get("log_every", 1) == 0:
avg_loss = {k: v / log_step for k, v in running_loss.items()}
# progress bar
pbar.set_postfix(
{
"iter": global_step,
"epoch": epoch,
"lr": optimizer.param_groups[0]["lr"],
"avg_loss_": avg_loss,
"avg_loss": avg_loss["all"],
"loss_": loss_dict,
"loss": vae_loss.item(),
},
step=actual_gen_update_step,
# "step": step,
# "global_step": global_step,
# "actual_update_step": actual_update_step,
# "lr": optimizer.param_groups[0]["lr"],
**{k: f"{v:.2f}" for k, v in avg_loss.items()},
}
)
# tensorboard
tb_writer.add_scalar("loss", vae_loss.item(), actual_update_step)
# wandb
if cfg.get("wandb", False):
wandb.log(
{
"iter": global_step,
"epoch": epoch,
"lr": optimizer.param_groups[0]["lr"],
"avg_loss_": avg_loss,
"avg_loss": avg_loss["all"],
"loss_": loss_dict,
"loss": vae_loss.item(),
"global_grad_norm": optimizer.get_grad_norm(),
},
step=actual_update_step,
)

running_loss = {k: 0.0 for k in running_loss}
log_gen_step = 0

# == checkpoint saving ==
ckpt_every = cfg.get("ckpt_every", 0)
if ckpt_every > 0 and actual_gen_update_step % ckpt_every == 0 and coordinator.is_master():
subprocess.run("sudo drop_cache", shell=True)

if ckpt_every > 0 and actual_gen_update_step % ckpt_every == 0:
# mannually garbage collection
gc.collect()

save_dir = checkpoint_io.save(
booster,
exp_dir,
model=model,
ema=ema,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
sampler=sampler,
epoch=epoch,
step=step + 1,
global_step=global_step + 1,
batch_size=cfg.get("batch_size", None),
actual_update_step=actual_gen_update_step,
ema_shape_dict=ema_shape_dict,
async_io=True,
)
running_loss = {k: 0.0 for k in running_loss}
log_step = 0

if is_log_process(plugin_type, plugin_config):
os.system(f"chgrp -R share {save_dir}")
# == checkpoint saving ==
ckpt_every = cfg.get("ckpt_every", 0)
if ckpt_every > 0 and actual_update_step % ckpt_every == 0 and coordinator.is_master():
subprocess.run("sudo drop_cache", shell=True)

if use_discriminator:
disc_booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
disc_booster.save_optimizer(
disc_optimizer,
os.path.join(save_dir, "disc_optimizer"),
shard=True,
size_per_shard=4096,
if ckpt_every > 0 and actual_update_step % ckpt_every == 0:
# mannually garbage collection
gc.collect()

save_dir = checkpoint_io.save(
booster,
exp_dir,
model=model,
ema=ema,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
sampler=sampler,
epoch=epoch,
step=step + 1,
global_step=global_step + 1,
batch_size=cfg.get("batch_size", None),
actual_update_step=actual_update_step,
ema_shape_dict=ema_shape_dict,
async_io=True,
)
if disc_lr_scheduler is not None:
disc_booster.save_lr_scheduler(
disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler")

if is_log_process(plugin_type, plugin_config):
os.system(f"chgrp -R share {save_dir}")

if use_discriminator:
booster.save_model(discriminator, os.path.join(save_dir, "discriminator"), shard=True)
booster.save_optimizer(
disc_optimizer,
os.path.join(save_dir, "disc_optimizer"),
shard=True,
size_per_shard=4096,
)
dist.barrier()

logger.info(
"Saved checkpoint at epoch %s, step %s, global_step %s to %s",
epoch,
step + 1,
actual_gen_update_step,
save_dir,
)
if disc_lr_scheduler is not None:
booster.save_lr_scheduler(
disc_lr_scheduler, os.path.join(save_dir, "disc_lr_scheduler")
)
dist.barrier()

logger.info(
"Saved checkpoint at epoch %s, step %s, global_step %s to %s",
epoch,
step + 1,
actual_update_step,
save_dir,
)

# remove old checkpoints
rm_checkpoints(exp_dir, keep_n_latest=cfg.get("keep_n_latest", -1))
logger.info(
"Removed old checkpoints and kept %s latest ones.", cfg.get("keep_n_latest", -1)
)

# remove old checkpoints
rm_checkpoints(exp_dir, keep_n_latest=cfg.get("keep_n_latest", -1))
logger.info("Removed old checkpoints and kept %s latest ones.", cfg.get("keep_n_latest", -1))
if cfg.get("profile", False):
profiler_ctxt.export_chrome_trace("./log/profile/trace.json")

sampler.reset()
start_step = 0


Loading…
Cancel
Save
Baidu
map