5 Commits

Author SHA1 Message Date
  Akegarasu fdefba9799
fix 3 months ago
  Akegarasu 47129e02f0
fix #714 3 months ago
  Akegarasu 4e2fdf2915
update fronted 3 months ago
  Akegarasu 3c7802fc05
feat: support cl tagger 3 months ago
  Akegarasu 9a09518786
feat: chroma support & preset 3 months ago
10 changed files with 555 additions and 187 deletions
Split View
  1. +13
    -0
      config/presets/chroma.toml
  2. +1
    -1
      frontend
  3. +3
    -0
      mikazuki/app/api.py
  4. +7
    -0
      mikazuki/app/models.py
  5. +3
    -0
      mikazuki/schema/flux-lora.ts
  6. +20
    -185
      mikazuki/tagger/interrogator.py
  7. +108
    -0
      mikazuki/tagger/interrogators/base.py
  8. +268
    -0
      mikazuki/tagger/interrogators/cl.py
  9. +131
    -0
      mikazuki/tagger/interrogators/wd14.py
  10. +1
    -1
      requirements.txt

+ 13
- 0
config/presets/chroma.toml View File

@@ -0,0 +1,13 @@
# Training preset for Chroma LoRA training, layered on the flux-lora schema.
[metadata]
name = "Chroma LoRA 训练"
version = "1.0"
author = "秋叶"
train_type = "flux-lora"
description = "这是一个样例模板,用于使用 Chroma LoRA 训练。"

# Values below override the flux-lora schema defaults when this preset is applied.
[data]
model_type = "chroma"
apply_t5_attn_mask = true
timestep_sampling = "sigmoid"
model_prediction_type = "raw"
guidance_scale = 0.0

+ 1
- 1
frontend

@@ -1 +1 @@
Subproject commit a61f4b8b7409e4bf5a94bbdad18de1a89519487f
Subproject commit 6dcccc00709fb8cd68f47bd8636cdc2a02086830

+ 3
- 0
mikazuki/app/api.py View File

@@ -207,6 +207,9 @@ async def run_interrogate(req: TaggerInterrogateRequest, background_tasks: Backg
batch_output_save_json=False,
interrogator=interrogator,
threshold=req.threshold,
character_threshold=req.character_threshold,
add_rating_tag=req.add_rating_tag,
add_model_tag=req.add_model_tag,
additional_tags=req.additional_tags,
exclude_tags=req.exclude_tags,
sort_by_alphabetical_order=False,


+ 7
- 0
mikazuki/app/models.py View File

@@ -12,6 +12,13 @@ class TaggerInterrogateRequest(BaseModel):
ge=0,
le=1
)
character_threshold: float = Field(
default=0.6,
ge=0,
le=1
)
add_rating_tag: bool = False
add_model_tag: bool = False
additional_tags: str = ""
exclude_tags: str = ""
escape_tag: bool = True


+ 3
- 0
mikazuki/schema/flux-lora.ts View File

@@ -1,6 +1,7 @@
Schema.intersect([
Schema.object({
model_train_type: Schema.string().default("flux-lora").disabled().description("训练种类"),
model_type: Schema.union(["flux", "chroma"]).default("flux").description("FLUX 模型类型 (支持 Chroma)"),
pretrained_model_name_or_path: Schema.string().role('filepicker', { type: "model-file" }).default("./sd-models/model.safetensors").description("Flux 模型路径"),
ae: Schema.string().role('filepicker', { type: "model-file" }).description("AE 模型文件路径"),
clip_l: Schema.string().role('filepicker', { type: "model-file" }).description("clip_l 模型文件路径"),
@@ -17,6 +18,8 @@ Schema.intersect([
guidance_scale: Schema.number().step(0.01).default(1.0).description("CFG 引导缩放"),
t5xxl_max_token_length: Schema.number().step(1).description("T5XXL 最大 token 长度(不填写使用自动)"),
train_t5xxl: Schema.boolean().default(false).description("训练 T5XXL(不推荐)"),
// apply attention mask to T5-XXL encode and FLUX double blocks
apply_t5_attn_mask: Schema.boolean().default(true).description("对 T5-XXL 编码器和 FLUX double块 应用注意力掩码"),
}).description("Flux 专用参数"),

Schema.object(


+ 20
- 185
mikazuki/tagger/interrogator.py View File

@@ -14,193 +14,13 @@ from PIL import UnidentifiedImageError
from huggingface_hub import hf_hub_download

from mikazuki.tagger import dbimutils, format
from mikazuki.tagger.interrogators.base import Interrogator
from mikazuki.tagger.interrogators.wd14 import WaifuDiffusionInterrogator
from mikazuki.tagger.interrogators.cl import CLTaggerInterrogator

tag_escape_pattern = re.compile(r'([\\()])')


class Interrogator:
@staticmethod
def postprocess_tags(
tags: Dict[str, float],

threshold=0.35,
additional_tags: List[str] = [],
exclude_tags: List[str] = [],
sort_by_alphabetical_order=False,
add_confident_as_weight=False,
replace_underscore=False,
replace_underscore_excludes: List[str] = [],
escape_tag=False
) -> Dict[str, float]:
for t in additional_tags:
tags[t] = 1.0

# those lines are totally not "pythonic" but looks better to me
tags = {
t: c

# sort by tag name or confident
for t, c in sorted(
tags.items(),
key=lambda i: i[0 if sort_by_alphabetical_order else 1],
reverse=not sort_by_alphabetical_order
)

# filter tags
if (
c >= threshold
and t not in exclude_tags
)
}

new_tags = []
for tag in list(tags):
new_tag = tag

if replace_underscore and tag not in replace_underscore_excludes:
new_tag = new_tag.replace('_', ' ')

if escape_tag:
new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)

if add_confident_as_weight:
new_tag = f'({new_tag}:{tags[tag]})'

new_tags.append((new_tag, tags[tag]))
tags = dict(new_tags)

return tags

def __init__(self, name: str) -> None:
self.name = name

def load(self):
raise NotImplementedError()

def unload(self) -> bool:
unloaded = False

if hasattr(self, 'model') and self.model is not None:
del self.model
unloaded = True
print(f'Unloaded {self.name}')

if hasattr(self, 'tags'):
del self.tags

return unloaded

def interrogate(
self,
image: Image
) -> Tuple[
Dict[str, float], # rating confidents
Dict[str, float] # tag confidents
]:
raise NotImplementedError()


class WaifuDiffusionInterrogator(Interrogator):
def __init__(
self,
name: str,
model_path='model.onnx',
tags_path='selected_tags.csv',
**kwargs
) -> None:
super().__init__(name)
self.model_path = model_path
self.tags_path = tags_path
self.kwargs = kwargs

def download(self) -> Tuple[os.PathLike, os.PathLike]:
print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")

model_path = Path(hf_hub_download(
**self.kwargs, filename=self.model_path))
tags_path = Path(hf_hub_download(
**self.kwargs, filename=self.tags_path))
return model_path, tags_path

def load(self) -> None:
model_path, tags_path = self.download()

# only one of these packages should be installed at a time in any one environment
# https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime
# TODO: remove old package when the environment changes?
# from mikazuki.launch_utils import is_installed, run_pip
# if not is_installed('onnxruntime'):
# package = os.environ.get(
# 'ONNXRUNTIME_PACKAGE',
# 'onnxruntime-gpu'
# )

# run_pip(f'install {package}', 'onnxruntime')

# Load torch to load cuda libs built in torch for onnxruntime, do not delete this.
import torch
from onnxruntime import InferenceSession

# https://onnxruntime.ai/docs/execution-providers/
# https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

self.model = InferenceSession(str(model_path), providers=providers)

print(f'Loaded {self.name} model from {model_path}')

self.tags = pd.read_csv(tags_path)

def interrogate(
self,
image: Image
) -> Tuple[
Dict[str, float], # rating confidents
Dict[str, float] # tag confidents
]:
# init model
if not hasattr(self, 'model') or self.model is None:
self.load()

# code for converting the image and running the model is taken from the link below
# thanks, SmilingWolf!
# https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

# convert an image to fit the model
_, height, _, _ = self.model.get_inputs()[0].shape

# alpha to white
image = image.convert('RGBA')
new_image = Image.new('RGBA', image.size, 'WHITE')
new_image.paste(image, mask=image)
image = new_image.convert('RGB')
image = np.asarray(image)

# PIL RGB to OpenCV BGR
image = image[:, :, ::-1]

image = dbimutils.make_square(image, height)
image = dbimutils.smart_resize(image, height)
image = image.astype(np.float32)
image = np.expand_dims(image, 0)

# evaluate model
input_name = self.model.get_inputs()[0].name
label_name = self.model.get_outputs()[0].name
confidents = self.model.run([label_name], {input_name: image})[0]

tags = self.tags[:][['name']]
tags['confidents'] = confidents[0]

# first 4 items are for rating (general, sensitive, questionable, explicit)
ratings = dict(tags[:4].values)

# rest are regular tags
tags = dict(tags[4:].values)

return ratings, tags


available_interrogators = {
'wd-convnext-v3': WaifuDiffusionInterrogator(
'wd-convnext-v3',
@@ -239,6 +59,12 @@ available_interrogators = {
'wd-vit-large-tagger-v3',
repo_id='SmilingWolf/wd-vit-large-tagger-v3',
),
'cl_tagger_1_01': CLTaggerInterrogator(
'cl_tagger_1_01',
repo_id='cella110n/cl_tagger',
model_path='cl_tagger_1_01/model.onnx',
tag_mapping_path='cl_tagger_1_01/tag_mapping.json',
),
}


@@ -257,7 +83,13 @@ def on_interrogate(
batch_output_save_json: bool,

interrogator: Interrogator,

threshold: float,
character_threshold: float,

add_rating_tag: bool,
add_model_tag: bool,

additional_tags: str,
exclude_tags: str,
sort_by_alphabetical_order: bool,
@@ -270,6 +102,9 @@ def on_interrogate(
):
postprocess_opts = (
threshold,
character_threshold,
add_rating_tag,
add_model_tag,
split_str(additional_tags),
split_str(exclude_tags),
sort_by_alphabetical_order,
@@ -361,7 +196,7 @@ def on_interrogate(
print(f'skipping {path}')
continue

ratings, tags = interrogator.interrogate(image)
tags = interrogator.interrogate(image)
processed_tags = Interrogator.postprocess_tags(
tags,
*postprocess_opts
@@ -398,7 +233,7 @@ def on_interrogate(

if batch_output_save_json:
output_path.with_suffix('.json').write_text(
json.dumps([ratings, tags])
json.dumps(tags)
)

print('all done / 识别完成')


+ 108
- 0
mikazuki/tagger/interrogators/base.py View File

@@ -0,0 +1,108 @@
import re
from typing import Dict, List, Tuple
from PIL import Image

tag_escape_pattern = re.compile(r'([\\()])')


class Interrogator:
    """Base class for image taggers.

    Subclasses implement load() / interrogate(); this class provides the
    shared tag post-processing and model unloading.
    """

    @staticmethod
    def postprocess_tags(
        tags: Dict[str, List[Tuple[str, float]]],

        threshold=0.35,
        character_threshold=0.6,

        add_rating_tag=False,
        add_model_tag=False,

        additional_tags: Optional[List[str]] = None,
        exclude_tags: Optional[List[str]] = None,
        sort_by_alphabetical_order=False,
        add_confident_as_weight=False,
        replace_underscore=False,
        replace_underscore_excludes: Optional[List[str]] = None,
        escape_tag=False
    ) -> Dict[str, float]:
        """Flatten categorized (tag, confidence) lists into one tag -> confidence dict.

        - 'character' tags are kept when confidence >= character_threshold.
        - 'rating' / 'model' categories are dropped unless add_rating_tag /
          add_model_tag is set; every other category uses `threshold`.
        - additional_tags are injected with confidence 1.0 (a real model
          confidence wins if the same tag also passes a category threshold).
        - exclude_tags are removed; the result is sorted by name or by
          confidence (descending) and optionally reformatted (underscore
          replacement, escaping, `(tag:conf)` weighting).
        """
        # Fix: avoid the mutable-default-argument footgun — fresh lists per call.
        additional_tags = additional_tags if additional_tags is not None else []
        exclude_tags = exclude_tags if exclude_tags is not None else []
        replace_underscore_excludes = (
            replace_underscore_excludes if replace_underscore_excludes is not None else []
        )

        ok_tags: Dict[str, float] = {}

        # Character tags have their own, separate threshold.
        for tag, conf in tags.get('character', []):
            if conf >= character_threshold:
                ok_tags[tag] = conf

        for tag in additional_tags:
            ok_tags[tag] = 1.0

        # Categories skipped by the generic threshold pass below.
        # Fix: the input dict is no longer mutated (old code del'd its keys).
        skipped = {'character'}
        if not add_rating_tag:
            skipped.add('rating')
        if not add_model_tag:
            skipped.add('model')

        for category, pairs in tags.items():
            if category in skipped:
                continue
            for tag, conf in pairs:
                if conf >= threshold:
                    ok_tags[tag] = conf

        for excluded in exclude_tags:
            # Fix: `del ok_tags[e]` raised KeyError whenever a user-supplied
            # excluded tag was not actually present in the result.
            ok_tags.pop(excluded, None)

        if sort_by_alphabetical_order:
            ok_tags = dict(sorted(ok_tags.items()))
        else:
            # sort tags by confidence, highest first
            ok_tags = dict(sorted(ok_tags.items(), key=lambda item: item[1], reverse=True))

        new_tags = []
        for tag, conf in ok_tags.items():
            new_tag = tag

            if replace_underscore and tag not in replace_underscore_excludes:
                new_tag = new_tag.replace('_', ' ')

            if escape_tag:
                new_tag = tag_escape_pattern.sub(r'\\\1', new_tag)

            if add_confident_as_weight:
                new_tag = f'({new_tag}:{conf})'

            new_tags.append((new_tag, conf))

        return dict(new_tags)

    def __init__(self, name: str) -> None:
        # Human-readable model identifier (also the registry key).
        self.name = name

    def load(self):
        """Load the model into memory; implemented by subclasses."""
        raise NotImplementedError()

    def unload(self) -> bool:
        """Drop the loaded model (and its tag table) if present.

        Returns True when a model was actually unloaded.
        """
        unloaded = False

        if hasattr(self, 'model') and self.model is not None:
            del self.model
            unloaded = True
            print(f'Unloaded {self.name}')

        if hasattr(self, 'tags'):
            del self.tags

        return unloaded

    def interrogate(
        self,
        image: "Image.Image"
    ) -> Dict[str, List[Tuple[str, float]]]:
        """
        Interrogate the given image and return tags with their confidence scores.
        :param image: The input image to be interrogated.
        :return: A dictionary with categories as keys and lists of (tag, confidence)

        categories: "rating", "general", "character", "copyright", "artist", "meta", "quality", "model"
        """

        raise NotImplementedError()

+ 268
- 0
mikazuki/tagger/interrogators/cl.py View File

@@ -0,0 +1,268 @@
import json
import os
import re
from collections import OrderedDict
from glob import glob
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from PIL import Image
from PIL import UnidentifiedImageError
from huggingface_hub import hf_hub_download
from dataclasses import dataclass
from mikazuki.tagger import dbimutils, format
from mikazuki.tagger.interrogators.base import Interrogator


@dataclass
class LabelData:
    """Index tables mapping model output positions to tag names.

    `names[i]` is the tag string for output index i (None for gaps); each
    category field holds the output indices belonging to that category.
    """
    # NOTE(review): load_tag_mapping actually stores np.int64 arrays
    # (np.array(..., dtype=np.int64)) in the category fields, and `names`
    # may contain None entries — annotations updated to match.
    names: list
    rating: np.ndarray
    general: np.ndarray
    artist: np.ndarray
    character: np.ndarray
    copyright: np.ndarray
    meta: np.ndarray
    quality: np.ndarray
    model: np.ndarray


def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, compositing any alpha channel onto white."""
    if image.mode not in ("RGB", "RGBA"):
        # Palette/other modes: keep transparency info if present, else go straight to RGB.
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)
    if image.mode == "RGBA":
        canvas = Image.new("RGB", image.size, (255, 255, 255))
        alpha_channel = image.split()[3]
        canvas.paste(image, mask=alpha_channel)
        return canvas
    return image


def pil_pad_square(image: Image.Image) -> Image.Image:
    """Pad *image* to a centered square using a (255, 255, 255) fill."""
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    # NOTE(review): the 3-tuple fill assumes a 3-channel mode; presumably
    # callers run pil_ensure_rgb first — confirm.
    padded = Image.new(image.mode, (side, side), (255, 255, 255))
    offset = ((side - width) // 2, (side - height) // 2)
    padded.paste(image, offset)
    return padded


def _argmax_tag(probs, indices, names, category):
    """Return [(name, conf)] for the single highest-probability tag among
    *indices*, or [] when nothing valid; prints a warning on each failure mode.

    Extracted because the original rating/quality blocks were copy-pasted
    duplicates (each also carried an unreachable "became empty" branch:
    a non-empty valid_indices always yields non-empty probs[valid_indices]).
    """
    if len(indices) == 0:
        return []
    valid_indices = indices[indices < len(probs)]
    if len(valid_indices) == 0:
        print(f"Warning: No valid indices found for {category} tags within probs length.")
        return []
    cat_probs = probs[valid_indices]
    idx_local = np.argmax(cat_probs)
    idx_global = valid_indices[idx_local]
    if idx_global < len(names) and names[idx_global] is not None:
        return [(names[idx_global], float(cat_probs[idx_local]))]
    print(f"Warning: Invalid global index {idx_global} for {category} tag.")
    return []


def get_tags(probs, labels: "LabelData"):
    """Split raw per-index probabilities into per-category (tag, conf) lists.

    :param probs: 1-D array of probabilities indexed by model output position.
    :param labels: LabelData with the name table and per-category indices.
    :return: dict with keys "rating", "general", "character", "copyright",
             "artist", "meta", "quality", "model"; "rating" and "quality"
             contain at most the single argmax tag, the rest contain every
             valid tag (no thresholding), each list sorted by confidence
             descending.
    """
    result = {
        "rating": [],
        "general": [],
        "character": [],
        "copyright": [],
        "artist": [],
        "meta": [],
        "quality": [],
        "model": []
    }

    # Rating and quality are single-choice categories: keep only the argmax.
    result["rating"] = _argmax_tag(probs, labels.rating, labels.names, "rating")
    result["quality"] = _argmax_tag(probs, labels.quality, labels.names, "quality")

    # All tags for each remaining category (no threshold applied here).
    category_map = {
        "general": labels.general,
        "character": labels.character,
        "copyright": labels.copyright,
        "artist": labels.artist,
        "meta": labels.meta,
        "model": labels.model
    }
    for category, indices in category_map.items():
        if len(indices) == 0:
            continue
        valid_indices = indices[indices < len(probs)]
        category_probs = probs[valid_indices]
        for idx_local, idx_global in enumerate(valid_indices):
            if idx_global < len(labels.names) and labels.names[idx_global] is not None:
                result[category].append((labels.names[idx_global], float(category_probs[idx_local])))
            else:
                print(f"Warning: Invalid global index {idx_global} for {category} tag.")

    # Sort each category by probability (descending).
    for k in result:
        result[k].sort(key=lambda pair: pair[1], reverse=True)
    return result


class CLTaggerInterrogator(Interrogator):
    """Tagger backed by the "cl_tagger" ONNX model (Hugging Face repo cella110n/cl_tagger).

    Downloads the model and its tag-mapping JSON from the Hub on first use,
    then classifies images into the category layout declared on the base
    class ("rating", "general", "character", ...).
    """

    def __init__(
        self,
        name: str,
        model_path='model.onnx',
        tag_mapping_path='tag_mapping.json',
        **kwargs
    ) -> None:
        # `kwargs` is forwarded verbatim to hf_hub_download (e.g. repo_id).
        super().__init__(name)
        self.model_path = model_path
        self.tag_mapping_path = tag_mapping_path
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Fetch the model and tag-mapping files; returns their local cache paths."""
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")

        model_path = Path(hf_hub_download(
            **self.kwargs, filename=self.model_path))
        tag_mapping_path = Path(hf_hub_download(
            **self.kwargs, filename=self.tag_mapping_path))
        return model_path, tag_mapping_path

    def load(self) -> None:
        """Create the ONNX inference session and parse the tag mapping."""
        model_path, tag_mapping_path = self.download()

        # torch is imported so its bundled CUDA libraries are loaded before
        # onnxruntime initializes; do not remove even though it looks unused.
        import torch
        from onnxruntime import InferenceSession

        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        self.model = InferenceSession(str(model_path), providers=providers)

        print(f'Loaded {self.name} model from {model_path}')

        # NOTE(review): load_tag_mapping returns a 3-tuple; interrogate()
        # only ever reads self.tags[0] (the LabelData).
        self.tags = self.load_tag_mapping(tag_mapping_path)

    def load_tag_mapping(self, mapping_path):
        """Parse the tag-mapping JSON into (LabelData, idx_to_tag, tag_to_category).

        Accepts two layouts: {"idx_to_tag": {...}, "tag_to_category": {...}},
        or a dict of int-keyed entries each with 'tag' and 'category' fields.

        :raises ValueError: when the JSON matches neither layout.
        """
        # Use the implementation from the original app.py as it was confirmed working
        with open(mapping_path, 'r', encoding='utf-8') as f:
            tag_mapping_data = json.load(f)
        # Check format compatibility (can be dict of dicts or dict with idx_to_tag/tag_to_category)
        if isinstance(tag_mapping_data, dict) and "idx_to_tag" in tag_mapping_data:
            idx_to_tag = {int(k): v for k, v in tag_mapping_data["idx_to_tag"].items()}
            tag_to_category = tag_mapping_data["tag_to_category"]
        elif isinstance(tag_mapping_data, dict):
            # Assuming the dict-of-dicts format from previous tests
            try:
                tag_mapping_data_int_keys = {int(k): v for k, v in tag_mapping_data.items()}
                idx_to_tag = {idx: data['tag'] for idx, data in tag_mapping_data_int_keys.items()}
                tag_to_category = {data['tag']: data['category'] for data in tag_mapping_data_int_keys.values()}
            except (KeyError, ValueError) as e:
                raise ValueError(f"Unsupported tag mapping format (dict): {e}. Expected int keys with 'tag' and 'category'.")
        else:
            raise ValueError("Unsupported tag mapping format: Expected a dictionary.")

        # Dense index -> name table; indices with no tag stay None.
        names = [None] * (max(idx_to_tag.keys()) + 1)
        rating, general, artist, character, copyright, meta, quality, model_name = [], [], [], [], [], [], [], []
        for idx, tag in idx_to_tag.items():
            if idx >= len(names):
                names.extend([None] * (idx - len(names) + 1))
            names[idx] = tag
            category = tag_to_category.get(tag, 'Unknown')  # Handle missing category mapping gracefully
            idx_int = int(idx)
            if category == 'Rating':
                rating.append(idx_int)
            elif category == 'General':
                general.append(idx_int)
            elif category == 'Artist':
                artist.append(idx_int)
            elif category == 'Character':
                character.append(idx_int)
            elif category == 'Copyright':
                copyright.append(idx_int)
            elif category == 'Meta':
                meta.append(idx_int)
            elif category == 'Quality':
                quality.append(idx_int)
            elif category == 'Model':
                model_name.append(idx_int)

        return LabelData(names=names, rating=np.array(rating, dtype=np.int64), general=np.array(general, dtype=np.int64), artist=np.array(artist, dtype=np.int64),
                         character=np.array(character, dtype=np.int64), copyright=np.array(copyright, dtype=np.int64), meta=np.array(meta, dtype=np.int64), quality=np.array(quality, dtype=np.int64), model=np.array(model_name, dtype=np.int64)), idx_to_tag, tag_to_category

    def preprocess_image(self, image: Image.Image, target_size=(448, 448)):
        """Convert `image` to the model's input layout.

        Returns (square-padded PIL image, float32 array of shape (1, 3, H, W)
        normalized to [-1, 1]).
        """
        # Adapted from onnx_predict.py's version
        image = pil_ensure_rgb(image)
        image = pil_pad_square(image)
        image_resized = image.resize(target_size, Image.BICUBIC)
        img_array = np.array(image_resized, dtype=np.float32) / 255.0
        img_array = img_array.transpose(2, 0, 1)  # HWC -> CHW
        # Channel flip to BGR before normalization. NOTE(review): kept per the
        # original's "UNCOMMENTED based on user feedback" note — presumably the
        # model expects BGR input; confirm against the upstream model card.
        img_array = img_array[::-1, :, :]
        mean = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
        std = np.array([0.5, 0.5, 0.5], dtype=np.float32).reshape(3, 1, 1)
        img_array = (img_array - mean) / std
        img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
        return image, img_array

    def interrogate(
        self,
        image: Image
    ) -> dict[str, list]:
        """Run the model on `image`; returns {category: [(tag, confidence), ...]}."""

        # init model lazily on first call
        if not hasattr(self, 'model') or self.model is None:
            self.load()

        input_name = self.model.get_inputs()[0].name
        output_name = self.model.get_outputs()[0].name

        original_pil_image, input_tensor = self.preprocess_image(image)
        input_tensor = input_tensor.astype(np.float32)

        outputs = self.model.run([output_name], {input_name: input_tensor})[0]

        if np.isnan(outputs).any() or np.isinf(outputs).any():
            print("Warning: NaN or Inf detected in model output. Clamping...")
            outputs = np.nan_to_num(outputs, nan=0.0, posinf=1.0, neginf=0.0)  # Clamp to 0-1 range

        # Apply sigmoid (outputs are likely logits)
        # Use a stable sigmoid implementation
        def stable_sigmoid(x):
            return 1 / (1 + np.exp(-np.clip(x, -30, 30)))  # Clip to avoid overflow
        probs = stable_sigmoid(outputs[0])  # Assuming batch size 1

        predictions = get_tags(probs, self.tags[0])  # self.tags[0] is the LabelData

        # NOTE(review): debug print left in — dumps the full prediction dict
        # for every image; consider removing.
        print(predictions)
        return predictions

+ 131
- 0
mikazuki/tagger/interrogators/wd14.py View File

@@ -0,0 +1,131 @@
# from https://github.com/toriato/stable-diffusion-webui-wd14-tagger
import json
import os
import re
from collections import OrderedDict
from glob import glob
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
from PIL import Image
from PIL import UnidentifiedImageError
from huggingface_hub import hf_hub_download
from mikazuki.tagger.interrogators.base import Interrogator
from mikazuki.tagger import dbimutils, format


class WaifuDiffusionInterrogator(Interrogator):
    """Tagger backed by SmilingWolf's WaifuDiffusion ONNX tagger models.

    Downloads model.onnx and selected_tags.csv from the Hugging Face Hub on
    first use; fills only the "rating" and "general" categories of the base
    class's result layout.
    """

    def __init__(
        self,
        name: str,
        model_path='model.onnx',
        tags_path='selected_tags.csv',
        **kwargs
    ) -> None:
        # `kwargs` is forwarded verbatim to hf_hub_download (e.g. repo_id).
        super().__init__(name)
        self.model_path = model_path
        self.tags_path = tags_path
        self.kwargs = kwargs

    def download(self) -> Tuple[os.PathLike, os.PathLike]:
        """Fetch the model and the tag CSV; returns their local cache paths."""
        print(f"Loading {self.name} model file from {self.kwargs['repo_id']}")

        model_path = Path(hf_hub_download(
            **self.kwargs, filename=self.model_path))
        tags_path = Path(hf_hub_download(
            **self.kwargs, filename=self.tags_path))
        return model_path, tags_path

    def load(self) -> None:
        """Create the ONNX inference session and read the tag table."""
        model_path, tags_path = self.download()

        # only one of these packages should be installed at a time in any one environment
        # https://onnxruntime.ai/docs/get-started/with-python.html#install-onnx-runtime

        # Load torch to load cuda libs built in torch for onnxruntime, do not delete this.
        import torch
        from onnxruntime import InferenceSession

        # https://onnxruntime.ai/docs/execution-providers/
        # https://github.com/toriato/stable-diffusion-webui-wd14-tagger/commit/e4ec460122cf674bbf984df30cdb10b4370c1224#r92654958
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']

        self.model = InferenceSession(str(model_path), providers=providers)

        print(f'Loaded {self.name} model from {model_path}')

        # CSV with one row per model output; first column of interest is 'name'.
        self.tags = pd.read_csv(tags_path)

    def interrogate(
        self,
        image: "Image.Image"
    ) -> Dict[str, List[Tuple[str, float]]]:
        """Run the tagger on `image`.

        Returns {category: [(tag, confidence), ...]} in the base-class format;
        only "rating" and "general" are populated by this model.
        """
        # init model lazily on first call
        if not hasattr(self, 'model') or self.model is None:
            self.load()

        # code for converting the image and running the model is taken from the link below
        # thanks, SmilingWolf!
        # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags/blob/main/app.py

        # convert an image to fit the model: the input is square, take its edge length
        _, height, _, _ = self.model.get_inputs()[0].shape

        # alpha to white
        image = image.convert('RGBA')
        new_image = Image.new('RGBA', image.size, 'WHITE')
        new_image.paste(image, mask=image)
        image = new_image.convert('RGB')
        image = np.asarray(image)

        # PIL RGB to OpenCV BGR
        image = image[:, :, ::-1]

        image = dbimutils.make_square(image, height)
        image = dbimutils.smart_resize(image, height)
        image = image.astype(np.float32)
        image = np.expand_dims(image, 0)

        # evaluate model
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        confidents = self.model.run([label_name], {input_name: image})[0]

        tags = self.tags[:][['name']]
        tags['confidents'] = confidents[0]

        # first 4 items are for rating (general, sensitive, questionable, explicit)
        ratings = dict(tags[:4].values)

        # rest are regular tags
        tags = dict(tags[4:].values)

        # Re-shape into the categorized layout shared by all interrogators.
        result = {
            "rating": [],
            "general": [],
            "character": [],
            "copyright": [],
            "artist": [],
            "meta": [],
            "quality": [],
            "model": []
        }

        for tag, conf in ratings.items():
            result["rating"].append((tag, conf))

        for tag, conf in tags.items():
            result["general"].append((tag, conf))

        return result

+ 1
- 1
requirements.txt View File

@@ -40,5 +40,5 @@ wandb==0.16.2
httpx==0.24.1
# extra
open-clip-torch==2.20.0
lycoris-lora==2.1.0.post3
lycoris-lora==3.2.0.post2
dadaptation==3.1

Loading…
Cancel
Save
Baidu
map