6 Commits

Author SHA1 Message Date
  Akegarasu e0f5194815
add prompt_file, close #680 6 months ago
  秋葉杏 2e956c3fbc
update bitsandbytes 6 months ago
  Rnglg2 ebf5889815
新增lumina2-lora训练的部分支持 (#688) 7 months ago
  Akegarasu b0bfb582ca
close #684 7 months ago
  Akegarasu a82077d54a
update to cu128 7 months ago
  Akegarasu 9c451919b1
update bitsandbytes for cu128 7 months ago
9 changed files with 157 additions and 27 deletions
Split View
  1. +2
    -2
      install-cn.ps1
  2. +3
    -3
      install.bash
  3. +2
    -2
      install.ps1
  4. +19
    -13
      mikazuki/app/api.py
  5. +20
    -4
      mikazuki/launch_utils.py
  6. +99
    -0
      mikazuki/schema/lumina2-lora.ts
  7. +1
    -0
      mikazuki/schema/shared.ts
  8. +10
    -2
      mikazuki/utils/train_utils.py
  9. +1
    -1
      requirements.txt

+ 2
- 2
install-cn.ps1 View File

@@ -38,9 +38,9 @@ Write-Output "
Write-Output "受限于国内加速镜像,torch 安装无法使用镜像源,安装较为缓慢。"
$install_torch = Read-Host "是否需要安装 Torch+xformers? [y/n] (默认为 y)"
if ($install_torch -eq "y" -or $install_torch -eq "Y" -or $install_torch -eq "") {
python -m pip install torch==2.4.1+cu124 torchvision==0.19.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
python -m pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --index-url https://download.pytorch.org/whl/cu128
Check "torch 安装失败,请删除 venv 文件夹后重新运行。"
python -m pip install -U -I --no-deps xformers===0.0.28.post1 --extra-index-url https://download.pytorch.org/whl/cu124
python -m pip install -U -I --no-deps xformers===0.0.30 --extra-index-url https://download.pytorch.org/whl/cu128
Check "xformers 安装失败。"
}



+ 3
- 3
install.bash View File

@@ -36,9 +36,9 @@ echo "CUDA Version: $cuda_version"


if (( cuda_major_version >= 12 )); then
echo "install torch 2.4.1+cu124"
pip install torch==2.4.1+cu124 torchvision==0.19.1+cu124 --extra-index-url https://download.pytorch.org/whl/cu124
pip install --no-deps xformers==0.0.28.post1 --extra-index-url https://download.pytorch.org/whl/cu124
echo "install torch 2.7.0+cu128"
pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
pip install --no-deps xformers==0.0.30 --extra-index-url https://download.pytorch.org/whl/cu128
elif (( cuda_major_version == 11 && cuda_minor_version >= 8 )); then
echo "install torch 2.4.0+cu118"
pip install torch==2.4.0+cu118 torchvision==0.19.0+cu118 --extra-index-url https://download.pytorch.org/whl/cu118


+ 2
- 2
install.ps1 View File

@@ -8,8 +8,8 @@ if (!(Test-Path -Path "venv")) {

Write-Output "Installing deps..."

pip install torch==2.4.1 torchvision==0.19.1 --extra-index-url https://download.pytorch.org/whl/cu124
pip install -U -I --no-deps xformers==0.0.28.post1 --extra-index-url https://download.pytorch.org/whl/cu124
pip install torch==2.7.0+cu128 torchvision==0.22.0+cu128 --extra-index-url https://download.pytorch.org/whl/cu128
pip install -U -I --no-deps xformers==0.0.30 --extra-index-url https://download.pytorch.org/whl/cu128
pip install --upgrade -r requirements.txt

Write-Output "Install completed"


+ 19
- 13
mikazuki/app/api.py View File

@@ -141,19 +141,25 @@ async def create_toml_file(request: Request):
if not validated:
return APIResponseFail(message=message)

try:
positive_prompt, sample_prompts_arg = get_sample_prompts(config=config)

if positive_prompt is not None and train_utils.is_promopt_like(sample_prompts_arg):
sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt")
with open(sample_prompts_file, "w", encoding="utf-8") as f:
f.write(sample_prompts_arg)

config["sample_prompts"] = sample_prompts_file
log.info(f"Wrote prompts to file {sample_prompts_file}")
except ValueError as e:
log.error(f"Error while processing prompts: {e}")
return APIResponseFail(message=str(e))
if "prompt_file" in config and config["prompt_file"].strip() != "":
prompt_file = config["prompt_file"].strip()
if not os.path.exists(prompt_file):
return APIResponseFail(message=f"Prompt 文件 {prompt_file} 不存在,请检查路径。")
config["sample_prompts"] = prompt_file
else:
try:
positive_prompt, sample_prompts_arg = get_sample_prompts(config=config)

if positive_prompt is not None and train_utils.is_promopt_like(sample_prompts_arg):
sample_prompts_file = os.path.join(os.getcwd(), f"config", "autosave", f"{timestamp}-promopt.txt")
with open(sample_prompts_file, "w", encoding="utf-8") as f:
f.write(sample_prompts_arg)
config["sample_prompts"] = sample_prompts_file
log.info(f"Wrote prompts to file {sample_prompts_file}")

except ValueError as e:
log.error(f"Error while processing prompts: {e}")
return APIResponseFail(message=str(e))

with open(toml_file, "w", encoding="utf-8") as f:
f.write(toml.dumps(config))


+ 20
- 4
mikazuki/launch_utils.py View File

@@ -202,10 +202,10 @@ def setup_windows_bitsandbytes():
return

# bnb_windows_index = os.environ.get("BNB_WINDOWS_INDEX", "https://jihulab.com/api/v4/projects/140618/packages/pypi/simple")
bnb_package = "bitsandbytes==0.44.0"
bnb_package = "bitsandbytes==0.46.0"
bnb_path = os.path.join(sysconfig.get_paths()["purelib"], "bitsandbytes")

installed_bnb = is_installed(bnb_package)
installed_bnb = is_installed("bitsandbytes") # don't check version here
bnb_cuda_setup = len([f for f in os.listdir(bnb_path) if re.findall(r"libbitsandbytes_cuda.+?\.dll", f)]) != 0

if not installed_bnb or not bnb_cuda_setup:
@@ -259,6 +259,22 @@ def check_run(file: str) -> bool:
return result.returncode == 0


def network_gfw_test(timeout=3):
    """Return True if https://www.google.com answers with HTTP 200 within *timeout* seconds.

    Used to decide whether the host has unrestricted internet access; a False
    result makes the caller fall back to CN pip/huggingface mirrors.

    :param timeout: per-request timeout in seconds (default 3).
    :return: bool — True only on a successful 200 response.
    """
    # Import guarded separately: the original swallowed RequestException only,
    # so a missing `requests` package crashed the caller with ImportError.
    try:
        import requests
    except ImportError as e:
        log.error(f"Network test failed: {e}")
        return False

    try:
        # requests will auto detect system proxies
        response = requests.get("https://www.google.com", timeout=timeout)
    except requests.exceptions.RequestException as e:
        log.error(f"Network test failed: {e}")
        return False

    if response.status_code == 200:
        log.info("Network test passed")
        return True

    log.error(f"Network test failed: {response.status_code}")
    return False


def prepare_environment(disable_auto_mirror: bool = True):
if sys.platform == "win32":
# disable triton on windows
@@ -269,8 +285,8 @@ def prepare_environment(disable_auto_mirror: bool = True):
os.environ["PYTHONWARNINGS"] = "ignore::UserWarning"
os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"

if not disable_auto_mirror and locale.getdefaultlocale()[0] == "zh_CN":
log.info("detected locale zh_CN, use pip & huggingface mirrors")
if not disable_auto_mirror and not network_gfw_test():
log.info("use pip & huggingface mirrors")
os.environ.setdefault("PIP_FIND_LINKS", "https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html")
os.environ.setdefault("PIP_INDEX_URL", "https://pypi.tuna.tsinghua.edu.cn/simple")
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")


+ 99
- 0
mikazuki/schema/lumina2-lora.ts View File

@@ -0,0 +1,99 @@
// Lumina2 LoRA training schema (options follow the sd-scripts configuration).
Schema.intersect([
Schema.object({
model_train_type: Schema.string().default("lumina-lora").disabled().description("训练种类"),
pretrained_model_name_or_path: Schema.string().role('filepicker', { type: "model-file" }).default("./sd-models/model.safetensors").description("Lumina 模型路径"),
ae: Schema.string().role('filepicker', { type: "model-file" }).description("AE 模型文件路径"),
gemma2: Schema.string().role('filepicker', { type: "model-file" }).description("gemma2 模型文件路径"),
resume: Schema.string().role('filepicker', { type: "folder" }).description("从某个 `save_state` 保存的中断状态继续训练,填写文件路径"),
}).description("训练用模型"),

Schema.object({
timestep_sampling: Schema.union(["sigma", "uniform", "sigmoid", "shift", "nextdit_shift"]).default("nextdit_shift").description("时间步采样"),
sigmoid_scale: Schema.number().step(0.001).default(1.0).description("sigmoid 缩放"),
model_prediction_type: Schema.union(["raw", "additive", "sigma_scaled"]).default("raw").description("模型预测类型"),
discrete_flow_shift: Schema.number().step(0.001).default(3.185).description("Euler 调度器离散流位移"),
//loss_type: Schema.union(["l1", "l2", "huber", "smooth_l1"]).default("l2").description("损失函数类型"),
guidance_scale: Schema.number().step(0.01).default(1.0).description("CFG 引导缩放"),
use_flash_attn: Schema.boolean().default(false).description("是否使用 Flash Attention"),
cfg_trunc: Schema.number().step(0.01).default(0.25).description("CFG 截断"),
renorm_cfg: Schema.number().step(0.01).default(1.0).description("重归一化 CFG"),
system_prompt: Schema.string().default("You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> ").description("Gemma2b的系统提示"),
}).description("Lumina 专用参数"),

Schema.object(
UpdateSchema(SHARED_SCHEMAS.RAW.DATASET_SETTINGS, {
resolution: Schema.string().default("1024,1024").description("训练图片分辨率,宽x高。支持非正方形,但必须是 64 倍数。"),
enable_bucket: Schema.boolean().default(true).description("启用 arb 桶以允许非固定宽高比的图片"),
min_bucket_reso: Schema.number().default(256).description("arb 桶最小分辨率"),
max_bucket_reso: Schema.number().default(2048).description("arb 桶最大分辨率"),
bucket_reso_steps: Schema.number().default(64).description("arb 桶分辨率划分单位"),
})
).description("数据集设置"),

// Save settings
SHARED_SCHEMAS.SAVE_SETTINGS,

Schema.object({
max_train_epochs: Schema.number().min(1).default(10).description("最大训练 epoch(轮数)"),
train_batch_size: Schema.number().min(1).default(2).description("批量大小, 越高显存占用越高"),
gradient_checkpointing: Schema.boolean().default(true).description("梯度检查点"),
gradient_accumulation_steps: Schema.number().min(1).default(1).description("梯度累加步数"),
network_train_unet_only: Schema.boolean().default(true).description("仅训练 U-Net"),
network_train_text_encoder_only: Schema.boolean().default(false).description("仅训练文本编码器"),
}).description("训练相关参数"),

// Learning rate & optimizer settings
SHARED_SCHEMAS.LR_OPTIMIZER,

Schema.intersect([
Schema.object({
network_module: Schema.union(["networks.lora_lumina", "networks.oft_lumina", "lycoris.kohya"]).default("networks.lora_lumina").description("训练网络模块"),
network_weights: Schema.string().role('filepicker').description("从已有的 LoRA 模型上继续训练,填写路径"),
network_dim: Schema.number().min(1).default(16).description("网络维度,常用 4~128,不是越大越好, 低dim可以降低显存占用"),
network_alpha: Schema.number().min(1).default(8).description("常用值:等于 network_dim 或 network_dim*1/2 或 1。使用较小的 alpha 需要提升学习率"),
network_dropout: Schema.number().step(0.01).default(0).description('dropout 概率 (与 lycoris 不兼容,需要用 lycoris 自带的)'),
scale_weight_norms: Schema.number().step(0.01).min(0).default(1.0).description("最大范数正则化。如果使用,推荐为 1"),
network_args_custom: Schema.array(String).role('table').description('自定义 network_args,一行一个'),
enable_base_weight: Schema.boolean().default(false).description('启用基础权重(差异炼丹)'),
}).description("网络设置"),

// LyCORIS parameters
SHARED_SCHEMAS.LYCORIS_MAIN,
SHARED_SCHEMAS.LYCORIS_LOKR,

SHARED_SCHEMAS.NETWORK_OPTION_BASEWEIGHT,
]),

// Preview image settings
SHARED_SCHEMAS.PREVIEW_IMAGE,

// Logging settings
SHARED_SCHEMAS.LOG_SETTINGS,

// Caption (tag) options
Schema.object(UpdateSchema(SHARED_SCHEMAS.RAW.CAPTION_SETTINGS, {}, ["max_token_length"])).description("caption(Tag)选项"),

// Noise settings
SHARED_SCHEMAS.NOISE_SETTINGS,

// Data augmentation
SHARED_SCHEMAS.DATA_ENCHANCEMENT,

// Other options
SHARED_SCHEMAS.OTHER,

// Speed optimization options
Schema.object(
UpdateSchema(SHARED_SCHEMAS.RAW.PRECISION_CACHE_BATCH, {
fp8_base: Schema.boolean().default(false).description("对基础模型使用 FP8 精度"), // Lumina defaults to false
fp8_base_unet: Schema.boolean().default(false).description("仅对 U-Net 使用 FP8 精度(CLIP-L不使用)"), // Lumina defaults to false
sdpa: Schema.boolean().default(true).description("启用 sdpa"), // not explicitly set by the training script, but generally recommended
cache_text_encoder_outputs: Schema.boolean().default(true).description("缓存文本编码器的输出,减少显存使用。使用时需要关闭 shuffle_caption"),
cache_text_encoder_outputs_to_disk: Schema.boolean().default(true).description("缓存文本编码器的输出到磁盘"),
}, ["xformers"])
).description("速度优化选项"),

// Distributed training
SHARED_SCHEMAS.DISTRIBUTED_TRAINING
]);

+ 1
- 0
mikazuki/schema/shared.ts View File

@@ -182,6 +182,7 @@
Schema.object({
enable_preview: Schema.const(true).required(),
randomly_choice_prompt: Schema.boolean().default(false).description('随机选择预览图 Prompt'),
prompt_file: Schema.string().role('textarea').description('预览图 Prompt 文件路径。填写后将采用文件内的 prompt,而下方的选项将失效。'),
positive_prompts: Schema.string().role('textarea').default('masterpiece, best quality, 1girl, solo').description("Prompt"),
negative_prompts: Schema.string().role('textarea').default('lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry').description("Negative Prompt"),
sample_width: Schema.number().default(512).description('预览图宽'),


+ 10
- 2
mikazuki/utils/train_utils.py View File

@@ -19,10 +19,18 @@ class ModelType(Enum):
SDXL = 3
SD3 = 4
FLUX = 5
LUMINA = 6
LoRA = 10


MODEL_SIGNATURE = [
{
"type": ModelType.LUMINA,
"signature": [
"cap_embedder.0.weight",
"context_refiner.0.attention.k_norm.weight",
]
},
{
"type": ModelType.FLUX,
"signature": [
@@ -143,8 +151,8 @@ def validate_model(model_name: str, training_type: str = "sd-lora"):
if model_type == ModelType.UNKNOWN:
log.error(f"Can't match model type from {model_name}")

if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SD3, ModelType.SDXL, ModelType.FLUX]:
return False, "Pretrained model is not a Stable Diffusion or Flux checkpoint / 校验失败:底模不是 Stable Diffusion 或 Flux 模型"
if model_type not in [ModelType.SD15, ModelType.SD2, ModelType.SD3, ModelType.SDXL, ModelType.FLUX, ModelType.LUMINA]:
return False, "Pretrained model is not a Stable Diffusion, Flux or Lumina checkpoint / 校验失败:底模不是 Stable Diffusion, Flux 或 Lumina 模型"

if model_type == ModelType.SDXL and training_type == "sd-lora":
return False, "Pretrained model is SDXL, but you are training with SD1.5 LoRA / 校验失败:你选择的是 SD1.5 LoRA 训练,但预训练模型是 SDXL。请前往专家模式选择正确的模型种类。"


+ 1
- 1
requirements.txt View File

@@ -6,7 +6,7 @@ ftfy==6.1.1
opencv-python==4.8.1.78
einops==0.7.0
pytorch-lightning==1.9.0
bitsandbytes==0.44.0
bitsandbytes==0.46.0
lion-pytorch==0.1.2
schedulefree==1.4
pytorch-optimizer==3.5.0


Loading…
Cancel
Save
Baidu
map