2 Commits

Author SHA1 Message Date
  Akegarasu 3f94c40ed7
update sd-scripts 4 months ago
  Akegarasu 2c65fb1af1
fix onnxruntime install 4 months ago
58 changed files with 9042 additions and 1110 deletions
Split View
  1. +30
    -25
      mikazuki/launch_utils.py
  2. +1
    -1
      requirements.txt
  3. +5
    -0
      scripts/dev/.gitignore
  4. +1
    -1
      scripts/dev/COMMIT_ID
  5. +7
    -6
      scripts/dev/README-ja.md
  6. +69
    -33
      scripts/dev/README.md
  7. +2
    -0
      scripts/dev/fine_tune.py
  8. +12
    -11
      scripts/dev/finetune/tag_images_by_wd14_tagger.py
  9. +43
    -25
      scripts/dev/flux_minimal_inference.py
  10. +3
    -2
      scripts/dev/flux_train.py
  11. +8
    -1
      scripts/dev/flux_train_control_net.py
  12. +45
    -57
      scripts/dev/flux_train_network.py
  13. +744
    -0
      scripts/dev/library/chroma_models.py
  14. +6
    -1
      scripts/dev/library/config_util.py
  15. +17
    -13
      scripts/dev/library/custom_offloading_utils.py
  16. +41
    -0
      scripts/dev/library/deepspeed_utils.py
  17. +14
    -178
      scripts/dev/library/flux_models.py
  18. +154
    -84
      scripts/dev/library/flux_train_utils.py
  19. +112
    -41
      scripts/dev/library/flux_utils.py
  20. +62
    -76
      scripts/dev/library/ipex/__init__.py
  21. +6
    -6
      scripts/dev/library/ipex/attention.py
  22. +80
    -1
      scripts/dev/library/ipex/diffusers.py
  23. +0
    -183
      scripts/dev/library/ipex/gradscaler.py
  24. +167
    -68
      scripts/dev/library/ipex/hijacks.py
  25. +186
    -0
      scripts/dev/library/jpeg_xl_util.py
  26. +1392
    -0
      scripts/dev/library/lumina_models.py
  27. +1098
    -0
      scripts/dev/library/lumina_train_util.py
  28. +259
    -0
      scripts/dev/library/lumina_util.py
  29. +484
    -129
      scripts/dev/library/sai_model_spec.py
  30. +2
    -2
      scripts/dev/library/sd3_models.py
  31. +4
    -4
      scripts/dev/library/sd3_utils.py
  32. +75
    -9
      scripts/dev/library/strategy_base.py
  33. +375
    -0
      scripts/dev/library/strategy_lumina.py
  34. +175
    -60
      scripts/dev/library/train_util.py
  35. +115
    -3
      scripts/dev/library/utils.py
  36. +418
    -0
      scripts/dev/lumina_minimal_inference.py
  37. +957
    -0
      scripts/dev/lumina_train.py
  38. +383
    -0
      scripts/dev/lumina_train_network.py
  39. +310
    -23
      scripts/dev/networks/lora_flux.py
  40. +1038
    -0
      scripts/dev/networks/lora_lumina.py
  41. +41
    -28
      scripts/dev/networks/resize_lora.py
  42. +1
    -0
      scripts/dev/pytest.ini
  43. +1
    -1
      scripts/dev/requirements.txt
  44. +3
    -0
      scripts/dev/sd3_train.py
  45. +2
    -1
      scripts/dev/sdxl_train.py
  46. +2
    -0
      scripts/dev/sdxl_train_control_net.py
  47. +2
    -0
      scripts/dev/sdxl_train_control_net_lllite.py
  48. +2
    -0
      scripts/dev/sdxl_train_control_net_lllite_old.py
  49. +0
    -1
      scripts/dev/sdxl_train_network.py
  50. +2
    -0
      scripts/dev/tools/cache_latents.py
  51. +2
    -0
      scripts/dev/tools/cache_text_encoder_outputs.py
  52. +4
    -7
      scripts/dev/tools/detect_face_rotate.py
  53. +4
    -16
      scripts/dev/tools/resize_images_to_resolution.py
  54. +1
    -0
      scripts/dev/train_control_net.py
  55. +2
    -0
      scripts/dev/train_db.py
  56. +69
    -12
      scripts/dev/train_network.py
  57. +2
    -1
      scripts/dev/train_textual_inversion.py
  58. +2
    -0
      scripts/dev/train_textual_inversion_XTI.py

+ 30
- 25
mikazuki/launch_utils.py View File

@@ -205,7 +205,7 @@ def setup_windows_bitsandbytes():
bnb_package = "bitsandbytes==0.46.0"
bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")

installed_bnb = is_installed("bitsandbytes") # don't check version here
installed_bnb = is_installed("bitsandbytes") # don't check version here
bnb_cuda_setup = len([f for f in os.listdir(bnb_path) if re.findall(r"libbitsandbytes_cuda.+?\.dll", f)]) != 0

if not installed_bnb or not bnb_cuda_setup:
@@ -214,20 +214,10 @@ def setup_windows_bitsandbytes():
run_pip(f"install {bnb_package}", bnb_package, live=True)


def setup_onnxruntime():
onnx_version = "1.18.1"
index_url = None

try:
import torch
torch_version = torch.__version__
if "cu12" in torch_version:
# for cuda 12
onnx_version = f"1.18.1"
index_url = "https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/"
except ImportError:
log.error("torch not found")

def setup_onnxruntime(
onnx_version: Optional[str] = None,
index_url: Optional[str] = None
):
if sys.platform == "linux":
libc_ver = platform.libc_ver()
if libc_ver[0] == "glibc" and libc_ver[1] <= "2.27":
@@ -235,24 +225,39 @@ def setup_onnxruntime():

onnx_version = os.environ.get("ONNXRUNTIME_VERSION", onnx_version)

if not is_installed(f"onnxruntime-gpu=={onnx_version}"):
if onnx_version and not is_installed(f"onnxruntime-gpu=={onnx_version}"):
log.info("uninstalling wrong onnxruntime version")
# run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime=={onnx_version}", live=True)
run_pip(f"uninstall onnxruntime -y", "onnxruntime", live=True)
run_pip(f"uninstall onnxruntime-gpu -y", "onnxruntime", live=True)

if not is_installed(f"onnxruntime-gpu"):
log.info(f"installing onnxruntime")
run_pip(f"install onnxruntime=={onnx_version}", f"onnxruntime", live=True)
if index_url:
run_pip(f"install onnxruntime-gpu=={onnx_version} -i {index_url}", f"onnxruntime-gpu", live=True)
else:
run_pip(f"install onnxruntime-gpu=={onnx_version}", f"onnxruntime-gpu", live=True)
pip_install("onnxruntime", onnx_version, index_url=index_url, live=True)
pip_install("onnxruntime-gpu", onnx_version, index_url=index_url, live=True)


def run_pip(command, desc=None, live=False):
return run(f'"{python_bin}" -m pip {command}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live)


def pip_install(package: str, version: Optional[str] = None, index_url: Optional[str] = None, live: bool = True):
"""
Install a package using pip.
:param package: The name of the package to install.
:param version: The version of the package to install (optional).
:param index_url: The index URL to use for installing the package (optional).
"""
if version:
package = f"{package}=={version}"

command = f"install {package}"

if index_url:
command = f"{command} -i {index_url}"

run_pip(command, desc=f"Installing {package}", live=live)


def check_run(file: str) -> bool:
result = subprocess.run([python_bin, file], capture_output=True, shell=False)
log.info(result.stdout.decode("utf-8").strip())
@@ -275,7 +280,7 @@ def network_gfw_test(timeout=3):
return False


def prepare_environment(disable_auto_mirror: bool = True):
def prepare_environment(disable_auto_mirror: bool = True, prepare_onnxruntime: bool = True):
if sys.platform == "win32":
# disable triton on windows
os.environ["XFORMERS_FORCE_DISABLE_TRITON"] = "1"
@@ -304,8 +309,8 @@ def prepare_environment(disable_auto_mirror: bool = True):
validate_requirements("requirements.txt")
setup_windows_bitsandbytes()

# if not skip_prepare_onnxruntime:
# setup_onnxruntime()
if prepare_onnxruntime:
setup_onnxruntime()


def catch_exception(f):


+ 1
- 1
requirements.txt View File

@@ -9,7 +9,7 @@ pytorch-lightning==1.9.0
bitsandbytes==0.46.0
lion-pytorch==0.1.2
schedulefree==1.4
pytorch-optimizer==3.5.0
pytorch-optimizer==3.7.0
prodigy-plus-schedule-free==1.9.0
prodigyopt==1.1.2
tensorboard==2.10.1


+ 5
- 0
scripts/dev/.gitignore View File

@@ -6,3 +6,8 @@ venv
build
.vscode
wandb
CLAUDE.md
GEMINI.md
.claude
.gemini
MagicMock

+ 1
- 1
scripts/dev/COMMIT_ID View File

@@ -1 +1 @@
6364379f17d50add1696b0672f39c25c08a006b6
18e62515c49fe502ca31b30ea2214a97a2e99633

+ 7
- 6
scripts/dev/README-ja.md View File

@@ -155,11 +155,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b

`#` で始まる行はコメントになります。`--n` のように「ハイフン二個+英小文字」の形でオプションを指定できます。以下が使用可能です。

* `--n` Negative prompt up to the next option.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--s` Specifies the number of steps in the generation.
* `--n` ネガティブプロンプト(次のオプションまで)
* `--w` 生成画像の幅を指定
* `--h` 生成画像の高さを指定
* `--d` 生成画像のシード値を指定
* `--l` 生成画像のCFGスケールを指定。FLUX.1モデルでは、デフォルトは `1.0` でCFGなしを意味します。Chromaモデルでは、CFGを有効にするために `4.0` 程度に設定してください
* `--g` 埋め込みガイダンス付きモデル(FLUX.1)の埋め込みガイダンススケールを指定、デフォルトは `3.5`。Chromaモデルでは `0.0` に設定してください
* `--s` 生成時のステップ数を指定

`( )` や `[ ]` などの重みづけも動作します。

+ 69
- 33
scripts/dev/README.md View File

@@ -9,11 +9,54 @@ __Please update PyTorch to 2.4.0. We have tested with `torch==2.4.0` and `torchv
The command to install PyTorch is as follows:
`pip3 install torch==2.4.0 torchvision==0.19.0 --index-url https://download.pytorch.org/whl/cu124`

If you are using DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.

- [FLUX.1 training](#flux1-training)
- [SD3 training](#sd3-training)

### Recent Updates

Jul 30, 2025:
- **Breaking Change**: For FLUX.1 and Chroma training, the CFG (Classifier-Free Guidance, using negative prompts) scale option for sample image generation during training has been changed from `--g` to `--l`. The `--g` option is now used for the embedded guidance scale. Please update your prompts accordingly. See [Sample Image Generation During Training](#sample-image-generation-during-training) for details.

- Support for [Chroma](https://huggingface.co/lodestones/Chroma) has been added in PR [#2157](https://github.com/kohya-ss/sd-scripts/pull/2157). Thank you to lodestones for the high-quality model.
- Chroma is a new model based on FLUX.1 schnell. In this repository, `flux_train_network.py` is used for training LoRAs for Chroma with `--model_type chroma`. `--apply_t5_attn_mask` is also needed for Chroma training.
- Please refer to the [FLUX.1 LoRA training documentation](./docs/flux_train_network.md) for more details.

Jul 21, 2025:
- Support for [Lumina-Image 2.0](https://github.com/Alpha-VLLM/Lumina-Image-2.0) has been added in PR [#1927](https://github.com/kohya-ss/sd-scripts/pull/1927) and [#2138](https://github.com/kohya-ss/sd-scripts/pull/2138). Special thanks to sdbds and RockerBOO for their contributions.
- Please refer to the [Lumina-Image 2.0 documentation](./docs/lumina_train_network.md) for more details.
- We have started adding comprehensive training-related documentation to [docs](./docs). These documents are being created with the help of generative AI and will be updated over time. While there are still many gaps at this stage, we plan to improve them gradually.

Currently, the following documents are available:
- train_network.md
- sdxl_train_network.md
- sdxl_train_network_advanced.md
- flux_train_network.md
- sd3_train_network.md
- lumina_train_network.md
Jul 10, 2025:
- [AI Coding Agents](#for-developers-using-ai-coding-agents) section is added to the README. This section provides instructions for developers using AI coding agents like Claude and Gemini to understand the project context and coding standards.

May 1, 2025:
- The error when training FLUX.1 with mixed precision in flux_train.py with DeepSpeed enabled has been resolved. Thanks to sharlynxy for PR [#2060](https://github.com/kohya-ss/sd-scripts/pull/2060). Please refer to the PR for details.
- If you enable DeepSpeed, please install DeepSpeed with `pip install deepspeed==0.16.7`.

Apr 27, 2025:
- FLUX.1 training now supports CFG scale in the sample generation during training. Please use the `--g` option to specify the CFG scale (note that `--l` is used as the embedded guidance scale). PR [#2064](https://github.com/kohya-ss/sd-scripts/pull/2064).
- See [here](#sample-image-generation-during-training) for details.
- If you have any issues with this, please let us know.

Apr 6, 2025:
- IP noise gamma has been enabled in FLUX.1. Thanks to rockerBOO for PR [#1992](https://github.com/kohya-ss/sd-scripts/pull/1992). See the PR for details.
- `--ip_noise_gamma` and `--ip_noise_gamma_random_strength` are available.
Mar 30, 2025:
- LoRA-GGPO is added for FLUX.1 LoRA training. Thank you to rockerBOO for PR [#1974](https://github.com/kohya-ss/sd-scripts/pull/1974).
- Specify `--network_args ggpo_sigma=0.03 ggpo_beta=0.01` in the command line or `network_args = ["ggpo_sigma=0.03", "ggpo_beta=0.01"]` in .toml file. See PR for details.
- The interpolation method for resizing the original image to the training size can now be specified. Thank you to rockerBOO for PR [#1936](https://github.com/kohya-ss/sd-scripts/pull/1936).

Mar 20, 2025:
- `pytorch-optimizer` is added to requirements.txt. Thank you to gesen2egee for PR [#1985](https://github.com/kohya-ss/sd-scripts/pull/1985).
- For example, you can use CAME optimizer with `--optimizer_type "pytorch_optimizer.CAME" --optimizer_args "weight_decay=0.01"`.
@@ -34,46 +77,30 @@ Jan 25, 2025:
- It will be added to other scripts as well.
- As a current limitation, validation loss is not supported when `--block_to_swap` is specified, or when schedule-free optimizer is used.

Dec 15, 2024:
## For Developers Using AI Coding Agents

- RAdamScheduleFree optimizer is supported. PR [#1830](https://github.com/kohya-ss/sd-scripts/pull/1830) Thanks to nhamanasu!
- Update to `schedulefree==1.4` is required. Please update individually or with `pip install --use-pep517 --upgrade -r requirements.txt`.
- Available with `--optimizer_type=RAdamScheduleFree`. No need to specify warm up steps as well as learning rate scheduler.
This repository provides recommended instructions to help AI agents like Claude and Gemini understand our project context and coding standards.

Dec 7, 2024:
To use them, you need to opt-in by creating your own configuration file in the project root.

- The option to specify the model name during ControlNet training was different in each script. It has been unified. Please specify `--controlnet_model_name_or_path`. PR [#1821](https://github.com/kohya-ss/sd-scripts/pull/1821) Thanks to sdbds!
<!--
Also, the ControlNet training script for SD has been changed from `train_controlnet.py` to `train_control_net.py`.
- `train_controlnet.py` is still available, but it will be removed in the future.
-->
**Quick Setup:**

- Fixed an issue where the saved model would be corrupted (pos_embed would not be saved) when `--enable_scaled_pos_embed` was specified in `sd3_train.py`.
1. Create a `CLAUDE.md` and/or `GEMINI.md` file in the project root.
2. Add the following line to your `CLAUDE.md` to import the repository's recommended prompt:

Dec 3, 2024:
```markdown
@./.ai/claude.prompt.md
```

-`--blocks_to_swap` now works in FLUX.1 ControlNet training. Sample commands for 24GB VRAM and 16GB VRAM are added [here](#flux1-controlnet-training).
or for Gemini:

Dec 2, 2024:
```markdown
@./.ai/gemini.prompt.md
```

- FLUX.1 ControlNet training is supported. PR [#1813](https://github.com/kohya-ss/sd-scripts/pull/1813). Thanks to minux302! See PR and [here](#flux1-controlnet-training) for details.
- Not fully tested. Feedback is welcome.
- 80GB VRAM is required for 1024x1024 resolution, and 48GB VRAM is required for 512x512 resolution.
- Currently, it only works in Linux environment (or Windows WSL2) because DeepSpeed is required.
- Multi-GPU training is not tested.
3. You can now add your own personal instructions below the import line (e.g., `Always respond in Japanese.`).

Dec 1, 2024:

- Pseudo Huber loss is now available for FLUX.1 and SD3.5 training. See PR [#1808](https://github.com/kohya-ss/sd-scripts/pull/1808) for details. Thanks to recris!
- Specify `--loss_type huber` or `--loss_type smooth_l1` to use it. `--huber_c` and `--huber_scale` are also available.

- [Prodigy + ScheduleFree](https://github.com/LoganBooker/prodigy-plus-schedule-free) is supported. See PR [#1811](https://github.com/kohya-ss/sd-scripts/pull/1811) for details. Thanks to rockerBOO!

Nov 14, 2024:

- Improved the implementation of block swap and made it available for both FLUX.1 and SD3 LoRA training. See [FLUX.1 LoRA training](#flux1-lora-training) etc. for how to use the new options. Training is possible with about 8-10GB of VRAM.
- During fine-tuning, the memory usage when specifying the same number of blocks has increased slightly, but the training speed when specifying block swap has been significantly improved.
- There may be bugs due to the significant changes. Feedback is welcome.
This approach ensures that you have full control over the instructions given to your agent while benefiting from the shared project context. Your `CLAUDE.md` and `GEMINI.md` are already listed in `.gitignore`, so they won't be committed to the repository.

## FLUX.1 training

@@ -861,6 +888,14 @@ Note: Some user reports ``ValueError: fp16 mixed precision requires a GPU`` is o

(Single GPU with id `0` will be used.)

## DeepSpeed installation (experimental, Linux or WSL2 only)
To install DeepSpeed, run the following command in your activated virtual environment:

```bash
pip install deepspeed==0.16.7
```

## Upgrade

When a new release comes out you can upgrade your repo with the following command:
@@ -1335,11 +1370,12 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b

Lines beginning with `#` are comments. You can specify options for the generated image with options like `--n` after the prompt. The following can be used.

* `--n` Negative prompt up to the next option.
* `--n` Negative prompt up to the next option. Ignored when CFG scale is `1.0`.
* `--w` Specifies the width of the generated image.
* `--h` Specifies the height of the generated image.
* `--d` Specifies the seed of the generated image.
* `--l` Specifies the CFG scale of the generated image.
* `--l` Specifies the CFG scale of the generated image. For FLUX.1 models, the default is `1.0`, which means no CFG. For Chroma models, set to around `4.0` to enable CFG.
* `--g` Specifies the embedded guidance scale for the models with embedded guidance (FLUX.1), the default is `3.5`. Set to `0.0` for Chroma models.
* `--s` Specifies the number of steps in the generation.

The prompt weighting such as `( )` and `[ ]` are working.

+ 2
- 0
scripts/dev/fine_tune.py View File

@@ -27,6 +27,7 @@ logger = logging.getLogger(__name__)

import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -519,6 +520,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)


+ 12
- 11
scripts/dev/finetune/tag_images_by_wd14_tagger.py View File

@@ -11,7 +11,7 @@ from PIL import Image
from tqdm import tqdm

import library.train_util as train_util
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image

setup_logging()
import logging
@@ -42,10 +42,7 @@ def preprocess_image(image):
pad_t = pad_y // 2
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)

if size > IMAGE_SIZE:
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
else:
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
image = resize_image(image, image.shape[0], image.shape[1], IMAGE_SIZE, IMAGE_SIZE)

image = image.astype(np.float32)
return image
@@ -100,15 +97,19 @@ def main(args):
else:
for file in SUB_DIR_FILES:
hf_hub_download(
args.repo_id,
file,
repo_id=args.repo_id,
filename=file,
subfolder=SUB_DIR,
cache_dir=os.path.join(model_location, SUB_DIR),
local_dir=os.path.join(model_location, SUB_DIR),
force_download=True,
force_filename=file,
)
for file in files:
hf_hub_download(args.repo_id, file, cache_dir=model_location, force_download=True, force_filename=file)
hf_hub_download(
repo_id=args.repo_id,
filename=file,
local_dir=model_location,
force_download=True,
)
else:
logger.info("using existing wd14 tagger model")

@@ -149,7 +150,7 @@ def main(args):
ort_sess = ort.InferenceSession(
onnx_path,
providers=(["OpenVINOExecutionProvider"]),
provider_options=[{'device_type' : "GPU_FP32"}],
provider_options=[{'device_type' : "GPU", "precision": "FP32"}],
)
else:
ort_sess = ort.InferenceSession(


+ 43
- 25
scripts/dev/flux_minimal_inference.py View File

@@ -78,16 +78,19 @@ def denoise(
neg_t5_attn_mask: Optional[torch.Tensor] = None,
cfg_scale: Optional[float] = None,
):
# this is ignored for schnell
# prepare classifier free guidance
logger.info(f"guidance: {guidance}, cfg_scale: {cfg_scale}")
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
do_cfg = neg_txt is not None and (cfg_scale is not None and cfg_scale != 1.0)

# prepare classifier free guidance
if neg_txt is not None and neg_vec is not None:
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0] * (2 if do_cfg else 1),), guidance, device=img.device, dtype=img.dtype)

if do_cfg:
print("Using classifier free guidance")
b_img_ids = torch.cat([img_ids, img_ids], dim=0)
b_txt_ids = torch.cat([txt_ids, txt_ids], dim=0)
b_txt = torch.cat([neg_txt, txt], dim=0)
b_vec = torch.cat([neg_vec, vec], dim=0)
b_vec = torch.cat([neg_vec, vec], dim=0) if neg_vec is not None else None
if t5_attn_mask is not None and neg_t5_attn_mask is not None:
b_t5_attn_mask = torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)
else:
@@ -103,24 +106,29 @@ def denoise(
t_vec = torch.full((b_img_ids.shape[0],), t_curr, dtype=img.dtype, device=img.device)

# classifier free guidance
if neg_txt is not None and neg_vec is not None:
if do_cfg:
b_img = torch.cat([img, img], dim=0)
else:
b_img = img

y_input = b_vec

mod_vectors = model.get_mod_vectors(timesteps=t_vec, guidance=guidance_vec, batch_size=b_img.shape[0])

pred = model(
img=b_img,
img_ids=b_img_ids,
txt=b_txt,
txt_ids=b_txt_ids,
y=b_vec,
y=y_input,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=b_t5_attn_mask,
mod_vectors=mod_vectors,
)

# classifier free guidance
if neg_txt is not None and neg_vec is not None:
if do_cfg:
pred_uncond, pred = torch.chunk(pred, 2, dim=0)
pred = pred_uncond + cfg_scale * (pred - pred_uncond)

@@ -134,7 +142,7 @@ def do_sample(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
l_pooled: torch.Tensor,
l_pooled: Optional[torch.Tensor],
t5_out: torch.Tensor,
txt_ids: torch.Tensor,
num_steps: int,
@@ -192,7 +200,7 @@ def do_sample(

def generate_image(
model,
clip_l: CLIPTextModel,
clip_l: Optional[CLIPTextModel],
t5xxl,
ae,
prompt: str,
@@ -231,7 +239,7 @@ def generate_image(
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width)

# prepare fp8 models
if is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
if clip_l is not None and is_fp8(clip_l_dtype) and (not hasattr(clip_l, "fp8_prepared") or not clip_l.fp8_prepared):
logger.info(f"prepare CLIP-L for fp8: set to {clip_l_dtype}, set embeddings to {torch.bfloat16}")
clip_l.to(clip_l_dtype) # fp8
clip_l.text_model.embeddings.to(dtype=torch.bfloat16)
@@ -267,18 +275,22 @@ def generate_image(

# prepare embeddings
logger.info("Encoding prompts...")
clip_l = clip_l.to(device)
if clip_l is not None:
clip_l = clip_l.to(device)
t5xxl = t5xxl.to(device)

def encode(prpt: str):
tokens_and_masks = tokenize_strategy.tokenize(prpt)
with torch.no_grad():
if is_fp8(clip_l_dtype):
with accelerator.autocast():
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
if clip_l is not None:
if is_fp8(clip_l_dtype):
with accelerator.autocast():
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
else:
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
else:
with torch.autocast(device_type=device.type, dtype=clip_l_dtype):
l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, [clip_l, None], tokens_and_masks)
l_pooled = None

if is_fp8(t5xxl_dtype):
with accelerator.autocast():
@@ -288,7 +300,7 @@ def generate_image(
else:
with torch.autocast(device_type=device.type, dtype=t5xxl_dtype):
_, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
tokenize_strategy, [None, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
tokenize_strategy, [clip_l, t5xxl], tokens_and_masks, args.apply_t5_attn_mask
)
return l_pooled, t5_out, txt_ids, t5_attn_mask

@@ -299,13 +311,14 @@ def generate_image(
neg_l_pooled, neg_t5_out, neg_t5_attn_mask = None, None, None

# NaN check
if torch.isnan(l_pooled).any():
if l_pooled is not None and torch.isnan(l_pooled).any():
raise ValueError("NaN in l_pooled")
if torch.isnan(t5_out).any():
raise ValueError("NaN in t5_out")

if args.offload:
clip_l = clip_l.cpu()
if clip_l is not None:
clip_l = clip_l.cpu()
t5xxl = t5xxl.cpu()
# del clip_l, t5xxl
device_utils.clean_memory()
@@ -318,6 +331,7 @@ def generate_image(

img_ids = img_ids.to(device)
t5_attn_mask = t5_attn_mask.to(device) if args.apply_t5_attn_mask else None
neg_t5_attn_mask = neg_t5_attn_mask.to(device) if neg_t5_attn_mask is not None and args.apply_t5_attn_mask else None

x = do_sample(
accelerator,
@@ -385,6 +399,7 @@ if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--model_type", type=str, choices=["flux", "chroma"], default="flux", help="Model type to use")
parser.add_argument("--clip_l", type=str, required=False)
parser.add_argument("--t5xxl", type=str, required=False)
parser.add_argument("--ae", type=str, required=False)
@@ -438,10 +453,13 @@ if __name__ == "__main__":
else:
accelerator = None

# load clip_l
logger.info(f"Loading clip_l from {args.clip_l}...")
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
clip_l.eval()
# load clip_l (skip for chroma model)
if args.model_type == "flux":
logger.info(f"Loading clip_l from {args.clip_l}...")
clip_l = flux_utils.load_clip_l(args.clip_l, clip_l_dtype, loading_device)
clip_l.eval()
else:
clip_l = None

logger.info(f"Loading t5xxl from {args.t5xxl}...")
t5xxl = flux_utils.load_t5xxl(args.t5xxl, t5xxl_dtype, loading_device)
@@ -453,7 +471,7 @@ if __name__ == "__main__":
# t5xxl = accelerator.prepare(t5xxl)

# DiT
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device)
is_schnell, model = flux_utils.load_flow_model(args.ckpt_path, None, loading_device, model_type=args.model_type)
model.eval()
logger.info(f"Casting model to {flux_dtype}")
model.to(flux_dtype) # make sure model is dtype


+ 3
- 2
scripts/dev/flux_train.py View File

@@ -30,7 +30,7 @@ from library.device_utils import init_ipex, clean_memory_on_device
init_ipex()

from accelerate.utils import set_seed
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux
from library import deepspeed_utils, flux_train_utils, flux_utils, strategy_base, strategy_flux, sai_model_spec
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

import library.train_util as train_util
@@ -271,7 +271,7 @@ def train(args):

# load FLUX
_, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
)

if args.gradient_checkpointing:
@@ -787,6 +787,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)


+ 8
- 1
scripts/dev/flux_train_control_net.py View File

@@ -32,6 +32,7 @@ init_ipex()
from accelerate.utils import set_seed

import library.train_util as train_util
import library.sai_model_spec as sai_model_spec
from library import (
deepspeed_utils,
flux_train_utils,
@@ -68,6 +69,11 @@ def train(args):
if not args.skip_cache_check:
args.skip_cache_check = args.skip_latents_validity_check

if args.model_type != "flux":
raise ValueError(
f"FLUX.1 ControlNet training requires model_type='flux'. / FLUX.1 ControlNetの学習にはmodel_type='flux'を指定してください。"
)

# assert (
# not args.weighted_captions
# ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"
@@ -259,7 +265,7 @@ def train(args):

# load FLUX
is_schnell, flux = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors, model_type="flux"
)
flux.requires_grad_(False)

@@ -815,6 +821,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser) # TODO split this
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)


+ 45
- 57
scripts/dev/flux_train_network.py View File

@@ -35,6 +35,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
self.sample_prompts_te_outputs = None
self.is_schnell: Optional[bool] = None
self.is_swapping_blocks: bool = False
self.model_type: Optional[str] = None

def assert_extra_args(
self,
@@ -45,6 +46,13 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
# sdxl_train_util.verify_sdxl_training_args(args)

self.model_type = args.model_type # "flux" or "chroma"
if self.model_type != "chroma":
self.use_clip_l = True
else:
self.use_clip_l = False # Chroma does not use CLIP-L
assert args.apply_t5_attn_mask, "apply_t5_attn_mask must be True for Chroma / Chromaではapply_t5_attn_maskを指定する必要があります"

if args.fp8_base_unet:
args.fp8_base = True # if fp8_base_unet is enabled, fp8_base is also enabled for FLUX.1

@@ -60,7 +68,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
), "when caching Text Encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / Text Encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

# prepare CLIP-L/T5XXL training flags
self.train_clip_l = not args.network_train_unet_only
self.train_clip_l = not args.network_train_unet_only and self.use_clip_l
self.train_t5xxl = False # default is False even if args.network_train_unet_only is False

if args.max_token_length is not None:
@@ -95,8 +103,12 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
loading_dtype = None if args.fp8_base else weight_dtype

# if we load to cpu, flux.to(fp8) takes a long time, so we should load to gpu in future
self.is_schnell, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path, loading_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors
_, model = flux_utils.load_flow_model(
args.pretrained_model_name_or_path,
loading_dtype,
"cpu",
disable_mmap=args.disable_mmap_load_safetensors,
model_type=self.model_type,
)
if args.fp8_base:
# check dtype of model
@@ -120,7 +132,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
model.enable_block_swap(args.blocks_to_swap, accelerator.device)

clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
if self.use_clip_l:
clip_l = flux_utils.load_clip_l(args.clip_l, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)
else:
clip_l = flux_utils.dummy_clip_l() # dummy CLIP-L for Chroma, which does not use CLIP-L
clip_l.eval()

# if the file is fp8 and we are using fp8_base (not unet), we can load it as is (fp8)
@@ -141,13 +156,20 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):

ae = flux_utils.load_ae(args.ae, weight_dtype, "cpu", disable_mmap=args.disable_mmap_load_safetensors)

return flux_utils.MODEL_VERSION_FLUX_V1, [clip_l, t5xxl], ae, model
model_version = flux_utils.MODEL_VERSION_FLUX_V1 if self.model_type != "chroma" else flux_utils.MODEL_VERSION_CHROMA
return model_version, [clip_l, t5xxl], ae, model

def get_tokenize_strategy(self, args):
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
# This method is called before `assert_extra_args`, so we cannot use `self.is_schnell` here.
# Instead, we analyze the checkpoint state to determine if it is schnell.
if args.model_type != "chroma":
_, is_schnell, _, _ = flux_utils.analyze_checkpoint_state(args.pretrained_model_name_or_path)
else:
is_schnell = False
self.is_schnell = is_schnell

if args.t5xxl_max_token_length is None:
if is_schnell:
if self.is_schnell:
t5xxl_max_token_length = 256
else:
t5xxl_max_token_length = 512
@@ -268,23 +290,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
text_encoders[0].to(accelerator.device, dtype=weight_dtype)
text_encoders[1].to(accelerator.device)

# def call_unet(self, args, accelerator, unet, noisy_latents, timesteps, text_conds, batch, weight_dtype):
# noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

# # get size embeddings
# orig_size = batch["original_sizes_hw"]
# crop_size = batch["crop_top_lefts"]
# target_size = batch["target_sizes_hw"]
# embs = sdxl_train_util.get_size_embeddings(orig_size, crop_size, target_size, accelerator.device).to(weight_dtype)

# # concat embeddings
# encoder_hidden_states1, encoder_hidden_states2, pool2 = text_conds
# vector_embedding = torch.cat([pool2, embs], dim=1).to(weight_dtype)
# text_embedding = torch.cat([encoder_hidden_states1, encoder_hidden_states2], dim=2).to(weight_dtype)

# noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding)
# return noise_pred

def sample_images(self, accelerator, args, epoch, global_step, device, ae, tokenizer, text_encoder, flux):
text_encoders = text_encoder # for compatibility
text_encoders = self.get_models_for_text_encoding(args, accelerator, text_encoders)
@@ -292,36 +297,6 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, flux, ae, text_encoders, self.sample_prompts_te_outputs
)
# return

"""
class FluxUpperLowerWrapper(torch.nn.Module):
def __init__(self, flux_upper: flux_models.FluxUpper, flux_lower: flux_models.FluxLower, device: torch.device):
super().__init__()
self.flux_upper = flux_upper
self.flux_lower = flux_lower
self.target_device = device

def prepare_block_swap_before_forward(self):
pass

def forward(self, img, img_ids, txt, txt_ids, timesteps, y, guidance=None, txt_attention_mask=None):
self.flux_lower.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_upper.to(self.target_device)
img, txt, vec, pe = self.flux_upper(img, img_ids, txt, txt_ids, timesteps, y, guidance, txt_attention_mask)
self.flux_upper.to("cpu")
clean_memory_on_device(self.target_device)
self.flux_lower.to(self.target_device)
return self.flux_lower(img, txt, vec, pe, txt_attention_mask)

wrapper = FluxUpperLowerWrapper(self.flux_upper, flux, accelerator.device)
clean_memory_on_device(accelerator.device)
flux_train_utils.sample_images(
accelerator, args, epoch, global_step, wrapper, ae, text_encoders, self.sample_prompts_te_outputs
)
clean_memory_on_device(accelerator.device)
"""

def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
@@ -366,7 +341,10 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
# ensure guidance_scale in args is float
guidance_vec = torch.full((bsz,), float(args.guidance_scale), device=accelerator.device)

# ensure the hidden state will require grad
# get modulation vectors for Chroma
with accelerator.autocast(), torch.no_grad():
mod_vectors = unet.get_mod_vectors(timesteps=timesteps / 1000, guidance=guidance_vec, batch_size=bsz)

if args.gradient_checkpointing:
noisy_model_input.requires_grad_(True)
for t in text_encoder_conds:
@@ -374,13 +352,15 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
t.requires_grad_(True)
img_ids.requires_grad_(True)
guidance_vec.requires_grad_(True)
if mod_vectors is not None:
mod_vectors.requires_grad_(True)

# Predict the noise residual
l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
if not args.apply_t5_attn_mask:
t5_attn_mask = None

def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask):
def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t5_attn_mask, mod_vectors):
# grad is enabled even if unet is not in train mode, because Text Encoder is in train mode
with torch.set_grad_enabled(is_train), accelerator.autocast():
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
@@ -393,6 +373,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps / 1000,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
mod_vectors=mod_vectors,
)
return model_pred

@@ -405,6 +386,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps,
guidance_vec=guidance_vec,
t5_attn_mask=t5_attn_mask,
mod_vectors=mod_vectors,
)

# unpack latents
@@ -436,6 +418,7 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
timesteps=timesteps[diff_output_pr_indices],
guidance_vec=guidance_vec[diff_output_pr_indices] if guidance_vec is not None else None,
t5_attn_mask=t5_attn_mask[diff_output_pr_indices] if t5_attn_mask is not None else None,
mod_vectors=mod_vectors[diff_output_pr_indices] if mod_vectors is not None else None,
)
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step

@@ -454,9 +437,14 @@ class FluxNetworkTrainer(train_network.NetworkTrainer):
return loss

def get_sai_model_spec(self, args):
return train_util.get_sai_model_spec(None, args, False, True, False, flux="dev")
if self.model_type != "chroma":
model_description = "schnell" if self.is_schnell else "dev"
else:
model_description = "chroma"
return train_util.get_sai_model_spec(None, args, False, True, False, flux=model_description)

def update_metadata(self, metadata, args):
metadata["ss_model_type"] = args.model_type
metadata["ss_apply_t5_attn_mask"] = args.apply_t5_attn_mask
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean


+ 744
- 0
scripts/dev/library/chroma_models.py View File

@@ -0,0 +1,744 @@
# copy from the official repo: https://github.com/lodestone-rock/flow/blob/master/src/models/chroma/model.py
# and modified
# licensed under Apache License 2.0

import math
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn
import torch.nn.functional as F
import torch.utils.checkpoint as ckpt

from .flux_models import attention, rope, apply_rope, EmbedND, timestep_embedding, MLPEmbedder, RMSNorm, QKNorm, SelfAttention, Flux
from . import custom_offloading_utils


def distribute_modulations(tensor: torch.Tensor, depth_single_blocks, depth_double_blocks):
    """
    Distributes slices of the tensor into the block_dict as ModulationOut objects.

    The Approximator emits one stacked tensor of modulation vectors; this maps
    contiguous slices of it onto named blocks in a fixed layout:
    single blocks (3 vectors each: shift/scale/gate), then double-block img
    modulations (2 x 3 vectors each), then double-block txt modulations
    (2 x 3 vectors each), and finally the last layer (2 vectors: shift/scale).
    The key insertion order below defines the slice layout, so it must stay in
    sync with Chroma.mod_index_length.

    Args:
        tensor (torch.Tensor): Input tensor with shape [batch_size, vectors, dim].
        depth_single_blocks (int): number of single-stream blocks.
        depth_double_blocks (int): number of double-stream blocks.

    Returns:
        dict: block name -> ModulationOut (single blocks),
        list of 2 ModulationOut (double-block img/txt modulations),
        or [shift, scale] tensors (final layer).
    """
    batch_size, vectors, dim = tensor.shape

    # HARD CODED VALUES! lookup table for the generated vectors
    # TODO: move this into chroma config!
    # Keys are inserted in layout order; dict insertion-order preservation
    # (Python 3.7+) keeps the slices aligned with the vector layout.
    block_dict = {}
    for i in range(depth_single_blocks):
        block_dict[f"single_blocks.{i}.modulation.lin"] = None
    for i in range(depth_double_blocks):
        block_dict[f"double_blocks.{i}.img_mod.lin"] = None
    for i in range(depth_double_blocks):
        block_dict[f"double_blocks.{i}.txt_mod.lin"] = None
    block_dict["final_layer.adaLN_modulation.1"] = None
    # 6.2b version
    # block_dict["lite_double_blocks.4.img_mod.lin"] = None
    # block_dict["lite_double_blocks.4.txt_mod.lin"] = None

    def slice_modulation(start: int) -> ModulationOut:
        # One ModulationOut consumes three consecutive vectors: shift, scale, gate.
        return ModulationOut(
            shift=tensor[:, start : start + 1, :],
            scale=tensor[:, start + 1 : start + 2, :],
            gate=tensor[:, start + 2 : start + 3, :],
        )

    idx = 0  # Index to keep track of the vector slices

    for key in block_dict.keys():
        if "single_blocks" in key:
            # Single block: 1 ModulationOut (3 vectors)
            block_dict[key] = slice_modulation(idx)
            idx += 3
        elif "img_mod" in key or "txt_mod" in key:
            # Double block (img or txt side): 2 ModulationOut (6 vectors)
            block_dict[key] = [slice_modulation(idx), slice_modulation(idx + 3)]
            idx += 6
        elif "final_layer" in key:
            # Final layer: shift and scale only, no gate (2 vectors)
            block_dict[key] = [
                tensor[:, idx : idx + 1, :],
                tensor[:, idx + 1 : idx + 2, :],
            ]
            idx += 2  # Advance by 2 vectors

    return block_dict


class Approximator(nn.Module):
    """MLP that approximates the distilled modulation vectors for Chroma.

    A stack of residual MLPEmbedder layers with RMSNorm pre-normalization,
    projecting in_dim -> hidden_dim -> out_dim.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        hidden_dim: width of the residual stack.
        n_layers: number of residual layers.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=4):
        super().__init__()
        self.in_proj = nn.Linear(in_dim, hidden_dim, bias=True)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim) for _ in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim) for _ in range(n_layers)])
        self.out_proj = nn.Linear(hidden_dim, out_dim)

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def enable_gradient_checkpointing(self):
        for layer in self.layers:
            layer.enable_gradient_checkpointing()

    def disable_gradient_checkpointing(self):
        for layer in self.layers:
            layer.disable_gradient_checkpointing()

    def forward(self, x: Tensor) -> Tensor:
        """Project x, apply the residual pre-norm MLP stack, and project out."""
        x = self.in_proj(x)

        # pre-norm residual blocks: x <- x + layer(norm(x))
        for layer, norm in zip(self.layers, self.norms):
            x = x + layer(norm(x))

        return self.out_proj(x)


@dataclass
class ModulationOut:
    """One block's modulation triplet; each tensor has a singleton sequence dim.

    Applied as (1 + scale) * x + shift, with gate scaling the residual update.
    """

    shift: Tensor
    scale: Tensor
    gate: Tensor


def _modulation_shift_scale_fn(x, scale, shift):
return (1 + scale) * x + shift


def _modulation_gate_fn(x, gate, gate_params):
return x + gate * gate_params


class DoubleStreamBlock(nn.Module):
    """Chroma double-stream block: separate image and text streams with joint attention.

    Unlike the Flux double block, modulation parameters are not computed by a
    per-block module; they arrive pre-computed via ``distill_vec`` (a pair of
    [attention, MLP] ModulationOut lists, one for each stream). Attention is
    run per sample so that each sample only attends over its own (trimmed)
    text length instead of full padding.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
    ):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        # image stream: pre-norm + self-attention, pre-norm + MLP
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
        )

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        # text stream: same structure as the image stream
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
        )

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.gradient_checkpointing = False

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def modulation_shift_scale_fn(self, x, scale, shift):
        # indirection kept so the fn can be swapped for a compiled variant
        return _modulation_shift_scale_fn(x, scale, shift)

    def modulation_gate_fn(self, x, gate, gate_params):
        # indirection kept so the fn can be swapped for a compiled variant
        return _modulation_gate_fn(x, gate, gate_params)

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        distill_vec: list[list[ModulationOut]],
        txt_seq_len: Tensor,
    ) -> tuple[Tensor, Tensor]:
        # distill_vec = [[img attn/MLP mods], [txt attn/MLP mods]]
        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = distill_vec

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        # equivalent to: img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_modulated = self.modulation_shift_scale_fn(img_modulated, img_mod1.scale, img_mod1.shift)
        img_qkv = self.img_attn.qkv(img_modulated)
        del img_modulated

        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        del img_qkv
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        # equivalent to: txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_modulated = self.modulation_shift_scale_fn(txt_modulated, txt_mod1.scale, txt_mod1.shift)
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        del txt_modulated

        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        del txt_qkv
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention: we split the batch into each element so every
        # sample only attends over img tokens + its own text length; the
        # per-sample text outputs are re-padded with zeros up to max_txt_len
        max_txt_len = torch.max(txt_seq_len).item()
        img_len = img_q.shape[-2]  # image token count (shared across the batch)
        txt_q = list(torch.chunk(txt_q, txt_q.shape[0], dim=0))  # list of [B, H, L, D] tensors
        txt_k = list(torch.chunk(txt_k, txt_k.shape[0], dim=0))
        txt_v = list(torch.chunk(txt_v, txt_v.shape[0], dim=0))
        img_q = list(torch.chunk(img_q, img_q.shape[0], dim=0))
        img_k = list(torch.chunk(img_k, img_k.shape[0], dim=0))
        img_v = list(torch.chunk(img_v, img_v.shape[0], dim=0))
        txt_attn = []
        img_attn = []
        for i in range(txt.shape[0]):
            # trim this sample's text q/k/v to its own length, then concat with image
            txt_q[i] = txt_q[i][:, :, : txt_seq_len[i]]
            q = torch.cat((img_q[i], txt_q[i]), dim=2)
            txt_q[i] = None  # release references early to cap peak memory
            img_q[i] = None

            txt_k[i] = txt_k[i][:, :, : txt_seq_len[i]]
            k = torch.cat((img_k[i], txt_k[i]), dim=2)
            txt_k[i] = None
            img_k[i] = None

            txt_v[i] = txt_v[i][:, :, : txt_seq_len[i]]
            v = torch.cat((img_v[i], txt_v[i]), dim=2)
            txt_v[i] = None
            img_v[i] = None

            attn = attention(q, k, v, pe=pe[i : i + 1, :, : q.shape[2]], attn_mask=None)  # attn = (1, L, D)
            del q, k, v
            img_attn_i = attn[:, :img_len, :]
            # zero-pad the text part back to the batch's max text length
            txt_attn_i = torch.zeros((1, max_txt_len, attn.shape[-1]), dtype=attn.dtype, device=self.device)
            txt_attn_i[:, : txt_seq_len[i], :] = attn[:, img_len:, :]
            del attn
            txt_attn.append(txt_attn_i)
            img_attn.append(img_attn_i)

        txt_attn = torch.cat(txt_attn, dim=0)
        img_attn = torch.cat(img_attn, dim=0)

        # (replaces the batched variant: attn over cat(txt, img) with a padding mask)

        # calculate the img blocks
        # equivalent to:
        # img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        # img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
        img = self.modulation_gate_fn(img, img_mod1.gate, self.img_attn.proj(img_attn))
        del img_attn, img_mod1
        img = self.modulation_gate_fn(
            img,
            img_mod2.gate,
            self.img_mlp(self.modulation_shift_scale_fn(self.img_norm2(img), img_mod2.scale, img_mod2.shift)),
        )
        del img_mod2

        # calculate the txt blocks
        # equivalent to:
        # txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        # txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        txt = self.modulation_gate_fn(txt, txt_mod1.gate, self.txt_attn.proj(txt_attn))
        del txt_attn, txt_mod1
        txt = self.modulation_gate_fn(
            txt,
            txt_mod2.gate,
            self.txt_mlp(self.modulation_shift_scale_fn(self.txt_norm2(txt), txt_mod2.scale, txt_mod2.shift)),
        )
        del txt_mod2

        return img, txt

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        distill_vec: list[list[ModulationOut]],
        txt_seq_len: Tensor,
    ) -> tuple[Tensor, Tensor]:
        # checkpoint only while training and when explicitly enabled
        if self.training and self.gradient_checkpointing:
            return ckpt.checkpoint(self._forward, img, txt, pe, distill_vec, txt_seq_len, use_reentrant=False)
        else:
            return self._forward(img, txt, pe, distill_vec, txt_seq_len)


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.

    qkv and the MLP input share one fused linear (linear1); the attention
    output and MLP output share one fused linear (linear2). Modulation is a
    single pre-computed ModulationOut passed via ``distill_vec``. The input
    sequence holds image tokens followed by (padded) text tokens; attention is
    run per sample, trimmed to image length + that sample's text length.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")

        self.gradient_checkpointing = False

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def modulation_shift_scale_fn(self, x, scale, shift):
        # indirection kept so the fn can be swapped for a compiled variant
        return _modulation_shift_scale_fn(x, scale, shift)

    def modulation_gate_fn(self, x, gate, gate_params):
        # indirection kept so the fn can be swapped for a compiled variant
        return _modulation_gate_fn(x, gate, gate_params)

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(self, x: Tensor, pe: Tensor, distill_vec: ModulationOut, txt_seq_len: Tensor) -> Tensor:
        mod = distill_vec
        # equivalent to: x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        x_mod = self.modulation_shift_scale_fn(self.pre_norm(x), mod.scale, mod.shift)
        # fused projection: first 3*hidden_size channels are qkv, the rest is the MLP input
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        del x_mod

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        del qkv
        q, k = self.norm(q, k, v)

        # compute attention: we split the batch into each element so every
        # sample only attends over image tokens + its own text length
        max_txt_len = torch.max(txt_seq_len).item()
        img_len = q.shape[-2] - max_txt_len  # x = image tokens followed by padded text tokens
        q = list(torch.chunk(q, q.shape[0], dim=0))
        k = list(torch.chunk(k, k.shape[0], dim=0))
        v = list(torch.chunk(v, v.shape[0], dim=0))
        attn = []
        for i in range(x.size(0)):
            # trim this sample's q/k/v (and positional encoding) to its real length
            q[i] = q[i][:, :, : img_len + txt_seq_len[i]]
            k[i] = k[i][:, :, : img_len + txt_seq_len[i]]
            v[i] = v[i][:, :, : img_len + txt_seq_len[i]]
            attn_trimmed = attention(q[i], k[i], v[i], pe=pe[i : i + 1, :, : img_len + txt_seq_len[i]], attn_mask=None)
            q[i] = None  # release references early to cap peak memory
            k[i] = None
            v[i] = None

            # zero-pad the result back to the full sequence length
            attn_i = torch.zeros((1, x.shape[1], attn_trimmed.shape[-1]), dtype=attn_trimmed.dtype, device=self.device)
            attn_i[:, : img_len + txt_seq_len[i], :] = attn_trimmed
            del attn_trimmed
            attn.append(attn_i)

        attn = torch.cat(attn, dim=0)

        # compute activation in mlp stream, cat again and run second linear layer
        mlp = self.mlp_act(mlp)
        output = self.linear2(torch.cat((attn, mlp), 2))
        del attn, mlp
        # equivalent to: return x + mod.gate * output
        return self.modulation_gate_fn(x, mod.gate, output)

    def forward(self, x: Tensor, pe: Tensor, distill_vec: ModulationOut, txt_seq_len: Tensor) -> Tensor:
        # checkpoint only while training and when explicitly enabled
        if self.training and self.gradient_checkpointing:
            return ckpt.checkpoint(self._forward, x, pe, distill_vec, txt_seq_len, use_reentrant=False)
        else:
            return self._forward(x, pe, distill_vec, txt_seq_len)


class LastLayer(nn.Module):
    """Final projection: modulated LayerNorm followed by a linear head to patch outputs."""

    def __init__(
        self,
        hidden_size: int,
        patch_size: int,
        out_channels: int,
    ):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)

    @property
    def device(self):
        # All parameters are assumed to live on a single device.
        return next(self.parameters()).device

    def modulation_shift_scale_fn(self, x, scale, shift):
        # indirection kept so the fn can be swapped for a compiled variant
        return _modulation_shift_scale_fn(x, scale, shift)

    def forward(self, x: Tensor, distill_vec: list[Tensor]) -> Tensor:
        """Apply the shift/scale modulation from distill_vec, then project."""
        shift, scale = distill_vec
        shift, scale = shift.squeeze(1), scale.squeeze(1)
        normed = self.norm_final(x)
        modulated = self.modulation_shift_scale_fn(normed, scale[:, None, :], shift[:, None, :])
        return self.linear(modulated)


@dataclass
class ChromaParams:
    """Hyperparameters describing a Chroma transformer."""

    in_channels: int  # per-token input feature size (img_in input / final output channels)
    context_in_dim: int  # text embedding width fed to txt_in
    hidden_size: int  # transformer width; must be divisible by num_heads
    mlp_ratio: float  # MLP hidden dim = hidden_size * mlp_ratio
    num_heads: int  # attention head count
    depth: int  # number of double-stream blocks
    depth_single_blocks: int  # number of single-stream blocks
    axes_dim: list[int]  # RoPE dims per axis; must sum to hidden_size // num_heads
    theta: int  # RoPE base frequency
    qkv_bias: bool  # bias on attention qkv projections
    guidance_embed: bool  # kept for config compatibility; not read in this file
    approximator_in_dim: int  # input feature size of the modulation Approximator
    approximator_depth: int  # number of Approximator residual layers
    approximator_hidden_size: int  # Approximator hidden width
    _use_compiled: bool  # kept for upstream compatibility; not read in this file


# Default Chroma configuration used by this library:
# 19 double blocks + 38 single blocks at width 3072 with 24 heads; the
# modulation Approximator takes 64-dim inputs through a 5120-wide stack.
chroma_params = ChromaParams(
    in_channels=64,
    context_in_dim=4096,
    hidden_size=3072,
    mlp_ratio=4.0,
    num_heads=24,
    depth=19,
    depth_single_blocks=38,
    axes_dim=[16, 56, 56],
    theta=10_000,
    qkv_bias=True,
    guidance_embed=True,
    approximator_in_dim=64,
    approximator_depth=5,
    approximator_hidden_size=5120,
    _use_compiled=False,
)


def modify_mask_to_attend_padding(mask, max_seq_length, num_extra_padding=8):
    """
    Modifies attention mask to allow attention to a few extra padding tokens.

    Vectorized over the batch (no per-row Python loop / .item() syncs).

    Args:
        mask: Original attention mask of shape (batch_size, seq_len)
            (1 for tokens to attend to, 0 for masked tokens). Assumed to mark
            real tokens as a leading run of ones. Not modified in place.
        max_seq_length: Maximum sequence length of the model
        num_extra_padding: Number of padding tokens to unmask

    Returns:
        Modified mask (a new tensor; the input is left untouched).
    """
    # Actual token count per row; truncates like the original per-row int().
    seq_length = mask.sum(dim=-1).long()  # (batch_size,)

    # Unmask positions in [seq_len, min(seq_len + num_extra_padding, max_seq_length)).
    # Rows with seq_len >= max_seq_length get an empty range and stay unchanged.
    unmask_end = torch.clamp(seq_length + num_extra_padding, max=max_seq_length)
    positions = torch.arange(mask.shape[-1], device=mask.device)
    extra = (positions >= seq_length.unsqueeze(-1)) & (positions < unmask_end.unsqueeze(-1))

    modified_mask = mask.clone()
    modified_mask[extra] = 1

    return modified_mask


class Chroma(Flux):
    """
    Transformer model for flow matching on sequences.

    Chroma variant of Flux: instead of per-block modulation modules, a single
    Approximator MLP (``distilled_guidance_layer``) produces all modulation
    vectors from timestep/guidance embeddings, and ``distribute_modulations``
    slices them out per block. Text tokens are trimmed per sample to their
    actual length (taken from the T5 attention mask) rather than attending
    through padding.
    """

    def __init__(self, params: ChromaParams):
        # intentionally calls nn.Module.__init__, not Flux.__init__, so that
        # Flux's own submodules are not created; only Chroma's are registered
        nn.Module.__init__(self)
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)

        # TODO: need proper mapping for this approximator output!
        # currently the mapping is hardcoded in distribute_modulations function
        self.distilled_guidance_layer = Approximator(
            params.approximator_in_dim,
            self.hidden_size,
            params.approximator_hidden_size,
            params.approximator_depth,
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(
            self.hidden_size,
            1,
            self.out_channels,
        )

        # TODO: move this hardcoded value to config
        # single layer has 3 modulation vectors
        # double layer has 6 modulation vectors for each expert
        # final layer has 2 modulation vectors
        self.mod_index_length = 3 * params.depth_single_blocks + 2 * 6 * params.depth + 2
        self.depth_single_blocks = params.depth_single_blocks
        self.depth_double_blocks = params.depth
        # non-persistent buffer: follows .to(device) but is excluded from state_dict
        self.register_buffer(
            "mod_index",
            torch.tensor(list(range(self.mod_index_length)), device="cpu"),
            persistent=False,
        )
        self.approximator_in_dim = params.approximator_in_dim

        # block-swap (offloading) state; configured externally via enable_block_swap
        self.blocks_to_swap = None
        self.offloader_double = None
        self.offloader_single = None
        self.num_double_blocks = len(self.double_blocks)
        self.num_single_blocks = len(self.single_blocks)

        # Initialize properties required by Flux parent class
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

    def get_model_type(self) -> str:
        """Return the model type identifier used by callers to branch on Chroma."""
        return "chroma"

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        # cpu_offload is recorded but the per-block checkpointing here does not
        # consume it (blocks call torch.utils.checkpoint directly)
        self.gradient_checkpointing = True
        self.cpu_offload_checkpointing = cpu_offload

        self.distilled_guidance_layer.enable_gradient_checkpointing()
        for block in self.double_blocks + self.single_blocks:
            block.enable_gradient_checkpointing()

        print(f"Chroma: Gradient checkpointing enabled.")

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

        self.distilled_guidance_layer.disable_gradient_checkpointing()
        for block in self.double_blocks + self.single_blocks:
            block.disable_gradient_checkpointing()

        print("Chroma: Gradient checkpointing disabled.")

    def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
        """Compute the stacked per-block modulation vectors.

        For every modulation slot, builds a feature vector of
        [timestep embedding | guidance embedding | slot-index embedding] and
        runs the whole stack through the Approximator. Returns a tensor of
        shape (batch_size, mod_index_length, hidden_size).

        NOTE(review): guidance is embedded unconditionally, so passing None
        would fail here — callers are expected to supply a tensor. Callers in
        this repo pass timesteps already divided by 1000.
        """
        # We extract this logic from forward to clarify the propagation of the gradients
        # original comment: https://github.com/lodestone-rock/flow/blob/c76f63058980d0488826936025889e256a2e0458/src/models/chroma/model.py#L195
        distill_timestep = timestep_embedding(timesteps, self.approximator_in_dim // 4)
        # TODO: need to add toggle to omit this from schnell but that's not a priority
        distil_guidance = timestep_embedding(guidance, self.approximator_in_dim // 4)
        # get all modulation index
        modulation_index = timestep_embedding(self.mod_index, self.approximator_in_dim // 2)
        # we need to broadcast the modulation index here so each batch has all of the index
        modulation_index = modulation_index.unsqueeze(0).repeat(batch_size, 1, 1)
        # and we need to broadcast timestep and guidance along too
        timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, self.mod_index_length, 1)
        # then and only then we could concatenate it together
        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1)

        mod_vectors = self.distilled_guidance_layer(input_vec)
        return mod_vectors

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        block_controlnet_hidden_states=None,
        block_controlnet_single_hidden_states=None,
        guidance: Tensor | None = None,
        txt_attention_mask: Tensor | None = None,
        attn_padding: int = 1,
        mod_vectors: Tensor | None = None,
    ) -> Tensor:
        """Run the Chroma transformer on packed image and text sequences.

        Args:
            img: packed image tokens, (batch, img_len, in_channels).
            img_ids / txt_ids: positional ids for RoPE.
            timesteps: flow-matching timesteps fed to the Approximator.
            y: unused; kept for Flux signature compatibility.
            block_controlnet_hidden_states / block_controlnet_single_hidden_states:
                unused in this implementation; kept for Flux signature compatibility.
            guidance: guidance values fed to the Approximator.
            txt_attention_mask: per-token text mask used to derive each
                sample's real text length (required in practice; None would fail).
            attn_padding: extra padding tokens each sample may attend to
                beyond its real text length.
            mod_vectors: pre-computed modulation vectors from get_mod_vectors;
                computed here (without grad) when None.

        Returns:
            Predicted output tokens, (batch, img_len, patch_size**2 * out_channels).
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        txt = self.txt_in(txt)

        if mod_vectors is None:  # fallback to the original logic
            with torch.no_grad():
                mod_vectors = self.get_mod_vectors(timesteps, guidance, img.shape[0])
        mod_vectors_dict = distribute_modulations(mod_vectors, self.depth_single_blocks, self.depth_double_blocks)

        # calculate text length for each batch instead of masking
        txt_emb_len = txt.shape[1]
        txt_seq_len = txt_attention_mask[:, :txt_emb_len].sum(dim=-1).to(torch.int64)  # (batch_size, )
        txt_seq_len = torch.clip(txt_seq_len + attn_padding, 0, txt_emb_len)
        max_txt_len = torch.max(txt_seq_len).item()  # max text length in the batch

        # trim txt embedding to the text length
        txt = txt[:, :max_txt_len, :]

        # create positional encoding for the text and image
        ids = torch.cat((img_ids, txt_ids[:, :max_txt_len]), dim=1)  # reverse order of ids for faster attention
        pe = self.pe_embedder(ids)  # B, 1, seq_length, 64, 2, 2

        for i, block in enumerate(self.double_blocks):
            if self.blocks_to_swap:
                # block-swap: make sure this block's weights are on-device
                self.offloader_double.wait_for_block(i)

            # the guidance replaced by FFN output
            # pop so consumed slices are released as we go
            img_mod = mod_vectors_dict.pop(f"double_blocks.{i}.img_mod.lin")
            txt_mod = mod_vectors_dict.pop(f"double_blocks.{i}.txt_mod.lin")
            double_mod = [img_mod, txt_mod]
            del img_mod, txt_mod

            img, txt = block(img=img, txt=txt, pe=pe, distill_vec=double_mod, txt_seq_len=txt_seq_len)
            del double_mod

            if self.blocks_to_swap:
                self.offloader_double.submit_move_blocks(self.double_blocks, i)

        # single blocks operate on image tokens followed by text tokens
        img = torch.cat((img, txt), 1)
        del txt

        for i, block in enumerate(self.single_blocks):
            if self.blocks_to_swap:
                self.offloader_single.wait_for_block(i)

            single_mod = mod_vectors_dict.pop(f"single_blocks.{i}.modulation.lin")
            img = block(img, pe=pe, distill_vec=single_mod, txt_seq_len=txt_seq_len)
            del single_mod

            if self.blocks_to_swap:
                self.offloader_single.submit_move_blocks(self.single_blocks, i)

        # drop the text tokens appended before the single blocks
        img = img[:, :-max_txt_len, ...]
        final_mod = mod_vectors_dict["final_layer.adaLN_modulation.1"]
        img = self.final_layer(img, distill_vec=final_mod)  # (N, T, patch_size ** 2 * out_channels)
        return img

+ 6
- 1
scripts/dev/library/config_util.py View File

@@ -75,6 +75,7 @@ class BaseSubsetParams:
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0
resize_interpolation: Optional[str] = None


@dataclass
@@ -106,7 +107,7 @@ class BaseDatasetParams:
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0
resize_interpolation: Optional[str] = None

@dataclass
class DreamBoothDatasetParams(BaseDatasetParams):
@@ -196,6 +197,7 @@ class ConfigSanitizer:
"caption_prefix": str,
"caption_suffix": str,
"custom_attributes": dict,
"resize_interpolation": str,
}
# DO means DropOut
DO_SUBSET_ASCENDABLE_SCHEMA = {
@@ -241,6 +243,7 @@ class ConfigSanitizer:
"validation_split": float,
"resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
"network_multiplier": float,
"resize_interpolation": str,
}

# options handled by argparse but not handled by user config
@@ -525,6 +528,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
[{dataset_type} {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
resize_interpolation: {dataset.resize_interpolation}
enable_bucket: {dataset.enable_bucket}
""")

@@ -558,6 +562,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
alpha_mask: {subset.alpha_mask}
resize_interpolation: {subset.resize_interpolation}
custom_attributes: {subset.custom_attributes}
"""), " ")



+ 17
- 13
scripts/dev/library/custom_offloading_utils.py View File

@@ -1,6 +1,6 @@
from concurrent.futures import ThreadPoolExecutor
import time
from typing import Optional
from typing import Optional, Union, Callable, Tuple
import torch
import torch.nn as nn

@@ -19,7 +19,7 @@ def synchronize_device(device: torch.device):
def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
assert layer_to_cpu.__class__ == layer_to_cuda.__class__

weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []

# This is not working for all cases (e.g. SD3), so we need to find the corresponding modules
# for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
@@ -42,7 +42,7 @@ def swap_weight_devices_cuda(device: torch.device, layer_to_cpu: nn.Module, laye

torch.cuda.current_stream().synchronize() # this prevents the illegal loss value

stream = torch.cuda.Stream()
stream = torch.Stream(device="cuda")
with torch.cuda.stream(stream):
# cuda to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
@@ -66,23 +66,24 @@ def swap_weight_devices_no_cuda(device: torch.device, layer_to_cpu: nn.Module, l
"""
assert layer_to_cpu.__class__ == layer_to_cuda.__class__

weight_swap_jobs = []
weight_swap_jobs: list[Tuple[nn.Module, nn.Module, torch.Tensor, torch.Tensor]] = []
for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))


# device to cpu
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
module_to_cpu.weight.data = cuda_data_view.data.to("cpu", non_blocking=True)

synchronize_device()
synchronize_device(device)

# cpu to device
for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
module_to_cuda.weight.data = cuda_data_view

synchronize_device()
synchronize_device(device)


def weighs_to_device(layer: nn.Module, device: torch.device):
@@ -148,13 +149,16 @@ class Offloader:
print(f"Waited for block {block_idx}: {time.perf_counter()-start_time:.2f}s")


# Gradient tensors
_grad_t = Union[tuple[torch.Tensor, ...], torch.Tensor]

class ModelOffloader(Offloader):
"""
supports forward offloading
"""

def __init__(self, blocks: list[nn.Module], num_blocks: int, blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(num_blocks, blocks_to_swap, device, debug)
def __init__(self, blocks: Union[list[nn.Module], nn.ModuleList], blocks_to_swap: int, device: torch.device, debug: bool = False):
super().__init__(len(blocks), blocks_to_swap, device, debug)

# register backward hooks
self.remove_handles = []
@@ -168,7 +172,7 @@ class ModelOffloader(Offloader):
for handle in self.remove_handles:
handle.remove()

def create_backward_hook(self, blocks: list[nn.Module], block_index: int) -> Optional[callable]:
def create_backward_hook(self, blocks: Union[list[nn.Module], nn.ModuleList], block_index: int) -> Optional[Callable[[nn.Module, _grad_t, _grad_t], Union[None, _grad_t]]]:
# -1 for 0-based index
num_blocks_propagated = self.num_blocks - block_index - 1
swapping = num_blocks_propagated > 0 and num_blocks_propagated <= self.blocks_to_swap
@@ -182,7 +186,7 @@ class ModelOffloader(Offloader):
block_idx_to_cuda = self.blocks_to_swap - num_blocks_propagated
block_idx_to_wait = block_index - 1

def backward_hook(module, grad_input, grad_output):
def backward_hook(module: nn.Module, grad_input: _grad_t, grad_output: _grad_t):
if self.debug:
print(f"Backward hook for block {block_index}")

@@ -194,7 +198,7 @@ class ModelOffloader(Offloader):

return backward_hook

def prepare_block_devices_before_forward(self, blocks: list[nn.Module]):
def prepare_block_devices_before_forward(self, blocks: Union[list[nn.Module], nn.ModuleList]):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return

@@ -207,7 +211,7 @@ class ModelOffloader(Offloader):

for b in blocks[self.num_blocks - self.blocks_to_swap :]:
b.to(self.device) # move block to device first
weighs_to_device(b, "cpu") # make sure weights are on cpu
weighs_to_device(b, torch.device("cpu")) # make sure weights are on cpu

synchronize_device(self.device)
clean_memory_on_device(self.device)
@@ -217,7 +221,7 @@ class ModelOffloader(Offloader):
return
self._wait_blocks_move(block_idx)

def submit_move_blocks(self, blocks: list[nn.Module], block_idx: int):
def submit_move_blocks(self, blocks: Union[list[nn.Module], nn.ModuleList], block_idx: int):
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
return
if block_idx >= self.blocks_to_swap:


+ 41
- 0
scripts/dev/library/deepspeed_utils.py View File

@@ -5,6 +5,8 @@ from accelerate import DeepSpeedPlugin, Accelerator

from .utils import setup_logging

from .device_utils import get_preferred_device

setup_logging()
import logging

@@ -94,6 +96,7 @@ def prepare_deepspeed_plugin(args: argparse.Namespace):
deepspeed_plugin.deepspeed_config["train_batch_size"] = (
args.train_batch_size * args.gradient_accumulation_steps * int(os.environ["WORLD_SIZE"])
)
deepspeed_plugin.set_mixed_precision(args.mixed_precision)
if args.mixed_precision.lower() == "fp16":
deepspeed_plugin.deepspeed_config["fp16"]["initial_scale_power"] = 0 # preventing overflow.
@@ -122,18 +125,56 @@ def prepare_deepspeed_model(args: argparse.Namespace, **models):
class DeepSpeedWrapper(torch.nn.Module):
def __init__(self, **kw_models) -> None:
super().__init__()
self.models = torch.nn.ModuleDict()
wrap_model_forward_with_torch_autocast = args.mixed_precision is not "no"

for key, model in kw_models.items():
if isinstance(model, list):
model = torch.nn.ModuleList(model)
if wrap_model_forward_with_torch_autocast:
model = self.__wrap_model_with_torch_autocast(model)
assert isinstance(
model, torch.nn.Module
), f"model must be an instance of torch.nn.Module, but got {key} is {type(model)}"

self.models.update(torch.nn.ModuleDict({key: model}))

def __wrap_model_with_torch_autocast(self, model):
if isinstance(model, torch.nn.ModuleList):
model = torch.nn.ModuleList([self.__wrap_model_forward_with_torch_autocast(m) for m in model])
else:
model = self.__wrap_model_forward_with_torch_autocast(model)
return model

def __wrap_model_forward_with_torch_autocast(self, model):
assert hasattr(model, "forward"), f"model must have a forward method."

forward_fn = model.forward

def forward(*args, **kwargs):
try:
device_type = model.device.type
except AttributeError:
logger.warning(
"[DeepSpeed] model.device is not available. Using get_preferred_device() "
"to determine the device_type for torch.autocast()."
)
device_type = get_preferred_device().type

with torch.autocast(device_type = device_type):
return forward_fn(*args, **kwargs)

model.forward = forward
return model
def get_models(self):
return self.models

ds_model = DeepSpeedWrapper(**models)
return ds_model

+ 14
- 178
scripts/dev/library/flux_models.py View File

@@ -930,6 +930,9 @@ class Flux(nn.Module):
self.num_double_blocks = len(self.double_blocks)
self.num_single_blocks = len(self.single_blocks)

def get_model_type(self) -> str:
return "flux"

@property
def device(self):
return next(self.parameters()).device
@@ -977,10 +980,10 @@ class Flux(nn.Module):
)

self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1006,6 +1009,9 @@ class Flux(nn.Module):
self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)

def get_mod_vectors(self, timesteps: Tensor, guidance: Tensor | None = None, batch_size: int | None = None) -> Tensor:
return None # FLUX.1 does not use mod_vectors, but Chroma does.

def forward(
self,
img: Tensor,
@@ -1018,6 +1024,7 @@ class Flux(nn.Module):
block_controlnet_single_hidden_states=None,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
mod_vectors: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -1169,7 +1176,7 @@ class ControlNetFlux(nn.Module):
nn.SiLU(),
nn.Conv2d(16, 16, 3, padding=1, stride=2),
nn.SiLU(),
zero_module(nn.Conv2d(16, 16, 3, padding=1))
zero_module(nn.Conv2d(16, 16, 3, padding=1)),
)

@property
@@ -1219,10 +1226,10 @@ class ControlNetFlux(nn.Module):
)

self.offloader_double = custom_offloading_utils.ModelOffloader(
self.double_blocks, self.num_double_blocks, double_blocks_to_swap, device # , debug=True
self.double_blocks, double_blocks_to_swap, device # , debug=True
)
self.offloader_single = custom_offloading_utils.ModelOffloader(
self.single_blocks, self.num_single_blocks, single_blocks_to_swap, device # , debug=True
self.single_blocks, single_blocks_to_swap, device # , debug=True
)
print(
f"FLUX: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
@@ -1233,8 +1240,8 @@ class ControlNetFlux(nn.Module):
if self.blocks_to_swap:
save_double_blocks = self.double_blocks
save_single_blocks = self.single_blocks
self.double_blocks = None
self.single_blocks = None
self.double_blocks = nn.ModuleList()
self.single_blocks = nn.ModuleList()

self.to(device)

@@ -1320,174 +1327,3 @@ class ControlNetFlux(nn.Module):
controlnet_single_block_samples = controlnet_single_block_samples + (block_sample,)

return controlnet_block_samples, controlnet_single_block_samples


"""
class FluxUpper(nn.Module):
""
Transformer model for flow matching on sequences.
""

def __init__(self, params: FluxParams):
super().__init__()

self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)

self.gradient_checkpointing = False

@property
def device(self):
return next(self.parameters()).device

@property
def dtype(self):
return next(self.parameters()).dtype

def enable_gradient_checkpointing(self):
self.gradient_checkpointing = True

self.time_in.enable_gradient_checkpointing()
self.vector_in.enable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.enable_gradient_checkpointing()

for block in self.double_blocks:
block.enable_gradient_checkpointing()

print("FLUX: Gradient checkpointing enabled.")

def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False

self.time_in.disable_gradient_checkpointing()
self.vector_in.disable_gradient_checkpointing()
if self.guidance_in.__class__ != nn.Identity:
self.guidance_in.disable_gradient_checkpointing()

for block in self.double_blocks:
block.disable_gradient_checkpointing()

print("FLUX: Gradient checkpointing disabled.")

def forward(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")

# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)

ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)

for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)

return img, txt, vec, pe


class FluxLower(nn.Module):
""
Transformer model for flow matching on sequences.
""

def __init__(self, params: FluxParams):
super().__init__()
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.out_channels = params.in_channels

self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)

self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

self.gradient_checkpointing = False

@property
def device(self):
return next(self.parameters()).device

@property
def dtype(self):
return next(self.parameters()).dtype

def enable_gradient_checkpointing(self):
self.gradient_checkpointing = True

for block in self.single_blocks:
block.enable_gradient_checkpointing()

print("FLUX: Gradient checkpointing enabled.")

def disable_gradient_checkpointing(self):
self.gradient_checkpointing = False

for block in self.single_blocks:
block.disable_gradient_checkpointing()

print("FLUX: Gradient checkpointing disabled.")

def forward(
self,
img: Tensor,
txt: Tensor,
vec: Tensor | None = None,
pe: Tensor | None = None,
txt_attention_mask: Tensor | None = None,
) -> Tensor:
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe, txt_attention_mask=txt_attention_mask)
img = img[:, txt.shape[1] :, ...]

img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
"""

+ 154
- 84
scripts/dev/library/flux_train_utils.py View File

@@ -40,7 +40,7 @@ def sample_images(
text_encoders,
sample_prompts_te_outputs,
prompt_replacement=None,
controlnet=None
controlnet=None,
):
if steps == 0:
if not args.sample_at_first:
@@ -67,7 +67,7 @@ def sample_images(
# unwrap unet and text_encoder(s)
flux = accelerator.unwrap_model(flux)
if text_encoders is not None:
text_encoders = [accelerator.unwrap_model(te) for te in text_encoders]
text_encoders = [(accelerator.unwrap_model(te) if te is not None else None) for te in text_encoders]
if controlnet is not None:
controlnet = accelerator.unwrap_model(controlnet)
# print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])
@@ -101,7 +101,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
controlnet,
)
else:
# Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
@@ -125,7 +125,7 @@ def sample_images(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
controlnet,
)

torch.set_rng_state(rng_state)
@@ -147,14 +147,15 @@ def sample_image_inference(
steps,
sample_prompts_te_outputs,
prompt_replacement,
controlnet
controlnet,
):
assert isinstance(prompt_dict, dict)
# negative_prompt = prompt_dict.get("negative_prompt")
negative_prompt = prompt_dict.get("negative_prompt")
sample_steps = prompt_dict.get("sample_steps", 20)
width = prompt_dict.get("width", 512)
height = prompt_dict.get("height", 512)
scale = prompt_dict.get("scale", 3.5)
emb_guidance_scale = prompt_dict.get("guidance_scale", 3.5)
cfg_scale = prompt_dict.get("scale", 1.0)
seed = prompt_dict.get("seed")
controlnet_image = prompt_dict.get("controlnet_image")
prompt: str = prompt_dict.get("prompt", "")
@@ -162,8 +163,8 @@ def sample_image_inference(

if prompt_replacement is not None:
prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
# if negative_prompt is not None:
# negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])
if negative_prompt is not None:
negative_prompt = negative_prompt.replace(prompt_replacement[0], prompt_replacement[1])

if seed is not None:
torch.manual_seed(seed)
@@ -173,16 +174,21 @@ def sample_image_inference(
torch.seed()
torch.cuda.seed()

# if negative_prompt is None:
# negative_prompt = ""
if negative_prompt is None:
negative_prompt = ""
height = max(64, height - height % 16) # round to divisible by 16
width = max(64, width - width % 16) # round to divisible by 16
logger.info(f"prompt: {prompt}")
# logger.info(f"negative_prompt: {negative_prompt}")
if cfg_scale != 1.0:
logger.info(f"negative_prompt: {negative_prompt}")
elif negative_prompt != "":
logger.info(f"negative prompt is ignored because scale is 1.0")
logger.info(f"height: {height}")
logger.info(f"width: {width}")
logger.info(f"sample_steps: {sample_steps}")
logger.info(f"scale: {scale}")
logger.info(f"embedded guidance scale: {emb_guidance_scale}")
if cfg_scale != 1.0:
logger.info(f"CFG scale: {cfg_scale}")
# logger.info(f"sample_sampler: {sampler_name}")
if seed is not None:
logger.info(f"seed: {seed}")
@@ -191,26 +197,37 @@ def sample_image_inference(
tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

text_encoder_conds = []
if sample_prompts_te_outputs and prompt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prompt]
print(f"Using cached text encoder outputs for prompt: {prompt}")
if text_encoders is not None:
print(f"Encoding prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)

# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]

l_pooled, t5_out, txt_ids, t5_attn_mask = text_encoder_conds
def encode_prompt(prpt):
text_encoder_conds = []
if sample_prompts_te_outputs and prpt in sample_prompts_te_outputs:
text_encoder_conds = sample_prompts_te_outputs[prpt]
print(f"Using cached text encoder outputs for prompt: {prpt}")
if text_encoders is not None:
print(f"Encoding prompt: {prpt}")
tokens_and_masks = tokenize_strategy.tokenize(prpt)
# strategy has apply_t5_attn_mask option
encoded_text_encoder_conds = encoding_strategy.encode_tokens(tokenize_strategy, text_encoders, tokens_and_masks)

# if text_encoder_conds is not cached, use encoded_text_encoder_conds
if len(text_encoder_conds) == 0:
text_encoder_conds = encoded_text_encoder_conds
else:
# if encoded_text_encoder_conds is not None, update cached text_encoder_conds
for i in range(len(encoded_text_encoder_conds)):
if encoded_text_encoder_conds[i] is not None:
text_encoder_conds[i] = encoded_text_encoder_conds[i]
return text_encoder_conds

l_pooled, t5_out, txt_ids, t5_attn_mask = encode_prompt(prompt)
# encode negative prompts
if cfg_scale != 1.0:
neg_l_pooled, neg_t5_out, _, neg_t5_attn_mask = encode_prompt(negative_prompt)
neg_t5_attn_mask = (
neg_t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask and neg_t5_attn_mask is not None else None
)
neg_cond = (cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask)
else:
neg_cond = None

# sample image
weight_dtype = ae.dtype # TODO give dtype as argument
@@ -224,7 +241,7 @@ def sample_image_inference(
dtype=weight_dtype,
generator=torch.Generator(device=accelerator.device).manual_seed(seed) if seed is not None else None,
)
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # FLUX.1 dev -> shift=True
timesteps = get_schedule(sample_steps, noise.shape[1], shift=True) # Chroma can use shift=True
img_ids = flux_utils.prepare_img_ids(1, packed_latent_height, packed_latent_width).to(accelerator.device, weight_dtype)
t5_attn_mask = t5_attn_mask.to(accelerator.device) if args.apply_t5_attn_mask else None

@@ -235,7 +252,20 @@ def sample_image_inference(
controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)

with accelerator.autocast(), torch.no_grad():
x = denoise(flux, noise, img_ids, t5_out, txt_ids, l_pooled, timesteps=timesteps, guidance=scale, t5_attn_mask=t5_attn_mask, controlnet=controlnet, controlnet_img=controlnet_image)
x = denoise(
flux,
noise,
img_ids,
t5_out,
txt_ids,
l_pooled,
timesteps=timesteps,
guidance=emb_guidance_scale,
t5_attn_mask=t5_attn_mask,
controlnet=controlnet,
controlnet_img=controlnet_image,
neg_cond=neg_cond,
)

x = flux_utils.unpack_latents(x, packed_latent_height, packed_latent_width)

@@ -305,22 +335,24 @@ def denoise(
model: flux_models.Flux,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt: torch.Tensor, # t5_out
txt_ids: torch.Tensor,
vec: torch.Tensor,
vec: torch.Tensor, # l_pooled
timesteps: list[float],
guidance: float = 4.0,
t5_attn_mask: Optional[torch.Tensor] = None,
controlnet: Optional[flux_models.ControlNetFlux] = None,
controlnet_img: Optional[torch.Tensor] = None,
neg_cond: Optional[Tuple[float, torch.Tensor, torch.Tensor, torch.Tensor]] = None,
):
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
do_cfg = neg_cond is not None

for t_curr, t_prev in zip(tqdm(timesteps[:-1]), timesteps[1:]):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
model.prepare_block_swap_before_forward()

if controlnet is not None:
block_samples, block_single_samples = controlnet(
img=img,
@@ -336,20 +368,48 @@ def denoise(
else:
block_samples = None
block_single_samples = None
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)

img = img + (t_prev - t_curr) * pred
if not do_cfg:
pred = model(
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec,
guidance=guidance_vec,
txt_attention_mask=t5_attn_mask,
)

img = img + (t_prev - t_curr) * pred
else:
cfg_scale, neg_l_pooled, neg_t5_out, neg_t5_attn_mask = neg_cond
nc_c_t5_attn_mask = None if t5_attn_mask is None else torch.cat([neg_t5_attn_mask, t5_attn_mask], dim=0)

# TODO is it ok to use the same block samples for both cond and uncond?
block_samples = None if block_samples is None else torch.cat([block_samples, block_samples], dim=0)
block_single_samples = (
None if block_single_samples is None else torch.cat([block_single_samples, block_single_samples], dim=0)
)

nc_c_pred = model(
img=torch.cat([img, img], dim=0),
img_ids=torch.cat([img_ids, img_ids], dim=0),
txt=torch.cat([neg_t5_out, txt], dim=0),
txt_ids=torch.cat([txt_ids, txt_ids], dim=0),
y=torch.cat([neg_l_pooled, vec], dim=0),
block_controlnet_hidden_states=block_samples,
block_controlnet_single_hidden_states=block_single_samples,
timesteps=t_vec.repeat(2),
guidance=guidance_vec.repeat(2),
txt_attention_mask=nc_c_t5_attn_mask,
)
neg_pred, pred = torch.chunk(nc_c_pred, 2, dim=0)
pred = neg_pred + (pred - neg_pred) * cfg_scale

img = img + (t_prev - t_curr) * pred

model.prepare_block_swap_before_forward()
return img
@@ -366,8 +426,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma


@@ -410,42 +468,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):


def get_noisy_model_input_and_timesteps(
args, noise_scheduler, latents, noise, device, dtype
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
bsz, _, h, w = latents.shape
sigmas = None
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random t-based noise sampling
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
# https://github.com/XLabs-AI/x-flux/tree/main
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
else:
t = torch.rand((bsz,), device=device)
sigmas = torch.rand((bsz,), device=device)

timesteps = t * 1000.0
t = t.view(-1, 1, 1, 1)
noisy_model_input = (1 - t) * latents + t * noise
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "shift":
shift = args.discrete_flow_shift
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)

t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_timesteps
elif args.timestep_sampling == "flux_shift":
logits_norm = torch.randn(bsz, device=device)
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
timesteps = logits_norm.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
timesteps = time_shift(mu, 1.0, timesteps)

t = timesteps.view(-1, 1, 1, 1)
timesteps = timesteps * 1000.0
noisy_model_input = (1 - t) * latents + t * noise
sigmas = torch.randn(bsz, device=device)
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
sigmas = sigmas.sigmoid()
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
sigmas = time_shift(mu, 1.0, sigmas)
timesteps = sigmas * num_timesteps
else:
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
@@ -456,12 +506,24 @@ def get_noisy_model_input_and_timesteps(
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
indices = (u * num_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=device)

# Add noise according to flow matching.
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents

# Broadcast sigmas to latent shape
sigmas = sigmas.view(-1, 1, 1, 1)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas

@@ -567,7 +629,7 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
"--controlnet_model_name_or_path",
type=str,
default=None,
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)"
help="path to controlnet (*.sft or *.safetensors) / controlnetのパス(*.sftまたは*.safetensors)",
)
parser.add_argument(
"--t5xxl_max_token_length",
@@ -617,3 +679,11 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)

parser.add_argument(
"--model_type",
type=str,
choices=["flux", "chroma"],
default="flux",
help="Model type to use for training / トレーニングに使用するモデルタイプ:flux or chroma (default: flux)",
)

+ 112
- 41
scripts/dev/library/flux_utils.py View File

@@ -23,6 +23,7 @@ from library.utils import load_safetensors
MODEL_VERSION_FLUX_V1 = "flux1"
MODEL_NAME_DEV = "dev"
MODEL_NAME_SCHNELL = "schnell"
MODEL_VERSION_CHROMA = "chroma"


def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
@@ -92,50 +93,84 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int


def load_flow_model(
ckpt_path: str, dtype: Optional[torch.dtype], device: Union[str, torch.device], disable_mmap: bool = False
ckpt_path: str,
dtype: Optional[torch.dtype],
device: Union[str, torch.device],
disable_mmap: bool = False,
model_type: str = "flux",
) -> Tuple[bool, flux_models.Flux]:
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL

# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"):
params = flux_models.configs[name].params

# set the number of blocks
if params.depth != num_double_blocks:
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
params = replace(params, depth=num_double_blocks)
if params.depth_single_blocks != num_single_blocks:
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
params = replace(params, depth_single_blocks=num_single_blocks)

model = flux_models.Flux(params)
if dtype is not None:
model = model.to(dtype)

# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))
if model_type == "flux":
is_diffusers, is_schnell, (num_double_blocks, num_single_blocks), ckpt_paths = analyze_checkpoint_state(ckpt_path)
name = MODEL_NAME_DEV if not is_schnell else MODEL_NAME_SCHNELL

# build model
logger.info(f"Building Flux model {name} from {'Diffusers' if is_diffusers else 'BFL'} checkpoint")
with torch.device("meta"):
params = flux_models.configs[name].params

# set the number of blocks
if params.depth != num_double_blocks:
logger.info(f"Setting the number of double blocks from {params.depth} to {num_double_blocks}")
params = replace(params, depth=num_double_blocks)
if params.depth_single_blocks != num_single_blocks:
logger.info(f"Setting the number of single blocks from {params.depth_single_blocks} to {num_single_blocks}")
params = replace(params, depth_single_blocks=num_single_blocks)

model = flux_models.Flux(params)
if dtype is not None:
model = model.to(dtype)

# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = {}
for ckpt_path in ckpt_paths:
sd.update(load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype))

# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")

# if the key has annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)

info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model

elif model_type == "chroma":
from . import chroma_models

# build model
logger.info("Building Chroma model")
with torch.device("meta"):
model = chroma_models.Chroma(chroma_models.chroma_params)
if dtype is not None:
model = model.to(dtype)

# load_sft doesn't support torch.device
logger.info(f"Loading state dict from {ckpt_path}")
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)

# convert Diffusers to BFL
if is_diffusers:
logger.info("Converting Diffusers to BFL")
sd = convert_diffusers_sd_to_bfl(sd, num_double_blocks, num_single_blocks)
logger.info("Converted Diffusers to BFL")
# if the key has annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)

# if the key has annoying prefix, remove it
for key in list(sd.keys()):
new_key = key.replace("model.diffusion_model.", "")
if new_key == key:
break # the model doesn't have annoying prefix
sd[new_key] = sd.pop(key)
info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Chroma: {info}")
is_schnell = False # Chroma is not schnell
return is_schnell, model

info = model.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded Flux: {info}")
return is_schnell, model
else:
raise ValueError(f"Unsupported model_type: {model_type}. Supported types are 'flux' and 'chroma'.")


def load_ae(
@@ -166,7 +201,43 @@ def load_controlnet(
sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)
info = controlnet.load_state_dict(sd, strict=False, assign=True)
logger.info(f"Loaded ControlNet: {info}")
return controlnet
return controlnet


def dummy_clip_l() -> torch.nn.Module:
    """
    Returns a dummy stand-in for the CLIP-L text encoder.

    NOTE: ``DummyCLIPL.output_shape`` is (77, 1), so the dummy's forward()
    yields a "pooler_output" of shape (N, 77, 1) — not the real CLIP-L's
    (N, 77, 768).
    """
    return DummyCLIPL()


class DummyTextModel(torch.nn.Module):
    """Minimal placeholder text model: exposes a single `embeddings` parameter."""

    def __init__(self) -> None:
        super().__init__()
        # One-element zero parameter so the module owns at least one weight.
        self.embeddings = torch.nn.Parameter(torch.zeros(1))


class DummyCLIPL(torch.nn.Module):
    """Dummy CLIP-L replacement that returns an all-zeros "pooler_output".

    Used when no real CLIP-L checkpoint is loaded; the output is zeros of
    shape (N, 77, 1), where N is taken from the first positional argument.
    """

    def __init__(self) -> None:
        super().__init__()
        self.output_shape = (77, 1)  # Note: The original code had (77, 768), but we use (77, 1) for the dummy output
        self.dummy_param = torch.nn.Parameter(torch.zeros(1))  # get dtype and device from this parameter
        self.text_model = DummyTextModel()  # placeholder so callers can access .text_model

    @property
    def device(self) -> torch.device:
        # Device is inferred from the dummy parameter (follows .to(device) moves).
        return self.dummy_param.device

    @property
    def dtype(self) -> torch.dtype:
        # Dtype is inferred from the dummy parameter (follows .to(dtype) casts).
        return self.dummy_param.dtype

    def forward(self, *args, **kwargs):
        """
        Returns a dummy output with the shape of (N, 77, 1).

        N is ``args[0].shape[0]`` when a positional argument (e.g. input_ids)
        is given, else 1. Only the "pooler_output" key is populated.
        """
        batch_size = args[0].shape[0] if args else 1
        return {"pooler_output": torch.zeros(batch_size, *self.output_shape, device=self.device, dtype=self.dtype)}


def load_clip_l(


+ 62
- 76
scripts/dev/library/ipex/__init__.py View File

@@ -1,14 +1,15 @@
import os
import sys
import contextlib
import torch
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
legacy = True
has_ipex = True
except Exception:
legacy = False
has_ipex = False
from .hijacks import ipex_hijacks

torch_version = float(torch.__version__[:3])

# pylint: disable=protected-access, missing-function-docstring, line-too-long

def ipex_init(): # pylint: disable=too-many-statements
@@ -16,7 +17,10 @@ def ipex_init(): # pylint: disable=too-many-statements
if hasattr(torch, "cuda") and hasattr(torch.cuda, "is_xpu_hijacked") and torch.cuda.is_xpu_hijacked:
return True, "Skipping IPEX hijack"
else:
try: # force xpu device on torch compile and triton
try:
# force xpu device on torch compile and triton
# import inductor utils to get around lazy import
from torch._inductor import utils as torch_inductor_utils # pylint: disable=import-error, unused-import # noqa: F401
torch._inductor.utils.GPU_TYPES = ["xpu"]
torch._inductor.utils.get_gpu_type = lambda *args, **kwargs: "xpu"
from triton import backends as triton_backends # pylint: disable=import-error
@@ -35,7 +39,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.is_available = torch.xpu.is_available
torch.cuda.is_initialized = torch.xpu.is_initialized
torch.cuda.is_current_stream_capturing = lambda: False
torch.cuda.set_device = torch.xpu.set_device
torch.cuda.stream = torch.xpu.stream
torch.cuda.Event = torch.xpu.Event
torch.cuda.Stream = torch.xpu.Stream
@@ -45,7 +48,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.Optional = torch.xpu.Optional
torch.cuda.__cached__ = torch.xpu.__cached__
torch.cuda.__loader__ = torch.xpu.__loader__
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.streams = torch.xpu.streams
torch.cuda.Any = torch.xpu.Any
torch.cuda.__doc__ = torch.xpu.__doc__
@@ -58,7 +60,6 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__annotations__ = torch.xpu.__annotations__
torch.cuda.__package__ = torch.xpu.__package__
torch.cuda.__builtins__ = torch.xpu.__builtins__
torch.cuda.List = torch.xpu.List
torch.cuda._lazy_init = torch.xpu._lazy_init
torch.cuda.StreamContext = torch.xpu.StreamContext
torch.cuda._lazy_call = torch.xpu._lazy_call
@@ -70,47 +71,40 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.__file__ = torch.xpu.__file__
# torch.cuda.is_current_stream_capturing = torch.xpu.is_current_stream_capturing

if legacy:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu
if float(ipex.__version__[:3]) < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda._lazy_new = torch.xpu._lazy_new

torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
if torch_version < 2.3:
torch.cuda._initialization_lock = torch.xpu.lazy_init._initialization_lock
torch.cuda._initialized = torch.xpu.lazy_init._initialized
torch.cuda._is_in_bad_fork = torch.xpu.lazy_init._is_in_bad_fork
torch.cuda._lazy_seed_tracker = torch.xpu.lazy_init._lazy_seed_tracker
torch.cuda._queued_calls = torch.xpu.lazy_init._queued_calls
torch.cuda._tls = torch.xpu.lazy_init._tls
torch.cuda.threading = torch.xpu.lazy_init.threading
torch.cuda.traceback = torch.xpu.lazy_init.traceback
torch.cuda._lazy_new = torch.xpu._lazy_new

if not legacy or float(ipex.__version__[:3]) >= 2.3:
torch.cuda.FloatTensor = torch.xpu.FloatTensor
torch.cuda.FloatStorage = torch.xpu.FloatStorage
torch.cuda.BFloat16Tensor = torch.xpu.BFloat16Tensor
torch.cuda.BFloat16Storage = torch.xpu.BFloat16Storage
torch.cuda.HalfTensor = torch.xpu.HalfTensor
torch.cuda.HalfStorage = torch.xpu.HalfStorage
torch.cuda.ByteTensor = torch.xpu.ByteTensor
torch.cuda.ByteStorage = torch.xpu.ByteStorage
torch.cuda.DoubleTensor = torch.xpu.DoubleTensor
torch.cuda.DoubleStorage = torch.xpu.DoubleStorage
torch.cuda.ShortTensor = torch.xpu.ShortTensor
torch.cuda.ShortStorage = torch.xpu.ShortStorage
torch.cuda.LongTensor = torch.xpu.LongTensor
torch.cuda.LongStorage = torch.xpu.LongStorage
torch.cuda.IntTensor = torch.xpu.IntTensor
torch.cuda.IntStorage = torch.xpu.IntStorage
torch.cuda.CharTensor = torch.xpu.CharTensor
torch.cuda.CharStorage = torch.xpu.CharStorage
torch.cuda.BoolTensor = torch.xpu.BoolTensor
torch.cuda.BoolStorage = torch.xpu.BoolStorage
torch.cuda.ComplexFloatStorage = torch.xpu.ComplexFloatStorage
torch.cuda.ComplexDoubleStorage = torch.xpu.ComplexDoubleStorage
else:
torch.cuda._initialization_lock = torch.xpu._initialization_lock
torch.cuda._initialized = torch.xpu._initialized
torch.cuda._is_in_bad_fork = torch.xpu._is_in_bad_fork
@@ -120,12 +114,24 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.threading = torch.xpu.threading
torch.cuda.traceback = torch.xpu.traceback

if torch_version < 2.5:
torch.cuda.os = torch.xpu.os
torch.cuda.Device = torch.xpu.Device
torch.cuda.warnings = torch.xpu.warnings
torch.cuda.classproperty = torch.xpu.classproperty
torch.UntypedStorage.cuda = torch.UntypedStorage.xpu

if torch_version < 2.7:
torch.cuda.Tuple = torch.xpu.Tuple
torch.cuda.List = torch.xpu.List


# Memory:
if 'linux' in sys.platform and "WSL2" in os.popen("uname -a").read():
torch.xpu.empty_cache = lambda: None
torch.cuda.empty_cache = torch.xpu.empty_cache

if legacy:
if has_ipex:
torch.cuda.memory_summary = torch.xpu.memory_summary
torch.cuda.memory_snapshot = torch.xpu.memory_snapshot
torch.cuda.memory = torch.xpu.memory
@@ -153,40 +159,19 @@ def ipex_init(): # pylint: disable=too-many-statements
torch.cuda.seed_all = torch.xpu.seed_all
torch.cuda.initial_seed = torch.xpu.initial_seed

# AMP:
if legacy:
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
torch.cuda.amp = torch.xpu.amp
if float(ipex.__version__[:3]) < 2.3:
torch.is_autocast_enabled = torch.xpu.is_autocast_xpu_enabled
torch.get_autocast_gpu_dtype = torch.xpu.get_autocast_xpu_dtype

if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = contextlib.nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False

try:
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
try:
from .gradscaler import gradscaler_init # pylint: disable=import-outside-toplevel, import-error
gradscaler_init()
torch.cuda.amp.GradScaler = torch.xpu.amp.GradScaler
except Exception: # pylint: disable=broad-exception-caught
torch.cuda.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler

# C
if legacy and float(ipex.__version__[:3]) < 2.3:
if torch_version < 2.3:
torch._C._cuda_getCurrentRawStream = ipex._C._getCurrentRawStream
ipex._C._DeviceProperties.multi_processor_count = ipex._C._DeviceProperties.gpu_subslice_count
ipex._C._DeviceProperties.major = 12
ipex._C._DeviceProperties.minor = 1
ipex._C._DeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750
else:
torch._C._cuda_getCurrentRawStream = torch._C._xpu_getCurrentRawStream
torch._C._XpuDeviceProperties.multi_processor_count = torch._C._XpuDeviceProperties.gpu_subslice_count
torch._C._XpuDeviceProperties.major = 12
torch._C._XpuDeviceProperties.minor = 1
torch._C._XpuDeviceProperties.L2_cache_size = 16*1024*1024 # A770 and A750

# Fix functions with ipex:
# torch.xpu.mem_get_info always returns the total memory as free memory
@@ -195,21 +180,22 @@ def ipex_init(): # pylint: disable=too-many-statements
torch._utils._get_available_device_type = lambda: "xpu"
torch.has_cuda = True
torch.cuda.has_half = True
torch.cuda.is_bf16_supported = lambda *args, **kwargs: True
torch.cuda.is_bf16_supported = getattr(torch.xpu, "is_bf16_supported", lambda *args, **kwargs: True)
torch.cuda.is_fp16_supported = lambda *args, **kwargs: True
torch.backends.cuda.is_built = lambda *args, **kwargs: True
torch.version.cuda = "12.1"
torch.cuda.get_arch_list = lambda: ["ats-m150", "pvc"]
torch.cuda.get_arch_list = getattr(torch.xpu, "get_arch_list", lambda: ["pvc", "dg2", "ats-m150"])
torch.cuda.get_device_capability = lambda *args, **kwargs: (12,1)
torch.cuda.get_device_properties.major = 12
torch.cuda.get_device_properties.minor = 1
torch.cuda.get_device_properties.L2_cache_size = 16*1024*1024 # A770 and A750
torch.cuda.ipc_collect = lambda *args, **kwargs: None
torch.cuda.utilization = lambda *args, **kwargs: 0

device_supports_fp64, can_allocate_plus_4gb = ipex_hijacks(legacy=legacy)
device_supports_fp64 = ipex_hijacks()
try:
from .diffusers import ipex_diffusers
ipex_diffusers(device_supports_fp64=device_supports_fp64, can_allocate_plus_4gb=can_allocate_plus_4gb)
ipex_diffusers(device_supports_fp64=device_supports_fp64)
except Exception: # pylint: disable=broad-exception-caught
pass
torch.cuda.is_xpu_hijacked = True


+ 6
- 6
scripts/dev/library/ipex/attention.py View File

@@ -61,13 +61,13 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop
if query.device.type != "xpu":
return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
is_unsqueezed = False
if len(query.shape) == 3:
if query.dim() == 3:
query = query.unsqueeze(0)
is_unsqueezed = True
if len(key.shape) == 3:
key = key.unsqueeze(0)
if len(value.shape) == 3:
value = value.unsqueeze(0)
if key.dim() == 3:
key = key.unsqueeze(0)
if value.dim() == 3:
value = value.unsqueeze(0)
do_batch_split, do_head_split, do_query_split, split_batch_size, split_head_size, split_query_size = find_sdpa_slice_sizes(query.shape, key.shape, query.element_size(), slice_rate=attention_slice_rate, trigger_rate=sdpa_slice_trigger_rate)

# Slice SDPA
@@ -115,5 +115,5 @@ def dynamic_scaled_dot_product_attention(query, key, value, attn_mask=None, drop
else:
hidden_states = original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, **kwargs)
if is_unsqueezed:
hidden_states.squeeze(0)
hidden_states = hidden_states.squeeze(0)
return hidden_states

+ 80
- 1
scripts/dev/library/ipex/diffusers.py View File

@@ -1,11 +1,13 @@
from functools import wraps
import torch
import diffusers # pylint: disable=import-error
from diffusers.utils import torch_utils # pylint: disable=import-error, unused-import # noqa: F401

# pylint: disable=protected-access, missing-function-docstring, line-too-long


# Diffusers FreeU
# Diffusers is imported before ipex hijacks so fourier_filter needs hijacking too
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter
@wraps(diffusers.utils.torch_utils.fourier_filter)
def fourier_filter(x_in, threshold, scale):
@@ -41,7 +43,84 @@ class FluxPosEmbed(torch.nn.Module):
return freqs_cos, freqs_sin


def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
def hidream_rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Compute HiDream rotary position-embedding matrices on the CPU.

    The math is done on the CPU with float64 frequencies (workaround for XPU
    devices without fp64), then moved back to ``pos.device`` as float32.

    Args:
        pos: Position indices, shape (batch, seq_len).
        dim: Per-head embedding dimension; must be even.
        theta: RoPE frequency base.

    Returns:
        float32 tensor of shape (batch, seq_len, dim // 2, 2, 2) on the
        original device of ``pos``.
    """
    assert dim % 2 == 0, "The dimension must be even."
    original_device = pos.device
    positions = pos.to("cpu")

    # Frequency ladder omega_d = theta ** (-2d / dim), in double precision.
    exponents = torch.arange(0, dim, 2, dtype=torch.float64, device=positions.device) / dim
    freqs = 1.0 / (theta ** exponents)

    n_batch = positions.shape[0]
    angles = torch.einsum("...n,d->...nd", positions, freqs)
    cos_a = torch.cos(angles)
    sin_a = torch.sin(angles)

    # Pack each angle into its 2x2 rotation matrix [[cos, -sin], [sin, cos]].
    rotations = torch.stack([cos_a, -sin_a, sin_a, cos_a], dim=-1)
    rotations = rotations.view(n_batch, -1, dim // 2, 2, 2)
    return rotations.to(original_device, dtype=torch.float32)


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
    """Torch replacement for diffusers' 1D sin-cos positional embedding.

    For ``output_type == "np"`` this defers to diffusers' original NumPy
    implementation; otherwise the embedding is computed with torch ops on
    ``pos.device`` (used on XPU devices without fp64 support).

    Args:
        embed_dim: Output embedding dimension D (must be even on the torch path).
        pos: Positions; flattened to M values.
        output_type: "np" for the diffusers NumPy path, anything else for torch.

    Returns:
        Tensor of shape (M, D): sin half followed by cos half.

    Raises:
        ValueError: If ``embed_dim`` is odd on the torch path.
    """
    if output_type == "np":
        return diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    half_dim = embed_dim // 2
    # Frequencies 1 / 10000 ** (i / (D/2)) for i in [0, D/2).
    freqs = torch.arange(half_dim, device=pos.device, dtype=torch.float32)
    freqs = 1.0 / 10000 ** (freqs / (embed_dim / 2.0))

    flat_pos = pos.reshape(-1)            # (M,)
    angles = torch.outer(flat_pos, freqs)  # (M, D/2)

    # Sin half first, then cos half — matches the diffusers layout.
    return torch.concat([torch.sin(angles), torch.cos(angles)], dim=1)


def apply_rotary_emb(x, freqs_cis, use_real: bool = True, use_real_unbind_dim: int = -1):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: Query/key tensor, last dim is the rotary dimension.
        freqs_cis: (cos, sin) pair of shape [S, D] when ``use_real``,
            otherwise a complex frequency tensor.
        use_real: True for the real-valued cos/sin formulation.
        use_real_unbind_dim: -1 for interleaved pairs (flux, cogvideox,
            hunyuan-dit), -2 for split halves (Stable Audio, OmniGen,
            CogView4, Cosmos).

    Returns:
        Rotated tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If ``use_real_unbind_dim`` is neither -1 nor -2.
    """
    if not use_real:
        # Complex path (lumina). The complex multiply is forced onto the CPU
        # (Alchemist workaround) and the result is moved back afterwards.
        as_complex = torch.view_as_complex(x.to("cpu").float().reshape(*x.shape[:-1], -1, 2))
        rotated_c = as_complex * freqs_cis.to("cpu").unsqueeze(2)
        return torch.view_as_real(rotated_c).flatten(3).type_as(x).to(x.device)

    cos, sin = freqs_cis  # [S, D]
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    if use_real_unbind_dim == -1:
        # Interleaved layout: adjacent (real, imag) pairs along the last axis.
        real, imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
        rotated = torch.stack([-imag, real], dim=-1).flatten(3)
    elif use_real_unbind_dim == -2:
        # Split-half layout: first half real, second half imaginary.
        real, imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
        rotated = torch.cat([-imag, real], dim=-1)
    else:
        raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

    return (x.float() * cos + rotated.float() * sin).to(x.dtype)


def ipex_diffusers(device_supports_fp64=False):
    """Patch diffusers with XPU-friendly replacements.

    Always swaps in the hijacked ``fourier_filter`` (FreeU). When the device
    lacks fp64 support, additionally replaces diffusers' positional-embedding
    and RoPE helpers with the torch/CPU implementations defined in this module.

    Args:
        device_supports_fp64: Whether the XPU device supports float64; when
            True only ``fourier_filter`` is patched.
    """
    diffusers.utils.torch_utils.fourier_filter = fourier_filter
    if not device_supports_fp64:
        # get around lazy imports (the submodules must exist before we can
        # assign attributes on them)
        from diffusers.models import embeddings as diffusers_embeddings  # pylint: disable=import-error, unused-import # noqa: F401
        from diffusers.models import transformers as diffusers_transformers  # pylint: disable=import-error, unused-import # noqa: F401
        from diffusers.models import controlnets as diffusers_controlnets  # pylint: disable=import-error, unused-import # noqa: F401
        diffusers.models.embeddings.get_1d_sincos_pos_embed_from_grid = get_1d_sincos_pos_embed_from_grid
        diffusers.models.embeddings.FluxPosEmbed = FluxPosEmbed
        diffusers.models.embeddings.apply_rotary_emb = apply_rotary_emb
        diffusers.models.transformers.transformer_flux.FluxPosEmbed = FluxPosEmbed
        diffusers.models.transformers.transformer_lumina2.apply_rotary_emb = apply_rotary_emb
        diffusers.models.controlnets.controlnet_flux.FluxPosEmbed = FluxPosEmbed
        diffusers.models.transformers.transformer_hidream_image.rope = hidream_rope

+ 0
- 183
scripts/dev/library/ipex/gradscaler.py View File

@@ -1,183 +0,0 @@
from collections import defaultdict
import torch
import intel_extension_for_pytorch as ipex # pylint: disable=import-error, unused-import
import intel_extension_for_pytorch._C as core # pylint: disable=import-error, unused-import

# pylint: disable=protected-access, missing-function-docstring, line-too-long

# Probe fp64 support — the API name differs across IPEX/torch.xpu versions.
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
# Re-export private helpers from IPEX's CPU GradScaler implementation so the
# patched methods below can reuse them.
OptState = ipex.cpu.autocast._grad_scaler.OptState
_MultiDeviceReplicator = ipex.cpu.autocast._grad_scaler._MultiDeviceReplicator
_refresh_per_optimizer_state = ipex.cpu.autocast._grad_scaler._refresh_per_optimizer_state

def _unscale_grads_(self, optimizer, inv_scale, found_inf, allow_fp16): # pylint: disable=unused-argument
    """GradScaler._unscale_grads_ replacement that runs the inf/NaN check on the CPU.

    Gradients are gathered per device/dtype, moved to the CPU, and passed to
    IPEX's fused ``core._amp_foreach_non_finite_check_and_unscale_`` kernel.

    NOTE(review): ``to_unscale.to("cpu")`` copies XPU grads to the CPU; the
    kernel then unscales those copies — verify against IPEX whether the
    originals are meant to be updated here or only checked for inf/NaN.
    """
    per_device_inv_scale = _MultiDeviceReplicator(inv_scale)
    per_device_found_inf = _MultiDeviceReplicator(found_inf)

    # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype.
    # There could be hundreds of grads, so we'd like to iterate through them just once.
    # However, we don't know their devices or dtypes in advance.

    # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict
    # Google says mypy struggles with defaultdicts type annotations.
    per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated]
    # sync grad to master weight
    if hasattr(optimizer, "sync_grad"):
        optimizer.sync_grad()
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue
                if (not allow_fp16) and param.grad.dtype == torch.float16:
                    raise ValueError("Attempting to unscale FP16 gradients.")
                if param.grad.is_sparse:
                    # is_coalesced() == False means the sparse grad has values with duplicate indices.
                    # coalesce() deduplicates indices and adds all values that have the same index.
                    # For scaled fp16 values, there's a good chance coalescing will cause overflow,
                    # so we should check the coalesced _values().
                    if param.grad.dtype is torch.float16:
                        param.grad = param.grad.coalesce()
                    to_unscale = param.grad._values()
                else:
                    to_unscale = param.grad

                # -: is there a way to split by device and dtype without appending in the inner loop?
                to_unscale = to_unscale.to("cpu")
                per_device_and_dtype_grads[to_unscale.device][
                    to_unscale.dtype
                ].append(to_unscale)

    for _, per_dtype_grads in per_device_and_dtype_grads.items():
        for grads in per_dtype_grads.values():
            core._amp_foreach_non_finite_check_and_unscale_(
                grads,
                per_device_found_inf.get("cpu"),
                per_device_inv_scale.get("cpu"),
            )

    return per_device_found_inf._per_device_tensors

def unscale_(self, optimizer):
    """
    Divides ("unscales") the optimizer's gradient tensors by the scale factor.
    :meth:`unscale_` is optional, serving cases where you need to
    :ref:`modify or inspect gradients<working-with-unscaled-gradients>`
    between the backward pass(es) and :meth:`step`.
    If :meth:`unscale_` is not called explicitly, gradients will be unscaled automatically during :meth:`step`.
    Simple example, using :meth:`unscale_` to enable clipping of unscaled gradients::
        ...
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()
    Args:
        optimizer (torch.optim.Optimizer): Optimizer that owns the gradients to be unscaled.
    .. warning::
        :meth:`unscale_` should only be called once per optimizer per :meth:`step` call,
        and only after all gradients for that optimizer's assigned parameters have been accumulated.
        Calling :meth:`unscale_` twice for a given optimizer between each :meth:`step` triggers a RuntimeError.
    .. warning::
        :meth:`unscale_` may unscale sparse gradients out of place, replacing the ``.grad`` attribute.
    """
    if not self._enabled:
        return

    self._check_scale_growth_tracker("unscale_")

    optimizer_state = self._per_optimizer_states[id(optimizer)]

    # Enforce the once-per-step contract described in the docstring above.
    if optimizer_state["stage"] is OptState.UNSCALED: # pylint: disable=no-else-raise
        raise RuntimeError(
            "unscale_() has already been called on this optimizer since the last update()."
        )
    elif optimizer_state["stage"] is OptState.STEPPED:
        raise RuntimeError("unscale_() is being called after step().")

    # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64.
    assert self._scale is not None
    # XPU workaround: devices without fp64 (e.g. Alchemist) compute the FP64
    # reciprocal on the CPU and move the result back.
    if device_supports_fp64:
        inv_scale = self._scale.double().reciprocal().float()
    else:
        inv_scale = self._scale.to("cpu").double().reciprocal().float().to(self._scale.device)
    found_inf = torch.full(
        (1,), 0.0, dtype=torch.float32, device=self._scale.device
    )

    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
        optimizer, inv_scale, found_inf, False
    )
    optimizer_state["stage"] = OptState.UNSCALED

def update(self, new_scale=None):
    """
    Updates the scale factor.
    If any optimizer steps were skipped the scale is multiplied by ``backoff_factor``
    to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively,
    the scale is multiplied by ``growth_factor`` to increase it.
    Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not
    used directly, it's used to fill GradScaler's internal scale tensor. So if
    ``new_scale`` was a tensor, later in-place changes to that tensor will not further
    affect the scale GradScaler uses internally.)
    Args:
        new_scale (float or :class:`torch.FloatTensor`, optional, default=None):  New scale factor.
    .. warning::
        :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has
        been invoked for all optimizers used this iteration.
    """
    if not self._enabled:
        return

    _scale, _growth_tracker = self._check_scale_growth_tracker("update")

    if new_scale is not None:
        # Accept a new user-defined scale.
        if isinstance(new_scale, float):
            self._scale.fill_(new_scale) # type: ignore[union-attr]
        else:
            reason = "new_scale should be a float or a 1-element torch.FloatTensor with requires_grad=False."
            assert isinstance(new_scale, torch.FloatTensor), reason # type: ignore[attr-defined]
            assert new_scale.numel() == 1, reason
            assert new_scale.requires_grad is False, reason
            self._scale.copy_(new_scale) # type: ignore[union-attr]
    else:
        # Consume shared inf/nan data collected from optimizers to update the scale.
        # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous.
        found_infs = [
            found_inf.to(device="cpu", non_blocking=True)
            for state in self._per_optimizer_states.values()
            for found_inf in state["found_inf_per_device"].values()
        ]

        assert len(found_infs) > 0, "No inf checks were recorded prior to update."

        found_inf_combined = found_infs[0]
        if len(found_infs) > 1:
            for i in range(1, len(found_infs)):
                found_inf_combined += found_infs[i]

        # XPU workaround: run IPEX's CPU _amp_update_scale_ kernel, so move the
        # scale/growth tensors to the CPU and restore them afterwards.
        to_device = _scale.device
        _scale = _scale.to("cpu")
        _growth_tracker = _growth_tracker.to("cpu")

        core._amp_update_scale_(
            _scale,
            _growth_tracker,
            found_inf_combined,
            self._growth_factor,
            self._backoff_factor,
            self._growth_interval,
        )

        _scale = _scale.to(to_device)
        _growth_tracker = _growth_tracker.to(to_device)
    # To prepare for next iteration, clear the data collected from optimizers this iteration.
    self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state)

def gradscaler_init():
    """Install the CPU-offloaded GradScaler onto ``torch.xpu.amp``.

    Uses IPEX's CPU GradScaler class and monkey-patches its unscale/update
    steps with the XPU-safe versions defined above. Returns the patched class.
    """
    torch.xpu.amp.GradScaler = ipex.cpu.autocast._grad_scaler.GradScaler
    torch.xpu.amp.GradScaler._unscale_grads_ = _unscale_grads_
    torch.xpu.amp.GradScaler.unscale_ = unscale_
    torch.xpu.amp.GradScaler.update = update
    return torch.xpu.amp.GradScaler

+ 167
- 68
scripts/dev/library/ipex/hijacks.py View File

@@ -4,17 +4,23 @@ from contextlib import nullcontext
import torch
import numpy as np

device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties("xpu").has_fp64
if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0' and (torch.xpu.get_device_properties("xpu").total_memory / 1024 / 1024 / 1024) > 4.1:
try:
x = torch.ones((33000,33000), dtype=torch.float32, device="xpu")
del x
torch.xpu.empty_cache()
can_allocate_plus_4gb = True
except Exception:
can_allocate_plus_4gb = False
torch_version = float(torch.__version__[:3])
current_xpu_device = f"xpu:{torch.xpu.current_device()}"
device_supports_fp64 = torch.xpu.has_fp64_dtype() if hasattr(torch.xpu, "has_fp64_dtype") else torch.xpu.get_device_properties(current_xpu_device).has_fp64

if os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '0':
if (torch.xpu.get_device_properties(current_xpu_device).total_memory / 1024 / 1024 / 1024) > 4.1:
try:
x = torch.ones((33000,33000), dtype=torch.float32, device=current_xpu_device)
del x
torch.xpu.empty_cache()
use_dynamic_attention = False
except Exception:
use_dynamic_attention = True
else:
use_dynamic_attention = True
else:
can_allocate_plus_4gb = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '-1')
use_dynamic_attention = bool(os.environ.get('IPEX_FORCE_ATTENTION_SLICE', '0') == '1')

# pylint: disable=protected-access, missing-function-docstring, line-too-long, unnecessary-lambda, no-else-return

@@ -22,32 +28,67 @@ class DummyDataParallel(torch.nn.Module): # pylint: disable=missing-class-docstr
def __new__(cls, module, device_ids=None, output_device=None, dim=0): # pylint: disable=unused-argument
if isinstance(device_ids, list) and len(device_ids) > 1:
print("IPEX backend doesn't support DataParallel on multiple XPU devices")
return module.to("xpu")
return module.to(f"xpu:{torch.xpu.current_device()}")

def return_null_context(*args, **kwargs):  # pylint: disable=unused-argument
    """Stub that ignores every argument and returns a no-op context manager."""
    return nullcontext()

@property
def is_cuda(self):
return self.device.type == 'xpu' or self.device.type == 'cuda'
return self.device.type == "xpu" or self.device.type == "cuda"

def check_device(device):
return bool((isinstance(device, torch.device) and device.type == "cuda") or (isinstance(device, str) and "cuda" in device) or isinstance(device, int))
def check_device_type(device, device_type: str) -> bool:
    """Return True if ``device`` parses to a torch.device of type ``device_type``.

    Only str / int / torch.device specifiers are parsed (exact type match, so
    bool is rejected); None and everything else return False.
    """
    if device is not None and type(device) in (str, int, torch.device):
        return torch.device(device).type == device_type
    return False

def return_xpu(device):
return f"xpu:{device.split(':')[-1]}" if isinstance(device, str) and ":" in device else f"xpu:{device}" if isinstance(device, int) else torch.device(f"xpu:{device.index}" if device.index is not None else "xpu") if isinstance(device, torch.device) else "xpu"
def check_cuda(device) -> bool:
    """True when ``device`` refers to CUDA; any int counts as a CUDA ordinal."""
    if isinstance(device, int):
        return True
    return bool(check_device_type(device, "cuda"))

def return_xpu(device):
    """Map a device specifier to its XPU equivalent, keeping the instance type.

    Strings and ints yield a string ("xpu:<index>" or "xpu"), a torch.device
    yields a torch.device, None resolves to the current XPU device, and any
    other value falls back to the bare string "xpu".
    """
    if device is None:
        return f"xpu:{torch.xpu.current_device()}"
    if isinstance(device, str) and ":" in device:
        return f"xpu:{device.split(':')[-1]}"
    if isinstance(device, int):
        return f"xpu:{device}"
    if isinstance(device, torch.device):
        index = device.index
        return torch.device("xpu" if index is None else f"xpu:{index}")
    return "xpu"


# Autocast
original_autocast_init = torch.amp.autocast_mode.autocast.__init__
@wraps(torch.amp.autocast_mode.autocast.__init__)
def autocast_init(self, device_type, dtype=None, enabled=True, cache_enabled=None):
if device_type == "cuda":
def autocast_init(self, device_type=None, dtype=None, enabled=True, cache_enabled=None):
if device_type is None or check_cuda(device_type):
return original_autocast_init(self, device_type="xpu", dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)
else:
return original_autocast_init(self, device_type=device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled)


original_grad_scaler_init = torch.amp.grad_scaler.GradScaler.__init__
@wraps(torch.amp.grad_scaler.GradScaler.__init__)
def GradScaler_init(self, device: str = None, init_scale: float = 2.0**16, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000, enabled: bool = True):
    """GradScaler.__init__ hijack: redirect CUDA (or unspecified) devices to XPU."""
    target = return_xpu(device) if device is None or check_cuda(device) else device
    return original_grad_scaler_init(
        self,
        device=target,
        init_scale=init_scale,
        growth_factor=growth_factor,
        backoff_factor=backoff_factor,
        growth_interval=growth_interval,
        enabled=enabled,
    )


original_is_autocast_enabled = torch.is_autocast_enabled
@wraps(torch.is_autocast_enabled)
def torch_is_autocast_enabled(device_type=None):
    """torch.is_autocast_enabled hijack: treat CUDA / unspecified queries as XPU."""
    if device_type is not None and not check_cuda(device_type):
        return original_is_autocast_enabled(device_type)
    return original_is_autocast_enabled(return_xpu(device_type))


original_get_autocast_dtype = torch.get_autocast_dtype
@wraps(torch.get_autocast_dtype)
def torch_get_autocast_dtype(device_type=None):
    """torch.get_autocast_dtype hijack: XPU, CUDA and None always report bfloat16."""
    targets_xpu = (
        device_type is None
        or check_cuda(device_type)
        or check_device_type(device_type, "xpu")
    )
    return torch.bfloat16 if targets_xpu else original_get_autocast_dtype(device_type)


# Latent Antialias CPU Offload:
# IPEX 2.5 and above has partial support but doesn't really work most of the time.
original_interpolate = torch.nn.functional.interpolate
@wraps(torch.nn.functional.interpolate)
def interpolate(tensor, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False): # pylint: disable=too-many-arguments
@@ -66,23 +107,22 @@ original_from_numpy = torch.from_numpy
@wraps(torch.from_numpy)
def from_numpy(ndarray):
if ndarray.dtype == float:
return original_from_numpy(ndarray.astype('float32'))
return original_from_numpy(ndarray.astype("float32"))
else:
return original_from_numpy(ndarray)

original_as_tensor = torch.as_tensor
@wraps(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
if check_device(device):
if check_cuda(device):
device = return_xpu(device)
if isinstance(data, np.ndarray) and data.dtype == float and not (
(isinstance(device, torch.device) and device.type == "cpu") or (isinstance(device, str) and "cpu" in device)):
if isinstance(data, np.ndarray) and data.dtype == float and not check_device_type(device, "cpu"):
return original_as_tensor(data, dtype=torch.float32, device=device)
else:
return original_as_tensor(data, dtype=dtype, device=device)


if can_allocate_plus_4gb:
if not use_dynamic_attention:
original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
else:
# 32 bit attention workarounds for Alchemist:
@@ -106,7 +146,7 @@ original_torch_bmm = torch.bmm
@wraps(torch.bmm)
def torch_bmm(input, mat2, *, out=None):
if input.dtype != mat2.dtype:
mat2 = mat2.to(input.dtype)
mat2 = mat2.to(dtype=input.dtype)
return original_torch_bmm(input, mat2, out=out)

# Diffusers FreeU
@@ -195,38 +235,36 @@ original_torch_tensor = torch.tensor
@wraps(torch.tensor)
def torch_tensor(data, *args, dtype=None, device=None, **kwargs):
global device_supports_fp64
if check_device(device):
if check_cuda(device):
device = return_xpu(device)
if not device_supports_fp64:
if (isinstance(device, torch.device) and device.type == "xpu") or (isinstance(device, str) and "xpu" in device):
if check_device_type(device, "xpu"):
if dtype == torch.float64:
dtype = torch.float32
elif dtype is None and (hasattr(data, "dtype") and (data.dtype == torch.float64 or data.dtype == float)):
dtype = torch.float32
return original_torch_tensor(data, *args, dtype=dtype, device=device, **kwargs)

original_Tensor_to = torch.Tensor.to
torch.Tensor.original_Tensor_to = torch.Tensor.to
@wraps(torch.Tensor.to)
def Tensor_to(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_to(self, return_xpu(device), *args, **kwargs)
if check_cuda(device):
return self.original_Tensor_to(return_xpu(device), *args, **kwargs)
else:
return original_Tensor_to(self, device, *args, **kwargs)
return self.original_Tensor_to(device, *args, **kwargs)

original_Tensor_cuda = torch.Tensor.cuda
@wraps(torch.Tensor.cuda)
def Tensor_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_Tensor_cuda(self, return_xpu(device), *args, **kwargs)
if device is None or check_cuda(device):
return self.to(return_xpu(device), *args, **kwargs)
else:
return original_Tensor_cuda(self, device, *args, **kwargs)

original_Tensor_pin_memory = torch.Tensor.pin_memory
@wraps(torch.Tensor.pin_memory)
def Tensor_pin_memory(self, device=None, *args, **kwargs):
if device is None:
device = "xpu"
if check_device(device):
if device is None or check_cuda(device):
return original_Tensor_pin_memory(self, return_xpu(device), *args, **kwargs)
else:
return original_Tensor_pin_memory(self, device, *args, **kwargs)
@@ -234,23 +272,32 @@ def Tensor_pin_memory(self, device=None, *args, **kwargs):
original_UntypedStorage_init = torch.UntypedStorage.__init__
@wraps(torch.UntypedStorage.__init__)
def UntypedStorage_init(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_UntypedStorage_init(*args, device=return_xpu(device), **kwargs)
else:
return original_UntypedStorage_init(*args, device=device, **kwargs)

original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, *args, **kwargs):
if check_device(device):
return original_UntypedStorage_cuda(self, return_xpu(device), *args, **kwargs)
else:
return original_UntypedStorage_cuda(self, device, *args, **kwargs)
if torch_version >= 2.4:
original_UntypedStorage_to = torch.UntypedStorage.to
@wraps(torch.UntypedStorage.to)
def UntypedStorage_to(self, *args, device=None, **kwargs):
if check_cuda(device):
return original_UntypedStorage_to(self, *args, device=return_xpu(device), **kwargs)
else:
return original_UntypedStorage_to(self, *args, device=device, **kwargs)

original_UntypedStorage_cuda = torch.UntypedStorage.cuda
@wraps(torch.UntypedStorage.cuda)
def UntypedStorage_cuda(self, device=None, non_blocking=False, **kwargs):
if device is None or check_cuda(device):
return self.to(device=return_xpu(device), non_blocking=non_blocking, **kwargs)
else:
return original_UntypedStorage_cuda(self, device=device, non_blocking=non_blocking, **kwargs)

original_torch_empty = torch.empty
@wraps(torch.empty)
def torch_empty(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_empty(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_empty(*args, device=device, **kwargs)
@@ -260,7 +307,7 @@ original_torch_randn = torch.randn
def torch_randn(*args, device=None, dtype=None, **kwargs):
if dtype is bytes:
dtype = None
if check_device(device):
if check_cuda(device):
return original_torch_randn(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_randn(*args, device=device, **kwargs)
@@ -268,7 +315,7 @@ def torch_randn(*args, device=None, dtype=None, **kwargs):
original_torch_ones = torch.ones
@wraps(torch.ones)
def torch_ones(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_ones(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_ones(*args, device=device, **kwargs)
@@ -276,7 +323,7 @@ def torch_ones(*args, device=None, **kwargs):
original_torch_zeros = torch.zeros
@wraps(torch.zeros)
def torch_zeros(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_zeros(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_zeros(*args, device=device, **kwargs)
@@ -284,7 +331,7 @@ def torch_zeros(*args, device=None, **kwargs):
original_torch_full = torch.full
@wraps(torch.full)
def torch_full(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_full(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_full(*args, device=device, **kwargs)
@@ -292,63 +339,91 @@ def torch_full(*args, device=None, **kwargs):
original_torch_linspace = torch.linspace
@wraps(torch.linspace)
def torch_linspace(*args, device=None, **kwargs):
if check_device(device):
if check_cuda(device):
return original_torch_linspace(*args, device=return_xpu(device), **kwargs)
else:
return original_torch_linspace(*args, device=device, **kwargs)

original_torch_eye = torch.eye
@wraps(torch.eye)
def torch_eye(*args, device=None, **kwargs):
    # Identity-matrix factory with the device argument rewritten CUDA -> XPU.
    if check_cuda(device):
        return original_torch_eye(*args, device=return_xpu(device), **kwargs)
    else:
        return original_torch_eye(*args, device=device, **kwargs)

original_torch_load = torch.load
@wraps(torch.load)
def torch_load(f, map_location=None, *args, **kwargs):
if map_location is None:
map_location = "xpu"
if check_device(map_location):
if map_location is None or check_cuda(map_location):
return original_torch_load(f, *args, map_location=return_xpu(map_location), **kwargs)
else:
return original_torch_load(f, *args, map_location=map_location, **kwargs)

original_torch_Generator = torch.Generator
@wraps(torch.Generator)
def torch_Generator(device=None):
if check_device(device):
return original_torch_Generator(return_xpu(device))
else:
return original_torch_Generator(device)

@wraps(torch.cuda.synchronize)
def torch_cuda_synchronize(device=None):
if check_device(device):
if check_cuda(device):
return torch.xpu.synchronize(return_xpu(device))
else:
return torch.xpu.synchronize(device)

@wraps(torch.cuda.device)
def torch_cuda_device(device):
    # torch.cuda.device context manager redirected to its XPU counterpart.
    if check_cuda(device):
        return torch.xpu.device(return_xpu(device))
    else:
        return torch.xpu.device(device)

@wraps(torch.cuda.set_device)
def torch_cuda_set_device(device):
    # Route torch.cuda.set_device onto the XPU backend.
    if check_cuda(device):
        torch.xpu.set_device(return_xpu(device))
    else:
        torch.xpu.set_device(device)

# torch.Generator has to be a class for isinstance checks
original_torch_Generator = torch.Generator
class torch_Generator(original_torch_Generator):
    # Subclassing (rather than wrapping in a function) keeps
    # isinstance(g, torch.Generator) working after the hijack replaces
    # torch.Generator with this type.
    def __new__(self, device=None):
        # can't hijack __init__ because of C override so use return super().__new__
        if check_cuda(device):
            return super().__new__(self, return_xpu(device))
        else:
            return super().__new__(self, device)


# Hijack Functions:
def ipex_hijacks(legacy=True):
global device_supports_fp64, can_allocate_plus_4gb
if legacy and float(torch.__version__[:3]) < 2.5:
torch.nn.functional.interpolate = interpolate
def ipex_hijacks():
global device_supports_fp64
if torch_version >= 2.4:
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.UntypedStorage.to = UntypedStorage_to
torch.tensor = torch_tensor
torch.Tensor.to = Tensor_to
torch.Tensor.cuda = Tensor_cuda
torch.Tensor.pin_memory = Tensor_pin_memory
torch.UntypedStorage.__init__ = UntypedStorage_init
torch.UntypedStorage.cuda = UntypedStorage_cuda
torch.empty = torch_empty
torch.randn = torch_randn
torch.ones = torch_ones
torch.zeros = torch_zeros
torch.full = torch_full
torch.linspace = torch_linspace
torch.eye = torch_eye
torch.load = torch_load
torch.Generator = torch_Generator
torch.cuda.synchronize = torch_cuda_synchronize
torch.cuda.device = torch_cuda_device
torch.cuda.set_device = torch_cuda_set_device

torch.Generator = torch_Generator
torch._C.Generator = torch_Generator

torch.backends.cuda.sdp_kernel = return_null_context
torch.nn.DataParallel = DummyDataParallel
torch.UntypedStorage.is_cuda = is_cuda
torch.amp.autocast_mode.autocast.__init__ = autocast_init

torch.nn.functional.interpolate = interpolate
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
torch.nn.functional.group_norm = functional_group_norm
torch.nn.functional.layer_norm = functional_layer_norm
@@ -364,4 +439,28 @@ def ipex_hijacks(legacy=True):
if not device_supports_fp64:
torch.from_numpy = from_numpy
torch.as_tensor = as_tensor
return device_supports_fp64, can_allocate_plus_4gb

# AMP:
torch.amp.grad_scaler.GradScaler.__init__ = GradScaler_init
torch.is_autocast_enabled = torch_is_autocast_enabled
torch.get_autocast_gpu_dtype = torch_get_autocast_dtype
torch.get_autocast_dtype = torch_get_autocast_dtype

if hasattr(torch.xpu, "amp"):
if not hasattr(torch.xpu.amp, "custom_fwd"):
torch.xpu.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.xpu.amp.custom_bwd = torch.cuda.amp.custom_bwd
if not hasattr(torch.xpu.amp, "GradScaler"):
torch.xpu.amp.GradScaler = torch.amp.grad_scaler.GradScaler
torch.cuda.amp = torch.xpu.amp
else:
if not hasattr(torch.amp, "custom_fwd"):
torch.amp.custom_fwd = torch.cuda.amp.custom_fwd
torch.amp.custom_bwd = torch.cuda.amp.custom_bwd
torch.cuda.amp = torch.amp

if not hasattr(torch.cuda.amp, "common"):
torch.cuda.amp.common = nullcontext()
torch.cuda.amp.common.amp_definitely_not_available = lambda: False

return device_supports_fp64

+ 186
- 0
scripts/dev/library/jpeg_xl_util.py View File

@@ -0,0 +1,186 @@
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
# Added partial read support for up to 200x speedup

import os
from typing import List, Tuple

class JXLBitstream:
    """
    A stream of bits with methods for easy handling.

    Bits are consumed little-endian (byte 0's LSB first) from a growing
    byte buffer that is filled lazily from `file`. When `offsets` is given
    (built from `jxlp` partial-codestream boxes by `decode_container`),
    each entry is [sequence_index, data_start_offset, data_length] and the
    stream transparently hops between the fragments.
    """

    def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
        self.shift = 0                    # number of bits consumed so far
        self.bitstream = bytearray()      # bytes read from the file so far
        self.file = file
        self.offset = offset
        self.offsets = offsets
        if self.offsets:
            # Start reading at the first fragment's data offset.
            self.offset = self.offsets[0][1]
            self.previous_data_len = 0    # total length of fully consumed fragments
            self.index = 0                # index of the fragment being read
        self.file.seek(self.offset)

    def get_bits(self, length: int = 1) -> int:
        """Read and return the next `length` bits as an integer."""
        # NOTE: `length` bytes are read from the file for `length` bits —
        # a deliberate over-read so the buffer always covers the request.
        # NOTE(review): `self.shift` counts bits while offsets[...][2] looks
        # like a byte length; the comparison over-estimates what is needed,
        # which is safe (extra bytes stay buffered) but worth confirming.
        if self.offsets and self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
            self.partial_to_read_length = length
            if self.shift < self.previous_data_len + self.offsets[self.index][2]:
                self.partial_read(0, length)
            self.bitstream.extend(self.file.read(self.partial_to_read_length))
        else:
            self.bitstream.extend(self.file.read(length))
        bitmask = 2**length - 1
        bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
        self.shift += length
        return bits

    def partial_read(self, current_length: int, length: int) -> None:
        """Drain the rest of the current fragment and hop to the next one,
        recursing while the request still spans fragment boundaries."""
        self.previous_data_len += self.offsets[self.index][2]
        to_read_length = self.previous_data_len - (self.shift + current_length)
        self.bitstream.extend(self.file.read(to_read_length))
        current_length += to_read_length
        self.partial_to_read_length -= to_read_length
        self.index += 1
        self.file.seek(self.offsets[self.index][1])
        if self.shift + length > self.previous_data_len + self.offsets[self.index][2]:
            self.partial_read(current_length, length)


def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int, int]:
    """
    Decode (width, height) from the SizeHeader of a JXL codestream.
    JXL codestream specification: http://www-internal/2022/18181-1
    """
    # Wrap the raw file in a bit reader (handles jxlp fragment hopping).
    stream = JXLBitstream(file, offset=offset, offsets=offsets)

    # Skip the 16-bit codestream signature.
    stream.get_bits(16)

    # Bit widths of the four explicit size encodings, keyed by a 2-bit tag.
    explicit_bits = (9, 13, 18, 30)
    # (numerator, denominator) for the predefined aspect ratios, codes 1..7.
    aspect_ratios = ((1, 1), (12, 10), (4, 3), (3, 2), (16, 9), (5, 4), (2, 1))

    # SizeHeader: height comes first, either in compact div8 form or
    # as an explicit value with a selectable bit width.
    div8 = stream.get_bits(1)
    if div8:
        height = 8 * (1 + stream.get_bits(5))
    else:
        height = 1 + stream.get_bits(explicit_bits[stream.get_bits(2)])

    # Width is either derived from a ratio code or encoded like the height.
    ratio = stream.get_bits(3)
    if ratio:
        numerator, denominator = aspect_ratios[ratio - 1]
        width = (height * numerator) // denominator
    elif div8:
        width = 8 * (1 + stream.get_bits(5))
    else:
        width = 1 + stream.get_bits(explicit_bits[stream.get_bits(2)])
    return width, height


def decode_container(file) -> Tuple[int,int]:
    """
    Parses the ISOBMFF container, extracts the codestream, and decodes it.
    JXL container specification: http://www-internal/2022/18181-2
    """

    def parse_box(file, file_start: int) -> dict:
        # Parse one ISOBMFF box header starting at `file_start` and return
        # its total length, 4-byte type tag, and payload offset.
        file.seek(file_start)
        LBox = int.from_bytes(file.read(4), "big")
        XLBox = None
        # LBox values 2..8 can't fit a header, so they are invalid.
        if 1 < LBox <= 8:
            raise ValueError(f"Invalid LBox at byte {file_start}.")
        # LBox == 1 means the real size is in a 64-bit XLBox field.
        if LBox == 1:
            file.seek(file_start + 8)
            XLBox = int.from_bytes(file.read(8), "big")
            if XLBox <= 16:
                raise ValueError(f"Invalid XLBox at byte {file_start}.")
        if XLBox:
            header_length = 16
            box_length = XLBox
        else:
            header_length = 8
            # LBox == 0 means the box extends to the end of the file.
            if LBox == 0:
                box_length = os.fstat(file.fileno()).st_size - file_start
            else:
                box_length = LBox
        file.seek(file_start + 4)
        box_type = file.read(4)
        file.seek(file_start)
        return {
            "length": box_length,
            "type": box_type,
            "offset": header_length,
        }

    file.seek(0)
    # Reject files missing required boxes. These two boxes are required to be at
    # the start and contain no values, so we can manually check their presence.
    # Signature box. (Redundant as has already been checked.)
    if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
        raise ValueError("Invalid signature box.")
    # File Type box.
    if file.read(20) != bytes.fromhex(
        "00000014 66747970 6A786C20 00000000 6A786C20"
    ):
        raise ValueError("Invalid file type box.")

    # Walk the remaining boxes looking for either a single `jxlc` codestream
    # box or a sequence of `jxlp` partial-codestream boxes.
    offset = 0
    offsets = []
    data_offset_not_found = True
    container_pointer = 32
    file_size = os.fstat(file.fileno()).st_size
    while data_offset_not_found:
        box = parse_box(file, container_pointer)
        match box["type"]:
            case b"jxlc":
                offset = container_pointer + box["offset"]
                data_offset_not_found = False
            case b"jxlp":
                # jxlp payload starts with a 4-byte big-endian sequence index;
                # record [index, data_start, data_length] for the bit reader.
                file.seek(container_pointer + box["offset"])
                index = int.from_bytes(file.read(4), "big")
                offsets.append([index, container_pointer + box["offset"] + 4, box["length"] - box["offset"] - 4])
        container_pointer += box["length"]
        if container_pointer >= file_size:
            data_offset_not_found = False

    if offsets:
        # Fragments may appear out of order; sort by sequence index.
        offsets.sort(key=lambda i: i[0])
        file.seek(0)

    return decode_codestream(file, offset=offset, offsets=offsets)


def get_jxl_size(path: str) -> Tuple[int, int]:
    """Return (width, height) of the JPEG XL image at *path*.

    Dispatches on the first two bytes: a bare codestream starts with the
    FF 0A signature; anything else is parsed as an ISOBMFF container.
    """
    with open(path, "rb") as stream:
        is_bare_codestream = stream.read(2) == bytes.fromhex("FF0A")
        if is_bare_codestream:
            return decode_codestream(stream)
        return decode_container(stream)

+ 1392
- 0
scripts/dev/library/lumina_models.py View File

@@ -0,0 +1,1392 @@
# Copyright Alpha VLLM/Lumina Image 2.0 and contributors
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# References:
# GLIDE: https://github.com/openai/glide-text2im
# MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
# --------------------------------------------------------

import math
from typing import List, Optional, Tuple
from dataclasses import dataclass

import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import torch.nn as nn
import torch.nn.functional as F

from library import custom_offloading_utils

try:
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
except:
# flash_attn may not be available but it is not required
pass

try:
from sageattention import sageattn
except:
pass

try:
from apex.normalization import FusedRMSNorm as RMSNorm
except:
import warnings

warnings.warn("Cannot import apex RMSNorm, switch to vanilla implementation")

#############################################################################
# RMSNorm #
#############################################################################

class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization (vanilla fallback used when
    apex's FusedRMSNorm is unavailable).

    Normalizes the last dimension by the reciprocal RMS of its elements and
    applies a learnable per-channel scale.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x) -> Tensor:
        """Return *x* scaled by the reciprocal RMS of its last dimension."""
        return x * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor):
        """
        Apply RMSNorm to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor, in the input's dtype.
        """
        x_dtype = x.dtype
        # To handle float8 we need to convert the tensor to float
        x = x.float()
        # Fix: use self.eps here — the previous code hard-coded 1e-6,
        # silently ignoring the eps passed to __init__.
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        return ((x * rrms) * self.weight.float()).to(dtype=x_dtype)



@dataclass
class LuminaParams:
    """Parameters for Lumina model configuration.

    `axes_dims` and `axes_lens` default to None (a list can't be a dataclass
    default directly) and are filled with the standard values in
    `__post_init__`; their annotations are Optional to reflect that.
    """

    patch_size: int = 2
    in_channels: int = 4
    dim: int = 4096
    n_layers: int = 30
    n_refiner_layers: int = 2
    n_heads: int = 24
    n_kv_heads: int = 8
    multiple_of: int = 256
    axes_dims: Optional[List[int]] = None  # per-axis RoPE dims; defaulted in __post_init__
    axes_lens: Optional[List[int]] = None  # per-axis max positions; defaulted in __post_init__
    qk_norm: bool = False
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5
    scaling_factor: float = 1.0
    cap_feat_dim: int = 32

    def __post_init__(self):
        # None is the sentinel for "use the default axis configuration".
        if self.axes_dims is None:
            self.axes_dims = [36, 36, 36]
        if self.axes_lens is None:
            self.axes_lens = [300, 512, 512]

    @classmethod
    def get_2b_config(cls) -> "LuminaParams":
        """Returns the configuration for the 2B parameter model"""
        return cls(
            patch_size=2,
            in_channels=16,  # VAE channels
            dim=2304,
            n_layers=26,
            n_heads=24,
            n_kv_heads=8,
            axes_dims=[32, 32, 32],
            axes_lens=[300, 512, 512],
            qk_norm=True,
            cap_feat_dim=2304,  # Gemma 2 hidden_size
        )

    @classmethod
    def get_7b_config(cls) -> "LuminaParams":
        """Returns the configuration for the 7B parameter model"""
        return cls(
            patch_size=2,
            dim=4096,
            n_layers=32,
            n_heads=32,
            n_kv_heads=8,
            axes_dims=[64, 64, 64],
            axes_lens=[300, 512, 512],
        )


class GradientCheckpointMixin(nn.Module):
    """Mixin that routes forward() through torch.utils.checkpoint when
    gradient checkpointing is enabled and the module is in training mode.

    Subclasses implement `_forward`; this class decides whether to call it
    directly or under activation checkpointing.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.gradient_checkpointing = False
        # NOTE(review): this flag is declared but never updated — the
        # cpu_offload argument of the enable/disable methods below is
        # ignored. Confirm whether CPU-offload support is intentionally
        # a no-op here.
        self.cpu_offload_checkpointing = False

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self, cpu_offload: bool = False):
        self.gradient_checkpointing = False

    def forward(self, *args, **kwargs):
        # Non-reentrant checkpointing; positional and keyword args are
        # forwarded unchanged to the subclass's _forward.
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)



def modulate(x, scale):
    """Modulate *x* by a per-sample gain: x * (1 + scale), with *scale*
    broadcast over the sequence dimension (dim 1)."""
    gain = scale.unsqueeze(1) + 1
    return x * gain


#############################################################################
# Embedding Layers for Timesteps and Class Labels #
#############################################################################


class TimestepEmbedder(GradientCheckpointMixin):
    """
    Embeds scalar timesteps into vector representations.

    A sinusoidal embedding of the timestep is projected through a two-layer
    SiLU MLP up to `hidden_size`.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(
                frequency_embedding_size,
                hidden_size,
                bias=True,
            ),
            nn.SiLU(),
            nn.Linear(
                hidden_size,
                hidden_size,
                bias=True,
            ),
        )
        # Small-std normal init for weights, zero biases.
        nn.init.normal_(self.mlp[0].weight, std=0.02)
        nn.init.zeros_(self.mlp[0].bias)
        nn.init.normal_(self.mlp[2].weight, std=0.02)
        nn.init.zeros_(self.mlp[2].bias)

        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # Odd dims get a zero column so the output is exactly `dim` wide.
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def _forward(self, t):
        # Cast the float32 sinusoidal embedding to the MLP's parameter dtype
        # before projecting (supports bf16/fp16 training).
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
        return t_emb


def to_cuda(x):
    """Recursively move every tensor in a nested structure to CUDA.

    Dicts are rebuilt key-by-key; lists AND tuples are both returned as
    lists (matching the original behavior); any other leaf is returned
    unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.cuda()
    if isinstance(x, dict):
        return {key: to_cuda(value) for key, value in x.items()}
    if isinstance(x, (list, tuple)):
        return [to_cuda(item) for item in x]
    return x


def to_cpu(x):
    """Recursively move every tensor in a nested structure to the CPU.

    Dicts are rebuilt key-by-key; lists AND tuples are both returned as
    lists (matching the original behavior); any other leaf is returned
    unchanged.
    """
    if isinstance(x, torch.Tensor):
        return x.cpu()
    if isinstance(x, dict):
        return {key: to_cpu(value) for key, value in x.items()}
    if isinstance(x, (list, tuple)):
        return [to_cpu(item) for item in x]
    return x


#############################################################################
# Core NextDiT Model #
#############################################################################


class JointAttention(nn.Module):
    """Multi-head attention module with optional GQA, QK-norm, and pluggable
    backends (SDPA by default; flash-attn or SageAttention when enabled)."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int],
        qk_norm: bool,
        use_flash_attn=False,
        use_sage_attn=False,
    ):
        """
        Initialize the Attention module.

        Args:
            dim (int): Number of input dimensions.
            n_heads (int): Number of heads.
            n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
            qk_norm (bool): Whether to use normalization for queries and keys.
            use_flash_attn (bool): Route attention through flash_attn_varlen_func.
            use_sage_attn (bool): Route attention through SageAttention.
        """
        super().__init__()
        self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
        self.n_local_heads = n_heads
        self.n_local_kv_heads = self.n_kv_heads
        # Number of query heads sharing each kv head (GQA expansion factor).
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

        # Fused projection producing queries, keys and values in one matmul.
        self.qkv = nn.Linear(
            dim,
            (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
            bias=False,
        )
        nn.init.xavier_uniform_(self.qkv.weight)

        self.out = nn.Linear(
            n_heads * self.head_dim,
            dim,
            bias=False,
        )
        nn.init.xavier_uniform_(self.out.weight)

        # Optional per-head RMS normalization of queries and keys.
        if qk_norm:
            self.q_norm = RMSNorm(self.head_dim)
            self.k_norm = RMSNorm(self.head_dim)
        else:
            self.q_norm = self.k_norm = nn.Identity()

        self.use_flash_attn = use_flash_attn
        self.use_sage_attn = use_sage_attn

        if use_sage_attn :
            self.attention_processor = self.sage_attn
        else:
            # self.attention_processor = xformers.ops.memory_efficient_attention
            self.attention_processor = F.scaled_dot_product_attention

    def set_attention_processor(self, attention_processor):
        # Allows callers to swap in a custom attention implementation.
        self.attention_processor = attention_processor

    def get_attention_processor(self):
        return self.attention_processor

    def forward(
        self,
        x: Tensor,
        x_mask: Tensor,
        freqs_cis: Tensor,
    ) -> Tensor:
        """
        Args:
            x: Input sequence, shape (batch, seq_len, dim).
            x_mask: Mask of valid positions, shape (batch, seq_len).
            freqs_cis: Precomputed complex rotary (RoPE) frequencies.
        """
        bsz, seqlen, _ = x.shape
        dtype = x.dtype

        # Split the fused QKV projection into per-role tensors.
        xq, xk, xv = torch.split(
            self.qkv(x),
            [
                self.n_local_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
                self.n_local_kv_heads * self.head_dim,
            ],
            dim=-1,
        )
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xq = self.q_norm(xq)
        xk = self.k_norm(xk)
        # Rotary position embedding is applied to queries and keys only.
        xq = apply_rope(xq, freqs_cis=freqs_cis)
        xk = apply_rope(xk, freqs_cis=freqs_cis)
        xq, xk = xq.to(dtype), xk.to(dtype)

        softmax_scale = math.sqrt(1 / self.head_dim)

        if self.use_sage_attn:
            # Handle GQA (Grouped Query Attention) if needed
            # NOTE(review): when n_rep == 1 the repeat is a value-preserving
            # no-op; the guard could equivalently be `n_rep > 1`.
            n_rep = self.n_local_heads // self.n_local_kv_heads
            if n_rep >= 1:
                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

            output = self.sage_attn(xq, xk, xv, x_mask, softmax_scale)
        elif self.use_flash_attn:
            output = self.flash_attn(xq, xk, xv, x_mask, softmax_scale)
        else:
            # Expand kv heads to match query heads, then run the configured
            # processor (SDPA by default) with a boolean keep-mask built
            # from x_mask and broadcast over heads and query positions.
            n_rep = self.n_local_heads // self.n_local_kv_heads
            if n_rep >= 1:
                xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
                xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

            output = (
                self.attention_processor(
                    xq.permute(0, 2, 1, 3),
                    xk.permute(0, 2, 1, 3),
                    xv.permute(0, 2, 1, 3),
                    attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
                    scale=softmax_scale,
                )
                .permute(0, 2, 1, 3)
                .to(dtype)
            )

        # Merge heads and project back to the model dimension.
        output = output.flatten(-2)
        return self.out(output)

    # copied from huggingface modeling_llama.py
    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Strip padding tokens and build cu_seqlens metadata for
        flash_attn_varlen_func (variable-length flash attention)."""
        def _get_unpad_data(attention_mask):
            # Per-batch valid lengths, flat indices of valid tokens, and
            # cumulative sequence-length boundaries.
            seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
            indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            max_seqlen_in_batch = seqlens_in_batch.max().item()
            cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
            return (
                indices,
                cu_seqlens,
                max_seqlen_in_batch,
            )

        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        if query_length == kv_seq_len:
            # Self-attention: queries share the key indices and lengths.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decode step: one query per batch element.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )

    def sage_attn(self, q: Tensor, k: Tensor, v: Tensor, x_mask: Tensor, softmax_scale: float):
        """Run SageAttention per batch element on only the unmasked tokens,
        writing zeros back into the masked positions."""
        try:
            bsz = q.shape[0]
            seqlen = q.shape[1]

            # Transpose tensors to match SageAttention's expected format (HND layout)
            q_transposed = q.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
            k_transposed = k.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
            v_transposed = v.permute(0, 2, 1, 3)  # [batch, heads, seq_len, head_dim]
            # Handle masking for SageAttention
            # We need to filter out masked positions - this approach handles variable sequence lengths
            outputs = []
            for b in range(bsz):
                # Find valid token positions from the mask
                valid_indices = torch.nonzero(x_mask[b], as_tuple=False).squeeze(-1)
                if valid_indices.numel() == 0:
                    # If all tokens are masked, create a zero output
                    batch_output = torch.zeros(
                        seqlen, self.n_local_heads, self.head_dim,
                        device=q.device, dtype=q.dtype
                    )
                else:
                    # Extract only valid tokens for this batch
                    batch_q = q_transposed[b, :, valid_indices, :]
                    batch_k = k_transposed[b, :, valid_indices, :]
                    batch_v = v_transposed[b, :, valid_indices, :]
                    # Run SageAttention on valid tokens only
                    batch_output_valid = sageattn(
                        batch_q.unsqueeze(0),  # Add batch dimension back
                        batch_k.unsqueeze(0),
                        batch_v.unsqueeze(0),
                        tensor_layout="HND",
                        is_causal=False,
                        sm_scale=softmax_scale
                    )
                    # Create output tensor with zeros for masked positions
                    batch_output = torch.zeros(
                        seqlen, self.n_local_heads, self.head_dim,
                        device=q.device, dtype=q.dtype
                    )
                    # Place valid outputs back in the right positions
                    batch_output[valid_indices] = batch_output_valid.squeeze(0).permute(1, 0, 2)
                outputs.append(batch_output)
            # Stack batch outputs and reshape to expected format
            output = torch.stack(outputs, dim=0)  # [batch, seq_len, heads, head_dim]
        except NameError as e:
            # NameError means the optional sageattn import at module top failed.
            raise RuntimeError(
                f"Could not load Sage Attention. Please install https://github.com/thu-ml/SageAttention. / Sage Attention を読み込めませんでした。https://github.com/thu-ml/SageAttention をインストールしてください。 / {e}"
            )

        return output

    def flash_attn(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        x_mask: Tensor,
        softmax_scale,
    ) -> Tensor:
        """Variable-length flash attention: unpad, run
        flash_attn_varlen_func, then re-pad to (batch, seq_len, heads, head_dim)."""
        bsz, seqlen, _, _ = q.shape

        try:
            # begin var_len flash attn
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(q, k, v, x_mask, seqlen)

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=0.0,
                causal=False,
                softmax_scale=softmax_scale,
            )
            output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
            # end var_len_flash_attn

            return output
        except NameError as e:
            # NameError means the optional flash_attn import at module top failed.
            raise RuntimeError(
                f"Could not load flash attention. Please install flash_attn. / フラッシュアテンションを読み込めませんでした。flash_attn をインストールしてください。 / {e}"
            )


def apply_rope(
    x_in: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary position embeddings to a query or key tensor.

    Consecutive pairs of the last dimension of *x_in* are viewed as complex
    numbers and multiplied by the precomputed complex exponentials in
    *freqs_cis* (broadcast over the head axis). The result is returned in
    the dtype of *x_in*.

    Args:
        x_in: Query or key tensor; last dimension must be even.
        freqs_cis: Precomputed complex frequency tensor.

    Returns:
        Tensor of the same shape and dtype as *x_in* with the rotation applied.
    """
    # The rotation must run in full precision, so autocast is disabled.
    with torch.autocast("cuda", enabled=False):
        pairs = x_in.float().reshape(*x_in.shape[:-1], -1, 2)
        as_complex = torch.view_as_complex(pairs)
        rotated = as_complex * freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(rotated).flatten(3)

    return x_out.type_as(x_in)


class FeedForward(nn.Module):
    """SwiGLU-style feed-forward block: w2(silu(w1(x)) * w3(x))."""

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        """
        Initialize the FeedForward module.

        Args:
            dim (int): Input (and output) dimension.
            hidden_dim (int): Base hidden dimension before rounding.
            multiple_of (int): The hidden dimension is rounded up to a
                multiple of this value.
            ffn_dim_multiplier (float, optional): Custom multiplier applied to
                hidden_dim before rounding. Defaults to None.
        """
        super().__init__()
        # Optional custom multiplier, then round up to a multiple of `multiple_of`.
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        def make_linear(in_features: int, out_features: int) -> nn.Linear:
            # Bias-free projection with Xavier-uniform init.
            layer = nn.Linear(in_features, out_features, bias=False)
            nn.init.xavier_uniform_(layer.weight)
            return layer

        # Created in w1, w2, w3 order to keep RNG consumption identical.
        self.w1 = make_linear(dim, hidden_dim)
        self.w2 = make_linear(hidden_dim, dim)
        self.w3 = make_linear(dim, hidden_dim)

    # @torch.compile
    def _forward_silu_gating(self, x1, x3):
        # SwiGLU gate: silu on the first projection times the third projection.
        return F.silu(x1) * x3

    def forward(self, x):
        gated = self._forward_silu_gating(self.w1(x), self.w3(x))
        return self.w2(gated)


class JointTransformerBlock(GradientCheckpointMixin):
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int],
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        use_flash_attn=False,
        use_sage_attn=False,
    ) -> None:
        """
        Initialize a TransformerBlock.

        Args:
            layer_id (int): Identifier for the layer.
            dim (int): Embedding dimension of the input features.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or set to None for the same as
                query.
            multiple_of (int): Number of multiple of the hidden dimension.
            ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
                feedforward layer.
            norm_eps (float): Epsilon value for normalization.
            qk_norm (bool): Whether to use normalization for queries and keys.
            modulation (bool): Whether to use modulation for the attention
                layer.
            use_flash_attn (bool): Whether the attention uses Flash Attention.
            use_sage_attn (bool): Whether the attention uses Sage Attention.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn)
        self.feed_forward = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            multiple_of=multiple_of,
            ffn_dim_multiplier=ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        # *_norm1 is applied before each sublayer (pre-norm); *_norm2 to its output.
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)

        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)

        self.modulation = modulation
        if modulation:
            # Zero-initialized adaLN projection: scales and gates start at zero,
            # so both sublayers initially contribute nothing to the residual.
            self.adaLN_modulation = nn.Sequential(
                nn.SiLU(),
                nn.Linear(
                    min(dim, 1024),
                    4 * dim,
                    bias=True,
                ),
            )
            nn.init.zeros_(self.adaLN_modulation[1].weight)
            nn.init.zeros_(self.adaLN_modulation[1].bias)

    def _forward(
        self,
        x: torch.Tensor,
        x_mask: torch.Tensor,
        pe: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
    ):
        """
        Perform a forward pass through the TransformerBlock.

        Args:
            x (Tensor): Input tensor.
            x_mask (Tensor): Attention mask for the token sequence.
            pe (Tensor): Rope position embedding.
            adaln_input (Optional[Tensor]): Conditioning vector for adaLN
                modulation; must be provided iff the block was built with
                modulation=True (enforced by the asserts below).

        Returns:
            Tensor: Output tensor after applying attention and
            feedforward layers.

        """
        if self.modulation:
            assert adaln_input is not None
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

            # Residual branches are gated through tanh; with the zero-initialized
            # adaLN above, tanh(0) = 0 makes the block an identity at start.
            x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
                self.attention(
                    modulate(self.attention_norm1(x), scale_msa),
                    x_mask,
                    pe,
                )
            )
            x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
                self.feed_forward(
                    modulate(self.ffn_norm1(x), scale_mlp),
                )
            )
        else:
            assert adaln_input is None
            x = x + self.attention_norm2(
                self.attention(
                    self.attention_norm1(x),
                    x_mask,
                    pe,
                )
            )
            x = x + self.ffn_norm2(
                self.feed_forward(
                    self.ffn_norm1(x),
                )
            )
        return x


class FinalLayer(GradientCheckpointMixin):
    """
    The final layer of NextDiT: adaLN-scaled LayerNorm followed by a linear
    projection to patch pixels.
    """

    def __init__(self, hidden_size, patch_size, out_channels):
        """
        Initialize the FinalLayer.

        Args:
            hidden_size (int): Hidden size of the input features.
            patch_size (int): Patch size of the input features.
            out_channels (int): Number of output channels.
        """
        super().__init__()
        # Affine-free LayerNorm; the scale comes from the adaLN modulation below.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        # Zero-initialized projection so the layer starts by emitting zeros.
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(min(hidden_size, 1024), hidden_size, bias=True),
        )
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        # Condition-dependent scale applied to the final normalization.
        scale = self.adaLN_modulation(c)
        normed = modulate(self.norm_final(x), scale)
        return self.linear(normed)


class RopeEmbedder:
    """Precomputes per-axis complex RoPE frequency tables and gathers them by
    position id at call time."""

    def __init__(
        self,
        theta: float = 10000.0,
        axes_dims: List[int] = [16, 56, 56],
        axes_lens: List[int] = [1, 512, 512],
    ):
        super().__init__()
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        # One complex frequency table per axis, precomputed on CPU.
        self.freqs_cis = NextDiT.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)

    def __call__(self, ids: torch.Tensor):
        """
        Gather rotary frequencies for the given position ids.

        Args:
            ids (torch.Tensor): (batch, seq, num_axes) integer position ids.

        Returns:
            torch.Tensor: Complex tensor of shape (batch, seq, sum(axes_dims) // 2).
        """
        # Move the cached tables to the ids' device once, and keep them there
        # for subsequent calls. (Previously each table was also redundantly
        # moved a second time inside the gather loop.)
        self.freqs_cis = [freqs_cis.to(ids.device) for freqs_cis in self.freqs_cis]
        result = []
        for i, freqs in enumerate(self.freqs_cis):
            # Expand the i-th axis' ids to index every frequency column.
            index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
            result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
        return torch.cat(result, dim=-1)


class NextDiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.

    Caption tokens are refined by ``context_refiner``, image patch tokens by
    ``noise_refiner``, and the concatenated sequence is processed by the main
    ``layers`` stack before being projected back to latent patches.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 4,
        dim: int = 4096,
        n_layers: int = 32,
        n_refiner_layers: int = 2,
        n_heads: int = 32,
        n_kv_heads: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        qk_norm: bool = False,
        cap_feat_dim: int = 5120,
        axes_dims: List[int] = [16, 56, 56],
        axes_lens: List[int] = [1, 512, 512],
        use_flash_attn=False,
        use_sage_attn=False,
    ) -> None:
        """
        Initialize the NextDiT model.

        Args:
            patch_size (int): Patch size of the input features.
            in_channels (int): Number of input channels.
            dim (int): Hidden size of the input features.
            n_layers (int): Number of Transformer layers.
            n_refiner_layers (int): Number of refiner layers.
            n_heads (int): Number of attention heads.
            n_kv_heads (Optional[int]): Number of attention heads in key and
                value features (if using GQA), or set to None for the same as
                query.
            multiple_of (int): Multiple of the hidden size.
            ffn_dim_multiplier (Optional[float]): Dimension multiplier for the
                feedforward layer.
            norm_eps (float): Epsilon value for normalization.
            qk_norm (bool): Whether to use query key normalization.
            cap_feat_dim (int): Dimension of the caption features.
            axes_dims (List[int]): List of dimensions for the axes.
            axes_lens (List[int]): List of lengths for the axes.
            use_flash_attn (bool): Whether to use Flash Attention.
            use_sage_attn (bool): Whether to use Sage Attention. Sage Attention only supports inference.

        Returns:
            None
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.patch_size = patch_size

        # Timestep embedding; width is capped at 1024 to match the adaLN inputs.
        self.t_embedder = TimestepEmbedder(min(dim, 1024))
        # Projects caption features (e.g. Gemma2 hidden states) into model dim.
        self.cap_embedder = nn.Sequential(
            RMSNorm(cap_feat_dim, eps=norm_eps),
            nn.Linear(
                cap_feat_dim,
                dim,
                bias=True,
            ),
        )

        nn.init.trunc_normal_(self.cap_embedder[1].weight, std=0.02)
        nn.init.zeros_(self.cap_embedder[1].bias)

        # Caption-only refiner blocks (no timestep modulation).
        self.context_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=False,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )

        self.x_embedder = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=dim,
            bias=True,
        )
        nn.init.xavier_uniform_(self.x_embedder.weight)
        nn.init.constant_(self.x_embedder.bias, 0.0)

        # Image-token refiner blocks (timestep-modulated).
        self.noise_refiner = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    modulation=True,
                )
                for layer_id in range(n_refiner_layers)
            ]
        )


        # Main joint transformer stack over concatenated caption + image tokens.
        self.layers = nn.ModuleList(
            [
                JointTransformerBlock(
                    layer_id,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    use_flash_attn=use_flash_attn,
                    use_sage_attn=use_sage_attn,
                )
                for layer_id in range(n_layers)
            ]
        )
        self.norm_final = RMSNorm(dim, eps=norm_eps)
        self.final_layer = FinalLayer(dim, patch_size, self.out_channels)

        # Per-head dimension must equal the sum of the per-axis RoPE dims.
        assert (dim // n_heads) == sum(axes_dims)
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
        self.dim = dim
        self.n_heads = n_heads

        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False  # TODO: not yet supported
        self.blocks_to_swap = None  # TODO: not yet supported

    @property
    def device(self):
        """Device of the first parameter (assumes a single-device module)."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self, cpu_offload: bool = False):
        """Enable gradient checkpointing on all transformer blocks and the final layer."""
        self.gradient_checkpointing = True
        self.cpu_offload_checkpointing = cpu_offload

        self.t_embedder.enable_gradient_checkpointing()

        for block in self.layers + self.context_refiner + self.noise_refiner:
            block.enable_gradient_checkpointing(cpu_offload=cpu_offload)

        self.final_layer.enable_gradient_checkpointing()

        print(f"Lumina: Gradient checkpointing enabled. CPU offload: {cpu_offload}")

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing everywhere it was enabled."""
        self.gradient_checkpointing = False
        self.cpu_offload_checkpointing = False

        self.t_embedder.disable_gradient_checkpointing()

        for block in self.layers + self.context_refiner + self.noise_refiner:
            block.disable_gradient_checkpointing()

        self.final_layer.disable_gradient_checkpointing()

        print("Lumina: Gradient checkpointing disabled.")

    def unpatchify(
        self,
        x: Tensor,
        width: int,
        height: int,
        encoder_seq_lengths: List[int],
        seq_lengths: List[int],
    ) -> Tensor:
        """
        Unpatchify the input tensor and embed the caption features.
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)

        Args:
            x (Tensor): Input tensor.
            width (int): Width of the input tensor.
            height (int): Height of the input tensor.
            encoder_seq_lengths (List[int]): List of encoder sequence lengths.
            seq_lengths (List[int]): List of sequence lengths

        Returns:
            output: (N, C, H, W)
        """
        pH = pW = self.patch_size

        output = []
        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
            # Tokens [encoder_seq_len:seq_len] are the image tokens of sample i;
            # fold the (pH, pW) patch pixels back into the spatial grid.
            output.append(
                x[i][encoder_seq_len:seq_len]
                .view(height // pH, width // pW, pH, pW, self.out_channels)
                .permute(4, 0, 2, 1, 3)
                .flatten(3, 4)
                .flatten(1, 2)
            )
        # Stacking requires every sample to share the same height/width.
        output = torch.stack(output, dim=0)

        return output

    def patchify_and_embed(
        self,
        x: Tensor,
        cap_feats: Tensor,
        cap_mask: Tensor,
        t: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:
        """
        Patchify and embed the input image and caption features.

        Args:
            x: (N, C, H, W) image latents
            cap_feats: (N, C, D) caption features
            cap_mask: (N, C, D) caption attention mask
            t: (N), T timesteps

        Returns:
            Tuple[Tensor, Tensor, Tensor, List[int], List[int]]:

            return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths
        """
        bsz, channels, height, width = x.shape
        pH = pW = self.patch_size
        device = x.device

        # Number of valid (non-padded) caption tokens per sample.
        l_effective_cap_len = cap_mask.sum(dim=1).tolist()
        encoder_seq_len = cap_mask.shape[1]
        image_seq_len = (height // self.patch_size) * (width // self.patch_size)

        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
        max_seq_len = max(seq_lengths)

        # Axis 0: caption position (then frozen at cap_len for image tokens);
        # axes 1/2: image token row/column.
        position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)

        for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            H_tokens, W_tokens = height // pH, width // pW

            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
            position_ids[i, cap_len:seq_len, 0] = cap_len

            row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
            col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()

            position_ids[i, cap_len:seq_len, 1] = row_ids
            position_ids[i, cap_len:seq_len, 2] = col_ids

        # Get combined rotary embeddings
        freqs_cis = self.rope_embedder(position_ids)

        # Create separate rotary embeddings for captions and images
        cap_freqs_cis = torch.zeros(
            bsz,
            encoder_seq_len,
            freqs_cis.shape[-1],
            device=device,
            dtype=freqs_cis.dtype,
        )
        img_freqs_cis = torch.zeros(
            bsz,
            image_seq_len,
            freqs_cis.shape[-1],
            device=device,
            dtype=freqs_cis.dtype,
        )

        for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_len:seq_len]

        # Refine caption context
        for layer in self.context_refiner:
            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)

        # (N, C, H, W) -> (N, image_seq_len, pH * pW * C) patch tokens.
        x = x.view(bsz, channels, height // pH, pH, width // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)

        x_mask = torch.zeros(bsz, image_seq_len, dtype=torch.bool, device=device)
        for i in range(bsz):
            # NOTE(review): the self-assignment is a no-op since every sample has
            # exactly image_seq_len tokens here; kept for parity with upstream.
            x[i, :image_seq_len] = x[i]
            x_mask[i, :image_seq_len] = True

        x = self.x_embedder(x)

        # Refine image context
        for layer in self.noise_refiner:
            x = layer(x, x_mask, img_freqs_cis, t)

        # Pack refined captions and image tokens into one joint padded sequence.
        joint_hidden_states = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x.dtype)
        attention_mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device)
        for i, (cap_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
            attention_mask[i, :seq_len] = True
            joint_hidden_states[i, :cap_len] = cap_feats[i, :cap_len]
            joint_hidden_states[i, cap_len:seq_len] = x[i]

        x = joint_hidden_states

        return x, attention_mask, freqs_cis, l_effective_cap_len, seq_lengths

    def forward(self, x: Tensor, t: Tensor, cap_feats: Tensor, cap_mask: Tensor) -> Tensor:
        """
        Forward pass of NextDiT.
        Args:
            x: (N, C, H, W) image latents
            t: (N,) tensor of diffusion timesteps
            cap_feats: (N, L, D) caption features
            cap_mask: (N, L) caption attention mask

        Returns:
            x: (N, C, H, W) denoised latents
        """
        _, _, height, width = x.shape  # B, C, H, W
        t = self.t_embedder(t)  # (N, D)
        cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

        x, mask, freqs_cis, l_effective_cap_len, seq_lengths = self.patchify_and_embed(x, cap_feats, cap_mask, t)

        if not self.blocks_to_swap:
            for layer in self.layers:
                x = layer(x, mask, freqs_cis, t)
        else:
            # Block swapping: wait for each block to arrive on device, run it,
            # then queue the swap of finished blocks back to CPU.
            for block_idx, layer in enumerate(self.layers):
                self.offloader_main.wait_for_block(block_idx)
                x = layer(x, mask, freqs_cis, t)
                self.offloader_main.submit_move_blocks(self.layers, block_idx)

        x = self.final_layer(x, t)
        x = self.unpatchify(x, width, height, l_effective_cap_len, seq_lengths)

        return x

    def forward_with_cfg(
        self,
        x: Tensor,
        t: Tensor,
        cap_feats: Tensor,
        cap_mask: Tensor,
        cfg_scale: float,
        cfg_trunc: float = 0.25,
        renorm_cfg: float = 1.0,
    ):
        """
        Forward pass of NextDiT, but also batches the unconditional forward pass
        for classifier-free guidance.
        """
        # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        if t[0] < cfg_trunc:
            # NOTE(review): `combined` is built but `forward` is called with `x`;
            # this assumes callers already pass x duplicated as [cond, uncond] — confirm.
            combined = torch.cat([half, half], dim=0)  # [2, 16, 128, 128]
            assert (
                cap_mask.shape[0] == combined.shape[0]
            ), f"caption attention mask shape: {cap_mask.shape[0]} latents shape: {combined.shape[0]}"
            model_out = self.forward(x, t, cap_feats, cap_mask)  # [2, 16, 128, 128]
            # For exact reproducibility reasons, we apply classifier-free guidance on only
            # three channels by default. The standard approach to cfg applies it to all channels.
            # This can be done by uncommenting the following line and commenting-out the line following that.
            eps, rest = (
                model_out[:, : self.in_channels],
                model_out[:, self.in_channels :],
            )
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
            if float(renorm_cfg) > 0.0:
                # Cap the guided eps norm at renorm_cfg x the conditional eps norm.
                # NOTE(review): the tensor comparison below assumes one sample per half-batch.
                ori_pos_norm = torch.linalg.vector_norm(cond_eps, dim=tuple(range(1, len(cond_eps.shape))), keepdim=True)
                max_new_norm = ori_pos_norm * float(renorm_cfg)
                new_pos_norm = torch.linalg.vector_norm(half_eps, dim=tuple(range(1, len(half_eps.shape))), keepdim=True)
                if new_pos_norm >= max_new_norm:
                    half_eps = half_eps * (max_new_norm / new_pos_norm)
        else:
            # Past the truncation threshold: run only the conditional half (no CFG).
            combined = half
            model_out = self.forward(
                combined,
                t[: len(x) // 2],
                cap_feats[: len(x) // 2],
                cap_mask[: len(x) // 2],
            )
            eps, rest = (
                model_out[:, : self.in_channels],
                model_out[:, self.in_channels :],
            )
            half_eps = eps

        output = torch.cat([half_eps, half_eps], dim=0)
        return output

    @staticmethod
    def precompute_freqs_cis(
        dim: List[int],
        end: List[int],
        theta: float = 10000.0,
    ) -> List[Tensor]:
        """
        Precompute the frequency tensor for complex exponentials (cis) with
        given dimensions.

        This function calculates a frequency tensor with complex exponentials
        using the given dimension 'dim' and the end index 'end'. The 'theta'
        parameter scales the frequencies. The returned tensor contains complex
        values in complex64 data type.

        Args:
            dim (list): Dimension of the frequency tensor.
            end (list): End index for precomputing frequencies.
            theta (float, optional): Scaling factor for frequency computation.
                Defaults to 10000.0.

        Returns:
            List[torch.Tensor]: Precomputed frequency tensor with complex
            exponentials.
        """
        freqs_cis = []
        # MPS has no float64 support, so fall back to float32 there.
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

        for i, (d, e) in enumerate(zip(dim, end)):
            pos = torch.arange(e, dtype=freqs_dtype, device="cpu")
            freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=freqs_dtype, device="cpu") / d))
            freqs = torch.outer(pos, freqs)
            freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs)  # [S, D/2]
            freqs_cis.append(freqs_cis_i)

        return freqs_cis

    def parameter_count(self) -> int:
        """Return the total number of parameters, counted recursively."""
        total_params = 0

        def _recursive_count_params(module):
            nonlocal total_params
            for param in module.parameters(recurse=False):
                total_params += param.numel()
            for submodule in module.children():
                _recursive_count_params(submodule)

        _recursive_count_params(self)
        return total_params

    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
        """Modules to wrap individually for FSDP (the main layer stack)."""
        return list(self.layers)

    def get_checkpointing_wrap_module_list(self) -> List[nn.Module]:
        """Modules to wrap for activation checkpointing (the main layer stack)."""
        return list(self.layers)

    def enable_block_swap(self, blocks_to_swap: int, device: torch.device):
        """
        Enable block swapping to reduce memory usage during inference.
        Args:
            blocks_to_swap (int): Number of blocks to swap between CPU and device
            device (torch.device): Device to use for computation
        """
        self.blocks_to_swap = blocks_to_swap
        # Calculate how many blocks to swap from main layers
        assert blocks_to_swap <= len(self.layers) - 2, (
            f"Cannot swap more than {len(self.layers) - 2} main blocks. "
            f"Requested {blocks_to_swap} blocks."
        )
        self.offloader_main = custom_offloading_utils.ModelOffloader(
            self.layers, blocks_to_swap, device, debug=False
        )

    def move_to_device_except_swap_blocks(self, device: torch.device):
        """
        Move the model to the device except for blocks that will be swapped.
        This reduces temporary memory usage during model loading.
        Args:
            device (torch.device): Device to move the model to
        """
        # Temporarily detach the swappable layers so .to() skips them.
        if self.blocks_to_swap:
            save_layers = self.layers
            self.layers = nn.ModuleList([])
        self.to(device)
        if self.blocks_to_swap:
            self.layers = save_layers

    def prepare_block_swap_before_forward(self):
        """
        Prepare blocks for swapping before forward pass.
        """
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self.offloader_main.prepare_block_devices_before_forward(self.layers)


#############################################################################
# NextDiT Configs #
#############################################################################


def NextDiT_2B_GQA_patch2_Adaln_Refiner(params: Optional[LuminaParams] = None, **kwargs):
    """Build the 2B NextDiT (GQA, patch size 2, adaLN refiner) from a LuminaParams
    config, defaulting to the stock 2B configuration."""
    if params is None:
        params = LuminaParams.get_2b_config()

    config = dict(
        patch_size=params.patch_size,
        in_channels=params.in_channels,
        dim=params.dim,
        n_layers=params.n_layers,
        n_heads=params.n_heads,
        n_kv_heads=params.n_kv_heads,
        axes_dims=params.axes_dims,
        axes_lens=params.axes_lens,
        qk_norm=params.qk_norm,
        ffn_dim_multiplier=params.ffn_dim_multiplier,
        norm_eps=params.norm_eps,
        cap_feat_dim=params.cap_feat_dim,
    )
    return NextDiT(**config, **kwargs)


def NextDiT_3B_GQA_patch2_Adaln_Refiner(**kwargs):
    """Build the 3B NextDiT (GQA, patch size 2, adaLN refiner) configuration."""
    config = dict(
        patch_size=2,
        dim=2592,
        n_layers=30,
        n_heads=24,
        n_kv_heads=8,
        axes_dims=[36, 36, 36],
        axes_lens=[300, 512, 512],
    )
    return NextDiT(**config, **kwargs)


def NextDiT_4B_GQA_patch2_Adaln_Refiner(**kwargs):
    """Build the 4B NextDiT (GQA, patch size 2, adaLN refiner) configuration."""
    config = dict(
        patch_size=2,
        dim=2880,
        n_layers=32,
        n_heads=24,
        n_kv_heads=8,
        axes_dims=[40, 40, 40],
        axes_lens=[300, 512, 512],
    )
    return NextDiT(**config, **kwargs)


def NextDiT_7B_GQA_patch2_Adaln_Refiner(**kwargs):
    """Build the 7B NextDiT (GQA, patch size 2, adaLN refiner) configuration."""
    config = dict(
        patch_size=2,
        dim=3840,
        n_layers=32,
        n_heads=32,
        n_kv_heads=8,
        axes_dims=[40, 40, 40],
        axes_lens=[300, 512, 512],
    )
    return NextDiT(**config, **kwargs)

+ 1098
- 0
scripts/dev/library/lumina_train_util.py View File

@@ -0,0 +1,1098 @@
import inspect
import argparse
import math
import os
import numpy as np
import time
from typing import Callable, Dict, List, Optional, Tuple, Any, Union, Generator

import torch
from torch import Tensor
from accelerate import Accelerator, PartialState
from transformers import Gemma2Model
from tqdm import tqdm
from PIL import Image
from safetensors.torch import save_file

from library import lumina_models, strategy_base, strategy_lumina, train_util
from library.flux_models import AutoEncoder
from library.device_utils import init_ipex, clean_memory_on_device
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

init_ipex()

from .utils import setup_logging, mem_eff_save_file

setup_logging()
import logging

logger = logging.getLogger(__name__)


# region sample images


def batchify(
    prompt_dicts, batch_size=None
) -> Generator[list[dict[str, str]], None, None]:
    """
    Group prompt dictionaries into batches with configurable batch size.

    Prompts are grouped by their generation parameters (rounded size,
    guidance scale, seed, steps, truncation and renorm settings) so that
    every yielded batch can be sampled together.

    Args:
        prompt_dicts (list): List of dictionaries containing prompt parameters.
        batch_size (int, optional): Maximum number of prompts per batch.
            None yields each parameter group as a single batch.

    Yields:
        list[dict[str, str]]: Batch of prompts sharing identical parameters.

    Raises:
        ValueError: If batch_size is not a positive integer or None.
    """
    if batch_size is not None and (not isinstance(batch_size, int) or batch_size <= 0):
        raise ValueError("batch_size must be a positive integer or None")

    def _group_key(prompt_dict):
        # Parameters that must match for prompts to share a batch.
        width = int(prompt_dict.get("width", 1024))
        height = int(prompt_dict.get("height", 1024))
        height = max(64, height - height % 8)  # round down to a multiple of 8
        width = max(64, width - width % 8)  # round down to a multiple of 8
        seed = prompt_dict.get("seed", None)
        return (
            width,
            height,
            float(prompt_dict.get("scale", 3.5)),
            int(seed) if seed is not None else None,
            int(prompt_dict.get("sample_steps", 38)),
            float(prompt_dict.get("cfg_trunc_ratio", 0.25)),
            float(prompt_dict.get("renorm_cfg", 1.0)),
        )

    # Group prompts by their parameters (insertion order preserved).
    batches = {}
    for prompt_dict in prompt_dicts:
        batches.setdefault(_group_key(prompt_dict), []).append(prompt_dict)

    for prompts in batches.values():
        if batch_size is None:
            # Yield the entire group as a single batch
            yield prompts
        else:
            # Split the group into batches of at most `batch_size` prompts
            for start in range(0, len(prompts), batch_size):
                yield prompts[start : start + batch_size]


@torch.no_grad()
def sample_images(
    accelerator: Accelerator,
    args: argparse.Namespace,
    epoch: int,
    global_step: int,
    nextdit: lumina_models.NextDiT,
    vae: AutoEncoder,
    gemma2_model: Gemma2Model,
    sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
    prompt_replacement: Optional[Tuple[str, str]] = None,
    controlnet=None,
):
    """
    Generate sample images using the NextDiT model.

    Args:
        accelerator (Accelerator): Accelerator instance.
        args (argparse.Namespace): Command-line arguments.
        epoch (int): Current epoch number.
        global_step (int): Current global step number.
        nextdit (lumina_models.NextDiT): The NextDiT model instance.
        vae (AutoEncoder): The VAE module.
        gemma2_model (Gemma2Model): The Gemma2 model instance.
        sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]):
            Dictionary of tuples containing the encoded prompts, text masks, and timestep for each sample.
        prompt_replacement (Optional[Tuple[str, str]], optional):
            Tuple containing the prompt and negative prompt replacements. Defaults to None.
        controlnet (): ControlNet model, not yet supported

    Returns:
        None
    """
    # Decide whether this step/epoch should produce samples at all.
    if global_step == 0:
        if not args.sample_at_first:
            return
    else:
        if args.sample_every_n_steps is None and args.sample_every_n_epochs is None:
            return
        if args.sample_every_n_epochs is not None:
            # sample_every_n_steps is ignored when sample_every_n_epochs is set
            if epoch is None or epoch % args.sample_every_n_epochs != 0:
                return
        else:
            if (
                global_step % args.sample_every_n_steps != 0 or epoch is not None
            ):  # steps is not divisible or end of epoch
                return

    assert (
        args.sample_prompts is not None
    ), "No sample prompts found. Provide `--sample_prompts` / サンプルプロンプトが見つかりません。`--sample_prompts` を指定してください"

    logger.info("")
    logger.info(
        f"generating sample images at step / サンプル画像生成 ステップ: {global_step}"
    )
    # A missing prompt file is acceptable when cached Gemma2 outputs are provided.
    if (
        not os.path.isfile(args.sample_prompts)
        and sample_prompts_gemma2_outputs is None
    ):
        logger.error(
            f"No prompt file / プロンプトファイルがありません: {args.sample_prompts}"
        )
        return

    distributed_state = (
        PartialState()
    )  # for multi gpu distributed inference. this is a singleton, so it's safe to use it here

    # unwrap nextdit and gemma2_model
    nextdit = accelerator.unwrap_model(nextdit)
    if gemma2_model is not None:
        gemma2_model = accelerator.unwrap_model(gemma2_model)
    # if controlnet is not None:
    #     controlnet = accelerator.unwrap_model(controlnet)
    # print([(te.parameters().__next__().device if te is not None else None) for te in text_encoders])

    prompts = train_util.load_prompts(args.sample_prompts)

    save_dir = args.output_dir + "/sample"
    os.makedirs(save_dir, exist_ok=True)

    # save random state to restore later; sampling uses its own generator/seed
    rng_state = torch.get_rng_state()
    cuda_rng_state = None
    try:
        cuda_rng_state = (
            torch.cuda.get_rng_state() if torch.cuda.is_available() else None
        )
    except Exception:
        pass

    batch_size = args.sample_batch_size or args.train_batch_size or 1

    if distributed_state.num_processes <= 1:
        # If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
        # TODO: batch prompts together with buckets of image sizes
        for prompt_dicts in batchify(prompts, batch_size):
            sample_image_inference(
                accelerator,
                args,
                nextdit,
                gemma2_model,
                vae,
                save_dir,
                prompt_dicts,
                epoch,
                global_step,
                sample_prompts_gemma2_outputs,
                prompt_replacement,
                controlnet,
            )
    else:
        # Creating list with N elements, where each element is a list of prompt_dicts, and N is the number of processes available (number of devices available)
        # prompt_dicts are assigned to lists based on order of processes, to attempt to time the image creation time to match enum order. Probably only works when steps and sampler are identical.
        per_process_prompts = []  # list of lists
        for i in range(distributed_state.num_processes):
            per_process_prompts.append(prompts[i :: distributed_state.num_processes])

        with distributed_state.split_between_processes(
            per_process_prompts
        ) as prompt_dict_lists:
            # TODO: batch prompts together with buckets of image sizes
            for prompt_dicts in batchify(prompt_dict_lists[0], batch_size):
                sample_image_inference(
                    accelerator,
                    args,
                    nextdit,
                    gemma2_model,
                    vae,
                    save_dir,
                    prompt_dicts,
                    epoch,
                    global_step,
                    sample_prompts_gemma2_outputs,
                    prompt_replacement,
                    controlnet,
                )

    # restore the RNG state saved above so training randomness is unaffected
    torch.set_rng_state(rng_state)
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)

    clean_memory_on_device(accelerator.device)


@torch.no_grad()
def sample_image_inference(
    accelerator: Accelerator,
    args: argparse.Namespace,
    nextdit: lumina_models.NextDiT,
    gemma2_model: list[Gemma2Model],
    vae: AutoEncoder,
    save_dir: str,
    prompt_dicts: list[Dict[str, str]],
    epoch: int,
    global_step: int,
    sample_prompts_gemma2_outputs: dict[str, Tuple[Tensor, Tensor, Tensor]],
    prompt_replacement: Optional[Tuple[str, str]] = None,
    controlnet=None,
):
    """
    Generates sample images for one batch of prompts.

    Args:
        accelerator (Accelerator): Accelerator object
        args (argparse.Namespace): Arguments object
        nextdit (lumina_models.NextDiT): NextDiT model
        gemma2_model (list[Gemma2Model]): Gemma2 model
        vae (AutoEncoder): VAE model
        save_dir (str): Directory to save images
        prompt_dicts (list[Dict[str, str]]): Prompt dictionaries for this batch
        epoch (int): Epoch number
        global_step (int): Current global step (used in filenames)
        sample_prompts_gemma2_outputs (dict[str, Tuple[Tensor, Tensor, Tensor]]): Cached Gemma 2 outputs keyed by prompt
        prompt_replacement (Optional[Tuple[str, str]], optional): Replacement for positive and negative prompt. Defaults to None.
        controlnet: ControlNet model, not yet supported

    Returns:
        None
    """

    # encode prompts
    tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
    encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

    assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
    assert isinstance(encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)

    text_conds = []

    # assuming seed, width, height, sample steps, guidance are the same
    # (guaranteed by batchify, which groups prompts by these parameters)
    width = int(prompt_dicts[0].get("width", 1024))
    height = int(prompt_dicts[0].get("height", 1024))
    height = max(64, height - height % 8)  # round to divisible by 8
    width = max(64, width - width % 8)  # round to divisible by 8

    guidance_scale = float(prompt_dicts[0].get("scale", 3.5))
    cfg_trunc_ratio = float(prompt_dicts[0].get("cfg_trunc_ratio", 0.25))
    renorm_cfg = float(prompt_dicts[0].get("renorm_cfg", 1.0))
    sample_steps = int(prompt_dicts[0].get("sample_steps", 36))
    seed = prompt_dicts[0].get("seed", None)
    seed = int(seed) if seed is not None else None
    assert seed is None or seed > 0, f"Invalid seed {seed}"
    generator = torch.Generator(device=accelerator.device)
    if seed is not None:
        generator.manual_seed(seed)

    for prompt_dict in prompt_dicts:
        # NOTE(review): controlnet_image is read but only used by the
        # commented-out ControlNet path below.
        controlnet_image = prompt_dict.get("controlnet_image")
        prompt: str = prompt_dict.get("prompt", "")
        negative_prompt = prompt_dict.get("negative_prompt", "")
        # sampler_name: str = prompt_dict.get("sample_sampler", args.sample_sampler)

        if prompt_replacement is not None:
            prompt = prompt.replace(prompt_replacement[0], prompt_replacement[1])
            if negative_prompt is not None:
                negative_prompt = negative_prompt.replace(
                    prompt_replacement[0], prompt_replacement[1]
                )

        if negative_prompt is None:
            negative_prompt = ""
        logger.info(f"prompt: {prompt}")
        logger.info(f"negative_prompt: {negative_prompt}")
        logger.info(f"height: {height}")
        logger.info(f"width: {width}")
        logger.info(f"sample_steps: {sample_steps}")
        logger.info(f"scale: {guidance_scale}")
        logger.info(f"trunc: {cfg_trunc_ratio}")
        logger.info(f"renorm: {renorm_cfg}")
        # logger.info(f"sample_sampler: {sampler_name}")


        # No need to add system prompt here, as it has been handled in the tokenize_strategy

        # Get sample prompts from cache
        if sample_prompts_gemma2_outputs and prompt in sample_prompts_gemma2_outputs:
            gemma2_conds = sample_prompts_gemma2_outputs[prompt]
            logger.info(f"Using cached Gemma2 outputs for prompt: {prompt}")

        if (
            sample_prompts_gemma2_outputs
            and negative_prompt in sample_prompts_gemma2_outputs
        ):
            neg_gemma2_conds = sample_prompts_gemma2_outputs[negative_prompt]
            logger.info(
                f"Using cached Gemma2 outputs for negative prompt: {negative_prompt}"
            )

        # Load sample prompts from Gemma 2
        # NOTE(review): when gemma2_model is loaded this re-encodes even if cached
        # outputs were found above; if neither a cache hit nor a model is
        # available, gemma2_conds is unbound and raises — confirm callers
        # always provide one of the two.
        if gemma2_model is not None:
            tokens_and_masks = tokenize_strategy.tokenize(prompt)
            gemma2_conds = encoding_strategy.encode_tokens(
                tokenize_strategy, gemma2_model, tokens_and_masks
            )

            tokens_and_masks = tokenize_strategy.tokenize(negative_prompt, is_negative=True)
            neg_gemma2_conds = encoding_strategy.encode_tokens(
                tokenize_strategy, gemma2_model, tokens_and_masks
            )

        # Unpack Gemma2 outputs
        gemma2_hidden_states, _, gemma2_attn_mask = gemma2_conds
        neg_gemma2_hidden_states, _, neg_gemma2_attn_mask = neg_gemma2_conds

        text_conds.append(
            (
                gemma2_hidden_states.squeeze(0),
                gemma2_attn_mask.squeeze(0),
                neg_gemma2_hidden_states.squeeze(0),
                neg_gemma2_attn_mask.squeeze(0),
            )
        )

    # Stack conditioning
    cond_hidden_states = torch.stack([text_cond[0] for text_cond in text_conds]).to(
        accelerator.device
    )
    cond_attn_masks = torch.stack([text_cond[1] for text_cond in text_conds]).to(
        accelerator.device
    )
    uncond_hidden_states = torch.stack([text_cond[2] for text_cond in text_conds]).to(
        accelerator.device
    )
    uncond_attn_masks = torch.stack([text_cond[3] for text_cond in text_conds]).to(
        accelerator.device
    )

    # sample image
    weight_dtype = vae.dtype  # TODO give dtype as argument
    latent_height = height // 8
    latent_width = width // 8
    latent_channels = 16
    # One shared noise tensor, repeated across the batch (same seed per batch).
    noise = torch.randn(
        1,
        latent_channels,
        latent_height,
        latent_width,
        device=accelerator.device,
        dtype=weight_dtype,
        generator=generator,
    )
    noise = noise.repeat(cond_hidden_states.shape[0], 1, 1, 1)

    scheduler = FlowMatchEulerDiscreteScheduler(shift=6.0)
    timesteps, num_inference_steps = retrieve_timesteps(
        scheduler, num_inference_steps=sample_steps
    )

    # if controlnet_image is not None:
    #     controlnet_image = Image.open(controlnet_image).convert("RGB")
    #     controlnet_image = controlnet_image.resize((width, height), Image.LANCZOS)
    #     controlnet_image = torch.from_numpy((np.array(controlnet_image) / 127.5) - 1)
    #     controlnet_image = controlnet_image.permute(2, 0, 1).unsqueeze(0).to(weight_dtype).to(accelerator.device)

    with accelerator.autocast():
        x = denoise(
            scheduler,
            nextdit,
            noise,
            cond_hidden_states,
            cond_attn_masks,
            uncond_hidden_states,
            uncond_attn_masks,
            timesteps=timesteps,
            guidance_scale=guidance_scale,
            cfg_trunc_ratio=cfg_trunc_ratio,
            renorm_cfg=renorm_cfg,
        )

    # Latent to image
    clean_memory_on_device(accelerator.device)
    org_vae_device = vae.device  # will be on cpu
    vae.to(accelerator.device)  # distributed_state.device is same as accelerator.device
    for img, prompt_dict in zip(x, prompt_dicts):

        # undo the latent scaling before decoding
        img = (img / vae.scale_factor) + vae.shift_factor

        with accelerator.autocast():
            # Add a single batch image for the VAE to decode
            img = vae.decode(img.unsqueeze(0))

        img = img.clamp(-1, 1)
        img = img.permute(0, 2, 3, 1)  # B, H, W, C
        # Scale images back to 0 to 255
        img = (127.5 * (img + 1.0)).float().cpu().numpy().astype(np.uint8)

        # Get single image
        image = Image.fromarray(img[0])

        # adding accelerator.wait_for_everyone() here should sync up and ensure that sample images are saved in the same order as the original prompt list
        # but adding 'enum' to the filename should be enough

        ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
        num_suffix = f"e{epoch:06d}" if epoch is not None else f"{global_step:06d}"
        seed_suffix = "" if seed is None else f"_{seed}"
        i: int = int(prompt_dict.get("enum", 0))
        img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}.png"
        image.save(os.path.join(save_dir, img_filename))

        # send images to wandb if enabled
        if "wandb" in [tracker.name for tracker in accelerator.trackers]:
            wandb_tracker = accelerator.get_tracker("wandb")

            import wandb

            # not to commit images to avoid inconsistency between training and logging steps
            wandb_tracker.log(
                {f"sample_{i}": wandb.Image(image, caption=prompt)}, commit=False
            )  # positive prompt as a caption

    vae.to(org_vae_device)
    clean_memory_on_device(accelerator.device)


def time_shift(mu: float, sigma: float, t: torch.Tensor):
    """Apply the resolution-dependent timestep shift to `t`.

    The reference formula assumes t=0 is the clean image and t=1 is pure
    noise; this codebase adopts the reverse convention, hence the `1 - t`
    flip on the way in and again on the way out.
    """
    flipped = 1 - t
    shifted = math.exp(mu) / (math.exp(mu) + (1 / flipped - 1) ** sigma)
    return 1 - shifted


def get_lin_function(
    x1: float = 256, x2: float = 4096, y1: float = 0.5, y2: float = 1.15
) -> Callable[[float], float]:
    """
    Build the linear function passing through (x1, y1) and (x2, y2).

    The defaults map an image sequence length of 256 to a shift of 0.5 and
    4096 to 1.15, i.e. the resolution-dependent schedule-shift line.

    Args:
        x1 (float, optional): base sequence length. Defaults to 256.
        x2 (float, optional): max sequence length. Defaults to 4096.
        y1 (float, optional): base shift. Defaults to 0.5.
        y2 (float, optional): max shift. Defaults to 1.15.

    Return:
        Callable[[float], float]: linear function f(x) = slope * x + intercept
    """
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1

    def linear(x: float) -> float:
        return slope * x + intercept

    return linear


def get_schedule(
    num_steps: int,
    image_seq_len: int,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
    shift: bool = True,
) -> list[float]:
    """
    Build the denoising timestep schedule, optionally resolution-shifted.

    Args:
        num_steps (int): Number of steps in the schedule.
        image_seq_len (int): Sequence length of the image.
        base_shift (float, optional): Base shift value. Defaults to 0.5.
        max_shift (float, optional): Maximum shift value. Defaults to 1.15.
        shift (bool, optional): Whether to shift the schedule. Defaults to True.

    Return:
        List[float]: timesteps schedule
    """
    # evenly spaced from 1 down to 1/num_steps
    schedule = torch.linspace(1, 1 / num_steps, num_steps)

    if shift:
        # estimate mu by linear interpolation between the two anchor points,
        # favoring high timesteps for higher-resolution (longer-sequence) images
        mu = get_lin_function(y1=base_shift, y2=max_shift, x1=256, x2=4096)(image_seq_len)
        schedule = time_shift(mu, 1.0, schedule)

    return schedule.tolist()


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
) -> Tuple[torch.Tensor, int]:
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    if timesteps is not None:
        # the scheduler must explicitly accept a custom timestep list
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        # likewise for a custom sigma schedule
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps

    return timesteps, num_inference_steps

def denoise(
    scheduler,
    model: lumina_models.NextDiT,
    img: Tensor,
    txt: Tensor,
    txt_mask: Tensor,
    neg_txt: Tensor,
    neg_txt_mask: Tensor,
    timesteps: Union[List[float], torch.Tensor],
    guidance_scale: float = 4.0,
    cfg_trunc_ratio: float = 0.25,
    renorm_cfg: float = 1.0,
):
    """
    Denoise an image using the NextDiT model.

    Args:
        scheduler:
            Noise scheduler; must expose `config.num_train_timesteps` and `step`.
        model (lumina_models.NextDiT): The NextDiT model instance.
        img (Tensor):
            The input image latent tensor.
        txt (Tensor):
            The input text tensor.
        txt_mask (Tensor):
            The input text mask tensor.
        neg_txt (Tensor):
            The negative input text tensor.
        neg_txt_mask (Tensor):
            The negative input text mask tensor.
        timesteps (List[Union[float, torch.FloatTensor]]):
            A list of timesteps for the denoising process.
        guidance_scale (float, optional):
            The guidance scale for the denoising process. Defaults to 4.0.
        cfg_trunc_ratio (float, optional):
            The ratio of the timestep interval to apply normalization-based guidance scale.
        renorm_cfg (float, optional):
            The factor to limit the maximum norm after guidance. Default: 1.0
    Returns:
        img (Tensor): Denoised latent tensor
    """
    for t in tqdm(timesteps):
        model.prepare_block_swap_before_forward()

        # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
        current_timestep = 1 - t / scheduler.config.num_train_timesteps
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        current_timestep = current_timestep * torch.ones(
            img.shape[0], device=img.device
        )

        noise_pred_cond = model(
            img,
            current_timestep,
            cap_feats=txt,  # Gemma2 hidden states as caption features
            cap_mask=txt_mask.to(dtype=torch.int32),  # Gemma2 attention mask
        )

        # compute whether to apply classifier-free guidance based on current timestep
        if current_timestep[0] < cfg_trunc_ratio:
            model.prepare_block_swap_before_forward()
            noise_pred_uncond = model(
                img,
                current_timestep,
                cap_feats=neg_txt,  # Gemma2 hidden states as caption features
                cap_mask=neg_txt_mask.to(dtype=torch.int32),  # Gemma2 attention mask
            )
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )
            # apply normalization after classifier-free guidance: cap the guided
            # prediction's norm at renorm_cfg times the conditional prediction's norm
            if float(renorm_cfg) > 0.0:
                cond_norms = torch.linalg.vector_norm(
                    noise_pred_cond,
                    dim=tuple(range(1, len(noise_pred_cond.shape))),
                    keepdim=True,
                )
                max_new_norms = cond_norms * float(renorm_cfg)
                noise_norms = torch.linalg.vector_norm(
                    noise_pred, dim=tuple(range(1, len(noise_pred.shape))), keepdim=True
                )
                # iterate through the batch; use a dedicated index name so the
                # (former) outer loop variable is not shadowed
                for bi, (noise_norm, max_new_norm) in enumerate(zip(noise_norms, max_new_norms)):
                    if noise_norm >= max_new_norm:
                        noise_pred[bi] = noise_pred[bi] * (max_new_norm / noise_norm)
        else:
            noise_pred = noise_pred_cond

        # compute the previous noisy sample x_t -> x_t-1
        img_dtype = img.dtype
        noise_pred = -noise_pred
        img = scheduler.step(noise_pred, t, img, return_dict=False)[0]

        # BUGFIX: restore the dtype *after* scheduler.step. The original code
        # compared img.dtype against a value captured on the previous line, so
        # the branch could never trigger; diffusers applies this workaround
        # after the step call, where the scheduler may have upcast the latents.
        if img.dtype != img_dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                img = img.to(img_dtype)

    model.prepare_block_swap_before_forward()
    return img


# endregion


# region train
def get_sigmas(
    noise_scheduler: FlowMatchEulerDiscreteScheduler,
    timesteps: Tensor,
    device: torch.device,
    n_dim=4,
    dtype=torch.float32,
) -> Tensor:
    """
    Look up the scheduler sigmas matching the given timesteps, reshaped so they
    broadcast against an `n_dim`-dimensional tensor.

    Args:
        noise_scheduler (FlowMatchEulerDiscreteScheduler): The noise scheduler instance.
        timesteps (Tensor): A tensor of timesteps for the denoising process.
        device (torch.device): The device on which the tensors are stored.
        n_dim (int, optional): The number of dimensions for the output tensor. Defaults to 4.
        dtype (torch.dtype, optional): The data type for the output tensor. Defaults to torch.float32.

    Returns:
        sigmas (Tensor): The sigmas tensor.
    """
    all_sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype)
    schedule_timesteps = noise_scheduler.timesteps.to(device)
    timesteps = timesteps.to(device)

    # position of each requested timestep within the scheduler's schedule
    indices = [(schedule_timesteps == ts).nonzero().item() for ts in timesteps]

    selected = all_sigmas[indices].flatten()
    # append trailing singleton dims until the result broadcasts against n_dim inputs
    while selected.dim() < n_dim:
        selected = selected.unsqueeze(-1)
    return selected


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme (str): The weighting scheme to use.
        batch_size (int): The batch size for the sampling process.
        logit_mean (float, optional): The mean of the logit distribution. Defaults to None.
        logit_std (float, optional): The standard deviation of the logit distribution. Defaults to None.
        mode_scale (float, optional): The mode scale for the mode weighting scheme. Defaults to None.

    Returns:
        u (Tensor): The sampled timesteps.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        samples = torch.normal(
            mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu"
        )
        return torch.nn.functional.sigmoid(samples)

    uniform = torch.rand(size=(batch_size,), device="cpu")
    if weighting_scheme == "mode":
        return 1 - uniform - mode_scale * (torch.cos(math.pi * uniform / 2) ** 2 - 1 + uniform)
    # default: plain uniform sampling
    return uniform


def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None) -> Tensor:
    """Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme (str): The weighting scheme to use.
        sigmas (Tensor, optional): The sigmas tensor. Defaults to None.

    Returns:
        weighting (Tensor): Per-sample loss weights, same shape as `sigmas`.
    """
    if weighting_scheme == "sigma_sqrt":
        return (sigmas**-2.0).float()
    if weighting_scheme == "cosmap":
        denom = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * denom)
    # any other scheme: uniform weighting
    return torch.ones_like(sigmas)


def get_noisy_model_input_and_timesteps(
    args, noise_scheduler, latents, noise, device, dtype
) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Get noisy model input and timesteps.

    Args:
        args (argparse.Namespace): Arguments.
        noise_scheduler (noise_scheduler): Noise scheduler.
        latents (Tensor): Latents.
        noise (Tensor): Latent noise.
        device (torch.device): Device.
        dtype (torch.dtype): Data type

    Return:
        Tuple[Tensor, Tensor, Tensor]:
            noisy model input
            timesteps
            sigmas (None unless the weighting-scheme path is taken)
    """
    bsz, _, h, w = latents.shape
    sigmas = None
    sampling = args.timestep_sampling

    if sampling in ("uniform", "sigmoid"):
        # simple random t-based noise sampling
        if sampling == "sigmoid":
            # https://github.com/XLabs-AI/x-flux/tree/main
            t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
        else:
            t = torch.rand((bsz,), device=device)

        timesteps = t * 1000.0
        t = t.view(-1, 1, 1, 1)
        noisy_model_input = (1 - t) * noise + t * latents
    elif sampling == "shift":
        shift = args.discrete_flow_shift
        # larger sigmoid_scale gives a more uniform sampling distribution
        logits = torch.randn(bsz, device=device) * args.sigmoid_scale
        timesteps = logits.sigmoid()
        timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)

        t = timesteps.view(-1, 1, 1, 1)
        timesteps = timesteps * 1000.0
        noisy_model_input = (1 - t) * noise + t * latents
    elif sampling == "nextdit_shift":
        t = torch.rand((bsz,), device=device)
        mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
        t = time_shift(mu, 1.0, t)

        timesteps = t * 1000.0
        t = t.view(-1, 1, 1, 1)
        noisy_model_input = (1 - t) * noise + t * latents
    else:
        # weighting-scheme-based non-uniform timestep sampling
        u = compute_density_for_timestep_sampling(
            weighting_scheme=args.weighting_scheme,
            batch_size=bsz,
            logit_mean=args.logit_mean,
            logit_std=args.logit_std,
            mode_scale=args.mode_scale,
        )
        indices = (u * noise_scheduler.config.num_train_timesteps).long()
        timesteps = noise_scheduler.timesteps[indices].to(device=device)

        # add noise according to flow matching
        sigmas = get_sigmas(
            noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype
        )
        noisy_model_input = sigmas * latents + (1.0 - sigmas) * noise

    return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas


def apply_model_prediction_type(
    args, model_pred: Tensor, noisy_model_input: Tensor, sigmas: Tensor
) -> Tuple[Tensor, Optional[Tensor]]:
    """
    Apply model prediction type to the model prediction and the sigmas.

    Args:
        args (argparse.Namespace): Arguments.
        model_pred (Tensor): Model prediction.
        noisy_model_input (Tensor): Noisy model input.
        sigmas (Tensor): Sigmas.

    Return:
        Tuple[Tensor, Optional[Tensor]]: processed prediction and optional loss weighting.
    """
    weighting = None
    pred_type = args.model_prediction_type

    if pred_type == "additive":
        # add the model_pred to the noisy_model_input
        model_pred = model_pred + noisy_model_input
    elif pred_type == "sigma_scaled":
        # apply sigma scaling; these weighting schemes use a uniform timestep
        # sampling and instead post-weight the loss
        model_pred = model_pred * (-sigmas) + noisy_model_input
        weighting = compute_loss_weighting_for_sd3(
            weighting_scheme=args.weighting_scheme, sigmas=sigmas
        )
    # "raw" (or anything else) leaves the prediction untouched

    return model_pred, weighting


def save_models(
    ckpt_path: str,
    lumina: lumina_models.NextDiT,
    sai_metadata: Dict[str, Any],
    save_dtype: Optional[torch.dtype] = None,
    use_mem_eff_save: bool = False,
):
    """
    Save the model to the checkpoint path.

    Args:
        ckpt_path (str): Path to the checkpoint.
        lumina (lumina_models.NextDiT): NextDIT model.
        sai_metadata (Dict[str, Any]): SAI ModelSpec metadata to embed in the file.
        save_dtype (Optional[torch.dtype]): If given, weights are cast to this dtype on CPU before saving.
        use_mem_eff_save (bool): Use the memory-efficient writer instead of safetensors' save_file.

    Return:
        None
    """
    state_dict = {}
    for name, tensor in lumina.state_dict().items():
        if save_dtype is not None and tensor.dtype != save_dtype:
            # detach and cast on CPU so no GPU copy is kept alive by the saved dict
            tensor = tensor.detach().clone().to("cpu").to(save_dtype)
        state_dict[name] = tensor

    writer = mem_eff_save_file if use_mem_eff_save else save_file
    writer(state_dict, ckpt_path, metadata=sai_metadata)


def save_lumina_model_on_train_end(
    args: argparse.Namespace,
    save_dtype: torch.dtype,
    epoch: int,
    global_step: int,
    lumina: lumina_models.NextDiT,
):
    """Save the final NextDiT checkpoint when training ends."""

    def sd_saver(ckpt_file, epoch_no, global_step):
        # build SAI ModelSpec metadata for a Lumina-2 checkpoint
        metadata = train_util.get_sai_model_spec(
            None, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2"
        )
        save_models(ckpt_file, lumina, metadata, save_dtype, args.mem_eff_save)

    train_util.save_sd_model_on_train_end_common(
        args, True, True, epoch, global_step, sd_saver, None
    )


# epochとstepの保存、メタデータにepoch/stepが含まれ引数が同じになるため、統合してている
# on_epoch_end: Trueならepoch終了時、Falseならstep経過時
def save_lumina_model_on_epoch_end_or_stepwise(
    args: argparse.Namespace,
    on_epoch_end: bool,
    accelerator: Accelerator,
    save_dtype: torch.dtype,
    epoch: int,
    num_train_epochs: int,
    global_step: int,
    lumina: lumina_models.NextDiT,
):
    """
    Save an intermediate NextDiT checkpoint.

    Args:
        args (argparse.Namespace): Arguments.
        on_epoch_end (bool): True when saving at the end of an epoch, False for step-based saving.
        accelerator (Accelerator): Accelerator passed through to the common save helper.
        save_dtype (torch.dtype): Data type weights are cast to before saving.
        epoch (int): Current epoch.
        num_train_epochs (int): Total number of training epochs.
        global_step (int): Current global step.
        lumina (lumina_models.NextDiT): NextDIT model.

    Return:
        None
    """

    def sd_saver(ckpt_file: str, epoch_no: int, global_step: int):
        metadata = train_util.get_sai_model_spec(
            {}, args, False, False, False, is_stable_diffusion_ckpt=True, lumina="lumina2"
        )
        save_models(ckpt_file, lumina, metadata, save_dtype, args.mem_eff_save)

    train_util.save_sd_model_on_epoch_end_or_stepwise_common(
        args,
        on_epoch_end,
        accelerator,
        True,
        True,
        epoch,
        num_train_epochs,
        global_step,
        sd_saver,
        None,
    )


# endregion


def add_lumina_train_arguments(parser: argparse.ArgumentParser):
    """Register Lumina-2 specific command-line arguments on the given parser."""
    # --- model component paths ---
    parser.add_argument(
        "--gemma2",
        type=str,
        help="path to gemma2 model (*.sft or *.safetensors), should be float16 / gemma2のパス(*.sftまたは*.safetensors)、float16が前提",
    )
    parser.add_argument(
        "--ae",
        type=str,
        help="path to ae (*.sft or *.safetensors) / aeのパス(*.sftまたは*.safetensors)",
    )
    parser.add_argument(
        "--gemma2_max_token_length",
        type=int,
        default=None,
        help="maximum token length for Gemma2. if omitted, 256"
        " / Gemma2の最大トークン長。省略された場合、256になります",
    )

    # --- timestep sampling / flow-matching options (consumed by
    # get_noisy_model_input_and_timesteps and apply_model_prediction_type) ---
    parser.add_argument(
        "--timestep_sampling",
        choices=["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"],
        default="shift",
        help="Method to sample timesteps: sigma-based, uniform random, sigmoid of random normal, shift of sigmoid and NextDIT.1 shifting. Default is 'shift'."
        " / タイムステップをサンプリングする方法:sigma、random uniform、random normalのsigmoid、sigmoidのシフト、NextDIT.1のシフト。デフォルトは'shift'です。",
    )
    parser.add_argument(
        "--sigmoid_scale",
        type=float,
        default=1.0,
        help='Scale factor for sigmoid timestep sampling (only used when timestep-sampling is "sigmoid"). / sigmoidタイムステップサンプリングの倍率(timestep-samplingが"sigmoid"の場合のみ有効)。',
    )
    parser.add_argument(
        "--model_prediction_type",
        choices=["raw", "additive", "sigma_scaled"],
        default="raw",
        help="How to interpret and process the model prediction: "
        "raw (use as is), additive (add to noisy input), sigma_scaled (apply sigma scaling)."
        " / モデル予測の解釈と処理方法:"
        "raw(そのまま使用)、additive(ノイズ入力に加算)、sigma_scaled(シグマスケーリングを適用)。",
    )
    parser.add_argument(
        "--discrete_flow_shift",
        type=float,
        default=6.0,
        help="Discrete flow shift for the Euler Discrete Scheduler, default is 6.0 / Euler Discrete Schedulerの離散フローシフト、デフォルトは6.0",
    )
    # --- attention backend selection ---
    parser.add_argument(
        "--use_flash_attn",
        action="store_true",
        help="Use Flash Attention for the model / モデルにFlash Attentionを使用する",
    )
    parser.add_argument(
        "--use_sage_attn",
        action="store_true",
        help="Use Sage Attention for the model / モデルにSage Attentionを使用する",
    )
    # --- sample-image generation options ---
    parser.add_argument(
        "--system_prompt",
        type=str,
        default="",
        help="System prompt to add to the prompt / プロンプトに追加するシステムプロンプト",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=None,
        help="Batch size to use for sampling, defaults to --training_batch_size value. Sample batches are bucketed by width, height, guidance scale, and seed / サンプリングに使用するバッチサイズ。デフォルトは --training_batch_size の値です。サンプルバッチは、幅、高さ、ガイダンススケール、シードによってバケット化されます",
    )

+ 259
- 0
scripts/dev/library/lumina_util.py View File

@@ -0,0 +1,259 @@
import json
import os
from dataclasses import replace
from typing import List, Optional, Tuple, Union

import einops
import torch
from accelerate import init_empty_weights
from safetensors import safe_open
from safetensors.torch import load_file
from transformers import Gemma2Config, Gemma2Model

from library.utils import setup_logging
from library import lumina_models, flux_models
from library.utils import load_safetensors
import logging

setup_logging()
logger = logging.getLogger(__name__)

MODEL_VERSION_LUMINA_V2 = "lumina2"


def load_lumina_model(
    ckpt_path: str,
    dtype: Optional[torch.dtype],
    device: torch.device,
    disable_mmap: bool = False,
    use_flash_attn: bool = False,
    use_sage_attn: bool = False,
):
    """
    Load the Lumina NextDiT model from the checkpoint path.

    Args:
        ckpt_path (str): Path to the checkpoint.
        dtype (Optional[torch.dtype]): The data type for the model.
        device (torch.device): The device to load the state dict on.
        disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
        use_flash_attn (bool, optional): Whether to use flash attention. Defaults to False.
        use_sage_attn (bool, optional): Whether to use Sage attention. Defaults to False.

    Returns:
        model (lumina_models.NextDiT): The loaded model.
    """
    logger.info("Building Lumina")
    # instantiate on the meta device: no memory is allocated until weights are assigned below
    with torch.device("meta"):
        model = lumina_models.NextDiT_2B_GQA_patch2_Adaln_Refiner(use_flash_attn=use_flash_attn, use_sage_attn=use_sage_attn).to(
            dtype
        )

    logger.info(f"Loading state dict from {ckpt_path}")
    state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)

    # Neta-Lumina support: such checkpoints store the DiT under a "model.diffusion_model." prefix
    if "model.diffusion_model.cap_embedder.0.weight" in state_dict:
        # remove "model.diffusion_model." prefix
        filtered_state_dict = {
            k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.")
        }
        state_dict = filtered_state_dict

    # strict=False tolerates missing/unexpected keys; assign=True materializes the meta tensors
    info = model.load_state_dict(state_dict, strict=False, assign=True)
    logger.info(f"Loaded Lumina: {info}")
    return model


def load_ae(
    ckpt_path: str,
    dtype: torch.dtype,
    device: Union[str, torch.device],
    disable_mmap: bool = False,
) -> flux_models.AutoEncoder:
    """
    Load the AutoEncoder model from the checkpoint path.

    Args:
        ckpt_path (str): Path to the checkpoint.
        dtype (torch.dtype): The data type for the model.
        device (Union[str, torch.device]): The device to load the model on.
        disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.

    Returns:
        ae (flux_models.AutoEncoder): The loaded model.
    """
    logger.info("Building AutoEncoder")
    # dev and schnell share the same AE hyperparameters, so the schnell config works for both
    with torch.device("meta"):
        autoencoder = flux_models.AutoEncoder(flux_models.configs["schnell"].ae_params).to(dtype)

    logger.info(f"Loading state dict from {ckpt_path}")
    state_dict = load_safetensors(ckpt_path, device=device, disable_mmap=disable_mmap, dtype=dtype)

    # Neta-Lumina support: such checkpoints store the VAE under a "vae." prefix
    if "vae.decoder.conv_in.bias" in state_dict:
        state_dict = {k.replace("vae.", ""): v for k, v in state_dict.items() if k.startswith("vae.")}

    result = autoencoder.load_state_dict(state_dict, strict=False, assign=True)
    logger.info(f"Loaded AE: {result}")
    return autoencoder


def load_gemma2(
    ckpt_path: Optional[str],
    dtype: torch.dtype,
    device: Union[str, torch.device],
    disable_mmap: bool = False,
    state_dict: Optional[dict] = None,
) -> Gemma2Model:
    """
    Load the Gemma2 model from the checkpoint path.

    Args:
        ckpt_path (str): Path to the checkpoint.
        dtype (torch.dtype): The data type for the model.
        device (Union[str, torch.device]): The device to load the model on.
        disable_mmap (bool, optional): Whether to disable mmap. Defaults to False.
        state_dict (Optional[dict], optional): The state dict to load. Defaults to None.

    Returns:
        gemma2 (Gemma2Model): The loaded model
    """
    logger.info("Building Gemma2")
    GEMMA2_CONFIG = {
        "_name_or_path": "google/gemma-2-2b",
        "architectures": ["Gemma2Model"],
        "attention_bias": False,
        "attention_dropout": 0.0,
        "attn_logit_softcapping": 50.0,
        "bos_token_id": 2,
        "cache_implementation": "hybrid",
        "eos_token_id": 1,
        "final_logit_softcapping": 30.0,
        "head_dim": 256,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_activation": "gelu_pytorch_tanh",
        "hidden_size": 2304,
        "initializer_range": 0.02,
        "intermediate_size": 9216,
        "max_position_embeddings": 8192,
        "model_type": "gemma2",
        "num_attention_heads": 8,
        "num_hidden_layers": 26,
        "num_key_value_heads": 4,
        "pad_token_id": 0,
        "query_pre_attn_scalar": 256,
        "rms_norm_eps": 1e-06,
        "rope_theta": 10000.0,
        "sliding_window": 4096,
        "torch_dtype": "float32",
        "transformers_version": "4.44.2",
        "use_cache": True,
        "vocab_size": 256000,
    }

    config = Gemma2Config(**GEMMA2_CONFIG)
    with init_empty_weights():
        gemma2 = Gemma2Model._from_config(config)

    if state_dict is not None:
        sd = state_dict
    else:
        logger.info(f"Loading state dict from {ckpt_path}")
        sd = load_safetensors(ckpt_path, device=str(device), disable_mmap=disable_mmap, dtype=dtype)

    # Strip the "model." prefix used by Gemma2ForCausalLM-style checkpoints.
    # BUGFIX: the previous loop used key.replace("model.", ""), which replaces
    # *every* occurrence (mangling keys that merely contain ".model.", such as
    # the Neta-Lumina "text_encoders.gemma2_2b.transformer.model.*" keys), and
    # it broke out of the loop at the first unprefixed key, making the result
    # depend on dict iteration order. removeprefix handles each key safely.
    for key in list(sd.keys()):
        if key.startswith("model."):
            sd[key.removeprefix("model.")] = sd.pop(key)

    # Neta-Lumina support
    if "text_encoders.gemma2_2b.logit_scale" in sd:
        # remove "text_encoders.gemma2_2b.transformer.model." prefix
        filtered_sd = {
            k.replace("text_encoders.gemma2_2b.transformer.model.", ""): v
            for k, v in sd.items()
            if k.startswith("text_encoders.gemma2_2b.transformer.model.")
        }
        sd = filtered_sd

    info = gemma2.load_state_dict(sd, strict=False, assign=True)
    logger.info(f"Loaded Gemma2: {info}")
    return gemma2


def unpack_latents(x: torch.Tensor, packed_latent_height: int, packed_latent_width: int) -> torch.Tensor:
    """
    Unpack patch tokens back into image-layout latents.

    x: [b (h w) (c ph pw)] -> [b c (h ph) (w pw)], ph=2, pw=2

    Implemented with plain torch ops (equivalent to the einops rearrange
    pattern above) so this module does not need the einops dependency.

    Args:
        x (torch.Tensor): Packed latents of shape (b, h*w, c*4).
        packed_latent_height (int): h, the packed (half-resolution) height.
        packed_latent_width (int): w, the packed (half-resolution) width.

    Returns:
        torch.Tensor: Latents of shape (b, c, h*2, w*2).
    """
    b, _, packed_channels = x.shape
    c = packed_channels // 4  # ph * pw == 4
    x = x.reshape(b, packed_latent_height, packed_latent_width, c, 2, 2)  # b h w c ph pw
    x = x.permute(0, 3, 1, 4, 2, 5)  # b c h ph w pw
    return x.reshape(b, c, packed_latent_height * 2, packed_latent_width * 2)


def pack_latents(x: torch.Tensor) -> torch.Tensor:
    """
    Pack image-layout latents into patch tokens.

    x: [b c (h ph) (w pw)] -> [b (h w) (c ph pw)], ph=2, pw=2

    Implemented with plain torch ops (equivalent to the einops rearrange
    pattern above) so this module does not need the einops dependency.

    Args:
        x (torch.Tensor): Latents of shape (b, c, H, W) with even H and W.

    Returns:
        torch.Tensor: Packed latents of shape (b, (H//2)*(W//2), c*4).
    """
    b, c, height, width = x.shape
    h, w = height // 2, width // 2
    x = x.reshape(b, c, h, 2, w, 2)  # b c h ph w pw
    x = x.permute(0, 2, 4, 1, 3, 5)  # b h w c ph pw
    return x.reshape(b, h * w, c * 4)


# Mapping from Diffusers parameter names to Alpha-VLLM (Lumina) names.
# The "()." token is a per-block index placeholder, expanded to "0.", "1.", ...
# by convert_diffusers_sd_to_alpha_vllm for each transformer block.
DIFFUSERS_TO_ALPHA_VLLM_MAP: dict[str, str] = {
    # Embedding layers
    "time_caption_embed.caption_embedder.0.weight": "cap_embedder.0.weight",
    "time_caption_embed.caption_embedder.1.weight": "cap_embedder.1.weight",
    "text_embedder.1.bias": "cap_embedder.1.bias",
    "patch_embedder.proj.weight": "x_embedder.weight",
    "patch_embedder.proj.bias": "x_embedder.bias",
    # Attention modulation
    "transformer_blocks.().adaln_modulation.1.weight": "layers.().adaLN_modulation.1.weight",
    "transformer_blocks.().adaln_modulation.1.bias": "layers.().adaLN_modulation.1.bias",
    # Final layers
    "final_adaln_modulation.1.weight": "final_layer.adaLN_modulation.1.weight",
    "final_adaln_modulation.1.bias": "final_layer.adaLN_modulation.1.bias",
    "final_linear.weight": "final_layer.linear.weight",
    "final_linear.bias": "final_layer.linear.bias",
    # Noise refiner
    "single_transformer_blocks.().adaln_modulation.1.weight": "noise_refiner.().adaLN_modulation.1.weight",
    "single_transformer_blocks.().adaln_modulation.1.bias": "noise_refiner.().adaLN_modulation.1.bias",
    "single_transformer_blocks.().attn.to_qkv.weight": "noise_refiner.().attention.qkv.weight",
    "single_transformer_blocks.().attn.to_out.0.weight": "noise_refiner.().attention.out.weight",
    # Normalization
    "transformer_blocks.().norm1.weight": "layers.().attention_norm1.weight",
    "transformer_blocks.().norm2.weight": "layers.().attention_norm2.weight",
    # FFN
    "transformer_blocks.().ff.net.0.proj.weight": "layers.().feed_forward.w1.weight",
    "transformer_blocks.().ff.net.2.weight": "layers.().feed_forward.w2.weight",
    "transformer_blocks.().ff.net.4.weight": "layers.().feed_forward.w3.weight",
}


def convert_diffusers_sd_to_alpha_vllm(sd: dict, num_double_blocks: int) -> dict:
    """Convert a Diffusers-format state dict to Alpha-VLLM (Lumina) format.

    Original keys are preserved alongside the converted ones.

    Args:
        sd (dict): Diffusers-format state dict.
        num_double_blocks (int): Number of transformer blocks to expand the
            "()." per-block placeholder over.

    Returns:
        dict: State dict containing both the original and the converted keys.
    """
    logger.info("Converting Diffusers checkpoint to Alpha-VLLM format")
    new_sd = sd.copy()  # Preserve original keys

    for diff_key, alpha_key in DIFFUSERS_TO_ALPHA_VLLM_MAP.items():
        if "()." in diff_key:
            # Block-indexed entry: expand the placeholder for every block
            for block_idx in range(num_double_blocks):
                block_diff_key = diff_key.replace("().", f"{block_idx}.")
                # direct O(1) lookup instead of scanning the whole state dict
                if block_diff_key in sd:
                    block_alpha_key = alpha_key.replace("().", f"{block_idx}.")
                    new_sd[block_alpha_key] = sd[block_diff_key]
        elif diff_key in sd:
            # use the module logger (the original printed to stdout)
            logger.info(f"Replacing {diff_key} with {alpha_key}")
            new_sd[alpha_key] = sd[diff_key]
        else:
            logger.warning(f"Not found: {diff_key}")

    logger.info(f"Converted {len(new_sd)} keys to Alpha-VLLM format")
    return new_sd

+ 484
- 129
scripts/dev/library/sai_model_spec.py View File

@@ -1,14 +1,19 @@
# based on https://github.com/Stability-AI/ModelSpec
import datetime
import hashlib
import argparse
import base64
import logging
import mimetypes
import subprocess
from dataclasses import dataclass, field
from io import BytesIO
import os
from typing import List, Optional, Tuple, Union
from typing import Union
import safetensors
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

@@ -31,23 +36,34 @@ metadata = {
"""

BASE_METADATA = {
# === Must ===
"modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec
# === MUST ===
"modelspec.sai_model_spec": "1.0.1",
"modelspec.architecture": None,
"modelspec.implementation": None,
"modelspec.title": None,
"modelspec.resolution": None,
# === Should ===
# === SHOULD ===
"modelspec.description": None,
"modelspec.author": None,
"modelspec.date": None,
# === Can ===
"modelspec.hash_sha256": None,
# === CAN===
"modelspec.implementation_version": None,
"modelspec.license": None,
"modelspec.usage_hint": None,
"modelspec.thumbnail": None,
"modelspec.tags": None,
"modelspec.merged_from": None,
"modelspec.trigger_phrase": None,
"modelspec.prediction_type": None,
"modelspec.timestep_range": None,
"modelspec.encoder_layer": None,
"modelspec.preprocessor": None,
"modelspec.is_negative_embedding": None,
"modelspec.unet_dtype": None,
"modelspec.vae_dtype": None,
}

# 別に使うやつだけ定義
@@ -60,7 +76,11 @@ ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"
ARCH_SD3_M = "stable-diffusion-3" # may be followed by "-m" or "-5-large" etc.
# ARCH_SD3_UNKNOWN = "stable-diffusion-3"
ARCH_FLUX_1_DEV = "flux-1-dev"
ARCH_FLUX_1_SCHNELL = "flux-1-schnell"
ARCH_FLUX_1_CHROMA = "chroma" # for Flux Chroma
ARCH_FLUX_1_UNKNOWN = "flux-1"
ARCH_LUMINA_2 = "lumina-2"
ARCH_LUMINA_UNKNOWN = "lumina"

ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"
@@ -69,11 +89,253 @@ IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_COMFY_UI = "https://github.com/comfyanonymous/ComfyUI"
IMPL_DIFFUSERS = "diffusers"
IMPL_FLUX = "https://github.com/black-forest-labs/flux"
IMPL_CHROMA = "https://huggingface.co/lodestones/Chroma"
IMPL_LUMINA = "https://github.com/Alpha-VLLM/Lumina-Image-2.0"

PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"


@dataclass
class ModelSpecMetadata:
    """
    ModelSpec 1.0.1 compliant metadata for safetensors models.
    All fields correspond to modelspec.* keys in the final metadata.
    """

    # === MUST ===
    architecture: str
    implementation: str
    title: str
    resolution: str
    sai_model_spec: str = "1.0.1"
    # === SHOULD ===
    description: str | None = None
    author: str | None = None
    date: str | None = None
    hash_sha256: str | None = None
    # === CAN ===
    implementation_version: str | None = None
    license: str | None = None
    usage_hint: str | None = None
    thumbnail: str | None = None
    tags: str | None = None
    merged_from: str | None = None
    trigger_phrase: str | None = None
    prediction_type: str | None = None
    timestep_range: str | None = None
    encoder_layer: str | None = None
    preprocessor: str | None = None
    is_negative_embedding: str | None = None
    unet_dtype: str | None = None
    vae_dtype: str | None = None
    # === Additional metadata ===
    # extra modelspec entries not covered by the named fields above
    additional_fields: dict[str, str] = field(default_factory=dict)

    def to_metadata_dict(self) -> dict[str, str]:
        """Convert dataclass to metadata dictionary with modelspec. prefixes."""
        metadata = {}
        # Add all non-None fields with modelspec prefix
        for field_name, value in self.__dict__.items():
            if field_name == "additional_fields":
                # Handle additional fields separately; keys already carrying the
                # "modelspec." prefix are passed through unchanged
                for key, val in value.items():
                    if key.startswith("modelspec."):
                        metadata[key] = val
                    else:
                        metadata[f"modelspec.{key}"] = val
            elif value is not None:
                metadata[f"modelspec.{field_name}"] = value
        return metadata

    @classmethod
    def from_args(cls, args, **kwargs) -> "ModelSpecMetadata":
        """Create ModelSpecMetadata from argparse Namespace, extracting metadata_* fields."""
        metadata_fields = {}
        # Extract all metadata_* attributes from args
        for attr_name in dir(args):
            if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
                value = getattr(args, attr_name, None)
                if value is not None:
                    # Remove metadata_ prefix
                    field_name = attr_name[9:]  # len("metadata_") = 9
                    metadata_fields[field_name] = value
        # Handle known standard fields (promoted to dedicated dataclass fields)
        standard_fields = {
            "author": metadata_fields.pop("author", None),
            "description": metadata_fields.pop("description", None),
            "license": metadata_fields.pop("license", None),
            "tags": metadata_fields.pop("tags", None),
        }
        # Remove None values
        standard_fields = {k: v for k, v in standard_fields.items() if v is not None}
        # Merge with kwargs and remaining metadata fields; explicit kwargs win
        # over values extracted from args
        all_fields = {**standard_fields, **kwargs}
        if metadata_fields:
            all_fields["additional_fields"] = metadata_fields
        return cls(**all_fields)


def determine_architecture(
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    model_config: dict[str, str] | None = None
) -> str:
    """Resolve the ModelSpec architecture identifier for the given model.

    Priority order: SDXL > SD3 > FLUX > Lumina > SD2 > SD1. A LoRA or
    textual-inversion adapter suffix is appended when applicable.
    """
    config = model_config or {}

    if sdxl:
        base = ARCH_SD_XL_V1_BASE
    elif "sd3" in config:
        base = ARCH_SD3_M + "-" + config["sd3"]
    elif "flux" in config:
        # Known FLUX variants map to dedicated identifiers; anything else
        # falls back to the "unknown" architecture string.
        flux_archs = {
            "dev": ARCH_FLUX_1_DEV,
            "schnell": ARCH_FLUX_1_SCHNELL,
            "chroma": ARCH_FLUX_1_CHROMA,
        }
        base = flux_archs.get(config["flux"], ARCH_FLUX_1_UNKNOWN)
    elif "lumina" in config:
        base = ARCH_LUMINA_2 if config["lumina"] == "lumina2" else ARCH_LUMINA_UNKNOWN
    elif v2:
        base = ARCH_SD_V2_768_V if v_parameterization else ARCH_SD_V2_512
    else:
        base = ARCH_SD_V1

    # Append the adapter suffix (LoRA wins over textual inversion).
    if lora:
        return f"{base}/{ADAPTER_LORA}"
    if textual_inversion:
        return f"{base}/{ADAPTER_TEXTUAL_INVERSION}"
    return base


def determine_implementation(
    lora: bool,
    textual_inversion: bool,
    sdxl: bool,
    model_config: dict[str, str] | None = None,
    is_stable_diffusion_ckpt: bool | None = None
) -> str:
    """Resolve the ModelSpec implementation identifier.

    FLUX/Chroma and Lumina take precedence; otherwise SDXL LoRA, textual
    inversion, and plain SD checkpoints map to the Stability AI
    implementation, and everything else to Diffusers.
    """
    config = model_config or {}
    if "flux" in config:
        return IMPL_CHROMA if config["flux"] == "chroma" else IMPL_FLUX
    if "lumina" in config:
        return IMPL_LUMINA
    if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
        return IMPL_STABILITY_AI
    return IMPL_DIFFUSERS


def get_implementation_version() -> str:
    """Return ``sd-scripts/{git commit hash}`` for the current checkout.

    Falls back to ``sd-scripts/unknown`` when git is unavailable, the
    command fails, or it times out.
    """
    fallback = "sd-scripts/unknown"
    repo_root = os.path.dirname(os.path.dirname(__file__))  # sd-scripts root
    try:
        proc = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True,
            text=True,
            cwd=repo_root,
            timeout=5,
        )
    except (subprocess.TimeoutExpired, subprocess.SubprocessError, FileNotFoundError) as e:
        logger.warning(f"Could not determine git commit: {e}")
        return fallback

    if proc.returncode != 0:
        logger.warning("Failed to get git commit hash, using fallback")
        return fallback
    return f"sd-scripts/{proc.stdout.strip()}"


def file_to_data_url(file_path: str) -> str:
    """Encode *file_path* as an RFC 2397 data URL for embedding in metadata.

    Args:
        file_path: path to the file to embed.

    Returns:
        ``data:<mime>;base64,<payload>`` string.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    # Fall back to a generic binary type when the extension is unknown.
    mime_type = mimetypes.guess_type(file_path)[0] or "application/octet-stream"

    with open(file_path, "rb") as fp:
        payload = base64.b64encode(fp.read()).decode("ascii")
    return f"data:{mime_type};base64,{payload}"


def determine_resolution(
reso: Union[int, tuple[int, int]] | None = None,
sdxl: bool = False,
model_config: dict[str, str] | None = None,
v2: bool = False,
v_parameterization: bool = False
) -> str:
"""Determine resolution string from parameters."""
model_config = model_config or {}
if reso is not None:
# Handle comma separated string
if isinstance(reso, str):
reso = tuple(map(int, reso.split(",")))
# Handle single int
if isinstance(reso, int):
reso = (reso, reso)
# Handle single-element tuple
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# Determine default resolution based on model type
if (sdxl or
"sd3" in model_config or
"flux" in model_config or
"lumina" in model_config):
reso = (1024, 1024)
elif v2 and v_parameterization:
reso = (768, 768)
else:
reso = (512, 512)
return f"{reso[0]}x{reso[1]}"


def load_bytes_in_safetensors(tensors):
bytes = safetensors.torch.save(tensors)
b = BytesIO(bytes)
@@ -103,77 +365,46 @@ def update_hash_sha256(metadata: dict, state_dict: dict):
raise NotImplementedError


def build_metadata(
state_dict: Optional[dict],
def build_metadata_dataclass(
state_dict: dict | None,
v2: bool,
v_parameterization: bool,
sdxl: bool,
lora: bool,
textual_inversion: bool,
timestamp: float,
title: Optional[str] = None,
reso: Optional[Union[int, Tuple[int, int]]] = None,
is_stable_diffusion_ckpt: Optional[bool] = None,
author: Optional[str] = None,
description: Optional[str] = None,
license: Optional[str] = None,
tags: Optional[str] = None,
merged_from: Optional[str] = None,
timesteps: Optional[Tuple[int, int]] = None,
clip_skip: Optional[int] = None,
sd3: Optional[str] = None,
flux: Optional[str] = None,
):
title: str | None = None,
reso: int | tuple[int, int] | None = None,
is_stable_diffusion_ckpt: bool | None = None,
author: str | None = None,
description: str | None = None,
license: str | None = None,
tags: str | None = None,
merged_from: str | None = None,
timesteps: tuple[int, int] | None = None,
clip_skip: int | None = None,
model_config: dict | None = None,
optional_metadata: dict | None = None,
) -> ModelSpecMetadata:
"""
sd3: only supports "m", flux: only supports "dev"
Build ModelSpec 1.0.1 compliant metadata dataclass.
Args:
model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
optional_metadata: Dict of additional metadata fields to include
"""
# if state_dict is None, hash is not calculated

metadata = {}
metadata.update(BASE_METADATA)

# TODO メモリを消費せずかつ正しいハッシュ計算の方法がわかったら実装する
# if state_dict is not None:
# hash = precalculate_safetensors_hashes(state_dict)
# metadata["modelspec.hash_sha256"] = hash

if sdxl:
arch = ARCH_SD_XL_V1_BASE
elif sd3 is not None:
arch = ARCH_SD3_M + "-" + sd3
elif flux is not None:
if flux == "dev":
arch = ARCH_FLUX_1_DEV
else:
arch = ARCH_FLUX_1_UNKNOWN
elif v2:
if v_parameterization:
arch = ARCH_SD_V2_768_V
else:
arch = ARCH_SD_V2_512
else:
arch = ARCH_SD_V1

if lora:
arch += f"/{ADAPTER_LORA}"
elif textual_inversion:
arch += f"/{ADAPTER_TEXTUAL_INVERSION}"

metadata["modelspec.architecture"] = arch
# Use helper functions for complex logic
architecture = determine_architecture(
v2, v_parameterization, sdxl, lora, textual_inversion, model_config
)

if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
is_stable_diffusion_ckpt = True # default is stable diffusion ckpt if not lora and not textual_inversion

if flux is not None:
# Flux
impl = IMPL_FLUX
elif (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
# Stable Diffusion ckpt, TI, SDXL LoRA
impl = IMPL_STABILITY_AI
else:
# v1/v2 LoRA or Diffusers
impl = IMPL_DIFFUSERS
metadata["modelspec.implementation"] = impl
implementation = determine_implementation(
lora, textual_inversion, sdxl, model_config, is_stable_diffusion_ckpt
)

if title is None:
if lora:
@@ -183,92 +414,145 @@ def build_metadata(
else:
title = "Checkpoint"
title += f"@{timestamp}"
metadata[MODELSPEC_TITLE] = title

if author is not None:
metadata["modelspec.author"] = author
else:
del metadata["modelspec.author"]

if description is not None:
metadata["modelspec.description"] = description
else:
del metadata["modelspec.description"]

if merged_from is not None:
metadata["modelspec.merged_from"] = merged_from
else:
del metadata["modelspec.merged_from"]

if license is not None:
metadata["modelspec.license"] = license
else:
del metadata["modelspec.license"]

if tags is not None:
metadata["modelspec.tags"] = tags
else:
del metadata["modelspec.tags"]

# remove microsecond from time
int_ts = int(timestamp)

# time to iso-8601 compliant date
date = datetime.datetime.fromtimestamp(int_ts).isoformat()
metadata["modelspec.date"] = date

if reso is not None:
# comma separated to tuple
if isinstance(reso, str):
reso = tuple(map(int, reso.split(",")))
if len(reso) == 1:
reso = (reso[0], reso[0])
else:
# resolution is defined in dataset, so use default
if sdxl or sd3 is not None or flux is not None:
reso = 1024
elif v2 and v_parameterization:
reso = 768
else:
reso = 512
if isinstance(reso, int):
reso = (reso, reso)

metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"
# Use helper function for resolution
resolution = determine_resolution(
reso, sdxl, model_config, v2, v_parameterization
)

if flux is not None:
del metadata["modelspec.prediction_type"]
elif v_parameterization:
metadata["modelspec.prediction_type"] = PRED_TYPE_V
else:
metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON
# Handle prediction type - Flux models don't use prediction_type
model_config = model_config or {}
prediction_type = None
if "flux" not in model_config:
if v_parameterization:
prediction_type = PRED_TYPE_V
else:
prediction_type = PRED_TYPE_EPSILON

# Handle timesteps
timestep_range = None
if timesteps is not None:
if isinstance(timesteps, str) or isinstance(timesteps, int):
timesteps = (timesteps, timesteps)
if len(timesteps) == 1:
timesteps = (timesteps[0], timesteps[0])
metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
else:
del metadata["modelspec.timestep_range"]
timestep_range = f"{timesteps[0]},{timesteps[1]}"

# Handle encoder layer (clip skip)
encoder_layer = None
if clip_skip is not None:
metadata["modelspec.encoder_layer"] = f"{clip_skip}"
else:
del metadata["modelspec.encoder_layer"]
encoder_layer = f"{clip_skip}"

# # assert all values are filled
# assert all([v is not None for v in metadata.values()]), metadata
if not all([v is not None for v in metadata.values()]):
logger.error(f"Internal error: some metadata values are None: {metadata}")
# TODO: Implement hash calculation when memory-efficient method is available
# hash_sha256 = None
# if state_dict is not None:
# hash_sha256 = precalculate_safetensors_hashes(state_dict)

# Process thumbnail - convert file path to data URL if needed
processed_optional_metadata = optional_metadata.copy() if optional_metadata else {}
if "thumbnail" in processed_optional_metadata:
thumbnail_value = processed_optional_metadata["thumbnail"]
# Check if it's already a data URL or if it's a file path
if thumbnail_value and not thumbnail_value.startswith("data:"):
try:
processed_optional_metadata["thumbnail"] = file_to_data_url(thumbnail_value)
logger.info(f"Converted thumbnail file {thumbnail_value} to data URL")
except FileNotFoundError as e:
logger.warning(f"Thumbnail file not found, skipping: {e}")
del processed_optional_metadata["thumbnail"]
except Exception as e:
logger.warning(f"Failed to convert thumbnail to data URL: {e}")
del processed_optional_metadata["thumbnail"]

# Automatically set implementation version if not provided
if "implementation_version" not in processed_optional_metadata:
processed_optional_metadata["implementation_version"] = get_implementation_version()

# Create the dataclass
metadata = ModelSpecMetadata(
architecture=architecture,
implementation=implementation,
title=title,
description=description,
author=author,
date=date,
license=license,
tags=tags,
merged_from=merged_from,
resolution=resolution,
prediction_type=prediction_type,
timestep_range=timestep_range,
encoder_layer=encoder_layer,
additional_fields=processed_optional_metadata
)

return metadata


def build_metadata(
    state_dict: dict | None,
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    timestamp: float,
    title: str | None = None,
    reso: int | tuple[int, int] | None = None,
    is_stable_diffusion_ckpt: bool | None = None,
    author: str | None = None,
    description: str | None = None,
    license: str | None = None,
    tags: str | None = None,
    merged_from: str | None = None,
    timesteps: tuple[int, int] | None = None,
    clip_skip: int | None = None,
    model_config: dict | None = None,
    optional_metadata: dict | None = None,
) -> dict[str, str]:
    """
    Build ModelSpec 1.0.1 compliant metadata for safetensors models.

    Legacy wrapper around ``build_metadata_dataclass`` that flattens the
    result into a plain dict - prefer the dataclass variant in new code.

    Args:
        model_config: Dict containing model type info, e.g. {"flux": "dev"}, {"sd3": "large"}
        optional_metadata: Dict of additional metadata fields to include
    """
    dataclass_kwargs = dict(
        state_dict=state_dict,
        v2=v2,
        v_parameterization=v_parameterization,
        sdxl=sdxl,
        lora=lora,
        textual_inversion=textual_inversion,
        timestamp=timestamp,
        title=title,
        reso=reso,
        is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
        author=author,
        description=description,
        license=license,
        tags=tags,
        merged_from=merged_from,
        timesteps=timesteps,
        clip_skip=clip_skip,
        model_config=model_config,
        optional_metadata=optional_metadata,
    )
    return build_metadata_dataclass(**dataclass_kwargs).to_metadata_dict()


# region utils


def get_title(metadata: dict) -> str | None:
    """Return the model title (``modelspec.title``) from *metadata*, or None.

    Note: the previous text contained two stacked ``def get_title`` signature
    lines (a diff-merge artifact), which is a syntax error; only the
    new-style signature is kept.
    """
    return metadata.get(MODELSPEC_TITLE, None)


@@ -283,7 +567,7 @@ def load_metadata_from_safetensors(model: str) -> dict:
return metadata


def build_merged_from(models: List[str]) -> str:
def build_merged_from(models: list[str]) -> str:
def get_title(model: str):
metadata = load_metadata_from_safetensors(model)
title = metadata.get(MODELSPEC_TITLE, None)
@@ -295,6 +579,77 @@ def build_merged_from(models: List[str]) -> str:
return ", ".join(titles)


def add_model_spec_arguments(parser: argparse.ArgumentParser):
    """Add all ModelSpec metadata arguments to the parser.

    Every option is a string flag named ``--metadata_<field>`` defaulting
    to None; help texts are bilingual (English / Japanese).
    """
    # (field name without the --metadata_ prefix, help text)
    option_specs = [
        ("title", "title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name"),
        ("author", "author name for model metadata / メタデータに書き込まれるモデル作者名"),
        ("description", "description for model metadata / メタデータに書き込まれるモデル説明"),
        ("license", "license for model metadata / メタデータに書き込まれるモデルライセンス"),
        ("tags", "tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り"),
        ("usage_hint", "usage hint for model metadata / メタデータに書き込まれる使用方法のヒント"),
        ("thumbnail", "thumbnail image as data URL or file path (will be converted to data URL) for model metadata / メタデータに書き込まれるサムネイル画像(データURLまたはファイルパス、ファイルパスの場合はデータURLに変換されます)"),
        ("merged_from", "source models for merged model metadata / メタデータに書き込まれるマージ元モデル名"),
        ("trigger_phrase", "trigger phrase for model metadata / メタデータに書き込まれるトリガーフレーズ"),
        ("preprocessor", "preprocessor used for model metadata / メタデータに書き込まれる前処理手法"),
        ("is_negative_embedding", "whether this is a negative embedding for model metadata / メタデータに書き込まれるネガティブ埋め込みかどうか"),
    ]
    for name, help_text in option_specs:
        parser.add_argument(
            f"--metadata_{name}",
            type=str,
            default=None,
            help=help_text,
        )


# endregion




+ 2
- 2
scripts/dev/library/sd3_models.py View File

@@ -1080,7 +1080,7 @@ class MMDiT(nn.Module):
), f"Cannot swap more than {self.num_blocks - 2} blocks. Requested: {self.blocks_to_swap} blocks."

self.offloader = custom_offloading_utils.ModelOffloader(
self.joint_blocks, self.num_blocks, self.blocks_to_swap, device # , debug=True
self.joint_blocks, self.blocks_to_swap, device # , debug=True
)
print(f"SD3: Block swap enabled. Swapping {num_blocks} blocks, total blocks: {self.num_blocks}, device: {device}.")

@@ -1088,7 +1088,7 @@ class MMDiT(nn.Module):
# assume model is on cpu. do not move blocks to device to reduce temporary memory usage
if self.blocks_to_swap:
save_blocks = self.joint_blocks
self.joint_blocks = None
self.joint_blocks = nn.ModuleList()

self.to(device)



+ 4
- 4
scripts/dev/library/sd3_utils.py View File

@@ -50,14 +50,14 @@ def analyze_state_dict_state(state_dict: Dict, prefix: str = ""):
context_embedder_in_features = context_shape[1]
context_embedder_out_features = context_shape[0]

# only supports 3-5-large, medium or 3-medium
# only supports 3-5-large, medium or 3-medium. This is added after `stable-diffusion-3-`.
if qk_norm is not None:
if len(x_block_self_attn_layers) == 0:
model_type = "3-5-large"
model_type = "5-large"
else:
model_type = "3-5-medium"
model_type = "5-medium"
else:
model_type = "3-medium"
model_type = "medium"

params = sd3_models.SD3Params(
patch_size=patch_size,


+ 75
- 9
scripts/dev/library/strategy_base.py View File

@@ -2,7 +2,7 @@

import os
import re
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Tuple, Union, Callable

import numpy as np
import torch
@@ -430,9 +430,21 @@ class LatentsCachingStrategy:
bucket_reso: Tuple[int, int],
npz_path: str,
flip_aug: bool,
alpha_mask: bool,
apply_alpha_mask: bool,
multi_resolution: bool = False,
):
) -> bool:
"""
Args:
latents_stride: stride of latents
bucket_reso: resolution of the bucket
npz_path: path to the npz file
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
multi_resolution: whether to use multi-resolution latents

Returns:
bool
"""
if not self.cache_to_disk:
return False
if not os.path.exists(npz_path):
@@ -451,7 +463,7 @@ class LatentsCachingStrategy:
return False
if flip_aug and "latents_flipped" + key_reso_suffix not in npz:
return False
if alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
if apply_alpha_mask and "alpha_mask" + key_reso_suffix not in npz:
return False
except Exception as e:
logger.error(f"Error loading file: {npz_path}")
@@ -462,22 +474,35 @@ class LatentsCachingStrategy:
# TODO remove circular dependency for ImageInfo
def _default_cache_batch_latents(
self,
encode_by_vae,
vae_device,
vae_dtype,
encode_by_vae: Callable,
vae_device: torch.device,
vae_dtype: torch.dtype,
image_infos: List,
flip_aug: bool,
alpha_mask: bool,
apply_alpha_mask: bool,
random_crop: bool,
multi_resolution: bool = False,
):
"""
Default implementation for cache_batch_latents. Image loading, VAE, flipping, alpha mask handling are common.

Args:
encode_by_vae: function to encode images by VAE
vae_device: device to use for VAE
vae_dtype: dtype to use for VAE
image_infos: list of ImageInfo
flip_aug: whether to flip images
apply_alpha_mask: whether to apply alpha mask
random_crop: whether to random crop images
multi_resolution: whether to use multi-resolution latents
Returns:
None
"""
from library import train_util # import here to avoid circular import

img_tensor, alpha_masks, original_sizes, crop_ltrbs = train_util.load_images_and_masks_for_caching(
image_infos, alpha_mask, random_crop
image_infos, apply_alpha_mask, random_crop
)
img_tensor = img_tensor.to(device=vae_device, dtype=vae_dtype)

@@ -519,12 +544,40 @@ class LatentsCachingStrategy:
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
for SD/SDXL

Args:
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
return self._default_load_latents_from_disk(None, npz_path, bucket_reso)

def _default_load_latents_from_disk(
self, latents_stride: Optional[int], npz_path: str, bucket_reso: Tuple[int, int]
) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]:
"""
Args:
latents_stride (Optional[int]): Stride for latents. If None, load all latents.
npz_path (str): Path to the npz file.
bucket_reso (Tuple[int, int]): The resolution of the bucket.
Returns:
Tuple[
Optional[np.ndarray],
Optional[List[int]],
Optional[List[int]],
Optional[np.ndarray],
Optional[np.ndarray]
]: Latent np tensors, original size, crop (left top, right bottom), flipped latents, alpha mask
"""
if latents_stride is None:
key_reso_suffix = ""
else:
@@ -552,6 +605,19 @@ class LatentsCachingStrategy:
alpha_mask=None,
key_reso_suffix="",
):
"""
Args:
npz_path (str): Path to the npz file.
latents_tensor (torch.Tensor): Latent tensor
original_size (List[int]): Original size of the image
crop_ltrb (List[int]): Crop left top right bottom
flipped_latents_tensor (Optional[torch.Tensor]): Flipped latent tensor
alpha_mask (Optional[torch.Tensor]): Alpha mask
key_reso_suffix (str): Key resolution suffix

Returns:
None
"""
kwargs = {}

if os.path.exists(npz_path):


+ 375
- 0
scripts/dev/library/strategy_lumina.py View File

@@ -0,0 +1,375 @@
import glob
import os
from typing import Any, List, Optional, Tuple, Union

import torch
from transformers import AutoTokenizer, AutoModel, Gemma2Model, GemmaTokenizerFast
from library import train_util
from library.strategy_base import (
LatentsCachingStrategy,
TokenizeStrategy,
TextEncodingStrategy,
TextEncoderOutputsCachingStrategy,
)
import numpy as np
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


GEMMA_ID = "google/gemma-2-2b"


class LuminaTokenizeStrategy(TokenizeStrategy):
    """Gemma-2 tokenization for Lumina: right-padded, optional system prompt."""

    def __init__(
        self, system_prompt: str, max_length: Optional[int], tokenizer_cache_dir: Optional[str] = None
    ) -> None:
        self.tokenizer: GemmaTokenizerFast = AutoTokenizer.from_pretrained(
            GEMMA_ID, cache_dir=tokenizer_cache_dir
        )
        self.tokenizer.padding_side = "right"

        # A non-empty system prompt is joined to captions via the special token.
        prompt = system_prompt or ""
        self.system_prompt = f"{prompt} <Prompt Start> " if prompt else ""

        self.max_length = 256 if max_length is None else max_length

    def tokenize(
        self, text: Union[str, List[str]], is_negative: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize one or more captions.

        Args:
            text (Union[str, List[str]]): Text to tokenize.
            is_negative (bool): When True the system prompt is NOT prepended
                (training always uses False).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                token input ids, attention_masks
        """
        captions = [text] if isinstance(text, str) else text
        if not is_negative:
            captions = [self.system_prompt + caption for caption in captions]

        encoded = self.tokenizer(
            captions,
            max_length=self.max_length,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            pad_to_multiple_of=8,
        )
        return (encoded.input_ids, encoded.attention_mask)

    def tokenize_with_weights(
        self, text: str | List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Tokenize and return uniform weights.

        Gemma doesn't support weighted prompts, so every token weight is 1.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
                token input ids, attention_masks, weights
        """
        input_ids, attention_masks = self.tokenize(text)
        uniform_weights = [torch.ones_like(row) for row in input_ids]
        return input_ids, attention_masks, uniform_weights


class LuminaTextEncodingStrategy(TextEncodingStrategy):
    """Encode Gemma-2 token batches into Lumina conditioning embeddings."""

    def __init__(self) -> None:
        super().__init__()

    def encode_tokens(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: Tuple[torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
            models (List[Any]): Text encoders
            tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                hidden_states, input_ids, attention_masks
        """
        text_encoder = models[0]
        # Accept either the raw model or a torch.compile OptimizedModule wrapper.
        # Bug fix: use getattr so an unexpected encoder *without* `_orig_mod`
        # raises the intended AssertionError instead of an AttributeError.
        assert isinstance(text_encoder, Gemma2Model) or isinstance(
            getattr(text_encoder, "_orig_mod", None), Gemma2Model
        ), f"text encoder is not Gemma2Model {text_encoder.__class__.__name__}"
        input_ids, attention_masks = tokens

        outputs = text_encoder(
            input_ids=input_ids.to(text_encoder.device),
            attention_mask=attention_masks.to(text_encoder.device),
            output_hidden_states=True,
            return_dict=True,
        )

        # The second-to-last hidden layer is used as the conditioning embedding.
        return outputs.hidden_states[-2], input_ids, attention_masks

    def encode_tokens_with_weights(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: Tuple[torch.Tensor, torch.Tensor],
        weights: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
            models (List[Any]): Text encoders
            tokens (Tuple[torch.Tensor, torch.Tensor]): tokens, attention_masks
            weights (List[torch.Tensor]): Currently unused

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                hidden_states, input_ids, attention_masks
        """
        # Weighted prompts are unsupported for Gemma; fall back to uniform weighting.
        return self.encode_tokens(tokenize_strategy, models, tokens)


class LuminaTextEncoderOutputsCachingStrategy(TextEncoderOutputsCachingStrategy):
    """Cache Gemma-2 text-encoder outputs for Lumina, on disk or in memory."""

    LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX = "_lumina_te.npz"

    def __init__(
        self,
        cache_to_disk: bool,
        batch_size: int,
        skip_disk_cache_validity_check: bool,
        is_partial: bool = False,
    ) -> None:
        super().__init__(
            cache_to_disk,
            batch_size,
            skip_disk_cache_validity_check,
            is_partial,
        )

    def get_outputs_npz_path(self, image_abs_path: str) -> str:
        """Return the cache file path derived from the image path."""
        return (
            os.path.splitext(image_abs_path)[0]
            + LuminaTextEncoderOutputsCachingStrategy.LUMINA_TEXT_ENCODER_OUTPUTS_NPZ_SUFFIX
        )

    def is_disk_cached_outputs_expected(self, npz_path: str) -> bool:
        """
        Check whether a valid text-encoder output cache exists on disk.

        Args:
            npz_path (str): Path to the npz file.

        Returns:
            bool: True if the npz file is expected to be cached.
        """
        if not self.cache_to_disk:
            return False
        if not os.path.exists(npz_path):
            return False
        if self.skip_disk_cache_validity_check:
            return True

        try:
            # Bug fix: np.load returns an NpzFile holding an open file handle;
            # use a context manager so the handle is always closed.
            with np.load(npz_path) as npz:
                for key in ("hidden_state", "attention_mask", "input_ids"):
                    if key not in npz:
                        return False
        except Exception as e:
            logger.error(f"Error loading file: {npz_path}")
            raise e

        return True

    def load_outputs_npz(self, npz_path: str) -> List[np.ndarray]:
        """
        Load outputs from a npz file

        Returns:
            List[np.ndarray]: hidden_state, input_ids, attention_mask
        """
        # Materialize the arrays inside the context manager, then close the file.
        with np.load(npz_path) as data:
            hidden_state = data["hidden_state"]
            attention_mask = data["attention_mask"]
            input_ids = data["input_ids"]
        return [hidden_state, input_ids, attention_mask]

    @torch.no_grad()
    def cache_batch_outputs(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        text_encoding_strategy: TextEncodingStrategy,
        batch: List[train_util.ImageInfo],
    ) -> None:
        """
        Encode the captions of a batch and store the outputs.

        Args:
            tokenize_strategy (LuminaTokenizeStrategy): Tokenize strategy
            models (List[Any]): Text encoders
            text_encoding_strategy (LuminaTextEncodingStrategy): Encoding strategy
            batch (List): List of ImageInfo

        Returns:
            None
        """
        assert isinstance(text_encoding_strategy, LuminaTextEncodingStrategy)
        assert isinstance(tokenize_strategy, LuminaTokenizeStrategy)

        captions = [info.caption for info in batch]

        if self.is_weighted:
            tokens, attention_masks, weights_list = tokenize_strategy.tokenize_with_weights(captions)
            hidden_state, input_ids, attention_masks = text_encoding_strategy.encode_tokens_with_weights(
                tokenize_strategy,
                models,
                (tokens, attention_masks),
                weights_list,
            )
        else:
            tokens = tokenize_strategy.tokenize(captions)
            hidden_state, input_ids, attention_masks = text_encoding_strategy.encode_tokens(
                tokenize_strategy, models, tokens
            )

        # Caches are always stored as float32 regardless of the encoder dtype.
        if hidden_state.dtype != torch.float32:
            hidden_state = hidden_state.float()

        hidden_state = hidden_state.cpu().numpy()
        attention_mask = attention_masks.cpu().numpy()  # (B, S)
        input_ids = input_ids.cpu().numpy()  # (B, S)

        for i, info in enumerate(batch):
            if self.cache_to_disk:
                assert (
                    info.text_encoder_outputs_npz is not None
                ), f"Text encoder cache outputs to disk not found for image {info.image_key}"
                np.savez(
                    info.text_encoder_outputs_npz,
                    hidden_state=hidden_state[i],
                    attention_mask=attention_mask[i],
                    input_ids=input_ids[i],
                )
            else:
                info.text_encoder_outputs = [
                    hidden_state[i],
                    input_ids[i],
                    attention_mask[i],
                ]


class LuminaLatentsCachingStrategy(LatentsCachingStrategy):
    """Cache VAE latents for Lumina (stride 8, multi-resolution npz files)."""

    LUMINA_LATENTS_NPZ_SUFFIX = "_lumina.npz"

    def __init__(
        self, cache_to_disk: bool, batch_size: int, skip_disk_cache_validity_check: bool
    ) -> None:
        super().__init__(cache_to_disk, batch_size, skip_disk_cache_validity_check)

    @property
    def cache_suffix(self) -> str:
        """Suffix appended to latent cache file names."""
        return LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX

    def get_latents_npz_path(
        self, absolute_path: str, image_size: Tuple[int, int]
    ) -> str:
        """Return the npz cache path encoding the image size, e.g. ``..._1024x0768_lumina.npz``."""
        base = os.path.splitext(absolute_path)[0]
        width, height = image_size
        return f"{base}_{width:04d}x{height:04d}" + LuminaLatentsCachingStrategy.LUMINA_LATENTS_NPZ_SUFFIX

    def is_disk_cached_latents_expected(
        self,
        bucket_reso: Tuple[int, int],
        npz_path: str,
        flip_aug: bool,
        alpha_mask: bool,
    ) -> bool:
        """
        Args:
            bucket_reso (Tuple[int, int]): The resolution of the bucket.
            npz_path (str): Path to the npz file.
            flip_aug (bool): Whether to flip the image.
            alpha_mask (bool): Whether to apply the alpha mask.
        """
        # Latents stride is 8; multi-resolution entries are keyed by resolution.
        return self._default_is_disk_cached_latents_expected(
            8, bucket_reso, npz_path, flip_aug, alpha_mask, multi_resolution=True
        )

    def load_latents_from_disk(
        self, npz_path: str, bucket_reso: Tuple[int, int]
    ) -> Tuple[
        Optional[np.ndarray],
        Optional[List[int]],
        Optional[List[int]],
        Optional[np.ndarray],
        Optional[np.ndarray],
    ]:
        """
        Args:
            npz_path (str): Path to the npz file.
            bucket_reso (Tuple[int, int]): The resolution of the bucket.

        Returns:
            Tuple of latents, original size, crop (left top, right bottom),
            flipped latents, alpha mask.
        """
        # Stride 8 selects the multi-resolution entry for this bucket.
        return self._default_load_latents_from_disk(8, npz_path, bucket_reso)

    # TODO remove circular dependency for ImageInfo
    def cache_batch_latents(
        self,
        model,
        batch: List,
        flip_aug: bool,
        alpha_mask: bool,
        random_crop: bool,
    ):
        """Encode a batch with the VAE and cache the latents (multi-resolution)."""

        def encode_by_vae(img_tensor):
            # Encode on the VAE device, then move the result to CPU for caching.
            return model.encode(img_tensor).to("cpu")

        self._default_cache_batch_latents(
            encode_by_vae,
            model.device,
            model.dtype,
            batch,
            flip_aug,
            alpha_mask,
            random_crop,
            multi_resolution=True,
        )

        if not train_util.HIGH_VRAM:
            train_util.clean_memory_on_device(model.device)

+ 175
- 60
scripts/dev/library/train_util.py View File

@@ -74,7 +74,7 @@ import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
import library.deepspeed_utils as deepspeed_utils
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image, validate_interpolation_fn

setup_logging()
import logging
@@ -113,14 +113,16 @@ except:
# JPEG-XL on Linux
try:
from jxlpy import JXLImagePlugin
from library.jpeg_xl_util import get_jxl_size

IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
pass

# JPEG-XL on Windows
# JPEG-XL on Linux and Windows
try:
import pillow_jxl
from library.jpeg_xl_util import get_jxl_size

IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except:
@@ -205,6 +207,7 @@ class ImageInfo:
self.text_encoder_pool2: Optional[torch.Tensor] = None

self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime
self.resize_interpolation: Optional[str] = None


class BucketManager:
@@ -429,6 +432,7 @@ class BaseSubset:
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -459,6 +463,8 @@ class BaseSubset:
self.validation_seed = validation_seed
self.validation_split = validation_split

self.resize_interpolation = resize_interpolation


class DreamBoothSubset(BaseSubset):
def __init__(
@@ -490,6 +496,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

@@ -517,6 +524,7 @@ class DreamBoothSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
resize_interpolation=resize_interpolation,
)

self.is_reg = is_reg
@@ -559,6 +567,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"

@@ -586,6 +595,7 @@ class FineTuningSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
resize_interpolation=resize_interpolation,
)

self.metadata_file = metadata_file
@@ -624,6 +634,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
resize_interpolation: Optional[str] = None,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

@@ -651,6 +662,7 @@ class ControlNetSubset(BaseSubset):
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
resize_interpolation=resize_interpolation,
)

self.conditioning_data_dir = conditioning_data_dir
@@ -671,6 +683,7 @@ class BaseDataset(torch.utils.data.Dataset):
resolution: Optional[Tuple[int, int]],
network_multiplier: float,
debug_dataset: bool,
resize_interpolation: Optional[str] = None
) -> None:
super().__init__()

@@ -705,6 +718,10 @@ class BaseDataset(torch.utils.data.Dataset):

self.image_transforms = IMAGE_TRANSFORMS

if resize_interpolation is not None:
assert validate_interpolation_fn(resize_interpolation), f"Resize interpolation \"{resize_interpolation}\" is not a valid interpolation"
self.resize_interpolation = resize_interpolation

self.image_data: Dict[str, ImageInfo] = {}
self.image_to_subset: Dict[str, Union[DreamBoothSubset, FineTuningSubset]] = {}

@@ -1043,8 +1060,11 @@ class BaseDataset(torch.utils.data.Dataset):
self.bucket_info["buckets"][i] = {"resolution": reso, "count": len(bucket)}
logger.info(f"bucket {i}: resolution {reso}, count: {len(bucket)}")

img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
if len(img_ar_errors) == 0:
mean_img_ar_error = 0 # avoid NaN
else:
img_ar_errors = np.array(img_ar_errors)
mean_img_ar_error = np.mean(np.abs(img_ar_errors))
self.bucket_info["mean_img_ar_error"] = mean_img_ar_error
logger.info(f"mean ar error (without repeats): {mean_img_ar_error}")

@@ -1448,6 +1468,8 @@ class BaseDataset(torch.utils.data.Dataset):
)

def get_image_size(self, image_path):
if image_path.endswith(".jxl") or image_path.endswith(".JXL"):
return get_jxl_size(image_path)
# return imagesize.get(image_path)
image_size = imagesize.get(image_path)
if image_size[0] <= 0:
@@ -1494,7 +1516,7 @@ class BaseDataset(torch.utils.data.Dataset):
nh = int(height * scale + 0.5)
nw = int(width * scale + 0.5)
assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
image = resize_image(image, width, height, nw, nh, subset.resize_interpolation)
face_cx = int(face_cx * scale + 0.5)
face_cy = int(face_cy * scale + 0.5)
height, width = nh, nw
@@ -1591,7 +1613,7 @@ class BaseDataset(torch.utils.data.Dataset):

if self.enable_bucket:
img, original_size, crop_ltrb = trim_and_resize_if_required(
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size
subset.random_crop, img, image_info.bucket_reso, image_info.resized_size, resize_interpolation=image_info.resize_interpolation
)
else:
if face_cx > 0: # 顔位置情報あり
@@ -1852,8 +1874,9 @@ class DreamBoothDataset(BaseDataset):
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)

assert resolution is not None, f"resolution is required / resolution(解像度)指定は必須です"

@@ -2078,6 +2101,7 @@ class DreamBoothDataset(BaseDataset):

for img_path, caption, size in zip(img_paths, captions, sizes):
info = ImageInfo(img_path, num_repeats, caption, subset.is_reg, img_path)
info.resize_interpolation = subset.resize_interpolation if subset.resize_interpolation is not None else self.resize_interpolation
if size is not None:
info.image_size = size
if subset.is_reg:
@@ -2133,8 +2157,9 @@ class FineTuningDataset(BaseDataset):
debug_dataset: bool,
validation_seed: int,
validation_split: float,
resize_interpolation: Optional[str],
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)

self.batch_size = batch_size

@@ -2360,9 +2385,10 @@ class ControlNetDataset(BaseDataset):
bucket_no_upscale: bool,
debug_dataset: bool,
validation_split: float,
validation_seed: Optional[int],
validation_seed: Optional[int],
resize_interpolation: Optional[str] = None,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)
super().__init__(resolution, network_multiplier, debug_dataset, resize_interpolation)

db_subsets = []
for subset in subsets:
@@ -2394,6 +2420,7 @@ class ControlNetDataset(BaseDataset):
subset.caption_suffix,
subset.token_warmup_min,
subset.token_warmup_step,
resize_interpolation=subset.resize_interpolation,
)
db_subsets.append(db_subset)

@@ -2412,6 +2439,7 @@ class ControlNetDataset(BaseDataset):
debug_dataset,
validation_split,
validation_seed,
resize_interpolation,
)

# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
@@ -2420,7 +2448,8 @@ class ControlNetDataset(BaseDataset):
self.num_train_images = self.dreambooth_dataset_delegate.num_train_images
self.num_reg_images = self.dreambooth_dataset_delegate.num_reg_images
self.validation_split = validation_split
self.validation_seed = validation_seed
self.validation_seed = validation_seed
self.resize_interpolation = resize_interpolation

# assert all conditioning data exists
missing_imgs = []
@@ -2508,9 +2537,8 @@ class ControlNetDataset(BaseDataset):
assert (
cond_img.shape[0] == original_size_hw[0] and cond_img.shape[1] == original_size_hw[1]
), f"size of conditioning image is not match / 画像サイズが合いません: {image_info.absolute_path}"
cond_img = cv2.resize(
cond_img, image_info.resized_size, interpolation=cv2.INTER_AREA
) # INTER_AREAでやりたいのでcv2でリサイズ

cond_img = resize_image(cond_img, original_size_hw[1], original_size_hw[0], target_size_hw[1], target_size_hw[0], self.resize_interpolation)

# TODO support random crop
# 現在サポートしているcropはrandomではなく中央のみ
@@ -2524,7 +2552,7 @@ class ControlNetDataset(BaseDataset):
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
# resize to target
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0])))
cond_img = resize_image(cond_img, cond_img.shape[0], cond_img.shape[1], target_size_hw[1], target_size_hw[0], self.resize_interpolation)

if flipped:
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2921,17 +2949,13 @@ def load_image(image_path, alpha=False):

# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int]
random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int], resize_interpolation: Optional[str] = None
) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]:
image_height, image_width = image.shape[0:2]
original_size = (image_width, image_height) # size before resize

if image_width != resized_size[0] or image_height != resized_size[1]:
# リサイズする
if image_width > resized_size[0] and image_height > resized_size[1]:
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
else:
image = pil_resize(image, resized_size)
image = resize_image(image, image_width, image_height, resized_size[0], resized_size[1], resize_interpolation)

image_height, image_width = image.shape[0:2]

@@ -2976,7 +3000,7 @@ def load_images_and_masks_for_caching(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)

original_sizes.append(original_size)
crop_ltrbs.append(crop_ltrb)
@@ -3017,7 +3041,7 @@ def cache_batch_latents(
for info in image_infos:
image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size, resize_interpolation=info.resize_interpolation)

info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
@@ -3458,7 +3482,9 @@ def get_sai_model_spec(
textual_inversion: bool,
is_stable_diffusion_ckpt: Optional[bool] = None, # None for TI and LoRA
sd3: str = None,
flux: str = None,
flux: str = None, # "dev", "schnell" or "chroma"
lumina: str = None,
optional_metadata: dict[str, str] | None = None
):
timestamp = time.time()

@@ -3475,6 +3501,34 @@ def get_sai_model_spec(
else:
timesteps = None

# Convert individual model parameters to model_config dict
# TODO: Update calls to this function to pass in the model config
model_config = {}
if sd3 is not None:
model_config["sd3"] = sd3
if flux is not None:
model_config["flux"] = flux
if lumina is not None:
model_config["lumina"] = lumina

# Extract metadata_* fields from args and merge with optional_metadata
extracted_metadata = {}
# Extract all metadata_* attributes from args
for attr_name in dir(args):
if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
value = getattr(args, attr_name, None)
if value is not None:
# Remove metadata_ prefix and exclude already handled fields
field_name = attr_name[9:] # len("metadata_") = 9
if field_name not in ["title", "author", "description", "license", "tags"]:
extracted_metadata[field_name] = value
# Merge extracted metadata with provided optional_metadata
all_optional_metadata = {**extracted_metadata}
if optional_metadata:
all_optional_metadata.update(optional_metadata)

metadata = sai_model_spec.build_metadata(
state_dict,
v2,
@@ -3492,12 +3546,75 @@ def get_sai_model_spec(
tags=args.metadata_tags,
timesteps=timesteps,
clip_skip=args.clip_skip, # None or int
sd3=sd3,
flux=flux,
model_config=model_config,
optional_metadata=all_optional_metadata if all_optional_metadata else None,
)
return metadata


def get_sai_model_spec_dataclass(
    state_dict: dict,
    args: argparse.Namespace,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    is_stable_diffusion_ckpt: Optional[bool] = None,
    sd3: str = None,
    flux: str = None,
    lumina: str = None,
    optional_metadata: dict[str, str] | None = None
) -> sai_model_spec.ModelSpecMetadata:
    """
    Get ModelSpec metadata as a dataclass - preferred for new code.

    Automatically extracts ``metadata_*`` fields from ``args`` (mirroring
    ``get_sai_model_spec``) and merges them with ``optional_metadata``;
    entries passed explicitly in ``optional_metadata`` win on key collisions.
    The standard fields (title/author/description/license/tags) are passed
    as dedicated keyword arguments and excluded from the extraction.

    Args:
        state_dict: model weights, used by the metadata builder
        args: parsed training arguments (metadata_*, resolution, timesteps, ...)
        sdxl: True when the base model is SDXL
        lora: True when the weights are a LoRA network
        textual_inversion: True when the weights are TI embeddings
        is_stable_diffusion_ckpt: None for TI and LoRA
        sd3 / flux / lumina: model variant identifier, or None
        optional_metadata: extra free-form metadata entries

    Returns:
        sai_model_spec.ModelSpecMetadata
    """
    timestamp = time.time()

    v2 = args.v2
    v_parameterization = args.v_parameterization
    reso = args.resolution

    title = args.metadata_title if args.metadata_title is not None else args.output_name

    if args.min_timestep is not None or args.max_timestep is not None:
        min_time_step = args.min_timestep if args.min_timestep is not None else 0
        max_time_step = args.max_timestep if args.max_timestep is not None else 1000
        timesteps = (min_time_step, max_time_step)
    else:
        timesteps = None

    # Convert individual model parameters to model_config dict
    model_config = {}
    if sd3 is not None:
        model_config["sd3"] = sd3
    if flux is not None:
        model_config["flux"] = flux
    if lumina is not None:
        model_config["lumina"] = lumina

    # Extract metadata_* fields from args, as the docstring promises.
    # Previously this was only done in get_sai_model_spec; keep both
    # entry points consistent.
    extracted_metadata = {}
    for attr_name in dir(args):
        if attr_name.startswith("metadata_") and not attr_name.startswith("metadata___"):
            value = getattr(args, attr_name, None)
            if value is not None:
                field_name = attr_name[len("metadata_"):]
                # standard fields are passed explicitly below
                if field_name not in ["title", "author", "description", "license", "tags"]:
                    extracted_metadata[field_name] = value
    all_optional_metadata = {**extracted_metadata}
    if optional_metadata:
        all_optional_metadata.update(optional_metadata)

    # Use the dataclass function directly
    return sai_model_spec.build_metadata_dataclass(
        state_dict,
        v2,
        v_parameterization,
        sdxl,
        lora,
        textual_inversion,
        timestamp,
        title=title,
        reso=reso,
        is_stable_diffusion_ckpt=is_stable_diffusion_ckpt,
        author=args.metadata_author,
        description=args.metadata_description,
        license=args.metadata_license,
        tags=args.metadata_tags,
        timesteps=timesteps,
        clip_skip=args.clip_skip,
        model_config=model_config,
        optional_metadata=all_optional_metadata if all_optional_metadata else None,
    )


def add_sd_models_arguments(parser: argparse.ArgumentParser):
# for pretrained models
parser.add_argument(
@@ -4077,39 +4194,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
parser.add_argument(
"--output_config", action="store_true", help="output command line args to given .toml file / 引数を.tomlファイルに出力する"
)

# SAI Model spec
parser.add_argument(
"--metadata_title",
type=str,
default=None,
help="title for model metadata (default is output_name) / メタデータに書き込まれるモデルタイトル、省略時はoutput_name",
)
parser.add_argument(
"--metadata_author",
type=str,
default=None,
help="author name for model metadata / メタデータに書き込まれるモデル作者名",
)
parser.add_argument(
"--metadata_description",
type=str,
default=None,
help="description for model metadata / メタデータに書き込まれるモデル説明",
)
parser.add_argument(
"--metadata_license",
type=str,
default=None,
help="license for model metadata / メタデータに書き込まれるモデルライセンス",
)
parser.add_argument(
"--metadata_tags",
type=str,
default=None,
help="tags for model metadata, separated by comma / メタデータに書き込まれるモデルタグ、カンマ区切り",
)

if support_dreambooth:
# DreamBooth training
parser.add_argument(
@@ -4495,7 +4579,13 @@ def add_dataset_arguments(
action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します",
)

parser.add_argument(
"--resize_interpolation",
type=str,
default=None,
choices=["lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area"],
help="Resize interpolation when required. Default: area Options: lanczos, nearest, bilinear, bicubic, area / 必要に応じてサイズ補間を変更します。デフォルト: area オプション: lanczos, nearest, bilinear, bicubic, area",
)
parser.add_argument(
"--token_warmup_min",
type=int,
@@ -5468,6 +5558,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio


def patch_accelerator_for_fp16_training(accelerator):
from accelerate import DistributedType
if accelerator.distributed_type == DistributedType.DEEPSPEED:
return
org_unscale_grads = accelerator.scaler._unscale_grads_

def _unscale_grads_replacer(optimizer, inv_scale, found_inf, allow_fp16):
@@ -5973,6 +6068,9 @@ def get_noise_noisy_latents_and_timesteps(
else:
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

# This moves the alphas_cumprod back to the CPU after it is moved in noise_scheduler.add_noise
noise_scheduler.alphas_cumprod = noise_scheduler.alphas_cumprod.cpu()

return noise, noisy_latents, timesteps


@@ -6151,6 +6249,11 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["scale"] = float(m.group(1))
continue

m = re.match(r"g ([\d\.]+)", parg, re.IGNORECASE)
if m: # guidance scale
prompt_dict["guidance_scale"] = float(m.group(1))
continue

m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
prompt_dict["negative_prompt"] = m.group(1)
@@ -6166,6 +6269,17 @@ def line_to_prompt_dict(line: str) -> dict:
prompt_dict["controlnet_image"] = m.group(1)
continue

m = re.match(r"ctr (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["cfg_trunc_ratio"] = float(m.group(1))
continue

m = re.match(r"rcfg (.+)", parg, re.IGNORECASE)
if m:
prompt_dict["renorm_cfg"] = float(m.group(1))
continue


except ValueError as ex:
logger.error(f"Exception in parsing / 解析エラー: {parg}")
logger.error(ex)
@@ -6533,3 +6647,4 @@ class LossRecorder:
if losses == 0:
return 0
return self.loss_total / losses


+ 115
- 3
scripts/dev/library/utils.py View File

@@ -16,7 +16,6 @@ from PIL import Image
import numpy as np
from safetensors.torch import load_file


def fire_in_thread(f, *args, **kwargs):
threading.Thread(target=f, args=args, kwargs=kwargs).start()

@@ -89,6 +88,8 @@ def setup_logging(args=None, log_level=None, reset=False):
logger = logging.getLogger(__name__)
logger.info(msg_init)

setup_logging()
logger = logging.getLogger(__name__)

# endregion

@@ -378,7 +379,7 @@ def load_safetensors(
# region Image utils


def pil_resize(image, size, interpolation=Image.LANCZOS):
def pil_resize(image, size, interpolation):
has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False

if has_alpha:
@@ -386,7 +387,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
else:
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

resized_pil = pil_image.resize(size, interpolation)
resized_pil = pil_image.resize(size, resample=interpolation)

# Convert back to cv2 format
if has_alpha:
@@ -397,6 +398,117 @@ def pil_resize(image, size, interpolation=Image.LANCZOS):
return resized_cv2


def resize_image(image: np.ndarray, width: int, height: int, resized_width: int, resized_height: int, resize_interpolation: Optional[str] = None):
    """
    Resize image with resize interpolation. Default interpolation to AREA if image is smaller, else LANCZOS.

    Args:
        image: numpy.ndarray in cv2 (BGR/BGRA) channel order
        width: int Original image width
        height: int Original image height
        resized_width: int Resized image width
        resized_height: int Resized image height
        resize_interpolation: Optional[str] Resize interpolation method "lanczos", "lanczos4", "area", "bilinear", "bicubic", "nearest", "box"

    Returns:
        resized image (numpy.ndarray)
    """

    # Ensure all size parameters are actual integers
    width = int(width)
    height = int(height)
    resized_width = int(resized_width)
    resized_height = int(resized_height)

    if resize_interpolation is None:
        if width >= resized_width and height >= resized_height:
            resize_interpolation = "area"  # downscaling: moire-free
        else:
            resize_interpolation = "lanczos"  # upscaling: sharper (backward compatibility)

    # we use PIL for lanczos (for backward compatibility) and box, cv2 for others.
    # NOTE: "lanczos4" must go through cv2 (INTER_LANCZOS4); PIL has no lanczos4
    # filter, and get_pil_interpolation would return None, making PIL silently
    # fall back to its default resampling filter.
    use_pil = resize_interpolation in ["lanczos", "box"]

    resized_size = (resized_width, resized_height)
    if use_pil:
        interpolation = get_pil_interpolation(resize_interpolation)
        image = pil_resize(image, resized_size, interpolation=interpolation)
        logger.debug(f"resize image using {resize_interpolation} (PIL)")
    else:
        interpolation = get_cv2_interpolation(resize_interpolation)
        image = cv2.resize(image, resized_size, interpolation=interpolation)
        logger.debug(f"resize image using {resize_interpolation} (cv2)")

    return image


def get_cv2_interpolation(interpolation: Optional[str]) -> Optional[int]:
    """
    Map an interpolation name to the corresponding cv2 interpolation flag.

    Returns None for an unknown name or None input.
    https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga5bb5a1fea74ea38e1a5445ca803ff121
    """
    if interpolation is None:
        return None

    name_to_flag = {
        # Lanczos interpolation over 8x8 neighborhood
        "lanczos": cv2.INTER_LANCZOS4,
        "lanczos4": cv2.INTER_LANCZOS4,
        # Bit exact nearest neighbor interpolation. This will produce same results as the nearest neighbor method in PIL, scikit-image or Matlab.
        "nearest": cv2.INTER_NEAREST_EXACT,
        # bilinear interpolation
        "bilinear": cv2.INTER_LINEAR,
        "linear": cv2.INTER_LINEAR,
        # bicubic interpolation
        "bicubic": cv2.INTER_CUBIC,
        "cubic": cv2.INTER_CUBIC,
        # resampling using pixel area relation. It may be a preferred method for image decimation, as it gives moire'-free results. But when the image is zoomed, it is similar to the INTER_NEAREST method.
        "area": cv2.INTER_AREA,
        # cv2 has no dedicated box filter; INTER_AREA is the closest equivalent
        "box": cv2.INTER_AREA,
    }
    return name_to_flag.get(interpolation)

def get_pil_interpolation(interpolation: Optional[str]) -> Optional[Image.Resampling]:
    """
    Map an interpolation name to the corresponding PIL resampling filter.

    Returns None for an unknown name or None input.
    https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-filters
    """
    if interpolation is None:
        return None

    name_to_filter = {
        "lanczos": Image.Resampling.LANCZOS,
        # Pick one nearest pixel from the input image. Ignore all other input pixels.
        "nearest": Image.Resampling.NEAREST,
        # For resize calculate the output pixel value using linear interpolation on all pixels that may contribute to the output value. For other transformations linear interpolation over a 2x2 environment in the input image is used.
        "bilinear": Image.Resampling.BILINEAR,
        "linear": Image.Resampling.BILINEAR,
        # For resize calculate the output pixel value using cubic interpolation on all pixels that may contribute to the output value. For other transformations cubic interpolation over a 4x4 environment in the input image is used.
        "bicubic": Image.Resampling.BICUBIC,
        "cubic": Image.Resampling.BICUBIC,
        # Image.Resampling.BOX may be more appropriate if upscaling
        # Area interpolation is related to cv2.INTER_AREA
        # Produces a sharper image than Resampling.BILINEAR, doesn't have dislocations on local level like with Resampling.BOX.
        "area": Image.Resampling.HAMMING,
        # Each pixel of source image contributes to one pixel of the destination image with identical weights. For upscaling is equivalent of Resampling.NEAREST.
        "box": Image.Resampling.BOX,
    }
    return name_to_filter.get(interpolation)

# Names accepted by resize_image / the --resize_interpolation option.
_SUPPORTED_INTERPOLATIONS = ("lanczos", "nearest", "bilinear", "linear", "bicubic", "cubic", "area", "box")


def validate_interpolation_fn(interpolation_str: str) -> bool:
    """Return True if the given interpolation name is supported."""
    return interpolation_str in _SUPPORTED_INTERPOLATIONS

# endregion

# TODO make inf_utils.py


+ 418
- 0
scripts/dev/lumina_minimal_inference.py View File

@@ -0,0 +1,418 @@
# Minimum Inference Code for Lumina
# Based on flux_minimal_inference.py

import logging
import argparse
import math
import os
import random
import time
from typing import Optional

import einops
import numpy as np
import torch
from accelerate import Accelerator
from PIL import Image
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import Gemma2Model
from library.flux_models import AutoEncoder

from library import (
device_utils,
lumina_models,
lumina_train_util,
lumina_util,
sd3_train_utils,
strategy_lumina,
)
import networks.lora_lumina as lora_lumina
from library.device_utils import get_preferred_device, init_ipex
from library.utils import setup_logging, str_to_dtype

init_ipex()
setup_logging()
logger = logging.getLogger(__name__)


def generate_image(
    model: lumina_models.NextDiT,
    gemma2: Gemma2Model,
    ae: AutoEncoder,
    prompt: str,
    system_prompt: str,
    seed: Optional[int],
    image_width: int,
    image_height: int,
    steps: int,
    guidance_scale: float,
    negative_prompt: Optional[str],
    args: argparse.Namespace,
    cfg_trunc_ratio: float = 0.25,
    renorm_cfg: float = 1.0,
):
    """
    Generate a single image with the Lumina pipeline and save it as a PNG.

    Pipeline: encode prompt and negative prompt with Gemma2, sample initial
    latents from the given (or random) seed, denoise with
    lumina_train_util.denoise using a FlowMatchEulerDiscreteScheduler, decode
    with the autoencoder, and write the result to ``args.output_dir``.

    Args:
        model: Lumina NextDiT transformer (moved to device inside this function)
        gemma2: Gemma2 text encoder
        ae: autoencoder used to decode latents
        prompt / negative_prompt: text conditioning
        system_prompt: system prompt passed to the tokenize strategy
        seed: RNG seed; a random 32-bit seed is drawn when None
        image_width / image_height: output resolution in pixels (divided by 8 for latents)
        steps: number of scheduler timesteps
        guidance_scale: classifier-free guidance scale
        args: CLI namespace (device, dtypes, offload, output_dir, ...)
        cfg_trunc_ratio: fraction of timesteps to apply guidance to (passed to denoise)
        renorm_cfg: post-guidance renormalization factor (passed to denoise)

    Side effects: moves models between CPU/GPU (when args.offload), seeds the
    global torch RNG, creates args.output_dir, and writes a timestamped PNG.
    """
    #
    # 0. Prepare arguments
    #
    device = get_preferred_device()
    if args.device:
        device = torch.device(args.device)

    dtype = str_to_dtype(args.dtype)
    ae_dtype = str_to_dtype(args.ae_dtype)
    gemma2_dtype = str_to_dtype(args.gemma2_dtype)

    #
    # 1. Prepare models
    #
    # model.to(device, dtype=dtype)
    # DiT is converted to dtype but kept on its current device for now;
    # it is moved to the compute device after text encoding (see below),
    # which keeps peak VRAM lower when offloading.
    model.to(dtype)
    model.eval()

    gemma2.to(device, dtype=gemma2_dtype)
    gemma2.eval()

    ae.to(ae_dtype)
    ae.eval()

    #
    # 2. Encode prompts
    #
    logger.info("Encoding prompts...")

    tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(system_prompt, args.gemma2_max_token_length)
    encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()

    tokens_and_masks = tokenize_strategy.tokenize(prompt)
    with torch.no_grad():
        gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)

    # NOTE(review): `True and not x` is just `not x`; the negative prompt is
    # tokenized as "negative" only when the system prompt is NOT added to it.
    tokens_and_masks = tokenize_strategy.tokenize(
        negative_prompt, is_negative=True and not args.add_system_prompt_to_negative_prompt
    )
    with torch.no_grad():
        neg_gemma2_conds = encoding_strategy.encode_tokens(tokenize_strategy, [gemma2], tokens_and_masks)

    # Unpack Gemma2 outputs
    prompt_hidden_states, _, prompt_attention_mask = gemma2_conds
    uncond_hidden_states, _, uncond_attention_mask = neg_gemma2_conds

    if args.offload:
        print("Offloading models to CPU to save VRAM...")
        gemma2.to("cpu")
        device_utils.clean_memory()

    # text encoding is done; now the DiT can occupy the device
    model.to(device)

    #
    # 3. Prepare latents
    #
    seed = seed if seed is not None else random.randint(0, 2**32 - 1)
    logger.info(f"Seed: {seed}")
    torch.manual_seed(seed)

    # latent space is 1/8 of pixel space with 16 channels
    latent_height = image_height // 8
    latent_width = image_width // 8
    latent_channels = 16

    latents = torch.randn(
        (1, latent_channels, latent_height, latent_width),
        device=device,
        dtype=dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )

    #
    # 4. Denoise
    #
    logger.info("Denoising...")
    scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
    scheduler.set_timesteps(steps, device=device)
    timesteps = scheduler.timesteps

    # # compare with lumina_train_util.retrieve_timesteps
    # lumina_timestep = lumina_train_util.retrieve_timesteps(scheduler, num_inference_steps=steps)
    # print(f"Using timesteps: {timesteps}")
    # print(f"vs Lumina timesteps: {lumina_timestep}") # should be the same

    with torch.autocast(device_type=device.type, dtype=dtype), torch.no_grad():
        latents = lumina_train_util.denoise(
            scheduler,
            model,
            latents.to(device),
            prompt_hidden_states.to(device),
            prompt_attention_mask.to(device),
            uncond_hidden_states.to(device),
            uncond_attention_mask.to(device),
            timesteps,
            guidance_scale,
            cfg_trunc_ratio,
            renorm_cfg,
        )

    if args.offload:
        model.to("cpu")
        device_utils.clean_memory()
    # autoencoder is needed on-device for decoding in either case
    ae.to(device)

    #
    # 5. Decode latents
    #
    logger.info("Decoding image...")
    # latents = latents / ae.scale_factor + ae.shift_factor
    with torch.no_grad():
        image = ae.decode(latents.to(ae_dtype))
    # [-1, 1] -> [0, 1] -> HWC uint8
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    image = (image * 255).round().astype("uint8")

    #
    # 6. Save image
    #
    pil_image = Image.fromarray(image[0])
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)
    ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime())
    seed_suffix = f"_{seed}"
    output_path = os.path.join(output_dir, f"image_{ts_str}{seed_suffix}.png")
    pil_image.save(output_path)
    logger.info(f"Image saved to {output_path}")


def setup_parser() -> argparse.ArgumentParser:
    """Build the argument parser for Lumina minimal inference."""
    p = argparse.ArgumentParser()

    # required model paths
    p.add_argument(
        "--pretrained_model_name_or_path", type=str, default=None, required=True,
        help="Lumina DiT model path / Lumina DiTモデルのパス",
    )
    p.add_argument("--gemma2_path", type=str, default=None, required=True, help="Gemma2 model path / Gemma2モデルのパス")
    p.add_argument("--ae_path", type=str, default=None, required=True, help="Autoencoder model path / Autoencoderモデルのパス")

    # prompts and output
    p.add_argument("--prompt", type=str, default="A beautiful sunset over the mountains", help="Prompt for image generation")
    p.add_argument("--negative_prompt", type=str, default="", help="Negative prompt for image generation, default is empty")
    p.add_argument("--output_dir", type=str, default="outputs", help="Output directory for generated images")
    p.add_argument("--system_prompt", type=str, default="", help="System prompt for Gemma2 model")
    p.add_argument("--add_system_prompt_to_negative_prompt", action="store_true", help="Add system prompt to negative prompt")

    # sampling parameters
    p.add_argument("--seed", type=int, default=None, help="Random seed")
    p.add_argument("--steps", type=int, default=36, help="Number of inference steps")
    p.add_argument("--guidance_scale", type=float, default=3.5, help="Guidance scale for classifier-free guidance")
    p.add_argument("--image_width", type=int, default=1024, help="Image width")
    p.add_argument("--image_height", type=int, default=1024, help="Image height")
    p.add_argument("--discrete_flow_shift", type=float, default=6.0, help="Shift value for FlowMatchEulerDiscreteScheduler")
    p.add_argument(
        "--cfg_trunc_ratio", type=float, default=0.25,
        help="The ratio of the timestep interval to apply normalization-based guidance scale. For example, 0.25 means the first 25%% of timesteps will be guided.",
    )
    p.add_argument(
        "--renorm_cfg", type=float, default=1.0,
        help="The factor to limit the maximum norm after guidance. Default: 1.0, 0.0 means no renormalization.",
    )

    # device / precision
    p.add_argument("--dtype", type=str, default="bf16", help="Data type for model (bf16, fp16, float)")
    p.add_argument("--gemma2_dtype", type=str, default="bf16", help="Data type for Gemma2 (bf16, fp16, float)")
    p.add_argument("--ae_dtype", type=str, default="bf16", help="Data type for Autoencoder (bf16, fp16, float)")
    p.add_argument("--device", type=str, default=None, help="Device to use (e.g., 'cuda:0')")
    p.add_argument("--offload", action="store_true", help="Offload models to CPU to save VRAM")
    p.add_argument("--gemma2_max_token_length", type=int, default=256, help="Max token length for Gemma2 tokenizer")

    # attention backends
    p.add_argument("--use_flash_attn", action="store_true", help="Use flash attention for Lumina model")
    p.add_argument("--use_sage_attn", action="store_true", help="Use sage attention for Lumina model")

    # LoRA and interactive mode
    p.add_argument(
        "--lora_weights", type=str, nargs="*", default=[],
        help="LoRA weights, each argument is a `path;multiplier` (semi-colon separated)",
    )
    p.add_argument("--merge_lora_weights", action="store_true", help="Merge LoRA weights to model")
    p.add_argument(
        "--interactive", action="store_true",
        help="Enable interactive mode for generating multiple images / 対話モードで複数の画像を生成する",
    )
    return p


if __name__ == "__main__":
    # Script entry point: load models once, then generate either a single
    # image or loop in interactive mode.
    parser = setup_parser()
    args = parser.parse_args()

    logger.info("Loading models...")
    device = get_preferred_device()
    if args.device:
        device = torch.device(args.device)

    # Load Lumina DiT model
    model = lumina_util.load_lumina_model(
        args.pretrained_model_name_or_path,
        dtype=None, # Load in fp32 and then convert
        device="cpu",
        use_flash_attn=args.use_flash_attn,
        use_sage_attn=args.use_sage_attn,
    )

    # Load Gemma2
    gemma2 = lumina_util.load_gemma2(args.gemma2_path, dtype=None, device="cpu")

    # Load Autoencoder
    ae = lumina_util.load_ae(args.ae_path, dtype=None, device="cpu")

    # LoRA: each entry is "path" or "path;multiplier"
    lora_models = []
    for weights_file in args.lora_weights:
        if ";" in weights_file:
            weights_file, multiplier = weights_file.split(";")
            multiplier = float(multiplier)
        else:
            multiplier = 1.0

        weights_sd = load_file(weights_file)
        lora_model, _ = lora_lumina.create_network_from_weights(multiplier, None, ae, [gemma2], model, weights_sd, True)

        if args.merge_lora_weights:
            # bake weights into the base model; the network itself is not applied
            lora_model.merge_to([gemma2], model, weights_sd)
        else:
            # keep the network separate so multipliers can be changed interactively
            lora_model.apply_to([gemma2], model)
            info = lora_model.load_state_dict(weights_sd, strict=True)
            logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
            lora_model.to(device)
            lora_model.set_multiplier(multiplier)
            lora_model.eval()

        lora_models.append(lora_model)

    if not args.interactive:
        # one-shot generation from CLI arguments
        generate_image(
            model,
            gemma2,
            ae,
            args.prompt,
            args.system_prompt,
            args.seed,
            args.image_width,
            args.image_height,
            args.steps,
            args.guidance_scale,
            args.negative_prompt,
            args,
            args.cfg_trunc_ratio,
            args.renorm_cfg,
        )
    else:
        # Interactive mode loop
        # NOTE: size/steps/guidance/ctr/rcfg options persist across
        # generations once set; seed and negative prompt reset each turn.
        image_width = args.image_width
        image_height = args.image_height
        steps = args.steps
        guidance_scale = args.guidance_scale
        cfg_trunc_ratio = args.cfg_trunc_ratio
        renorm_cfg = args.renorm_cfg

        print("Entering interactive mode.")
        while True:
            print(
                "\nEnter prompt (or 'exit'). Options: --w <int> --h <int> --s <int> --d <int> --g <float> --n <str> --ctr <float> --rcfg <float> --m <m1,m2...>"
            )
            user_input = input()
            if user_input.lower() == "exit":
                break
            if not user_input:
                continue

            # Parse options: text before the first "--" is the prompt
            options = user_input.split("--")
            prompt = options[0].strip()

            # Set defaults for each generation
            seed = None # New random seed each time unless specified
            negative_prompt = args.negative_prompt # Reset to default

            for opt in options[1:]:
                try:
                    opt = opt.strip()
                    if not opt:
                        continue

                    # split into option key and (possibly empty) value
                    key, value = (opt.split(None, 1) + [""])[:2]

                    if key == "w":
                        image_width = int(value)
                    elif key == "h":
                        image_height = int(value)
                    elif key == "s":
                        steps = int(value)
                    elif key == "d":
                        seed = int(value)
                    elif key == "g":
                        guidance_scale = float(value)
                    elif key == "n":
                        # "-" clears the negative prompt
                        negative_prompt = value if value != "-" else ""
                    elif key == "ctr":
                        cfg_trunc_ratio = float(value)
                    elif key == "rcfg":
                        renorm_cfg = float(value)
                    elif key == "m":
                        # comma-separated multipliers, one per loaded LoRA
                        multipliers = value.split(",")
                        if len(multipliers) != len(lora_models):
                            logger.error(f"Invalid number of multipliers, expected {len(lora_models)}")
                            continue
                        for i, lora_model in enumerate(lora_models):
                            lora_model.set_multiplier(float(multipliers[i].strip()))
                    else:
                        logger.warning(f"Unknown option: --{key}")

                except (ValueError, IndexError) as e:
                    logger.error(f"Invalid value for option --{key}: '{value}'. Error: {e}")

            generate_image(
                model,
                gemma2,
                ae,
                prompt,
                args.system_prompt,
                seed,
                image_width,
                image_height,
                steps,
                guidance_scale,
                negative_prompt,
                args,
                cfg_trunc_ratio,
                renorm_cfg,
            )

    logger.info("Done.")

+ 957
- 0
scripts/dev/lumina_train.py View File

@@ -0,0 +1,957 @@
# training with captions

# Swap blocks between CPU and GPU:
# This implementation is inspired by and based on the work of 2kpr.
# Many thanks to 2kpr for the original concept and implementation of memory-efficient offloading.
# The original idea has been adapted and extended to fit the current project's needs.

# Key features:
# - CPU offloading during forward and backward passes
# - Use of fused optimizer and grad_hook for efficient gradient processing
# - Per-block fused optimizer instances

import argparse
import copy
import math
import os
from multiprocessing import Value
import toml

from tqdm import tqdm

import torch
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from accelerate.utils import set_seed
from library import (
deepspeed_utils,
lumina_train_util,
lumina_util,
strategy_base,
strategy_lumina,
sai_model_spec
)
from library.sd3_train_utils import FlowMatchEulerDiscreteScheduler

import library.train_util as train_util

from library.utils import setup_logging, add_logging_arguments

setup_logging()
import logging

logger = logging.getLogger(__name__)

import library.config_util as config_util

# import library.sdxl_train_util as sdxl_train_util
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
)
from library.custom_train_functions import apply_masked_loss, add_custom_train_arguments


def train(args):
    """Full fine-tuning entry point for the Lumina 2 (NextDiT) model.

    Pipeline: validate args -> build dataset(s) -> optionally cache VAE latents and
    Gemma2 text-encoder outputs -> load NextDiT -> build optimizer(s)/dataloader/
    lr scheduler via accelerate -> run the flow-matching training loop -> save
    checkpoints per step/epoch and at train end.

    Args:
        args: argparse.Namespace from setup_parser() (after config-file merge).
    """
    train_util.verify_training_args(args)
    train_util.prepare_dataset_args(args, True)
    # sdxl_train_util.verify_sdxl_training_args(args)
    deepspeed_utils.prepare_deepspeed_args(args)
    setup_logging(args, reset=True)

    # temporary: backward compatibility for deprecated options. remove in the future
    if not args.skip_cache_check:
        args.skip_cache_check = args.skip_latents_validity_check

    # assert (
    #     not args.weighted_captions
    # ), "weighted_captions is not supported currently / weighted_captionsは現在サポートされていません"

    # disk caching implies in-memory caching of text-encoder outputs
    if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
        logger.warning(
            "cache_text_encoder_outputs_to_disk is enabled, so cache_text_encoder_outputs is also enabled / cache_text_encoder_outputs_to_diskが有効になっているため、cache_text_encoder_outputsも有効になります"
        )
        args.cache_text_encoder_outputs = True

    # CPU-offloaded checkpointing only makes sense with gradient checkpointing on
    if args.cpu_offload_checkpointing and not args.gradient_checkpointing:
        logger.warning(
            "cpu_offload_checkpointing is enabled, so gradient_checkpointing is also enabled / cpu_offload_checkpointingが有効になっているため、gradient_checkpointingも有効になります"
        )
        args.gradient_checkpointing = True

    # assert (
    #     args.blocks_to_swap is None or args.blocks_to_swap == 0
    # ) or not args.cpu_offload_checkpointing, "blocks_to_swap is not supported with cpu_offload_checkpointing / blocks_to_swapはcpu_offload_checkpointingと併用できません"

    cache_latents = args.cache_latents
    # DreamBooth-style folder dataset when no metadata JSON is given
    use_dreambooth_method = args.in_json is None

    if args.seed is not None:
        set_seed(args.seed)  # initialize the random number sequence

    # prepare caching strategy: this must be set before preparing dataset. because dataset may use this strategy for initialization.
    if args.cache_latents:
        latents_caching_strategy = strategy_lumina.LuminaLatentsCachingStrategy(
            args.cache_latents_to_disk, args.vae_batch_size, args.skip_cache_check
        )
        strategy_base.LatentsCachingStrategy.set_strategy(latents_caching_strategy)

    # prepare the dataset
    if args.dataset_class is None:
        blueprint_generator = BlueprintGenerator(
            ConfigSanitizer(True, True, args.masked_loss, True)
        )
        if args.dataset_config is not None:
            logger.info(f"Load dataset config from {args.dataset_config}")
            user_config = config_util.load_user_config(args.dataset_config)
            # CLI dataset options are superseded by the config file
            ignored = ["train_data_dir", "in_json"]
            if any(getattr(args, attr) is not None for attr in ignored):
                logger.warning(
                    "ignore following options because config file is found: {0} / 設定ファイルが利用されるため以下のオプションは無視されます: {0}".format(
                        ", ".join(ignored)
                    )
                )
        else:
            if use_dreambooth_method:
                logger.info("Using DreamBooth method.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": config_util.generate_dreambooth_subsets_config_by_subdirs(
                                args.train_data_dir, args.reg_data_dir
                            )
                        }
                    ]
                }
            else:
                logger.info("Training with captions.")
                user_config = {
                    "datasets": [
                        {
                            "subsets": [
                                {
                                    "image_dir": args.train_data_dir,
                                    "metadata_file": args.in_json,
                                }
                            ]
                        }
                    ]
                }

        blueprint = blueprint_generator.generate(user_config, args)
        train_dataset_group, val_dataset_group = (
            config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)
        )
    else:
        train_dataset_group = train_util.load_arbitrary_dataset(args)
        val_dataset_group = None

    # shared counters so dataloader workers can see the current epoch/step
    current_epoch = Value("i", 0)
    current_step = Value("i", 0)
    ds_for_collator = (
        train_dataset_group if args.max_data_loader_n_workers == 0 else None
    )
    collator = train_util.collator_class(current_epoch, current_step, ds_for_collator)

    train_dataset_group.verify_bucket_reso_steps(16)  # TODO confirm this is correct

    if args.debug_dataset:
        # set strategies so debug_dataset can exercise tokenization/caching paths
        if args.cache_text_encoder_outputs:
            strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
                strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
                    args.cache_text_encoder_outputs_to_disk,
                    args.text_encoder_batch_size,
                    args.skip_cache_check,
                    False,
                )
            )
        strategy_base.TokenizeStrategy.set_strategy(
            strategy_lumina.LuminaTokenizeStrategy(args.system_prompt)
        )

        train_dataset_group.set_current_strategies()
        train_util.debug_dataset(train_dataset_group, True)
        return
    if len(train_dataset_group) == 0:
        logger.error(
            "No data found. Please verify the metadata file and train_data_dir option. / 画像がありません。メタデータおよびtrain_data_dirオプションを確認してください。"
        )
        return

    if cache_latents:
        assert (
            train_dataset_group.is_latent_cacheable()
        ), "when caching latents, either color_aug or random_crop cannot be used / latentをキャッシュするときはcolor_augとrandom_cropは使えません"

    if args.cache_text_encoder_outputs:
        assert (
            train_dataset_group.is_text_encoder_output_cacheable()
        ), "when caching text encoder output, either caption_dropout_rate, shuffle_caption, token_warmup_step or caption_tag_dropout_rate cannot be used / text encoderの出力をキャッシュするときはcaption_dropout_rate, shuffle_caption, token_warmup_step, caption_tag_dropout_rateは使えません"

    # prepare the accelerator
    logger.info("prepare accelerator")
    accelerator = train_util.prepare_accelerator(args)

    # prepare dtypes matching the mixed-precision setting; cast as needed later
    weight_dtype, save_dtype = train_util.prepare_dtype(args)

    # load the models

    # load VAE for caching latents
    ae = None
    if cache_latents:
        ae = lumina_util.load_ae(
            args.ae, weight_dtype, "cpu", args.disable_mmap_load_safetensors
        )
        ae.to(accelerator.device, dtype=weight_dtype)
        ae.requires_grad_(False)
        ae.eval()

        train_dataset_group.new_cache_latents(ae, accelerator)

        ae.to("cpu")  # if no sampling, vae can be deleted
        clean_memory_on_device(accelerator.device)

        accelerator.wait_for_everyone()

    # prepare tokenize strategy
    if args.gemma2_max_token_length is None:
        gemma2_max_token_length = 256  # default token limit for Gemma2 captions
    else:
        gemma2_max_token_length = args.gemma2_max_token_length

    lumina_tokenize_strategy = strategy_lumina.LuminaTokenizeStrategy(
        args.system_prompt, gemma2_max_token_length
    )
    strategy_base.TokenizeStrategy.set_strategy(lumina_tokenize_strategy)

    # load gemma2 for caching text encoder outputs
    gemma2 = lumina_util.load_gemma2(
        args.gemma2, weight_dtype, "cpu", args.disable_mmap_load_safetensors
    )
    gemma2.eval()
    gemma2.requires_grad_(False)

    text_encoding_strategy = strategy_lumina.LuminaTextEncodingStrategy()
    strategy_base.TextEncodingStrategy.set_strategy(text_encoding_strategy)

    # cache text encoder outputs
    sample_prompts_te_outputs = None
    if args.cache_text_encoder_outputs:
        # Text Encodes are eval and no grad here
        gemma2.to(accelerator.device)

        text_encoder_caching_strategy = (
            strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
                args.cache_text_encoder_outputs_to_disk,
                args.text_encoder_batch_size,
                False,
                False,
            )
        )
        strategy_base.TextEncoderOutputsCachingStrategy.set_strategy(
            text_encoder_caching_strategy
        )

        with accelerator.autocast():
            train_dataset_group.new_cache_text_encoder_outputs([gemma2], accelerator)

        # cache sample prompt's embeddings to free text encoder's memory
        if args.sample_prompts is not None:
            logger.info(
                f"cache Text Encoder outputs for sample prompt: {args.sample_prompts}"
            )

            text_encoding_strategy: strategy_lumina.LuminaTextEncodingStrategy = (
                strategy_base.TextEncodingStrategy.get_strategy()
            )

            prompts = train_util.load_prompts(args.sample_prompts)
            sample_prompts_te_outputs = {}  # key: prompt, value: text encoder outputs
            with accelerator.autocast(), torch.no_grad():
                for prompt_dict in prompts:
                    for i, p in enumerate([
                        prompt_dict.get("prompt", ""),
                        prompt_dict.get("negative_prompt", ""),
                    ]):
                        if p not in sample_prompts_te_outputs:
                            logger.info(f"cache Text Encoder outputs for prompt: {p}")
                            tokens_and_masks = lumina_tokenize_strategy.tokenize(p, i == 1)  # i == 1 means negative prompt
                            sample_prompts_te_outputs[p] = (
                                text_encoding_strategy.encode_tokens(
                                    lumina_tokenize_strategy,
                                    [gemma2],
                                    tokens_and_masks,
                                )
                            )

        accelerator.wait_for_everyone()

        # now we can delete Text Encoders to free memory
        gemma2 = None
        clean_memory_on_device(accelerator.device)

    # load lumina (NextDiT) on CPU first; moved to device by accelerate below
    nextdit = lumina_util.load_lumina_model(
        args.pretrained_model_name_or_path,
        weight_dtype,
        torch.device("cpu"),
        disable_mmap=args.disable_mmap_load_safetensors,
        use_flash_attn=args.use_flash_attn,
    )

    if args.gradient_checkpointing:
        nextdit.enable_gradient_checkpointing(
            cpu_offload=args.cpu_offload_checkpointing
        )

    # full fine-tune: the entire DiT is trainable
    nextdit.requires_grad_(True)

    # block swap

    # backward compatibility
    # if args.blocks_to_swap is None:
    #     blocks_to_swap = args.double_blocks_to_swap or 0
    #     if args.single_blocks_to_swap is not None:
    #         blocks_to_swap += args.single_blocks_to_swap // 2
    #     if blocks_to_swap > 0:
    #         logger.warning(
    #             "double_blocks_to_swap and single_blocks_to_swap are deprecated. Use blocks_to_swap instead."
    #             " / double_blocks_to_swapとsingle_blocks_to_swapは非推奨です。blocks_to_swapを使ってください。"
    #         )
    #         logger.info(
    #             f"double_blocks_to_swap={args.double_blocks_to_swap} and single_blocks_to_swap={args.single_blocks_to_swap} are converted to blocks_to_swap={blocks_to_swap}."
    #         )
    #         args.blocks_to_swap = blocks_to_swap
    #     del blocks_to_swap

    # is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0
    # if is_swapping_blocks:
    #     # Swap blocks between CPU and GPU to reduce memory usage, in forward and backward passes.
    #     # This idea is based on 2kpr's great work. Thank you!
    #     logger.info(f"enable block swap: blocks_to_swap={args.blocks_to_swap}")
    #     flux.enable_block_swap(args.blocks_to_swap, accelerator.device)

    if not cache_latents:
        # load VAE here if not cached
        ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")
        ae.requires_grad_(False)
        ae.eval()
        ae.to(accelerator.device, dtype=weight_dtype)

    training_models = []
    params_to_optimize = []
    training_models.append(nextdit)
    name_and_params = list(nextdit.named_parameters())
    # single param group for now
    params_to_optimize.append(
        {"params": [p for _, p in name_and_params], "lr": args.learning_rate}
    )
    param_names = [[n for n, _ in name_and_params]]

    # calculate number of trainable parameters
    n_params = 0
    for group in params_to_optimize:
        for p in group["params"]:
            n_params += p.numel()

    accelerator.print(f"number of trainable parameters: {n_params}")

    # prepare the classes needed for training
    accelerator.print("prepare optimizer, data loader etc.")

    if args.blockwise_fused_optimizers:
        # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
        # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each block of parameters.
        # This balances memory usage and management complexity.

        # split params into groups. currently different learning rates are not supported
        grouped_params = []
        param_group = {}
        for group in params_to_optimize:
            named_parameters = list(nextdit.named_parameters())
            assert len(named_parameters) == len(
                group["params"]
            ), "number of parameters does not match"
            for p, np in zip(group["params"], named_parameters):
                # determine target layer and block index for each parameter
                # NOTE(review): the "double_blocks"/"single_blocks" prefixes mirror the
                # FLUX implementation — confirm NextDiT parameter names use them too,
                # otherwise everything lands in the single "other" group.
                block_type = "other"  # double, single or other
                if np[0].startswith("double_blocks"):
                    block_index = int(np[0].split(".")[1])
                    block_type = "double"
                elif np[0].startswith("single_blocks"):
                    block_index = int(np[0].split(".")[1])
                    block_type = "single"
                else:
                    block_index = -1

                param_group_key = (block_type, block_index)
                if param_group_key not in param_group:
                    param_group[param_group_key] = []
                param_group[param_group_key].append(p)

        block_types_and_indices = []
        for param_group_key, param_group in param_group.items():
            block_types_and_indices.append(param_group_key)
            grouped_params.append({"params": param_group, "lr": args.learning_rate})

            num_params = 0
            for p in param_group:
                num_params += p.numel()
            accelerator.print(f"block {param_group_key}: {num_params} parameters")

        # prepare optimizers for each group
        optimizers = []
        for group in grouped_params:
            _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group])
            optimizers.append(optimizer)
        optimizer = optimizers[0]  # avoid error in the following code

        logger.info(
            f"using {len(optimizers)} optimizers for blockwise fused optimizers"
        )

        if train_util.is_schedulefree_optimizer(optimizers[0], args):
            raise ValueError(
                "Schedule-free optimizer is not supported with blockwise fused optimizers"
            )
        optimizer_train_fn = lambda: None  # dummy function
        optimizer_eval_fn = lambda: None  # dummy function
    else:
        _, _, optimizer = train_util.get_optimizer(
            args, trainable_params=params_to_optimize
        )
        optimizer_train_fn, optimizer_eval_fn = train_util.get_optimizer_train_eval_fn(
            optimizer, args
        )

    # prepare dataloader
    # strategies are set here because they cannot be referenced in another process. Copy them with the dataset
    # some strategies can be None
    train_dataset_group.set_current_strategies()

    # note: persistent_workers cannot be used when the number of DataLoader processes is 0
    n_workers = min(
        args.max_data_loader_n_workers, os.cpu_count()
    )  # cpu_count or max_data_loader_n_workers
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset_group,
        batch_size=1,
        shuffle=True,
        collate_fn=collator,
        num_workers=n_workers,
        persistent_workers=args.persistent_data_loader_workers,
    )

    # calculate the number of training steps
    if args.max_train_epochs is not None:
        args.max_train_steps = args.max_train_epochs * math.ceil(
            len(train_dataloader)
            / accelerator.num_processes
            / args.gradient_accumulation_steps
        )
        accelerator.print(
            f"override steps. steps for {args.max_train_epochs} epochs is / 指定エポックまでのステップ数: {args.max_train_steps}"
        )

    # also send the number of training steps to the dataset side
    train_dataset_group.set_max_train_steps(args.max_train_steps)

    # prepare the lr scheduler(s)
    if args.blockwise_fused_optimizers:
        # prepare lr schedulers for each optimizer
        lr_schedulers = [
            train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes)
            for optimizer in optimizers
        ]
        lr_scheduler = lr_schedulers[0]  # avoid error in the following code
    else:
        lr_scheduler = train_util.get_scheduler_fix(
            args, optimizer, accelerator.num_processes
        )

    # experimental feature: train with fp16/bf16 including gradients — cast the whole model
    if args.full_fp16:
        assert (
            args.mixed_precision == "fp16"
        ), "full_fp16 requires mixed precision='fp16' / full_fp16を使う場合はmixed_precision='fp16'を指定してください。"
        accelerator.print("enable full fp16 training.")
        nextdit.to(weight_dtype)
        if gemma2 is not None:  # gemma2 is None when TE outputs were cached
            gemma2.to(weight_dtype)
    elif args.full_bf16:
        assert (
            args.mixed_precision == "bf16"
        ), "full_bf16 requires mixed precision='bf16' / full_bf16を使う場合はmixed_precision='bf16'を指定してください。"
        accelerator.print("enable full bf16 training.")
        nextdit.to(weight_dtype)
        if gemma2 is not None:
            gemma2.to(weight_dtype)

    # if we don't cache text encoder outputs, move them to device
    if not args.cache_text_encoder_outputs:
        gemma2.to(accelerator.device)

    clean_memory_on_device(accelerator.device)

    is_swapping_blocks = args.blocks_to_swap is not None and args.blocks_to_swap > 0

    if args.deepspeed:
        ds_model = deepspeed_utils.prepare_deepspeed_model(args, nextdit=nextdit)
        # most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
        ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            ds_model, optimizer, train_dataloader, lr_scheduler
        )
        training_models = [ds_model]

    else:
        # accelerator does some magic
        # if we doesn't swap blocks, we can move the model to device
        nextdit = accelerator.prepare(
            nextdit, device_placement=[not is_swapping_blocks]
        )
        if is_swapping_blocks:
            accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(
                accelerator.device
            )  # reduce peak memory usage
        optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            optimizer, train_dataloader, lr_scheduler
        )

    # experimental feature: fp16 training including gradients — patch PyTorch to enable grad scaling in fp16
    if args.full_fp16:
        # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do.
        # -> But we think it's ok to patch accelerator even if deepspeed is enabled.
        train_util.patch_accelerator_for_fp16_training(accelerator)

    # resume training state if requested
    train_util.resume_from_local_or_hf_if_specified(accelerator, args)

    if args.fused_backward_pass:
        # use fused optimizer for backward pass: other optimizers will be supported in the future
        import library.adafactor_fused

        library.adafactor_fused.patch_adafactor_fused(optimizer)

        for param_group, param_name_group in zip(optimizer.param_groups, param_names):
            for parameter, param_name in zip(param_group["params"], param_name_group):
                if parameter.requires_grad:

                    # factory so each hook closes over its own name/group
                    def create_grad_hook(p_name, p_group):
                        def grad_hook(tensor: torch.Tensor):
                            # step the optimizer per-parameter and free the grad immediately
                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                accelerator.clip_grad_norm_(tensor, args.max_grad_norm)
                            optimizer.step_param(tensor, p_group)
                            tensor.grad = None

                        return grad_hook

                    parameter.register_post_accumulate_grad_hook(
                        create_grad_hook(param_name, param_group)
                    )

    elif args.blockwise_fused_optimizers:
        # prepare for additional optimizers and lr schedulers
        for i in range(1, len(optimizers)):
            optimizers[i] = accelerator.prepare(optimizers[i])
            lr_schedulers[i] = accelerator.prepare(lr_schedulers[i])

        # counters are used to determine when to step the optimizer
        global optimizer_hooked_count
        global num_parameters_per_group
        global parameter_optimizer_map

        optimizer_hooked_count = {}
        num_parameters_per_group = [0] * len(optimizers)
        parameter_optimizer_map = {}

        for opt_idx, optimizer in enumerate(optimizers):
            for param_group in optimizer.param_groups:
                for parameter in param_group["params"]:
                    if parameter.requires_grad:

                        def grad_hook(parameter: torch.Tensor):
                            if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                                accelerator.clip_grad_norm_(
                                    parameter, args.max_grad_norm
                                )

                            # step the block's optimizer once all of its params have grads
                            i = parameter_optimizer_map[parameter]
                            optimizer_hooked_count[i] += 1
                            if optimizer_hooked_count[i] == num_parameters_per_group[i]:
                                optimizers[i].step()
                                optimizers[i].zero_grad(set_to_none=True)

                        parameter.register_post_accumulate_grad_hook(grad_hook)
                        parameter_optimizer_map[parameter] = opt_idx
                        num_parameters_per_group[opt_idx] += 1

    # calculate the number of epochs
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if (args.save_n_epoch_ratio is not None) and (args.save_n_epoch_ratio > 0):
        args.save_every_n_epochs = (
            math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1
        )

    # start training
    # total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    accelerator.print("running training / 学習開始")
    accelerator.print(
        f"  num examples / サンプル数: {train_dataset_group.num_train_images}"
    )
    accelerator.print(
        f"  num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}"
    )
    accelerator.print(f"  num epochs / epoch数: {num_train_epochs}")
    accelerator.print(
        f"  batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
    )
    # accelerator.print(
    #     f"  total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
    # )
    accelerator.print(
        f"  gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}"
    )
    accelerator.print(
        f"  total optimization steps / 学習ステップ数: {args.max_train_steps}"
    )

    progress_bar = tqdm(
        range(args.max_train_steps),
        smoothing=0,
        disable=not accelerator.is_local_main_process,
        desc="steps",
    )
    global_step = 0

    noise_scheduler = FlowMatchEulerDiscreteScheduler(
        num_train_timesteps=1000, shift=args.discrete_flow_shift
    )
    # copy so that sigma/timestep sampling never mutates the scheduler used for loss
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)

    if accelerator.is_main_process:
        init_kwargs = {}
        if args.wandb_run_name:
            init_kwargs["wandb"] = {"name": args.wandb_run_name}
        if args.log_tracker_config is not None:
            init_kwargs = toml.load(args.log_tracker_config)
        accelerator.init_trackers(
            "finetuning" if args.log_tracker_name is None else args.log_tracker_name,
            config=train_util.get_sanitized_config_or_none(args),
            init_kwargs=init_kwargs,
        )

    if is_swapping_blocks:
        accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()

    # For --sample_at_first
    optimizer_eval_fn()
    lumina_train_util.sample_images(
        accelerator,
        args,
        0,
        global_step,
        nextdit,
        ae,
        gemma2,
        sample_prompts_te_outputs,
    )
    optimizer_train_fn()
    if len(accelerator.trackers) > 0:
        # log empty object to commit the sample images to wandb
        accelerator.log({}, step=0)

    loss_recorder = train_util.LossRecorder()
    epoch = 0  # avoid error when max_train_steps is 0
    for epoch in range(num_train_epochs):
        accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}")
        current_epoch.value = epoch + 1

        for m in training_models:
            m.train()

        for step, batch in enumerate(train_dataloader):
            current_step.value = global_step

            if args.blockwise_fused_optimizers:
                optimizer_hooked_count = {
                    i: 0 for i in range(len(optimizers))
                }  # reset counter for each step

            with accelerator.accumulate(*training_models):
                if "latents" in batch and batch["latents"] is not None:
                    latents = batch["latents"].to(
                        accelerator.device, dtype=weight_dtype
                    )
                else:
                    with torch.no_grad():
                        # encode images to latents. images are [-1, 1]
                        latents = ae.encode(batch["images"].to(ae.dtype)).to(
                            accelerator.device, dtype=weight_dtype
                        )

                    # if NaN is found, show a warning and replace with zeros
                    if torch.any(torch.isnan(latents)):
                        accelerator.print("NaN found in latents, replacing with zeros")
                        latents = torch.nan_to_num(latents, 0, out=latents)

                text_encoder_outputs_list = batch.get("text_encoder_outputs_list", None)
                if text_encoder_outputs_list is not None:
                    text_encoder_conds = text_encoder_outputs_list
                else:
                    # not cached or training, so get from text encoders
                    tokens_and_masks = batch["input_ids_list"]
                    with torch.no_grad():
                        input_ids = [
                            ids.to(accelerator.device)
                            for ids in batch["input_ids_list"]
                        ]
                        text_encoder_conds = text_encoding_strategy.encode_tokens(
                            lumina_tokenize_strategy,
                            [gemma2],
                            input_ids,
                        )
                        if args.full_fp16:
                            text_encoder_conds = [
                                c.to(weight_dtype) for c in text_encoder_conds
                            ]

                # TODO support some features for noise implemented in get_noise_noisy_latents_and_timesteps

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)

                # get noisy model input and timesteps
                noisy_model_input, timesteps, sigmas = (
                    lumina_train_util.get_noisy_model_input_and_timesteps(
                        args,
                        noise_scheduler_copy,
                        latents,
                        noise,
                        accelerator.device,
                        weight_dtype,
                    )
                )
                # call model
                text_encoder_conds = text_encoder_conds
                gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds

                with accelerator.autocast():
                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
                    model_pred = nextdit(
                        x=noisy_model_input,  # image latents (B, C, H, W)
                        t=timesteps / 1000,  # timesteps must be divided by 1000 to match the model's expectation
                        cap_feats=gemma2_hidden_states,  # Gemma2 hidden states used as caption features
                        cap_mask=gemma2_attn_mask.to(
                            dtype=torch.int32
                        ),  # Gemma2 attention mask
                    )
                # apply model prediction type
                model_pred, weighting = lumina_train_util.apply_model_prediction_type(
                    args, model_pred, noisy_model_input, sigmas
                )

                # flow matching loss
                target = latents - noise

                # calculate loss
                huber_c = train_util.get_huber_threshold_if_needed(
                    args, timesteps, noise_scheduler
                )
                loss = train_util.conditional_loss(
                    model_pred.float(), target.float(), args.loss_type, "none", huber_c
                )
                if weighting is not None:
                    loss = loss * weighting
                if args.masked_loss or (
                    "alpha_masks" in batch and batch["alpha_masks"] is not None
                ):
                    loss = apply_masked_loss(loss, batch)
                loss = loss.mean([1, 2, 3])

                loss_weights = batch["loss_weights"]  # per-sample weights
                loss = loss * loss_weights
                loss = loss.mean()

                # backward
                accelerator.backward(loss)

                if not (args.fused_backward_pass or args.blockwise_fused_optimizers):
                    if accelerator.sync_gradients and args.max_grad_norm != 0.0:
                        params_to_clip = []
                        for m in training_models:
                            params_to_clip.extend(m.parameters())
                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                else:
                    # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook
                    lr_scheduler.step()
                    if args.blockwise_fused_optimizers:
                        for i in range(1, len(optimizers)):
                            lr_schedulers[i].step()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                optimizer_eval_fn()
                lumina_train_util.sample_images(
                    accelerator,
                    args,
                    None,
                    global_step,
                    nextdit,
                    ae,
                    gemma2,
                    sample_prompts_te_outputs,
                )

                # save the model at the specified step interval
                if (
                    args.save_every_n_steps is not None
                    and global_step % args.save_every_n_steps == 0
                ):
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
                            args,
                            False,
                            accelerator,
                            save_dtype,
                            epoch,
                            num_train_epochs,
                            global_step,
                            accelerator.unwrap_model(nextdit),
                        )
                optimizer_train_fn()

            current_loss = loss.detach().item()  # this is the mean, so batch size should not matter
            if len(accelerator.trackers) > 0:
                logs = {"loss": current_loss}
                train_util.append_lr_to_logs(
                    logs, lr_scheduler, args.optimizer_type, including_unet=True
                )

                accelerator.log(logs, step=global_step)

            loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
            avr_loss: float = loss_recorder.moving_average
            logs = {"avr_loss": avr_loss}  # , "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if len(accelerator.trackers) > 0:
            logs = {"loss/epoch": loss_recorder.moving_average}
            accelerator.log(logs, step=epoch + 1)

        accelerator.wait_for_everyone()

        optimizer_eval_fn()
        if args.save_every_n_epochs is not None:
            if accelerator.is_main_process:
                lumina_train_util.save_lumina_model_on_epoch_end_or_stepwise(
                    args,
                    True,
                    accelerator,
                    save_dtype,
                    epoch,
                    num_train_epochs,
                    global_step,
                    accelerator.unwrap_model(nextdit),
                )

        lumina_train_util.sample_images(
            accelerator,
            args,
            epoch + 1,
            global_step,
            nextdit,
            ae,
            gemma2,
            sample_prompts_te_outputs,
        )
        optimizer_train_fn()

    is_main_process = accelerator.is_main_process
    # if is_main_process:
    nextdit = accelerator.unwrap_model(nextdit)

    accelerator.end_training()
    optimizer_eval_fn()

    if args.save_state or args.save_state_on_train_end:
        train_util.save_state_on_train_end(args, accelerator)

    del accelerator  # delete this because memory is used after this point

    if is_main_process:
        lumina_train_util.save_lumina_model_on_train_end(
            args, save_dtype, epoch, global_step, nextdit
        )
        logger.info("model saved.")


def setup_parser() -> argparse.ArgumentParser:
    """Build the argument parser for Lumina fine-tuning.

    Registers the shared sd-scripts option groups first, then the flags that
    are specific to this training script.

    Returns:
        argparse.ArgumentParser ready for parse_args().
    """
    parser = argparse.ArgumentParser()

    # shared option groups from the library (order kept for stable --help output)
    add_logging_arguments(parser)
    train_util.add_sd_models_arguments(parser)  # TODO split this
    sai_model_spec.add_model_spec_arguments(parser)
    train_util.add_dataset_arguments(parser, True, True, True)
    train_util.add_training_arguments(parser, False)
    train_util.add_masked_loss_arguments(parser)
    deepspeed_utils.add_deepspeed_arguments(parser)
    train_util.add_sd_saving_arguments(parser)
    train_util.add_optimizer_arguments(parser)
    config_util.add_config_arguments(parser)
    add_custom_train_arguments(parser)  # TODO remove this from here
    train_util.add_dit_training_arguments(parser)
    lumina_train_util.add_lumina_train_arguments(parser)

    # script-specific flags
    script_flags = (
        (
            "--mem_eff_save",
            dict(
                action="store_true",
                help="[EXPERIMENTAL] use memory efficient custom model saving method / メモリ効率の良い独自のモデル保存方法を使う",
            ),
        ),
        (
            "--fused_optimizer_groups",
            dict(
                type=int,
                default=None,
                help="**this option is not working** will be removed in the future / このオプションは動作しません。将来削除されます",
            ),
        ),
        (
            "--blockwise_fused_optimizers",
            dict(
                action="store_true",
                help="enable blockwise optimizers for fused backward pass and optimizer step / fused backward passとoptimizer step のためブロック単位のoptimizerを有効にする",
            ),
        ),
        (
            "--skip_latents_validity_check",
            dict(
                action="store_true",
                help="[Deprecated] use 'skip_cache_check' instead / 代わりに 'skip_cache_check' を使用してください",
            ),
        ),
        (
            "--cpu_offload_checkpointing",
            dict(
                action="store_true",
                help="[EXPERIMENTAL] enable offloading of tensors to CPU during checkpointing / チェックポイント時にテンソルをCPUにオフロードする",
            ),
        ),
    )
    for flag, kwargs in script_flags:
        parser.add_argument(flag, **kwargs)

    return parser


if __name__ == "__main__":
    # Parse CLI options, merge any config file on top, then run training.
    cli_parser = setup_parser()

    cli_args = cli_parser.parse_args()
    train_util.verify_command_line_training_args(cli_args)
    cli_args = train_util.read_config_from_file(cli_args, cli_parser)

    train(cli_args)

+ 383
- 0
scripts/dev/lumina_train_network.py View File

@@ -0,0 +1,383 @@
import argparse
import copy
from typing import Any, Tuple

import torch

from library.device_utils import clean_memory_on_device, init_ipex

init_ipex()

from torch import Tensor
from accelerate import Accelerator


import train_network
from library import (
lumina_models,
lumina_util,
lumina_train_util,
sd3_train_utils,
strategy_base,
strategy_lumina,
train_util,
)
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class LuminaNetworkTrainer(train_network.NetworkTrainer):
def __init__(self):
    """Set up Lumina-specific trainer state on top of the generic NetworkTrainer."""
    super().__init__()
    # True once load_target_model enables CPU<->GPU block swapping
    self.is_swapping_blocks: bool = False
    # Gemma2 outputs cached for sample-image prompts (filled during TE caching)
    self.sample_prompts_te_outputs = None

def assert_extra_args(self, args, train_dataset_group, val_dataset_group):
    """Validate and normalize Lumina-specific options and dataset settings."""
    super().assert_extra_args(args, train_dataset_group, val_dataset_group)

    # requesting disk caching implies caching TE outputs at all
    if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
        logger.warning("Enabling cache_text_encoder_outputs due to disk caching")
        args.cache_text_encoder_outputs = True

    # bucket resolutions must be multiples of 32 (presumably the latent patch granularity)
    train_dataset_group.verify_bucket_reso_steps(32)
    if val_dataset_group is not None:
        val_dataset_group.verify_bucket_reso_steps(32)

    # Gemma2 is trained unless only the DiT ("unet") is being trained
    self.train_gemma2 = not args.network_train_unet_only

def load_target_model(self, args, weight_dtype, accelerator):
    """Load NextDiT, the Gemma2 text encoder, and the autoencoder (all on CPU).

    Returns:
        (model_version, [gemma2], ae, nextdit_model)
    """
    # keep the checkpoint's native dtype when fp8_base is requested; cast below
    loading_dtype = None if args.fp8_base else weight_dtype

    model = lumina_util.load_lumina_model(
        args.pretrained_model_name_or_path,
        loading_dtype,
        torch.device("cpu"),
        disable_mmap=args.disable_mmap_load_safetensors,
        use_flash_attn=args.use_flash_attn,
        use_sage_attn=args.use_sage_attn,
    )

    if args.fp8_base:
        # only e4m3fn fp8 checkpoints are usable as-is; other fp8 variants are rejected
        if model.dtype in (torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz):
            raise ValueError(f"Unsupported fp8 model dtype: {model.dtype}")
        if model.dtype == torch.float8_e4m3fn:
            logger.info("Loaded fp8 Lumina 2 model")
        else:
            logger.info(
                "Cast Lumina 2 model to fp8. This may take a while. You can reduce the time by using fp8 checkpoint."
                " / Lumina 2モデルをfp8に変換しています。これには時間がかかる場合があります。fp8チェックポイントを使用することで時間を短縮できます。"
            )
            model.to(torch.float8_e4m3fn)

    if args.blocks_to_swap:
        logger.info(f"Lumina 2: Enabling block swap: {args.blocks_to_swap}")
        model.enable_block_swap(args.blocks_to_swap, accelerator.device)
        self.is_swapping_blocks = True

    gemma2 = lumina_util.load_gemma2(args.gemma2, weight_dtype, "cpu")
    gemma2.eval()
    ae = lumina_util.load_ae(args.ae, weight_dtype, "cpu")

    return lumina_util.MODEL_VERSION_LUMINA_V2, [gemma2], ae, model

def get_tokenize_strategy(self, args):
    """Tokenize strategy carrying the system prompt and Gemma2 token-length limit."""
    max_token_length = args.gemma2_max_token_length
    return strategy_lumina.LuminaTokenizeStrategy(
        args.system_prompt, max_token_length, args.tokenizer_cache_dir
    )

def get_tokenizers(self, tokenize_strategy: strategy_lumina.LuminaTokenizeStrategy):
    """Expose the single Gemma2 tokenizer as a list for the generic trainer."""
    gemma2_tokenizer = tokenize_strategy.tokenizer
    return [gemma2_tokenizer]

def get_latents_caching_strategy(self, args):
    """Latents caching strategy; cache-validity checking stays enabled (False = don't skip)."""
    cache_to_disk = args.cache_latents_to_disk
    return strategy_lumina.LuminaLatentsCachingStrategy(
        cache_to_disk, args.vae_batch_size, False
    )

    def get_text_encoding_strategy(self, args):
        # stateless strategy that runs Gemma2 to produce caption features
        return strategy_lumina.LuminaTextEncodingStrategy()

    def get_text_encoders_train_flags(self, args, text_encoders):
        # single text encoder (Gemma2); it is trained unless --network_train_unet_only is set
        return [self.train_gemma2]

def get_text_encoder_outputs_caching_strategy(self, args):
if args.cache_text_encoder_outputs:
# if the text encoders is trained, we need tokenization, so is_partial is True
return strategy_lumina.LuminaTextEncoderOutputsCachingStrategy(
args.cache_text_encoder_outputs_to_disk,
args.text_encoder_batch_size,
args.skip_cache_check,
is_partial=self.train_gemma2,
)
else:
return None

def cache_text_encoder_outputs_if_needed(
self,
args,
accelerator: Accelerator,
unet,
vae,
text_encoders,
dataset,
weight_dtype,
):
if args.cache_text_encoder_outputs:
if not args.lowram:
# メモリ消費を減らす
logger.info("move vae and unet to cpu to save memory")
org_vae_device = vae.device
org_unet_device = unet.device
vae.to("cpu")
unet.to("cpu")
clean_memory_on_device(accelerator.device)

# When TE is not be trained, it will not be prepared so we need to use explicit autocast
logger.info("move text encoders to gpu")
text_encoders[0].to(accelerator.device, dtype=weight_dtype) # always not fp8

if text_encoders[0].dtype == torch.float8_e4m3fn:
# if we load fp8 weights, the model is already fp8, so we use it as is
self.prepare_text_encoder_fp8(1, text_encoders[1], text_encoders[1].dtype, weight_dtype)
else:
# otherwise, we need to convert it to target dtype
text_encoders[0].to(weight_dtype)

with accelerator.autocast():
dataset.new_cache_text_encoder_outputs(text_encoders, accelerator)

# cache sample prompts
if args.sample_prompts is not None:
logger.info(f"cache Text Encoder outputs for sample prompts: {args.sample_prompts}")

tokenize_strategy = strategy_base.TokenizeStrategy.get_strategy()
text_encoding_strategy = strategy_base.TextEncodingStrategy.get_strategy()

assert isinstance(tokenize_strategy, strategy_lumina.LuminaTokenizeStrategy)
assert isinstance(text_encoding_strategy, strategy_lumina.LuminaTextEncodingStrategy)

sample_prompts = train_util.load_prompts(args.sample_prompts)
sample_prompts_te_outputs = {} # key: prompt, value: text encoder outputs
with accelerator.autocast(), torch.no_grad():
for prompt_dict in sample_prompts:
prompts = [
prompt_dict.get("prompt", ""),
prompt_dict.get("negative_prompt", ""),
]
for i, prompt in enumerate(prompts):
if prompt in sample_prompts_te_outputs:
continue

logger.info(f"cache Text Encoder outputs for prompt: {prompt}")
tokens_and_masks = tokenize_strategy.tokenize(prompt, i == 1) # i == 1 means negative prompt
sample_prompts_te_outputs[prompt] = text_encoding_strategy.encode_tokens(
tokenize_strategy,
text_encoders,
tokens_and_masks,
)

self.sample_prompts_te_outputs = sample_prompts_te_outputs

accelerator.wait_for_everyone()

# move back to cpu
if not self.is_train_text_encoder(args):
logger.info("move Gemma 2 back to cpu")
text_encoders[0].to("cpu")
clean_memory_on_device(accelerator.device)

if not args.lowram:
logger.info("move vae and unet back to original device")
vae.to(org_vae_device)
unet.to(org_unet_device)
else:
# Text Encoderから毎回出力を取得するので、GPUに乗せておく
text_encoders[0].to(accelerator.device, dtype=weight_dtype)

    def sample_images(
        self,
        accelerator,
        args,
        epoch,
        global_step,
        device,
        vae,
        tokenizer,
        text_encoder,
        lumina,
    ):
        """Delegate sample-image generation to lumina_train_util, reusing cached prompt embeddings."""
        # note: `device` and `tokenizer` are accepted for interface compatibility but not forwarded
        lumina_train_util.sample_images(
            accelerator,
            args,
            epoch,
            global_step,
            lumina,
            vae,
            self.get_models_for_text_encoding(args, accelerator, text_encoder),
            self.sample_prompts_te_outputs,
        )

# Remaining methods maintain similar structure to flux implementation
# with Lumina-specific model calls and strategies

    def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
        """Create the flow-match Euler scheduler; keep a pristine deep copy on self
        (presumably for sigma/timestep lookups in the training loop — confirm in base trainer)."""
        noise_scheduler = sd3_train_utils.FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=args.discrete_flow_shift)
        self.noise_scheduler_copy = copy.deepcopy(noise_scheduler)
        return noise_scheduler

def encode_images_to_latents(self, args, vae, images):
return vae.encode(images)

# not sure, they use same flux vae
def shift_scale_latents(self, args, latents):
return latents

    def get_noise_pred_and_target(
        self,
        args,
        accelerator: Accelerator,
        noise_scheduler,
        latents,
        batch,
        text_encoder_conds: Tuple[Tensor, Tensor, Tensor],  # (hidden_states, input_ids, attention_masks)
        dit: lumina_models.NextDiT,
        network,
        weight_dtype,
        train_unet,
        is_train=True,
    ):
        """Run one flow-matching forward pass.

        Returns (model_pred, target, timesteps, weighting) where the target is the
        flow-matching velocity (latents - noise); entries flagged for differential
        output preservation get the network-disabled prediction as target instead.
        """
        assert isinstance(noise_scheduler, sd3_train_utils.FlowMatchEulerDiscreteScheduler)
        noise = torch.randn_like(latents)
        # get noisy model input and timesteps
        noisy_model_input, timesteps, sigmas = lumina_train_util.get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, accelerator.device, weight_dtype
        )

        # ensure the hidden state will require grad (needed so gradient checkpointing has a grad path)
        if args.gradient_checkpointing:
            noisy_model_input.requires_grad_(True)
            for t in text_encoder_conds:
                if t is not None and t.dtype.is_floating_point:
                    t.requires_grad_(True)

        # Unpack Gemma2 outputs
        gemma2_hidden_states, input_ids, gemma2_attn_mask = text_encoder_conds

        def call_dit(img, gemma2_hidden_states, gemma2_attn_mask, timesteps):
            with torch.set_grad_enabled(is_train), accelerator.autocast():
                # NextDiT forward expects (x, t, cap_feats, cap_mask)
                model_pred = dit(
                    x=img,  # image latents (B, C, H, W)
                    t=timesteps / 1000,  # timesteps divided by 1000 to match the model's expected range
                    cap_feats=gemma2_hidden_states,  # Gemma2 hidden states used as caption features
                    cap_mask=gemma2_attn_mask.to(dtype=torch.int32),  # Gemma2 attention mask
                )
            return model_pred

        model_pred = call_dit(
            img=noisy_model_input,
            gemma2_hidden_states=gemma2_hidden_states,
            gemma2_attn_mask=gemma2_attn_mask,
            timesteps=timesteps,
        )

        # apply model prediction type (may rescale the prediction and/or yield a loss weighting)
        model_pred, weighting = lumina_train_util.apply_model_prediction_type(args, model_pred, noisy_model_input, sigmas)

        # flow matching loss: target is the velocity from noise to data
        target = latents - noise

        # differential output preservation: for flagged samples, the target becomes the
        # base-model prediction (network multiplier 0) so the LoRA preserves base behavior
        if "custom_attributes" in batch:
            diff_output_pr_indices = []
            for i, custom_attributes in enumerate(batch["custom_attributes"]):
                if "diff_output_preservation" in custom_attributes and custom_attributes["diff_output_preservation"]:
                    diff_output_pr_indices.append(i)

            if len(diff_output_pr_indices) > 0:
                network.set_multiplier(0.0)
                with torch.no_grad():
                    model_pred_prior = call_dit(
                        img=noisy_model_input[diff_output_pr_indices],
                        gemma2_hidden_states=gemma2_hidden_states[diff_output_pr_indices],
                        timesteps=timesteps[diff_output_pr_indices],
                        gemma2_attn_mask=(gemma2_attn_mask[diff_output_pr_indices]),
                    )
                network.set_multiplier(1.0)

                # model_pred_prior = lumina_util.unpack_latents(
                #     model_pred_prior, packed_latent_height, packed_latent_width
                # )
                model_pred_prior, _ = lumina_train_util.apply_model_prediction_type(
                    args,
                    model_pred_prior,
                    noisy_model_input[diff_output_pr_indices],
                    sigmas[diff_output_pr_indices] if sigmas is not None else None,
                )
                target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)

        return model_pred, target, timesteps, weighting

def post_process_loss(self, loss, args, timesteps, noise_scheduler):
return loss

    def get_sai_model_spec(self, args):
        # SAI model-spec metadata for a Lumina 2 checkpoint (positional flags follow train_util's signature)
        return train_util.get_sai_model_spec(None, args, False, True, False, lumina="lumina2")

def update_metadata(self, metadata, args):
metadata["ss_weighting_scheme"] = args.weighting_scheme
metadata["ss_logit_mean"] = args.logit_mean
metadata["ss_logit_std"] = args.logit_std
metadata["ss_mode_scale"] = args.mode_scale
metadata["ss_timestep_sampling"] = args.timestep_sampling
metadata["ss_sigmoid_scale"] = args.sigmoid_scale
metadata["ss_model_prediction_type"] = args.model_prediction_type
metadata["ss_discrete_flow_shift"] = args.discrete_flow_shift

def is_text_encoder_not_needed_for_training(self, args):
return args.cache_text_encoder_outputs and not self.is_train_text_encoder(args)

def prepare_text_encoder_grad_ckpt_workaround(self, index, text_encoder):
text_encoder.embed_tokens.requires_grad_(True)

    def prepare_text_encoder_fp8(self, index, text_encoder, te_weight_dtype, weight_dtype):
        """Cast Gemma2 to fp8 while keeping its embedding table in the working dtype."""
        logger.info(f"prepare Gemma2 for fp8: set to {te_weight_dtype}, set embeddings to {weight_dtype}")
        text_encoder.to(te_weight_dtype)  # fp8
        # embeddings stay in the working dtype — presumably embedding lookup in fp8 is avoided on purpose
        text_encoder.embed_tokens.to(dtype=weight_dtype)

    def prepare_unet_with_accelerator(
        self, args: argparse.Namespace, accelerator: Accelerator, unet: torch.nn.Module
    ) -> torch.nn.Module:
        """Prepare the NextDiT model with accelerate, with special handling for block swap."""
        if not self.is_swapping_blocks:
            # no block swap: the default preparation (with device placement) is fine
            return super().prepare_unet_with_accelerator(args, accelerator, unet)

        # Block swap IS enabled: let accelerate wrap the model WITHOUT automatic device
        # placement, then move everything except the swapped blocks ourselves.
        nextdit = unet
        assert isinstance(nextdit, lumina_models.NextDiT)
        nextdit = accelerator.prepare(nextdit, device_placement=[not self.is_swapping_blocks])
        accelerator.unwrap_model(nextdit).move_to_device_except_swap_blocks(accelerator.device)  # reduce peak memory usage
        accelerator.unwrap_model(nextdit).prepare_block_swap_before_forward()

        return nextdit

def on_validation_step_end(self, args, accelerator, network, text_encoders, unet, batch, weight_dtype):
if self.is_swapping_blocks:
# prepare for next forward: because backward pass is not called, we need to prepare it here
accelerator.unwrap_model(unet).prepare_block_swap_before_forward()


def setup_parser() -> argparse.ArgumentParser:
    """Extend the base train_network parser with DiT- and Lumina-specific arguments."""
    parser = train_network.setup_parser()
    train_util.add_dit_training_arguments(parser)
    lumina_train_util.add_lumina_train_arguments(parser)
    return parser


if __name__ == "__main__":
    parser = setup_parser()
    args = parser.parse_args()
    # validate CLI arguments, then overlay values from a config file if one was given
    train_util.verify_command_line_training_args(args)
    args = train_util.read_config_from_file(args, parser)

    trainer = LuminaNetworkTrainer()
    trainer.train(args)

+ 310
- 23
scripts/dev/networks/lora_flux.py View File

@@ -9,11 +9,13 @@

import math
import os
from contextlib import contextmanager
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import numpy as np
import torch
from torch import Tensor
import re
from library.utils import setup_logging
from library.sdxl_original_unet import SdxlUNet2DConditionModel
@@ -44,6 +46,8 @@ class LoRAModule(torch.nn.Module):
rank_dropout=None,
module_dropout=None,
split_dims: Optional[List[int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
):
"""
if alpha == 0 or None, alpha is rank (no scaling).
@@ -103,9 +107,20 @@ class LoRAModule(torch.nn.Module):
self.rank_dropout = rank_dropout
self.module_dropout = module_dropout

self.ggpo_sigma = ggpo_sigma
self.ggpo_beta = ggpo_beta

if self.ggpo_beta is not None and self.ggpo_sigma is not None:
self.combined_weight_norms = None
self.grad_norms = None
self.perturbation_norm_factor = 1.0 / math.sqrt(org_module.weight.shape[0])
self.initialize_norm_cache(org_module.weight)
self.org_module_shape: tuple[int] = org_module.weight.shape

def apply_to(self):
self.org_forward = self.org_module.forward
self.org_module.forward = self.forward

del self.org_module

def forward(self, x):
@@ -140,7 +155,25 @@ class LoRAModule(torch.nn.Module):

lx = self.lora_up(lx)

return org_forwarded + lx * self.multiplier * scale
# LoRA Gradient-Guided Perturbation Optimization
if (
self.training
and self.ggpo_sigma is not None
and self.ggpo_beta is not None
and self.combined_weight_norms is not None
and self.grad_norms is not None
):
with torch.no_grad():
perturbation_scale = (self.ggpo_sigma * torch.sqrt(self.combined_weight_norms**2)) + (
self.ggpo_beta * (self.grad_norms**2)
)
perturbation_scale_factor = (perturbation_scale * self.perturbation_norm_factor).to(self.device)
perturbation = torch.randn(self.org_module_shape, dtype=self.dtype, device=self.device)
perturbation.mul_(perturbation_scale_factor)
perturbation_output = x @ perturbation.T # Result: (batch × n)
return org_forwarded + (self.multiplier * scale * lx) + perturbation_output
else:
return org_forwarded + lx * self.multiplier * scale
else:
lxs = [lora_down(x) for lora_down in self.lora_down]

@@ -167,6 +200,115 @@ class LoRAModule(torch.nn.Module):

return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale

    @torch.no_grad()
    def initialize_norm_cache(self, org_module_weight: Tensor):
        """Estimate the mean per-row norm of the original weight by random row sampling.

        Avoids a full-matrix norm on large layers; the estimate (and its std) is cached
        for GGPO perturbation scaling.
        """
        # Choose a reasonable sample size
        n_rows = org_module_weight.shape[0]
        sample_size = min(1000, n_rows)  # Cap at 1000 samples or use all if smaller

        # Sample random indices across all rows
        indices = torch.randperm(n_rows)[:sample_size]

        # Convert to a supported data type first, then index
        # Use float32 for indexing operations (original dtype may not support indexing, e.g. fp8 — presumed)
        weights_float32 = org_module_weight.to(dtype=torch.float32)
        sampled_weights = weights_float32[indices].to(device=self.device)

        # Calculate sampled norms (per-row L2 norm)
        sampled_norms = torch.norm(sampled_weights, dim=1, keepdim=True)

        # Store the mean norm as our estimate
        self.org_weight_norm_estimate = sampled_norms.mean()

        # Optional: store standard deviation for confidence intervals
        self.org_weight_norm_std = sampled_norms.std()

        # Free memory
        del sampled_weights, weights_float32

    @torch.no_grad()
    def validate_norm_approximation(self, org_module_weight: Tensor, verbose=True):
        """Compare the sampled norm estimate against the exact mean row norm.

        Diagnostic helper only; returns a dict with the true/estimated norms and errors.
        """
        # Calculate the true norm (this will be slow but it's just for validation)
        true_norms = []
        chunk_size = 1024  # Process in chunks to avoid OOM

        for i in range(0, org_module_weight.shape[0], chunk_size):
            end_idx = min(i + chunk_size, org_module_weight.shape[0])
            chunk = org_module_weight[i:end_idx].to(device=self.device, dtype=self.dtype)
            chunk_norms = torch.norm(chunk, dim=1, keepdim=True)
            true_norms.append(chunk_norms.cpu())
            del chunk

        true_norms = torch.cat(true_norms, dim=0)
        true_mean_norm = true_norms.mean().item()

        # Compare with our estimate (set by initialize_norm_cache)
        estimated_norm = self.org_weight_norm_estimate.item()

        # Calculate error metrics
        absolute_error = abs(true_mean_norm - estimated_norm)
        relative_error = absolute_error / true_mean_norm * 100  # as percentage

        if verbose:
            logger.info(f"True mean norm: {true_mean_norm:.6f}")
            logger.info(f"Estimated norm: {estimated_norm:.6f}")
            logger.info(f"Absolute error: {absolute_error:.6f}")
            logger.info(f"Relative error: {relative_error:.2f}%")

        return {
            "true_mean_norm": true_mean_norm,
            "estimated_norm": estimated_norm,
            "absolute_error": absolute_error,
            "relative_error": relative_error,
        }

@torch.no_grad()
def update_norms(self):
# Not running GGPO so not currently running update norms
if self.ggpo_beta is None or self.ggpo_sigma is None:
return

# only update norms when we are training
if self.training is False:
return

module_weights = self.lora_up.weight @ self.lora_down.weight
module_weights.mul(self.scale)

self.weight_norms = torch.norm(module_weights, dim=1, keepdim=True)
self.combined_weight_norms = torch.sqrt(
(self.org_weight_norm_estimate**2) + torch.sum(module_weights**2, dim=1, keepdim=True)
)

    @torch.no_grad()
    def update_grad_norms(self):
        """Approximate the gradient norm of the merged LoRA delta for GGPO.

        Uses d(BA)/dstep ≈ B·dA + dB·A with the current parameter grads; stores the
        per-row norms in self.grad_norms.
        """
        if self.training is False:
            # NOTE(review): debug print left in — should probably be logger.debug
            print(f"skipping update_grad_norms for {self.lora_name}")
            return

        lora_down_grad = None
        lora_up_grad = None

        # fetch the raw grads of both LoRA factors
        for name, param in self.named_parameters():
            if name == "lora_down.weight":
                lora_down_grad = param.grad
            elif name == "lora_up.weight":
                lora_up_grad = param.grad

        # Calculate gradient norms if we have both gradients
        if lora_down_grad is not None and lora_up_grad is not None:
            with torch.autocast(self.device.type):
                approx_grad = self.scale * ((self.lora_up.weight @ lora_down_grad) + (lora_up_grad @ self.lora_down.weight))
                self.grad_norms = torch.norm(approx_grad, dim=1, keepdim=True)

    @property
    def device(self):
        # device of the LoRA parameters (both factors live on the same device)
        return next(self.parameters()).device

    @property
    def dtype(self):
        # dtype of the LoRA parameters
        return next(self.parameters()).dtype


class LoRAInfModule(LoRAModule):
def __init__(
@@ -420,6 +562,15 @@ def create_network(
if split_qkv is not None:
split_qkv = True if split_qkv == "True" else False

ggpo_beta = kwargs.get("ggpo_beta", None)
ggpo_sigma = kwargs.get("ggpo_sigma", None)

if ggpo_beta is not None:
ggpo_beta = float(ggpo_beta)

if ggpo_sigma is not None:
ggpo_sigma = float(ggpo_sigma)

# train T5XXL
train_t5xxl = kwargs.get("train_t5xxl", False)
if train_t5xxl is not None:
@@ -430,6 +581,42 @@ def create_network(
if verbose is not None:
verbose = True if verbose == "True" else False

# regex-specific learning rates
def parse_kv_pairs(kv_pair_str: str, is_int: bool) -> Dict[str, float]:
"""
Parse a string of key-value pairs separated by commas.
"""
pairs = {}
for pair in kv_pair_str.split(","):
pair = pair.strip()
if not pair:
continue
if "=" not in pair:
logger.warning(f"Invalid format: {pair}, expected 'key=value'")
continue
key, value = pair.split("=", 1)
key = key.strip()
value = value.strip()
try:
pairs[key] = int(value) if is_int else float(value)
except ValueError:
logger.warning(f"Invalid value for {key}: {value}")
return pairs

# parse regular expression based learning rates
network_reg_lrs = kwargs.get("network_reg_lrs", None)
if network_reg_lrs is not None:
reg_lrs = parse_kv_pairs(network_reg_lrs, is_int=False)
else:
reg_lrs = None

# regex-specific dimensions (ranks)
network_reg_dims = kwargs.get("network_reg_dims", None)
if network_reg_dims is not None:
reg_dims = parse_kv_pairs(network_reg_dims, is_int=True)
else:
reg_dims = None

# すごく引数が多いな ( ^ω^)・・・
network = LoRANetwork(
text_encoders,
@@ -449,6 +636,10 @@ def create_network(
in_dims=in_dims,
train_double_block_indices=train_double_block_indices,
train_single_block_indices=train_single_block_indices,
reg_dims=reg_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
reg_lrs=reg_lrs,
verbose=verbose,
)

@@ -466,7 +657,6 @@ def create_network(

# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weights_sd=None, for_inference=False, **kwargs):
# if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True
if weights_sd is None:
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file, safe_open
@@ -497,22 +687,6 @@ def create_network_from_weights(multiplier, file, ae, text_encoders, flux, weigh
if train_t5xxl is None:
train_t5xxl = False

# # split qkv
# double_qkv_rank = None
# single_qkv_rank = None
# rank = None
# for lora_name, dim in modules_dim.items():
# if "double" in lora_name and "qkv" in lora_name:
# double_qkv_rank = dim
# elif "single" in lora_name and "linear1" in lora_name:
# single_qkv_rank = dim
# elif rank is None:
# rank = dim
# if double_qkv_rank is not None and single_qkv_rank is not None and rank is not None:
# break
# split_qkv = (double_qkv_rank is not None and double_qkv_rank != rank) or (
# single_qkv_rank is not None and single_qkv_rank != rank
# )
split_qkv = False # split_qkv is not needed to care, because state_dict is qkv combined

module_class = LoRAInfModule if for_inference else LoRAModule
@@ -561,6 +735,10 @@ class LoRANetwork(torch.nn.Module):
in_dims: Optional[List[int]] = None,
train_double_block_indices: Optional[List[bool]] = None,
train_single_block_indices: Optional[List[bool]] = None,
reg_dims: Optional[Dict[str, int]] = None,
ggpo_beta: Optional[float] = None,
ggpo_sigma: Optional[float] = None,
reg_lrs: Optional[Dict[str, float]] = None,
verbose: Optional[bool] = False,
) -> None:
super().__init__()
@@ -581,6 +759,8 @@ class LoRANetwork(torch.nn.Module):
self.in_dims = in_dims
self.train_double_block_indices = train_double_block_indices
self.train_single_block_indices = train_single_block_indices
self.reg_dims = reg_dims
self.reg_lrs = reg_lrs

self.loraplus_lr_ratio = None
self.loraplus_unet_lr_ratio = None
@@ -599,10 +779,15 @@ class LoRANetwork(torch.nn.Module):
# logger.info(
# f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
# )

if ggpo_beta is not None and ggpo_sigma is not None:
logger.info(f"LoRA-GGPO training sigma: {ggpo_sigma} beta: {ggpo_beta}")

if self.split_qkv:
logger.info(f"split qkv for LoRA")
if self.train_blocks is not None:
logger.info(f"train {self.train_blocks} blocks only")

if train_t5xxl:
logger.info(f"train T5XXL as well")

@@ -648,8 +833,16 @@ class LoRANetwork(torch.nn.Module):
if lora_name in modules_dim:
dim = modules_dim[lora_name]
alpha = modules_alpha[lora_name]
else:
# 通常、すべて対象とする
elif self.reg_dims is not None:
for reg, d in self.reg_dims.items():
if re.search(reg, lora_name):
dim = d
alpha = self.alpha
logger.info(f"LoRA {lora_name} matched with regex {reg}, using dim: {dim}")
break

# if modules_dim is None, we use default lora_dim. if modules_dim is not None, we use the specified dim (no default)
if dim is None and modules_dim is None:
if is_linear or is_conv2d_1x1:
dim = default_dim if default_dim is not None else self.lora_dim
alpha = self.alpha
@@ -722,6 +915,8 @@ class LoRANetwork(torch.nn.Module):
rank_dropout=rank_dropout,
module_dropout=module_dropout,
split_dims=split_dims,
ggpo_beta=ggpo_beta,
ggpo_sigma=ggpo_sigma,
)
loras.append(lora)

@@ -735,6 +930,9 @@ class LoRANetwork(torch.nn.Module):
skipped_te = []
for i, text_encoder in enumerate(text_encoders):
index = i
if text_encoder is None:
logger.info(f"Text Encoder {index+1} is None, skipping LoRA creation for this encoder.")
continue
if not train_t5xxl and index > 0: # 0: CLIP, 1: T5XXL, so we skip T5XXL if train_t5xxl is False
break

@@ -790,6 +988,35 @@ class LoRANetwork(torch.nn.Module):
for lora in self.text_encoder_loras + self.unet_loras:
lora.enabled = is_enabled

def update_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_norms()

def update_grad_norms(self):
for lora in self.text_encoder_loras + self.unet_loras:
lora.update_grad_norms()

def grad_norms(self) -> Tensor | None:
grad_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "grad_norms") and lora.grad_norms is not None:
grad_norms.append(lora.grad_norms.mean(dim=0))
return torch.stack(grad_norms) if len(grad_norms) > 0 else None

def weight_norms(self) -> Tensor | None:
weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "weight_norms") and lora.weight_norms is not None:
weight_norms.append(lora.weight_norms.mean(dim=0))
return torch.stack(weight_norms) if len(weight_norms) > 0 else None

def combined_weight_norms(self) -> Tensor | None:
combined_weight_norms = []
for lora in self.text_encoder_loras + self.unet_loras:
if hasattr(lora, "combined_weight_norms") and lora.combined_weight_norms is not None:
combined_weight_norms.append(lora.combined_weight_norms.mean(dim=0))
return torch.stack(combined_weight_norms) if len(combined_weight_norms) > 0 else None

def load_weights(self, file):
if os.path.splitext(file)[1] == ".safetensors":
from safetensors.torch import load_file
@@ -976,17 +1203,77 @@ class LoRANetwork(torch.nn.Module):
all_params = []
lr_descriptions = []

reg_lrs_list = list(self.reg_lrs.items()) if self.reg_lrs is not None else []

def assemble_params(loras, lr, loraplus_ratio):
param_groups = {"lora": {}, "plus": {}}
# regular expression param groups: {"reg_lr_0": {"lora": {}, "plus": {}}, ...}
reg_groups = {}

for lora in loras:
# check if this lora matches any regex learning rate
matched_reg_lr = None
for i, (regex_str, reg_lr) in enumerate(reg_lrs_list):
try:
if re.search(regex_str, lora.lora_name):
matched_reg_lr = (i, reg_lr)
logger.info(f"Module {lora.lora_name} matched regex '{regex_str}' -> LR {reg_lr}")
break
except re.error:
# regex error should have been caught during parsing, but just in case
continue

for name, param in lora.named_parameters():
if loraplus_ratio is not None and "lora_up" in name:
param_groups["plus"][f"{lora.lora_name}.{name}"] = param
param_key = f"{lora.lora_name}.{name}"
is_plus = loraplus_ratio is not None and "lora_up" in name

if matched_reg_lr is not None:
# use regex-specific learning rate
reg_idx, reg_lr = matched_reg_lr
group_key = f"reg_lr_{reg_idx}"
if group_key not in reg_groups:
reg_groups[group_key] = {"lora": {}, "plus": {}, "lr": reg_lr}

if is_plus:
reg_groups[group_key]["plus"][param_key] = param
else:
reg_groups[group_key]["lora"][param_key] = param
else:
param_groups["lora"][f"{lora.lora_name}.{name}"] = param
# use default learning rate
if is_plus:
param_groups["plus"][param_key] = param
else:
param_groups["lora"][param_key] = param

params = []
descriptions = []

# process regex-specific groups first (higher priority)
for group_key in sorted(reg_groups.keys()):
group = reg_groups[group_key]
reg_lr = group["lr"]

for param_type in ["lora", "plus"]:
if len(group[param_type]) == 0:
continue

param_data = {"params": group[param_type].values()}

if param_type == "plus" and loraplus_ratio is not None:
param_data["lr"] = reg_lr * loraplus_ratio
else:
param_data["lr"] = reg_lr

if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
continue

params.append(param_data)
desc = f"reg_lr_{group_key.split('_')[-1]}"
if param_type == "plus":
desc += " plus"
descriptions.append(desc)

# process default groups
for key in param_groups.keys():
param_data = {"params": param_groups[key].values()}



+ 1038
- 0
scripts/dev/networks/lora_lumina.py View File

@@ -0,0 +1,1038 @@
# temporary minimum implementation of LoRA
# Lumina 2 does not have Conv2d, so ignore
# TODO commonize with the original implementation

# LoRA network module
# reference:
# https://github.com/microsoft/LoRA/blob/main/loralib/layers.py
# https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py

import math
import os
from typing import Dict, List, Optional, Tuple, Type, Union
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from transformers import CLIPTextModel
import torch
from torch import Tensor, nn
from library.utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)


class LoRAModule(torch.nn.Module):
    """
    Replaces the forward method of the original Linear, instead of replacing the
    original Linear module itself. Supports an optional split of the output
    dimension (split_dims) to mimic Diffusers-style split qkv projections.
    """

    def __init__(
        self,
        lora_name: str,
        org_module: nn.Module,
        multiplier: float =1.0,
        lora_dim: int = 4,
        alpha: Optional[float | int | Tensor] = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        split_dims: Optional[List[int]] = None,
    ):
        """
        if alpha == 0 or None, alpha is rank (no scaling).

        split_dims is used to mimic the split qkv of lumina as same as Diffusers.
        dropout: element dropout on the down-projection output.
        rank_dropout: drops whole ranks (with compensating rescale).
        module_dropout: skips the whole LoRA branch for a step.
        """
        super().__init__()
        self.lora_name = lora_name

        # derive in/out dims from the wrapped module (Conv2d or Linear)
        if org_module.__class__.__name__ == "Conv2d":
            in_dim = org_module.in_channels
            out_dim = org_module.out_channels
        else:
            in_dim = org_module.in_features
            out_dim = org_module.out_features

        assert isinstance(in_dim, int)
        assert isinstance(out_dim, int)

        self.lora_dim = lora_dim
        self.split_dims = split_dims

        if split_dims is None:
            # single (down, up) factor pair
            if org_module.__class__.__name__ == "Conv2d":
                kernel_size = org_module.kernel_size
                stride = org_module.stride
                padding = org_module.padding
                self.lora_down = nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False)
                self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
            else:
                self.lora_down = nn.Linear(in_dim, self.lora_dim, bias=False)
                self.lora_up = nn.Linear(self.lora_dim, out_dim, bias=False)

            # standard LoRA init: random down, zero up (so the delta starts at zero)
            nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_up.weight)
        else:
            # conv2d not supported
            assert sum(split_dims) == out_dim, "sum of split_dims must be equal to out_dim"
            assert org_module.__class__.__name__ == "Linear", "split_dims is only supported for Linear"
            # print(f"split_dims: {split_dims}")
            # one (down, up) pair per output slice; outputs are concatenated in forward
            self.lora_down = nn.ModuleList(
                [nn.Linear(in_dim, self.lora_dim, bias=False) for _ in range(len(split_dims))]
            )
            self.lora_up = nn.ModuleList([torch.nn.Linear(self.lora_dim, split_dim, bias=False) for split_dim in split_dims])

            for lora_down in self.lora_down:
                nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
            for lora_up in self.lora_up:
                nn.init.zeros_(lora_up.weight)

        if isinstance(alpha, Tensor):
            alpha = alpha.detach().cpu().float().item()  # without casting, bf16 causes error
        alpha = self.lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # treated as a constant (saved with the state dict)

        # same as microsoft's
        self.multiplier = multiplier
        self.org_module = org_module  # removed in apply_to()
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout

    def apply_to(self):
        # hijack the original module's forward; drop our reference so it isn't saved/moved with us
        self.org_forward = self.org_module.forward
        self.org_module.forward = self.forward
        del self.org_module

    def forward(self, x):
        org_forwarded = self.org_forward(x)

        # module dropout: occasionally bypass the LoRA branch entirely
        if self.module_dropout is not None and self.training:
            if torch.rand(1) < self.module_dropout:
                return org_forwarded

        if self.split_dims is None:
            lx = self.lora_down(x)

            # normal dropout
            if self.dropout is not None and self.training:
                lx = torch.nn.functional.dropout(lx, p=self.dropout)

            # rank dropout
            if self.rank_dropout is not None and self.training:
                mask = torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout
                if len(lx.size()) == 3:
                    mask = mask.unsqueeze(1)  # for Text Encoder
                elif len(lx.size()) == 4:
                    mask = mask.unsqueeze(-1).unsqueeze(-1)  # for Conv2d
                lx = lx * mask

                # scaling for rank dropout: treat as if the rank is changed
                # could be computed from the mask, but rank_dropout is used expecting an augmentation-like effect
                scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
            else:
                scale = self.scale

            lx = self.lora_up(lx)

            return org_forwarded + lx * self.multiplier * scale
        else:
            # split_dims: run each factor pair and concatenate along the feature axis
            lxs = [lora_down(x) for lora_down in self.lora_down]

            # normal dropout
            if self.dropout is not None and self.training:
                lxs = [torch.nn.functional.dropout(lx, p=self.dropout) for lx in lxs]

            # rank dropout
            if self.rank_dropout is not None and self.training:
                masks = [torch.rand((lx.size(0), self.lora_dim), device=lx.device) > self.rank_dropout for lx in lxs]
                for i in range(len(lxs)):
                    if len(lxs[i].size()) == 3:
                        masks[i] = masks[i].unsqueeze(1)
                    elif len(lxs[i].size()) == 4:
                        masks[i] = masks[i].unsqueeze(-1).unsqueeze(-1)
                    lxs[i] = lxs[i] * masks[i]

                # scaling for rank dropout: treat as if the rank is changed
                scale = self.scale * (1.0 / (1.0 - self.rank_dropout))  # redundant for readability
            else:
                scale = self.scale

            lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]

            return org_forwarded + torch.cat(lxs, dim=-1) * self.multiplier * scale


class LoRAInfModule(LoRAModule):
    """
    Inference-time LoRA module: dropout is disabled, the module can be toggled via
    `enabled`, and it supports merging its delta into the original module's weights.
    """

    def __init__(
        self,
        lora_name,
        org_module: torch.nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        **kwargs,
    ):
        # no dropout for inference
        super().__init__(lora_name, org_module, multiplier, lora_dim, alpha)

        self.org_module_ref = [org_module]  # keep a reference (apply_to() deletes self.org_module)
        self.enabled = True
        self.network: LoRANetwork = None

    def set_network(self, network):
        self.network = network

    # freeze and merge LoRA weights into the original module
    def merge_to(self, sd, dtype, device):
        """Merge the LoRA delta from state dict `sd` into the wrapped module's weight.

        dtype/device default to the original weight's dtype/device when None.
        """
        # extract weight from org_module
        org_sd = self.org_module.state_dict()
        weight = org_sd["weight"]
        org_dtype = weight.dtype
        org_device = weight.device
        weight = weight.to(torch.float)  # calc in float

        if dtype is None:
            dtype = org_dtype
        if device is None:
            device = org_device

        if self.split_dims is None:
            # get up/down weight
            down_weight = sd["lora_down.weight"].to(torch.float).to(device)
            up_weight = sd["lora_up.weight"].to(torch.float).to(device)

            # merge weight
            if len(weight.size()) == 2:
                # linear
                weight = weight + self.multiplier * (up_weight @ down_weight) * self.scale
            elif down_weight.size()[2:4] == (1, 1):
                # conv2d 1x1
                weight = (
                    weight
                    + self.multiplier
                    * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                    * self.scale
                )
            else:
                # conv2d 3x3
                conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                weight = weight + self.multiplier * conved * self.scale
        else:
            # split_dims: each (down, up) pair covers a contiguous slice of the output dim
            total_dims = sum(self.split_dims)
            for i in range(len(self.split_dims)):
                # get up/down weight
                down_weight = sd[f"lora_down.{i}.weight"].to(torch.float).to(device)  # (rank, in_dim)
                up_weight = sd[f"lora_up.{i}.weight"].to(torch.float).to(device)  # (split dim, rank)

                # pad up_weight -> (total_dims, rank)
                # bugfix: the buffer must be (total_dims, rank) — the previous code allocated
                # (total_dims, split_dim) via up_weight.size(0)
                padded_up_weight = torch.zeros((total_dims, up_weight.size(1)), device=device, dtype=torch.float)
                padded_up_weight[sum(self.split_dims[:i]) : sum(self.split_dims[: i + 1])] = up_weight

                # merge weight
                # bugfix: use the padded matrix; the previous code used raw up_weight @ down_weight,
                # whose shape (split_dim, in_dim) cannot be added to the full (total_dims, in_dim) weight
                weight = weight + self.multiplier * (padded_up_weight @ down_weight) * self.scale

        # set merged weight back into org_module (hoisted out of the split loop; same end state)
        org_sd["weight"] = weight.to(dtype)
        self.org_module.load_state_dict(org_sd)

    # return this module's delta weight so the merge can be undone later
    def get_weight(self, multiplier=None):
        """Return the pre-computed LoRA delta weight (multiplier * up @ down * scale).

        NOTE(review): supports only the non-split layout; with split_dims the factors are
        ModuleLists and this method would need per-slice handling.
        """
        if multiplier is None:
            multiplier = self.multiplier

        # get up/down weight from module
        up_weight = self.lora_up.weight.to(torch.float)
        down_weight = self.lora_down.weight.to(torch.float)

        # pre-calculated weight
        if len(down_weight.size()) == 2:
            # linear
            # bugfix: use the (possibly overridden) local `multiplier`; the previous code
            # ignored the parameter and always used self.multiplier
            weight = multiplier * (up_weight @ down_weight) * self.scale
        elif down_weight.size()[2:4] == (1, 1):
            # conv2d 1x1
            weight = (
                multiplier
                * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                * self.scale
            )
        else:
            # conv2d 3x3
            conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
            weight = multiplier * conved * self.scale

        return weight

    def set_region(self, region):
        self.region = region
        self.region_mask = None

    def default_forward(self, x):
        # logger.info(f"default_forward {self.lora_name} {x.size()}")
        if self.split_dims is None:
            lx = self.lora_down(x)
            lx = self.lora_up(lx)
            return self.org_forward(x) + lx * self.multiplier * self.scale
        else:
            lxs = [lora_down(x) for lora_down in self.lora_down]
            lxs = [lora_up(lx) for lora_up, lx in zip(self.lora_up, lxs)]
            return self.org_forward(x) + torch.cat(lxs, dim=-1) * self.multiplier * self.scale

    def forward(self, x):
        # bypass the LoRA branch entirely when disabled
        if not self.enabled:
            return self.org_forward(x)
        return self.default_forward(x)


def create_network(
    multiplier: float,
    network_dim: Optional[int],
    network_alpha: Optional[float],
    ae: AutoencoderKL,
    text_encoders: List[CLIPTextModel],
    lumina,
    neuron_dropout: Optional[float] = None,
    **kwargs,
):
    """Build a LoRANetwork for Lumina from CLI-style network args.

    Optional values arrive in `kwargs` as strings (or are absent) and are
    converted here; boolean-ish flags are the literal string "True".
    """
    # defaults for rank/alpha
    network_dim = 4 if network_dim is None else network_dim
    network_alpha = 1.0 if network_alpha is None else network_alpha

    def _opt_int(name):
        raw = kwargs.get(name, None)
        return None if raw is None else int(raw)

    def _opt_float(name):
        raw = kwargs.get(name, None)
        return None if raw is None else float(raw)

    def _flag(name):
        raw = kwargs.get(name, False)
        # a missing/None flag stays as-is; anything else counts only as the string "True"
        return raw if raw is None else raw == "True"

    # dim/alpha for conv2d layers (alpha stays raw when conv_dim is absent)
    conv_dim = kwargs.get("conv_dim", None)
    conv_alpha = kwargs.get("conv_alpha", None)
    if conv_dim is not None:
        conv_dim = int(conv_dim)
        conv_alpha = 1.0 if conv_alpha is None else float(conv_alpha)

    # per-layer-type dims for JointTransformerBlock: attention, MLP, modulation, refiner
    type_dims = [_opt_int(name) for name in ("attn_dim", "mlp_dim", "mod_dim", "refiner_dim")]
    if not any(d is not None for d in type_dims):
        type_dims = None

    # dims for the three embedder layers, given as "[a,b,c]" or "a,b,c"
    embedder_dims = kwargs.get("embedder_dims", None)
    if embedder_dims is not None:
        text = embedder_dims.strip()
        if text.startswith("[") and text.endswith("]"):
            text = text[1:-1]
        embedder_dims = [int(part) for part in text.split(",")]
        assert len(embedder_dims) == 3, f"invalid embedder_dims: {embedder_dims}, must be 3 dimensions (x_embedder, t_embedder, cap_embedder)"

    # rank/module dropout
    rank_dropout = _opt_float("rank_dropout")
    module_dropout = _opt_float("module_dropout")

    # which block groups to train; None (default) is equivalent to "all"
    train_blocks = kwargs.get("train_blocks", None)
    if train_blocks is not None:
        assert train_blocks in ["all", "transformer", "refiners", "noise_refiner", "context_refiner"], f"invalid train_blocks: {train_blocks}"

    # qkv splitting and verbose logging flags
    split_qkv = _flag("split_qkv")
    verbose = _flag("verbose")

    # quite a lot of arguments...
    network = LoRANetwork(
        text_encoders,
        lumina,
        multiplier=multiplier,
        lora_dim=network_dim,
        alpha=network_alpha,
        dropout=neuron_dropout,
        rank_dropout=rank_dropout,
        module_dropout=module_dropout,
        conv_lora_dim=conv_dim,
        conv_alpha=conv_alpha,
        train_blocks=train_blocks,
        split_qkv=split_qkv,
        type_dims=type_dims,
        embedder_dims=embedder_dims,
        verbose=verbose,
    )

    # LoRA+ learning-rate ratios (overall / unet-only / text-encoder-only)
    loraplus_lr_ratio = _opt_float("loraplus_lr_ratio")
    loraplus_unet_lr_ratio = _opt_float("loraplus_unet_lr_ratio")
    loraplus_text_encoder_lr_ratio = _opt_float("loraplus_text_encoder_lr_ratio")
    if any(r is not None for r in (loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)):
        network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio)

    return network


# Create network from weights for inference, weights are not loaded here (because can be merged)
def create_network_from_weights(multiplier, file, ae, text_encoders, lumina, weights_sd=None, for_inference=False, **kwargs):
    """Instantiate a LoRANetwork whose per-module dims/alphas mirror a saved state dict.

    The weights themselves are NOT loaded into the modules here; callers may
    merge them instead. Returns (network, weights_sd).
    """
    # read the state dict from disk unless the caller already supplied one
    if weights_sd is None:
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file, safe_open

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

    # recover each module's rank (from lora_down's first dim) and alpha
    modules_dim = {}
    modules_alpha = {}
    for key, tensor in weights_sd.items():
        if "." not in key:
            continue

        lora_name = key.partition(".")[0]
        if "alpha" in key:
            modules_alpha[lora_name] = tensor
        elif "lora_down" in key:
            modules_dim[lora_name] = tensor.size()[0]

    # saved state dicts store qkv combined, so no split handling is needed here
    split_qkv = False

    module_class = LoRAInfModule if for_inference else LoRAModule

    network = LoRANetwork(
        text_encoders,
        lumina,
        multiplier=multiplier,
        modules_dim=modules_dim,
        modules_alpha=modules_alpha,
        module_class=module_class,
        split_qkv=split_qkv,
    )
    return network, weights_sd


class LoRANetwork(torch.nn.Module):
    """LoRA network for Lumina.

    Builds LoRA modules for Linear layers inside the Lumina transformer
    (JointTransformerBlock / FinalLayer) and the Gemma2 text encoder, and
    manages their lifecycle: applying/merging, saving/loading weights,
    optimizer parameter grouping (incl. LoRA+), and max-norm regularization.
    """

    LUMINA_TARGET_REPLACE_MODULE = ["JointTransformerBlock", "FinalLayer"]
    TEXT_ENCODER_TARGET_REPLACE_MODULE = ["Gemma2Attention", "Gemma2FlashAttention2", "Gemma2SdpaAttention", "Gemma2MLP"]
    LORA_PREFIX_LUMINA = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"  # Simplified prefix since we only have one text encoder

    def __init__(
        self,
        text_encoders,  # Now this will be a single Gemma2 model
        unet,
        multiplier: float = 1.0,
        lora_dim: int = 4,
        alpha: float = 1,
        dropout: Optional[float] = None,
        rank_dropout: Optional[float] = None,
        module_dropout: Optional[float] = None,
        conv_lora_dim: Optional[int] = None,
        conv_alpha: Optional[float] = None,
        module_class: Type[LoRAModule] = LoRAModule,
        modules_dim: Optional[Dict[str, int]] = None,
        modules_alpha: Optional[Dict[str, int]] = None,
        train_blocks: Optional[str] = None,
        split_qkv: bool = False,
        type_dims: Optional[List[int]] = None,
        embedder_dims: Optional[List[int]] = None,
        train_block_indices: Optional[List[bool]] = None,
        verbose: Optional[bool] = False,
    ) -> None:
        """Scan the text encoder and the Lumina model and create LoRA modules.

        When `modules_dim`/`modules_alpha` are given (network-from-weights),
        per-module rank/alpha come from those dicts and unknown modules are
        skipped; otherwise `lora_dim`/`alpha` apply, optionally refined by
        `type_dims` (attention/mlp/modulation/refiner) and `embedder_dims`.
        """
        super().__init__()
        self.multiplier = multiplier

        self.lora_dim = lora_dim
        self.alpha = alpha
        self.conv_lora_dim = conv_lora_dim
        self.conv_alpha = conv_alpha
        self.dropout = dropout
        self.rank_dropout = rank_dropout
        self.module_dropout = module_dropout
        self.train_blocks = train_blocks if train_blocks is not None else "all"
        self.split_qkv = split_qkv

        self.type_dims = type_dims
        self.embedder_dims = embedder_dims

        self.train_block_indices = train_block_indices

        self.loraplus_lr_ratio = None
        self.loraplus_unet_lr_ratio = None
        self.loraplus_text_encoder_lr_ratio = None

        if modules_dim is not None:
            logger.info(f"create LoRA network from weights")
            # NOTE(review): 5 entries here, but only 3 embedder filters are
            # consumed below (x/t/cap) — confirm the extra two are intentional.
            self.embedder_dims = [0] * 5  # create embedder_dims
            # verbose = True
        else:
            logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}")
            logger.info(
                f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}"
            )
            # if self.conv_lora_dim is not None:
            #     logger.info(
            #         f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}"
            #     )
            if self.split_qkv:
                logger.info(f"split qkv for LoRA")
            # always true: self.train_blocks was defaulted to "all" above
            if self.train_blocks is not None:
                logger.info(f"train {self.train_blocks} blocks only")

        # create module instances
        def create_modules(
            is_lumina: bool,
            root_module: torch.nn.Module,
            target_replace_modules: Optional[List[str]],
            filter: Optional[str] = None,
            default_dim: Optional[int] = None,
        ) -> List[LoRAModule]:
            """Walk `root_module` and build a LoRA for each matching Linear.

            Returns (loras, skipped_names). `target_replace_modules=None`
            means scan `root_module` itself once (used for the embedders);
            `filter` keeps only lora names containing that substring.
            """
            prefix = self.LORA_PREFIX_LUMINA if is_lumina else self.LORA_PREFIX_TEXT_ENCODER

            loras = []
            skipped = []
            for name, module in root_module.named_modules():
                if target_replace_modules is None or module.__class__.__name__ in target_replace_modules:
                    if target_replace_modules is None:  # for handling embedders
                        module = root_module

                    for child_name, child_module in module.named_modules():
                        is_linear = child_module.__class__.__name__ == "Linear"

                        lora_name = prefix + "." + (name + "." if name else "") + child_name
                        lora_name = lora_name.replace(".", "_")

                        # Only Linear is supported
                        if not is_linear:
                            skipped.append(lora_name)
                            continue

                        if filter is not None and filter not in lora_name:
                            continue

                        dim = default_dim if default_dim is not None else self.lora_dim
                        alpha = self.alpha

                        # Set dim/alpha to modules dim/alpha
                        if modules_dim is not None and modules_alpha is not None:
                            # network from weights
                            if lora_name in modules_dim:
                                dim = modules_dim[lora_name]
                                alpha = modules_alpha[lora_name]
                            else:
                                dim = 0  # skip if not found

                        else:
                            # Set dims to type_dims
                            if is_lumina and type_dims is not None:
                                identifier = [
                                    ("attention",),  # attention layers
                                    ("mlp",),  # MLP layers
                                    ("modulation",),  # modulation layers
                                    ("refiner",),  # refiner blocks
                                ]
                                for i, d in enumerate(type_dims):
                                    if d is not None and all([id in lora_name for id in identifier[i]]):
                                        dim = d  # may be 0 for skip
                                        break

                            # Drop blocks if we are only training some blocks
                            if (
                                is_lumina
                                and dim
                                and (
                                    self.train_block_indices is not None
                                )
                                and ("layer" in lora_name)
                            ):
                                # "lora_unet_layers_0_..." or "lora_unet_cap_refiner_0_..." or "lora_unet_noise_refiner_0_..."
                                block_index = int(lora_name.split("_")[3])  # bit dirty
                                if (
                                    "layer" in lora_name
                                    and self.train_block_indices is not None
                                    and not self.train_block_indices[block_index]
                                ):
                                    dim = 0


                        if dim is None or dim == 0:
                            # record the skipped module for the report below
                            skipped.append(lora_name)
                            continue

                        lora = module_class(
                            lora_name,
                            child_module,
                            self.multiplier,
                            dim,
                            alpha,
                            dropout=dropout,
                            rank_dropout=rank_dropout,
                            module_dropout=module_dropout,
                        )
                        loras.append(lora)

                    if target_replace_modules is None:
                        break  # all modules are searched
            return loras, skipped

        # create LoRA for text encoder (Gemma2)
        self.text_encoder_loras: List[Union[LoRAModule, LoRAInfModule]] = []
        skipped_te = []

        logger.info(f"create LoRA for Gemma2 Text Encoder:")
        text_encoder_loras, skipped = create_modules(False, text_encoders[0], LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE)
        logger.info(f"create LoRA for Gemma2 Text Encoder: {len(text_encoder_loras)} modules.")
        self.text_encoder_loras.extend(text_encoder_loras)
        skipped_te += skipped

        # create LoRA for U-Net
        # NOTE(review): every branch currently selects the same target list;
        # per-group limiting is still TODO (also "cap_refiner" here vs
        # "context_refiner" accepted by create_network — confirm intended).
        if self.train_blocks == "all":
            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
            # TODO: limit different blocks
        elif self.train_blocks == "transformer":
            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
        elif self.train_blocks == "refiners":
            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
        elif self.train_blocks == "noise_refiner":
            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE
        elif self.train_blocks == "cap_refiner":
            target_replace_modules = LoRANetwork.LUMINA_TARGET_REPLACE_MODULE

        self.unet_loras: List[Union[LoRAModule, LoRAInfModule]]
        self.unet_loras, skipped_un = create_modules(True, unet, target_replace_modules)

        # Handle embedders
        if self.embedder_dims:
            for filter, embedder_dim in zip(["x_embedder", "t_embedder", "cap_embedder"], self.embedder_dims):
                loras, _ = create_modules(True, unet, None, filter=filter, default_dim=embedder_dim)
                self.unet_loras.extend(loras)

        logger.info(f"create LoRA for Lumina blocks: {len(self.unet_loras)} modules.")
        if verbose:
            for lora in self.unet_loras:
                logger.info(f"\t{lora.lora_name:50} {lora.lora_dim}, {lora.alpha}")

        skipped = skipped_te + skipped_un
        if verbose and len(skipped) > 0:
            logger.warning(
                f"because dim (rank) is 0, {len(skipped)} LoRA modules are skipped / dim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:"
            )
            for name in skipped:
                logger.info(f"\t{name}")

        # assertion: lora names must be unique across TE and U-Net
        names = set()
        for lora in self.text_encoder_loras + self.unet_loras:
            assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}"
            names.add(lora.lora_name)

    def set_multiplier(self, multiplier):
        """Set the LoRA strength on the network and all child modules."""
        self.multiplier = multiplier
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.multiplier = self.multiplier

    def set_enabled(self, is_enabled):
        """Enable/disable the LoRA path on every child module."""
        for lora in self.text_encoder_loras + self.unet_loras:
            lora.enabled = is_enabled

    def load_weights(self, file):
        """Load LoRA weights from a .safetensors or torch checkpoint file."""
        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import load_file

            weights_sd = load_file(file)
        else:
            weights_sd = torch.load(file, map_location="cpu")

        # non-strict: the file may cover only a subset of modules
        info = self.load_state_dict(weights_sd, False)
        return info

    def load_state_dict(self, state_dict, strict=True):
        # override to convert original weight to split qkv
        # NOTE(review): the conversion below is commented out, so split_qkv
        # loading currently behaves exactly like the default path — confirm.
        if not self.split_qkv:
            return super().load_state_dict(state_dict, strict)

        # # split qkv
        # for key in list(state_dict.keys()):
        #     if "double" in key and "qkv" in key:
        #         split_dims = [3072] * 3
        #     elif "single" in key and "linear1" in key:
        #         split_dims = [3072] * 3 + [12288]
        #     else:
        #         continue

        #     weight = state_dict[key]
        #     lora_name = key.split(".")[0]

        #     if key not in state_dict:
        #         continue  # already merged

        #     # (rank, in_dim) * 3
        #     down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))]
        #     # (split dim, rank) * 3
        #     up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]

        #     alpha = state_dict.pop(f"{lora_name}.alpha")

        #     # merge down weight
        #     down_weight = torch.cat(down_weights, dim=0)  # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)

        #     # merge up weight (sum of split_dim, rank*3)
        #     rank = up_weights[0].size(1)
        #     up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
        #     i = 0
        #     for j in range(len(split_dims)):
        #         up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j]
        #         i += split_dims[j]

        #     state_dict[f"{lora_name}.lora_down.weight"] = down_weight
        #     state_dict[f"{lora_name}.lora_up.weight"] = up_weight
        #     state_dict[f"{lora_name}.alpha"] = alpha

        #     # print(
        #     #     f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
        #     # )
        #     print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")

        return super().load_state_dict(state_dict, strict)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        """State dict; with split_qkv, per-split weights are merged back into
        combined lora_down/lora_up tensors before returning."""
        if not self.split_qkv:
            return super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

        # merge qkv
        # NOTE(review): the "double"/"qkv" and "single"/"linear1" key patterns
        # (and the 3072/12288 dims) look inherited from the FLUX network;
        # Lumina lora names may never match them — confirm.
        state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        new_state_dict = {}
        for key in list(state_dict.keys()):
            if "double" in key and "qkv" in key:
                split_dims = [3072] * 3
            elif "single" in key and "linear1" in key:
                split_dims = [3072] * 3 + [12288]
            else:
                new_state_dict[key] = state_dict[key]
                continue

            # keys of an already-processed lora were popped below
            if key not in state_dict:
                continue  # already merged

            lora_name = key.split(".")[0]

            # (rank, in_dim) * 3
            down_weights = [state_dict.pop(f"{lora_name}.lora_down.{i}.weight") for i in range(len(split_dims))]
            # (split dim, rank) * 3
            up_weights = [state_dict.pop(f"{lora_name}.lora_up.{i}.weight") for i in range(len(split_dims))]

            alpha = state_dict.pop(f"{lora_name}.alpha")

            # merge down weight
            down_weight = torch.cat(down_weights, dim=0)  # (rank, split_dim) * 3 -> (rank*3, sum of split_dim)

            # merge up weight (sum of split_dim, rank*3): block-diagonal layout
            rank = up_weights[0].size(1)
            up_weight = torch.zeros((sum(split_dims), down_weight.size(0)), device=down_weight.device, dtype=down_weight.dtype)
            i = 0
            for j in range(len(split_dims)):
                up_weight[i : i + split_dims[j], j * rank : (j + 1) * rank] = up_weights[j]
                i += split_dims[j]

            new_state_dict[f"{lora_name}.lora_down.weight"] = down_weight
            new_state_dict[f"{lora_name}.lora_up.weight"] = up_weight
            new_state_dict[f"{lora_name}.alpha"] = alpha

            # print(
            #     f"merged {lora_name}: {lora_name}, {[w.shape for w in down_weights]}, {[w.shape for w in up_weights]} to {down_weight.shape}, {up_weight.shape}"
            # )
            # NOTE(review): uses print instead of the module logger
            print(f"new key: {lora_name}.lora_down.weight, {lora_name}.lora_up.weight, {lora_name}.alpha")

        return new_state_dict

    def apply_to(self, text_encoders, flux, apply_text_encoder=True, apply_unet=True):
        """Hook the LoRA modules into their wrapped modules and register them.

        Loras excluded by the flags are dropped from the network permanently.
        (NOTE(review): param name `flux` is unused/leftover — model is Lumina.)
        """
        if apply_text_encoder:
            logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            lora.apply_to()
            self.add_module(lora.lora_name, lora)

    # returns whether the weights can be merged into the base model
    def is_mergeable(self):
        return True

    # TODO refactor to common function with apply_to
    def merge_to(self, text_encoders, flux, weights_sd, dtype=None, device=None):
        """Merge `weights_sd` directly into the base model weights.

        Which parts get merged is decided by the key prefixes present in
        `weights_sd` (text encoder vs. Lumina blocks).
        """
        apply_text_encoder = apply_unet = False
        for key in weights_sd.keys():
            if key.startswith(LoRANetwork.LORA_PREFIX_TEXT_ENCODER):
                apply_text_encoder = True
            elif key.startswith(LoRANetwork.LORA_PREFIX_LUMINA):
                apply_unet = True

        if apply_text_encoder:
            logger.info("enable LoRA for text encoder")
        else:
            self.text_encoder_loras = []

        if apply_unet:
            logger.info("enable LoRA for U-Net")
        else:
            self.unet_loras = []

        for lora in self.text_encoder_loras + self.unet_loras:
            # collect this lora's sub-keys, stripped of the "lora_name." prefix
            sd_for_lora = {}
            for key in weights_sd.keys():
                if key.startswith(lora.lora_name):
                    sd_for_lora[key[len(lora.lora_name) + 1 :]] = weights_sd[key]
            lora.merge_to(sd_for_lora, dtype, device)

        logger.info(f"weights are merged")

    def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio):
        """Store LoRA+ LR ratios; the specific ratios override the general one."""
        self.loraplus_lr_ratio = loraplus_lr_ratio
        self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio
        self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio

        logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}")
        logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}")

    def prepare_optimizer_params_with_multiple_te_lrs(self, text_encoder_lr, unet_lr, default_lr):
        """Build optimizer param groups (with LoRA+ "plus" groups for lora_up).

        Returns (all_params, lr_descriptions), one description per group.
        """
        # make sure text_encoder_lr as list of two elements
        # if float, use the same value for both text encoders
        # NOTE(review): two-element handling looks inherited from a multi-TE
        # setup; only text_encoder_lr[0] is consumed below — confirm.
        if text_encoder_lr is None or (isinstance(text_encoder_lr, list) and len(text_encoder_lr) == 0):
            text_encoder_lr = [default_lr, default_lr]
        elif isinstance(text_encoder_lr, float) or isinstance(text_encoder_lr, int):
            text_encoder_lr = [float(text_encoder_lr), float(text_encoder_lr)]
        elif len(text_encoder_lr) == 1:
            text_encoder_lr = [text_encoder_lr[0], text_encoder_lr[0]]

        self.requires_grad_(True)

        all_params = []
        lr_descriptions = []

        def assemble_params(loras, lr, loraplus_ratio):
            # split params: "plus" gets lr * loraplus_ratio, the rest plain lr
            param_groups = {"lora": {}, "plus": {}}
            for lora in loras:
                for name, param in lora.named_parameters():
                    if loraplus_ratio is not None and "lora_up" in name:
                        param_groups["plus"][f"{lora.lora_name}.{name}"] = param
                    else:
                        param_groups["lora"][f"{lora.lora_name}.{name}"] = param

            params = []
            descriptions = []
            for key in param_groups.keys():
                param_data = {"params": param_groups[key].values()}

                if len(param_data["params"]) == 0:
                    continue

                if lr is not None:
                    if key == "plus":
                        param_data["lr"] = lr * loraplus_ratio
                    else:
                        param_data["lr"] = lr

                # drop groups with no effective LR
                if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None:
                    logger.info("NO LR skipping!")
                    continue

                params.append(param_data)
                descriptions.append("plus" if key == "plus" else "")

            return params, descriptions

        if self.text_encoder_loras:
            loraplus_lr_ratio = self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio

            # split text encoder loras for te1 and te3
            te_loras = [lora for lora in self.text_encoder_loras]
            if len(te_loras) > 0:
                logger.info(f"Text Encoder: {len(te_loras)} modules, LR {text_encoder_lr[0]}")
                params, descriptions = assemble_params(te_loras, text_encoder_lr[0], loraplus_lr_ratio)
                all_params.extend(params)
                lr_descriptions.extend(["textencoder " + (" " + d if d else "") for d in descriptions])

        if self.unet_loras:
            params, descriptions = assemble_params(
                self.unet_loras,
                unet_lr if unet_lr is not None else default_lr,
                self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio,
            )
            all_params.extend(params)
            lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions])

        return all_params, lr_descriptions

    def enable_gradient_checkpointing(self):
        # not supported
        pass

    def prepare_grad_etc(self, text_encoder, unet):
        """Enable grads for all LoRA parameters before training."""
        self.requires_grad_(True)

    def on_epoch_start(self, text_encoder, unet):
        """Put the network into training mode at the start of each epoch."""
        self.train()

    def get_trainable_params(self):
        """Return all trainable parameters of the network."""
        return self.parameters()

    def save_weights(self, file, dtype, metadata):
        """Save the (possibly qkv-merged) state dict to disk.

        Safetensors files additionally get sshs model/legacy hashes in their
        metadata.
        """
        if metadata is not None and len(metadata) == 0:
            metadata = None

        state_dict = self.state_dict()

        if dtype is not None:
            for key in list(state_dict.keys()):
                v = state_dict[key]
                v = v.detach().clone().to("cpu").to(dtype)
                state_dict[key] = v

        if os.path.splitext(file)[1] == ".safetensors":
            from safetensors.torch import save_file
            from library import train_util

            # Precalculate model hashes to save time on indexing
            if metadata is None:
                metadata = {}
            model_hash, legacy_hash = train_util.precalculate_safetensors_hashes(state_dict, metadata)
            metadata["sshs_model_hash"] = model_hash
            metadata["sshs_legacy_hash"] = legacy_hash

            save_file(state_dict, file, metadata)
        else:
            torch.save(state_dict, file)

    def backup_weights(self):
        # back up the original module weights (before any pre-calculation)
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not hasattr(org_module, "_lora_org_weight"):
                sd = org_module.state_dict()
                org_module._lora_org_weight = sd["weight"].detach().clone()
                # weights currently equal the backup, so mark as "restored"
                org_module._lora_restored = True

    def restore_weights(self):
        # restore the original module weights from the backup
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            if not org_module._lora_restored:
                sd = org_module.state_dict()
                sd["weight"] = org_module._lora_org_weight
                org_module.load_state_dict(sd)
                org_module._lora_restored = True

    def pre_calculation(self):
        # pre-merge each LoRA delta into its original module and disable the
        # runtime LoRA path (faster inference; restore_weights() undoes it)
        loras: List[LoRAInfModule] = self.text_encoder_loras + self.unet_loras
        for lora in loras:
            org_module = lora.org_module_ref[0]
            sd = org_module.state_dict()

            org_weight = sd["weight"]
            lora_weight = lora.get_weight().to(org_weight.device, dtype=org_weight.dtype)
            sd["weight"] = org_weight + lora_weight
            assert sd["weight"].shape == org_weight.shape
            org_module.load_state_dict(sd)

            org_module._lora_restored = False
            lora.enabled = False

    def apply_max_norm_regularization(self, max_norm_value, device):
        """Scale down lora_up/lora_down pairs whose combined norm exceeds
        `max_norm_value`; returns (keys_scaled, mean_norm, max_norm).

        NOTE(review): raises ZeroDivisionError if the state dict contains no
        lora_down keys (norms stays empty) — confirm callers guarantee some.
        """
        downkeys = []
        upkeys = []
        alphakeys = []
        norms = []
        keys_scaled = 0

        state_dict = self.state_dict()
        for key in state_dict.keys():
            if "lora_down" in key and "weight" in key:
                downkeys.append(key)
                upkeys.append(key.replace("lora_down", "lora_up"))
                alphakeys.append(key.replace("lora_down.weight", "alpha"))

        for i in range(len(downkeys)):
            down = state_dict[downkeys[i]].to(device)
            up = state_dict[upkeys[i]].to(device)
            alpha = state_dict[alphakeys[i]].to(device)
            dim = down.shape[0]
            scale = alpha / dim

            if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
                updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
            elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
                updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
            else:
                updown = up @ down

            updown *= scale

            # clamp the norm floor so tiny norms don't produce huge ratios
            norm = updown.norm().clamp(min=max_norm_value / 2)
            desired = torch.clamp(norm, max=max_norm_value)
            ratio = desired.cpu() / norm.cpu()
            sqrt_ratio = ratio**0.5
            if ratio != 1:
                keys_scaled += 1
                # apply sqrt to each factor so their product scales by `ratio`
                state_dict[upkeys[i]] *= sqrt_ratio
                state_dict[downkeys[i]] *= sqrt_ratio
            scalednorm = updown.norm() * ratio
            norms.append(scalednorm.item())

        return keys_scaled, sum(norms) / len(norms), max(norms)

+ 41
- 28
scripts/dev/networks/resize_lora.py View File

@@ -20,6 +20,13 @@ logger = logging.getLogger(__name__)

MIN_SV = 1e-6

LORA_DOWN_UP_FORMATS = [
("lora_down", "lora_up"), # sd-scripts LoRA
("lora_A", "lora_B"), # PEFT LoRA
("down", "up"), # ControlLoRA
]


# Model save and load functions


@@ -192,24 +199,11 @@ def rank_resize(S, rank, dynamic_method, dynamic_param, scale=1):


def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dynamic_method, dynamic_param, verbose):
network_alpha = None
network_dim = None
max_old_rank = None
new_alpha = None
verbose_str = "\n"
fro_list = []

# Extract loaded lora dim and alpha
for key, value in lora_sd.items():
if network_alpha is None and "alpha" in key:
network_alpha = value
if network_dim is None and "lora_down" in key and len(value.size()) == 2:
network_dim = value.size()[0]
if network_alpha is not None and network_dim is not None:
break
if network_alpha is None:
network_alpha = network_dim

scale = network_alpha / network_dim

if dynamic_method:
logger.info(
f"Dynamically determining new alphas and dims based off {dynamic_method}: {dynamic_param}, max rank is {new_rank}"
@@ -224,17 +218,33 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna

with torch.no_grad():
for key, value in tqdm(lora_sd.items()):
weight_name = None
if "lora_down" in key:
block_down_name = key.rsplit(".lora_down", 1)[0]
weight_name = key.rsplit(".", 1)[-1]
lora_down_weight = value
else:
key_parts = key.split(".")
block_down_name = None
for _format in LORA_DOWN_UP_FORMATS:
# Currently we only match lora_down_name in the last two parts of key
# because ("down", "up") are general words and may appear in block_down_name
if len(key_parts) >= 2 and _format[0] == key_parts[-2]:
block_down_name = ".".join(key_parts[:-2])
lora_down_name = "." + _format[0]
lora_up_name = "." + _format[1]
weight_name = "." + key_parts[-1]
break
if len(key_parts) >= 1 and _format[0] == key_parts[-1]:
block_down_name = ".".join(key_parts[:-1])
lora_down_name = "." + _format[0]
lora_up_name = "." + _format[1]
weight_name = ""
break

if block_down_name is None:
# This parameter is not lora_down
continue

# find corresponding lora_up and alpha
# Now weight_name can be ".weight" or ""
# Find corresponding lora_up and alpha
block_up_name = block_down_name
lora_up_weight = lora_sd.get(block_up_name + ".lora_up." + weight_name, None)
lora_down_weight = value
lora_up_weight = lora_sd.get(block_up_name + lora_up_name + weight_name, None)
lora_alpha = lora_sd.get(block_down_name + ".alpha", None)

weights_loaded = lora_down_weight is not None and lora_up_weight is not None
@@ -242,10 +252,13 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
if weights_loaded:

conv2d = len(lora_down_weight.size()) == 4
old_rank = lora_down_weight.size()[0]
max_old_rank = max(max_old_rank or 0, old_rank)

if lora_alpha is None:
scale = 1.0
else:
scale = lora_alpha / lora_down_weight.size()[0]
scale = lora_alpha / old_rank

if conv2d:
full_weight_matrix = merge_conv(lora_down_weight, lora_up_weight, device)
@@ -272,9 +285,9 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
verbose_str += "\n"

new_alpha = param_dict["new_alpha"]
o_lora_sd[block_down_name + "." + "lora_down.weight"] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." + "lora_up.weight"] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + "." "alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)
o_lora_sd[block_down_name + lora_down_name + weight_name] = param_dict["lora_down"].to(save_dtype).contiguous()
o_lora_sd[block_up_name + lora_up_name + weight_name] = param_dict["lora_up"].to(save_dtype).contiguous()
o_lora_sd[block_down_name + ".alpha"] = torch.tensor(param_dict["new_alpha"]).to(save_dtype)

block_down_name = None
block_up_name = None
@@ -287,7 +300,7 @@ def resize_lora_model(lora_sd, new_rank, new_conv_rank, save_dtype, device, dyna
print(verbose_str)
print(f"Average Frobenius norm retention: {np.mean(fro_list):.2%} | std: {np.std(fro_list):0.3f}")
logger.info("resizing complete")
return o_lora_sd, network_dim, new_alpha
return o_lora_sd, max_old_rank, new_alpha


def resize(args):


+ 1
- 0
scripts/dev/pytest.ini View File

@@ -6,3 +6,4 @@ filterwarnings =
ignore::DeprecationWarning
ignore::UserWarning
ignore::FutureWarning
pythonpath = .

+ 1
- 1
scripts/dev/requirements.txt View File

@@ -9,7 +9,7 @@ pytorch-lightning==1.9.0
bitsandbytes==0.44.0
lion-pytorch==0.0.6
schedulefree==1.4
pytorch-optimizer==3.5.0
pytorch-optimizer==3.7.0
prodigy-plus-schedule-free==1.9.0
prodigyopt==1.1.2
tensorboard


+ 3
- 0
scripts/dev/sd3_train.py View File

@@ -20,6 +20,8 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, sd3_models, sd3_train_utils, sd3_utils, strategy_base, strategy_sd3

import library.sai_model_spec as sai_model_spec
from library.sdxl_train_util import match_mixed_precision

# , sdxl_model_util
@@ -986,6 +988,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)


+ 2
- 1
scripts/dev/sdxl_train.py View File

@@ -17,7 +17,7 @@ init_ipex()

from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl
from library import deepspeed_utils, sdxl_model_util, strategy_base, strategy_sd, strategy_sdxl, sai_model_spec

import library.train_util as train_util

@@ -893,6 +893,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, False)
train_util.add_masked_loss_arguments(parser)


+ 2
- 0
scripts/dev/sdxl_train_control_net.py View File

@@ -25,6 +25,7 @@ from library import (
strategy_base,
strategy_sd,
strategy_sdxl,
sai_model_spec
)

import library.train_util as train_util
@@ -664,6 +665,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
# train_util.add_masked_loss_arguments(parser)


+ 2
- 0
scripts/dev/sdxl_train_control_net_lllite.py View File

@@ -32,6 +32,7 @@ from library import (
strategy_base,
strategy_sd,
strategy_sdxl,
sai_model_spec,
)

import library.model_util as model_util
@@ -589,6 +590,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)


+ 2
- 0
scripts/dev/sdxl_train_control_net_lllite_old.py View File

@@ -24,6 +24,7 @@ from library import deepspeed_utils, sai_model_spec, sdxl_model_util, sdxl_origi
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -536,6 +537,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, False, True, True)
train_util.add_training_arguments(parser, False)
deepspeed_utils.add_deepspeed_arguments(parser)


+ 0
- 1
scripts/dev/sdxl_train_network.py View File

@@ -24,7 +24,6 @@ class SdxlNetworkTrainer(train_network.NetworkTrainer):
self.is_sdxl = True

def assert_extra_args(self, args, train_dataset_group: Union[train_util.DatasetGroup, train_util.MinimalDataset], val_dataset_group: Optional[train_util.DatasetGroup]):
super().assert_extra_args(args, train_dataset_group, val_dataset_group)
sdxl_train_util.verify_sdxl_training_args(args)

if args.cache_text_encoder_outputs:


+ 2
- 0
scripts/dev/tools/cache_latents.py View File

@@ -12,6 +12,7 @@ from tqdm import tqdm
from library import config_util, flux_train_utils, flux_utils, strategy_base, strategy_flux, strategy_sd, strategy_sdxl
from library import train_util
from library import sdxl_train_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -161,6 +162,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)


+ 2
- 0
scripts/dev/tools/cache_text_encoder_outputs.py View File

@@ -22,6 +22,7 @@ from library import (
from library import train_util
from library import sdxl_train_util
from library import utils
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -188,6 +189,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_training_arguments(parser, True)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_masked_loss_arguments(parser)


+ 4
- 7
scripts/dev/tools/detect_face_rotate.py View File

@@ -15,7 +15,7 @@ import os
from anime_face_detector import create_detector
from tqdm import tqdm
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -170,12 +170,9 @@ def process(args):
scale = max(cur_crop_width / w, cur_crop_height / h)

if scale != 1.0:
w = int(w * scale + .5)
h = int(h * scale + .5)
if scale < 1.0:
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
else:
face_img = pil_resize(face_img, (w, h))
rw = int(w * scale + .5)
rh = int(h * scale + .5)
face_img = resize_image(face_img, w, h, rw, rh)
cx = int(cx * scale + .5)
cy = int(cy * scale + .5)
fw = int(fw * scale + .5)


+ 4
- 16
scripts/dev/tools/resize_images_to_resolution.py View File

@@ -6,7 +6,7 @@ import shutil
import math
from PIL import Image
import numpy as np
from library.utils import setup_logging, pil_resize
from library.utils import setup_logging, resize_image
setup_logging()
import logging
logger = logging.getLogger(__name__)
@@ -22,14 +22,6 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
if not os.path.exists(dst_img_folder):
os.makedirs(dst_img_folder)

# Select interpolation method
if interpolation == 'lanczos4':
pil_interpolation = Image.LANCZOS
elif interpolation == 'cubic':
pil_interpolation = Image.BICUBIC
else:
cv2_interpolation = cv2.INTER_AREA

# Iterate through all files in src_img_folder
img_exts = (".png", ".jpg", ".jpeg", ".webp", ".bmp") # copy from train_util.py
for filename in os.listdir(src_img_folder):
@@ -63,11 +55,7 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
new_height = int(img.shape[0] * math.sqrt(scale_factor))
new_width = int(img.shape[1] * math.sqrt(scale_factor))

# Resize image
if cv2_interpolation:
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
else:
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
img = resize_image(img, img.shape[0], img.shape[1], new_height, new_width, interpolation)
else:
new_height, new_width = img.shape[0:2]

@@ -113,8 +101,8 @@ def setup_parser() -> argparse.ArgumentParser:
help='Maximum resolution(s) in the format "512x512,384x384, etc, etc" / 最大画像サイズをカンマ区切りで指定 ("512x512,384x384, etc, etc" など)', default="512x512,384x384,256x256,128x128")
parser.add_argument('--divisible_by', type=int,
help='Ensure new dimensions are divisible by this value / リサイズ後の画像のサイズをこの値で割り切れるようにします', default=1)
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4'],
default='area', help='Interpolation method for resizing / リサイズ時の補完方法')
parser.add_argument('--interpolation', type=str, choices=['area', 'cubic', 'lanczos4', 'nearest', 'linear', 'box'],
default=None, help='Interpolation method for resizing. Default to area if smaller, lanczos if larger / サイズ変更の補間方法。小さい場合はデフォルトでエリア、大きい場合はランチョスになります。')
parser.add_argument('--save_as_png', action='store_true', help='Save as png format / png形式で保存')
parser.add_argument('--copy_associated_files', action='store_true',
help='Copy files with same base name to images (captions etc) / 画像と同じファイル名(拡張子を除く)のファイルもコピーする')


+ 1
- 0
scripts/dev/train_control_net.py View File

@@ -25,6 +25,7 @@ from safetensors.torch import load_file
import library.model_util as model_util
import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,


+ 2
- 0
scripts/dev/train_db.py View File

@@ -22,6 +22,7 @@ from diffusers import DDPMScheduler

import library.train_util as train_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -512,6 +513,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, False, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)


+ 69
- 12
scripts/dev/train_network.py View File

@@ -24,7 +24,7 @@ from accelerate.utils import set_seed
from accelerate import Accelerator
from diffusers import DDPMScheduler
from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
from library import deepspeed_utils, model_util, sai_model_spec, strategy_base, strategy_sd, sai_model_spec

import library.train_util as train_util
from library.train_util import DreamBoothDataset
@@ -69,13 +69,20 @@ class NetworkTrainer:
keys_scaled=None,
mean_norm=None,
maximum_norm=None,
mean_grad_norm=None,
mean_combined_norm=None,
):
logs = {"loss/current": current_loss, "loss/average": avr_loss}

if keys_scaled is not None:
logs["max_norm/keys_scaled"] = keys_scaled
logs["max_norm/average_key_norm"] = mean_norm
logs["max_norm/max_key_norm"] = maximum_norm
if mean_norm is not None:
logs["norm/avg_key_norm"] = mean_norm
if mean_grad_norm is not None:
logs["norm/avg_grad_norm"] = mean_grad_norm
if mean_combined_norm is not None:
logs["norm/avg_combined_norm"] = mean_combined_norm

lrs = lr_scheduler.get_last_lr()
for i, lr in enumerate(lrs):
@@ -168,7 +175,7 @@ class NetworkTrainer:
if val_dataset_group is not None:
val_dataset_group.verify_bucket_reso_steps(64)

def load_target_model(self, args, weight_dtype, accelerator):
def load_target_model(self, args, weight_dtype, accelerator) -> tuple:
text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator)

# モデルに xformers とか memory efficient attention を組み込む
@@ -382,7 +389,18 @@ class NetworkTrainer:
latents = typing.cast(torch.FloatTensor, batch["latents"].to(accelerator.device))
else:
# latentに変換
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
if args.vae_batch_size is None or len(batch["images"]) <= args.vae_batch_size:
latents = self.encode_images_to_latents(args, vae, batch["images"].to(accelerator.device, dtype=vae_dtype))
else:
chunks = [
batch["images"][i : i + args.vae_batch_size] for i in range(0, len(batch["images"]), args.vae_batch_size)
]
list_latents = []
for chunk in chunks:
with torch.no_grad():
chunk = self.encode_images_to_latents(args, vae, chunk.to(accelerator.device, dtype=vae_dtype))
list_latents.append(chunk)
latents = torch.cat(list_latents, dim=0)

# NaNが含まれていれば警告を表示し0に置き換える
if torch.any(torch.isnan(latents)):
@@ -396,12 +414,13 @@ class NetworkTrainer:
if text_encoder_outputs_list is not None:
text_encoder_conds = text_encoder_outputs_list # List of text encoder outputs


if len(text_encoder_conds) == 0 or text_encoder_conds[0] is None or train_text_encoder:
# TODO this does not work if 'some text_encoders are trained' and 'some are not and not cached'
with torch.set_grad_enabled(is_train and train_text_encoder), accelerator.autocast():
# Get the text embedding for conditioning
if args.weighted_captions:
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch["captions"])
input_ids_list, weights_list = tokenize_strategy.tokenize_with_weights(batch['captions'])
encoded_text_encoder_conds = text_encoding_strategy.encode_tokens_with_weights(
tokenize_strategy,
self.get_models_for_text_encoding(args, accelerator, text_encoders),
@@ -626,7 +645,7 @@ class NetworkTrainer:
net_kwargs = {}
if args.network_args is not None:
for net_arg in args.network_args:
key, value = net_arg.split("=")
key, value = net_arg.split("=", 1)
net_kwargs[key] = value

# if a new network is added in future, add if ~ then blocks for each network (;'∀')
@@ -651,6 +670,10 @@ class NetworkTrainer:
return
network_has_multiplier = hasattr(network, "set_multiplier")

# TODO remove `hasattr`s by setting up methods if not defined in the network like (hacky but works):
# if not hasattr(network, "prepare_network"):
# network.prepare_network = lambda args: None

if hasattr(network, "prepare_network"):
network.prepare_network(args)
if args.scale_weight_norms and not hasattr(network, "apply_max_norm_regularization"):
@@ -1017,6 +1040,7 @@ class NetworkTrainer:
"ss_max_validation_steps": args.max_validation_steps,
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
"ss_validate_every_n_steps": args.validate_every_n_steps,
"ss_resize_interpolation": args.resize_interpolation,
}

self.update_metadata(metadata, args) # architecture specific metadata
@@ -1042,6 +1066,7 @@ class NetworkTrainer:
"max_bucket_reso": dataset.max_bucket_reso,
"tag_frequency": dataset.tag_frequency,
"bucket_info": dataset.bucket_info,
"resize_interpolation": dataset.resize_interpolation,
}

subsets_metadata = []
@@ -1059,6 +1084,7 @@ class NetworkTrainer:
"enable_wildcard": bool(subset.enable_wildcard),
"caption_prefix": subset.caption_prefix,
"caption_suffix": subset.caption_suffix,
"resize_interpolation": subset.resize_interpolation,
}

image_dir_or_metadata_file = None
@@ -1400,6 +1426,11 @@ class NetworkTrainer:
params_to_clip = accelerator.unwrap_model(network).get_trainable_params()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

if hasattr(network, "update_grad_norms"):
network.update_grad_norms()
if hasattr(network, "update_norms"):
network.update_norms()

optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
@@ -1408,9 +1439,25 @@ class NetworkTrainer:
keys_scaled, mean_norm, maximum_norm = accelerator.unwrap_model(network).apply_max_norm_regularization(
args.scale_weight_norms, accelerator.device
)
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {"Keys Scaled": keys_scaled, "Average key norm": mean_norm}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
if hasattr(network, "weight_norms"):
weight_norms = network.weight_norms()
mean_norm = weight_norms.mean().item() if weight_norms is not None else None
grad_norms = network.grad_norms()
mean_grad_norm = grad_norms.mean().item() if grad_norms is not None else None
combined_weight_norms = network.combined_weight_norms()
mean_combined_norm = combined_weight_norms.mean().item() if combined_weight_norms is not None else None
maximum_norm = weight_norms.max().item() if weight_norms is not None else None
keys_scaled = None
max_mean_logs = {}
else:
keys_scaled, mean_norm, maximum_norm = None, None, None
mean_grad_norm = None
mean_combined_norm = None
max_mean_logs = {}

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -1421,6 +1468,7 @@ class NetworkTrainer:
self.sample_images(
accelerator, args, None, global_step, accelerator.device, vae, tokenizers, text_encoder, unet
)
progress_bar.unpause()

# 指定ステップごとにモデルを保存
if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0:
@@ -1442,14 +1490,21 @@ class NetworkTrainer:
loss_recorder.add(epoch=epoch, step=step, loss=current_loss)
avr_loss: float = loss_recorder.moving_average
logs = {"avr_loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if args.scale_weight_norms:
progress_bar.set_postfix(**{**max_mean_logs, **logs})
progress_bar.set_postfix(**{**max_mean_logs, **logs})

if is_tracking:
logs = self.generate_step_logs(
args, current_loss, avr_loss, lr_scheduler, lr_descriptions, optimizer, keys_scaled, mean_norm, maximum_norm
args,
current_loss,
avr_loss,
lr_scheduler,
lr_descriptions,
optimizer,
keys_scaled,
mean_norm,
maximum_norm,
mean_grad_norm,
mean_combined_norm,
)
self.step_logging(accelerator, logs, global_step, epoch + 1)

@@ -1627,6 +1682,7 @@ class NetworkTrainer:
train_util.save_and_remove_state_on_epoch_end(args, accelerator, epoch + 1)

self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizers, text_encoder, unet)
progress_bar.unpause()
optimizer_train_fn()

# end of epoch
@@ -1655,6 +1711,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, True)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)


+ 2
- 1
scripts/dev/train_textual_inversion.py View File

@@ -16,7 +16,7 @@ init_ipex()
from accelerate.utils import set_seed
from diffusers import DDPMScheduler
from transformers import CLIPTokenizer
from library import deepspeed_utils, model_util, strategy_base, strategy_sd
from library import deepspeed_utils, model_util, strategy_base, strategy_sd, sai_model_spec

import library.train_util as train_util
import library.huggingface_util as huggingface_util
@@ -771,6 +771,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)


+ 2
- 0
scripts/dev/train_textual_inversion_XTI.py View File

@@ -21,6 +21,7 @@ import library
import library.train_util as train_util
import library.huggingface_util as huggingface_util
import library.config_util as config_util
import library.sai_model_spec as sai_model_spec
from library.config_util import (
ConfigSanitizer,
BlueprintGenerator,
@@ -668,6 +669,7 @@ def setup_parser() -> argparse.ArgumentParser:

add_logging_arguments(parser)
train_util.add_sd_models_arguments(parser)
sai_model_spec.add_model_spec_arguments(parser)
train_util.add_dataset_arguments(parser, True, True, False)
train_util.add_training_arguments(parser, True)
train_util.add_masked_loss_arguments(parser)


Loading…
Cancel
Save
Baidu
map