"""
Overview:
This module provides utilities for image tagging using WD14 taggers.
It includes functions for loading models, processing images, and extracting tags.
The module is inspired by the `SmilingWolf/wd-v1-4-tags <https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags>`_
project on Hugging Face.
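
    A minimal usage sketch, assuming a local image file ``skadi.jpg``:

        >>> from imgutils.tagging import get_wd14_tags
        >>> rating, general, character = get_wd14_tags('skadi.jpg')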
"""
from typing import List, Tuple
import numpy as np
import onnxruntime
import pandas as pd
from PIL import Image
from hbutils.testing.requires.version import VersionInfo
from huggingface_hub import hf_hub_download
from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..utils import open_onnx_model, vreplace, sigmoid, ts_lru_cache
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
MOAT_MODEL_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
CONV_V3_MODEL_REPO = 'SmilingWolf/wd-convnext-tagger-v3'
SWIN_V3_MODEL_REPO = 'SmilingWolf/wd-swinv2-tagger-v3'
VIT_V3_MODEL_REPO = 'SmilingWolf/wd-vit-tagger-v3'
VIT_LARGE_MODEL_REPO = 'SmilingWolf/wd-vit-large-tagger-v3'
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
_IS_V3_SUPPORT = VersionInfo(onnxruntime.__version__) >= '1.17'
MODEL_NAMES = {
"EVA02_Large": EVA02_LARGE_MODEL_DSV3_REPO,
"ViT_Large": VIT_LARGE_MODEL_REPO,
"SwinV2": SWIN_MODEL_REPO,
"ConvNext": CONV_MODEL_REPO,
"ConvNextV2": CONV2_MODEL_REPO,
"ViT": VIT_MODEL_REPO,
"MOAT": MOAT_MODEL_REPO,
"SwinV2_v3": SWIN_V3_MODEL_REPO,
"ConvNext_v3": CONV_V3_MODEL_REPO,
"ViT_v3": VIT_V3_MODEL_REPO,
}
_DEFAULT_MODEL_NAME = 'SwinV2_v3'
def _version_support_check(model_name):
"""
Check if the current onnxruntime version supports the given model.
:param model_name: The name of the model to check.
:type model_name: str
:raises EnvironmentError: If the model is not supported by the current onnxruntime version.
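
    Example:
        A sketch of the expected behaviour; the function returns ``None`` when the
        model is supported and raises :class:`EnvironmentError` otherwise.

        >>> _version_support_check('SwinV2')     # always supported
        >>> _version_support_check('SwinV2_v3')  # raises EnvironmentError on onnxruntime < 1.17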
"""
if model_name.endswith('_v3') and not _IS_V3_SUPPORT:
        raise EnvironmentError(f'V3 taggers are not supported on onnxruntime {onnxruntime.__version__}, '
                               f'please upgrade to version 1.17 or later.\n'
                               f'If you are running on CPU, use "pip install -U onnxruntime".\n'
                               f'If you are running on GPU, use "pip install -U onnxruntime-gpu".')  # pragma: no cover
@ts_lru_cache()
def _get_wd14_model(model_name):
"""
Load an ONNX model from the Hugging Face Hub.
:param model_name: The name of the model to load.
:type model_name: str
    :return: The loaded ONNX inference session.
    :rtype: onnxruntime.InferenceSession
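
    Example:
        A sketch; the first call downloads the model file from the Hugging Face Hub,
        while subsequent calls reuse the LRU-cached session.

        >>> session = _get_wd14_model('SwinV2_v3')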
"""
_version_support_check(model_name)
return open_onnx_model(hf_hub_download(
repo_id='deepghs/wd14_tagger_with_embeddings',
filename=f'{MODEL_NAMES[model_name]}/model.onnx',
))
@ts_lru_cache()
def _get_wd14_weights(model_name):
"""
Load the weights for a WD14 model.
:param model_name: The name of the model.
:type model_name: str
:return: The loaded weights.
:rtype: numpy.ndarray
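
    Example:
        A sketch of how the archive is consumed; these entries form the linear head
        applied to embeddings in :func:`convert_wd14_emb_to_prediction`.

        >>> z_weights = _get_wd14_weights('SwinV2_v3')
        >>> weights, bias = z_weights['weights'], z_weights['bias']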
"""
_version_support_check(model_name)
return np.load(hf_hub_download(
repo_id='deepghs/wd14_tagger_with_embeddings',
filename=f'{MODEL_NAMES[model_name]}/inv.npz',
))
@ts_lru_cache()
def _get_wd14_labels(model_name, no_underline: bool = False) -> Tuple[List[str], List[int], List[int], List[int]]:
"""
Get labels for the WD14 model.
:param model_name: The name of the model.
:type model_name: str
:param no_underline: If True, replaces underscores in tag names with spaces.
:type no_underline: bool
:return: A tuple containing the list of tag names, and lists of indexes for rating, general, and character categories.
:rtype: Tuple[List[str], List[int], List[int], List[int]]
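
    Example:
        A sketch; in ``selected_tags.csv`` the category codes are ``9`` for rating,
        ``0`` for general and ``4`` for character tags.

        >>> tag_names, rating_idx, general_idx, character_idx = _get_wd14_labels('SwinV2_v3')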
"""
path = hf_hub_download(MODEL_NAMES[model_name], LABEL_FILENAME)
df = pd.read_csv(path)
name_series = df["name"]
if no_underline:
name_series = name_series.map(remove_underline)
tag_names = name_series.tolist()
rating_indexes = list(np.where(df["category"] == 9)[0])
general_indexes = list(np.where(df["category"] == 0)[0])
character_indexes = list(np.where(df["category"] == 4)[0])
return tag_names, rating_indexes, general_indexes, character_indexes
def _mcut_threshold(probs) -> float:
"""
Compute the Maximum Cut Thresholding (MCut) for multi-label classification.
This method is based on the paper:
Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
for Multi-label Classification. In 11th International Symposium, IDA 2012
(pp. 172-183).
:param probs: Array of probabilities.
:type probs: numpy.ndarray
:return: The computed threshold.
:rtype: float
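
    Example:
        A worked example; the largest gap in the sorted probabilities lies between
        ``0.75`` and ``0.25``, so the threshold is their midpoint.

        >>> import numpy as np
        >>> float(_mcut_threshold(np.array([0.875, 0.75, 0.25, 0.125])))
        0.5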
"""
    # Sort probabilities in descending order
    sorted_probs = probs[probs.argsort()[::-1]]
    # Find the largest gap between successive probabilities
    difs = sorted_probs[:-1] - sorted_probs[1:]
    t = difs.argmax()
    # The threshold is the midpoint of that largest gap
    thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
    return thresh
def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
"""
Prepare an image for tagging by resizing and padding it.
:param image: The input image.
:type image: ImageTyping
:param target_size: The target size for the image.
:type target_size: int
:return: The prepared image as a numpy array.
:rtype: numpy.ndarray
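
    Example:
        A sketch, assuming a local ``skadi.jpg``; the result is a
        ``1 x target_size x target_size x 3`` float32 array in BGR channel order.

        >>> arr = _prepare_image_for_tagging('skadi.jpg', 448)
        >>> arr.shape
        (1, 448, 448, 3)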
"""
    image = load_image(image, force_background=None, mode=None)
    image_shape = image.size
    max_dim = max(image_shape)
    pad_left = (max_dim - image_shape[0]) // 2
    pad_top = (max_dim - image_shape[1]) // 2

    # Pad the image to a square canvas with a white background
    padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
    try:
        # Use the image's own alpha channel as the paste mask when one is available
        padded_image.paste(image, (pad_left, pad_top), mask=image)
    except ValueError:
        padded_image.paste(image, (pad_left, pad_top))

    if max_dim != target_size:
        padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)

    # Convert RGB to BGR, as the WD14 models expect, and add a batch dimension
    image_array = np.asarray(padded_image, dtype=np.float32)
    image_array = image_array[:, :, ::-1]
    return np.expand_dims(image_array, axis=0)
def _postprocess_embedding(
pred, embedding,
model_name: str = _DEFAULT_MODEL_NAME,
general_threshold: float = 0.35,
general_mcut_enabled: bool = False,
character_threshold: float = 0.85,
character_mcut_enabled: bool = False,
no_underline: bool = False,
drop_overlap: bool = False,
fmt=('rating', 'general', 'character'),
):
"""
Post-process the embedding and prediction results.
:param pred: The prediction array.
:type pred: numpy.ndarray
:param embedding: The embedding array.
:type embedding: numpy.ndarray
:param model_name: The name of the model used.
:type model_name: str
:param general_threshold: Threshold for general tags.
:type general_threshold: float
:param general_mcut_enabled: Whether to use MCut for general tags.
:type general_mcut_enabled: bool
:param character_threshold: Threshold for character tags.
:type character_threshold: float
:param character_mcut_enabled: Whether to use MCut for character tags.
:type character_mcut_enabled: bool
:param no_underline: Whether to remove underscores from tag names.
:type no_underline: bool
:param drop_overlap: Whether to drop overlapping tags.
:type drop_overlap: bool
    :param fmt: The format of the output; each value may reference any of ``rating``,
        ``general``, ``character``, ``tag``, ``embedding`` and ``prediction``.
:return: The post-processed results.
"""
tag_names, rating_indexes, general_indexes, character_indexes = _get_wd14_labels(model_name, no_underline)
labels = list(zip(tag_names, pred.astype(float)))
rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}
general_names = [labels[i] for i in general_indexes]
if general_mcut_enabled:
general_probs = np.array([x[1] for x in general_names])
general_threshold = _mcut_threshold(general_probs)
general_res = {x: v.item() for x, v in general_names if v > general_threshold}
if drop_overlap:
general_res = drop_overlap_tags(general_res)
    character_names = [labels[i] for i in character_indexes]
    if character_mcut_enabled:
        character_probs = np.array([x[1] for x in character_names])
        character_threshold = _mcut_threshold(character_probs)
        # Enforce a minimum character threshold of 0.15 when MCut is enabled
        character_threshold = max(0.15, character_threshold)
    character_res = {x: v.item() for x, v in character_names if v > character_threshold}
return vreplace(
fmt,
{
'rating': rating,
'general': general_res,
'character': character_res,
'tag': {**general_res, **character_res},
'embedding': embedding.astype(np.float32),
'prediction': pred.astype(np.float32),
}
)
def convert_wd14_emb_to_prediction(
emb: np.ndarray,
model_name: str = _DEFAULT_MODEL_NAME,
general_threshold: float = 0.35,
general_mcut_enabled: bool = False,
character_threshold: float = 0.85,
character_mcut_enabled: bool = False,
no_underline: bool = False,
drop_overlap: bool = False,
fmt=('rating', 'general', 'character'),
):
"""
    Convert a WD14 embedding into a human-readable prediction result.
:param emb: The 1-dim extracted embedding.
:type emb: numpy.ndarray
:param model_name: The name of the model to use.
:type model_name: str
:param general_threshold: The threshold for general tags.
:type general_threshold: float
:param general_mcut_enabled: If True, applies MCut thresholding to general tags.
:type general_mcut_enabled: bool
:param character_threshold: The threshold for character tags.
:type character_threshold: float
:param character_mcut_enabled: If True, applies MCut thresholding to character tags.
:type character_mcut_enabled: bool
:param no_underline: If True, replaces underscores in tag names with spaces.
:type no_underline: bool
:param drop_overlap: If True, drops overlapping tags.
:type drop_overlap: bool
:param fmt: Return format, default is ``('rating', 'general', 'character')``.
:return: Prediction result based on the provided fmt.
    .. note::
        Only embeddings that have **not** been normalized can be converted back to a
        meaningful prediction result, because the conversion applies a linear head to
        the raw embedding values.
Example:
>>> import os
>>> from imgutils.tagging import get_wd14_tags, convert_wd14_emb_to_prediction
>>>
>>> # extract the feature embedding
>>> embedding = get_wd14_tags('skadi.jpg', fmt='embedding')
>>>
>>> # convert to understandable result
>>> rating, general, character = convert_wd14_emb_to_prediction(embedding)
        >>> # these 3 dicts will be the same as those returned by `get_wd14_tags('skadi.jpg')`
"""
    z_weights = _get_wd14_weights(model_name)
    weights, bias = z_weights['weights'], z_weights['bias']
    # Apply the stored linear head to the embedding, then squash the logits to probabilities
    pred = sigmoid(emb @ weights + bias)
return _postprocess_embedding(
pred=pred,
embedding=emb,
model_name=model_name,
general_threshold=general_threshold,
general_mcut_enabled=general_mcut_enabled,
character_threshold=character_threshold,
character_mcut_enabled=character_mcut_enabled,
no_underline=no_underline,
drop_overlap=drop_overlap,
fmt=fmt,
)