# Source code for imgutils.tagging.pixai
"""
Overview:
This module provides utilities for image tagging using PixAI taggers, which are specialized models
for analyzing anime-style images and extracting relevant tags. The module supports loading ONNX
models from Hugging Face Hub and processing images to generate categorized tags with confidence scores.
The models are originally developed by the PixAI team and available at
`pixai-labs <https://huggingface.co/pixai-labs>`_ on Hugging Face. This module uses ONNX-converted
versions of these models for efficient inference, available at
`deepghs <https://huggingface.co/deepghs>`_ repositories.
In addition to standard tagging, the models can identify anime character IP (Intellectual Property)
associations. For example, if a character like "misaka_mikoto" is detected, the system can map
this to the "toaru_kagaku_no_railgun" (A Certain Scientific Railgun) IP. All IP names follow
Danbooru-style tag conventions for consistency with existing anime tagging systems.
Example::
>>> from imgutils.tagging.pixai import get_pixai_tags
>>> # Get tags with default thresholds
>>> result = get_pixai_tags('path/to/anime_image.jpg', model_name='v0.9')
>>> general_tags, character_tags = result
>>> print("General tags:", general_tags)
>>> print("Character tags:", character_tags)
>>> # Get all tags in a single dictionary
>>> all_tags = get_pixai_tags('path/to/image.jpg', fmt='tag')
>>> print("All tags:", all_tags)
>>> # Get IP information for detected characters
>>> ips = get_pixai_tags('path/to/image.jpg', fmt='ips')
>>> print("Detected IPs:", ips)
"""
import json
import numbers
from collections import defaultdict
from typing import Union, Dict, Any, Tuple, List

import pandas as pd
from hbutils.design import SingletonMark
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import EntryNotFoundError

from imgutils.data import ImageTyping, load_image
from imgutils.preprocess import create_pillow_transforms
from imgutils.utils import open_onnx_model, ts_lru_cache, vreplace
# Sentinel used as the default value of ``fmt`` in ``get_pixai_tags``. It is compared by
# identity there to detect that the caller did not pass an explicit output format.
FMT_UNSET = SingletonMark('FMT_UNSET')
def _get_repo_id(model_name: str) -> str:
"""
Get the repository ID for the specified model name.
:param model_name: Name of the model (e.g., 'v0.9') or full repository path
:type model_name: str
:return: Full repository ID for Hugging Face Hub
:rtype: str
Example::
>>> _get_repo_id('v0.9')
'deepghs/pixai-tagger-v0.9-onnx'
>>> _get_repo_id('custom/model-repo')
'custom/model-repo'
"""
if '/' in model_name:
return model_name
else:
return f'deepghs/pixai-tagger-{model_name}-onnx'
@ts_lru_cache()
def _open_onnx_model(model_name: str):
    """
    Download ``model.onnx`` for the given PixAI tagger and open it as an ONNX model.

    The repository is resolved with :func:`_get_repo_id`. The function is wrapped with
    ``ts_lru_cache`` so that repeated calls with the same ``model_name`` can reuse the
    previously loaded model instead of downloading and opening it again.

    :param model_name: Name of the model to load
    :type model_name: str

    :return: The loaded ONNX model session
    :rtype: onnxruntime.InferenceSession
    """
    onnx_path = hf_hub_download(
        repo_id=_get_repo_id(model_name),
        repo_type='model',
        filename='model.onnx',
    )
    return open_onnx_model(onnx_path)
@ts_lru_cache()
def _open_tags(model_name: str) -> Tuple[pd.DataFrame, Dict[str, List[str]]]:
    """
    Download ``selected_tags.csv`` for the given model and load it as tag metadata.

    When the table has an ``ips`` column, every cell of that column is decoded with
    ``json.loads`` (the decoded values replace the column in the returned DataFrame),
    and a lookup is built from tag name to its decoded IP list. Tags whose decoded
    value is empty are left out of the lookup. Without an ``ips`` column the lookup
    is an empty dictionary.

    :param model_name: Name of the model
    :type model_name: str

    :return: Tuple of (DataFrame with tag information, mapping from tag name to its IPs)
    :rtype: Tuple[pd.DataFrame, Dict[str, List[str]]]
    """
    csv_path = hf_hub_download(
        repo_id=_get_repo_id(model_name),
        repo_type='model',
        filename='selected_tags.csv',
    )
    df_tags = pd.read_csv(csv_path)

    ip_lookup = {}
    if 'ips' in df_tags.columns:
        # The column is stored as JSON text; keep the decoded form in the DataFrame.
        df_tags['ips'] = df_tags['ips'].map(json.loads)
        ip_lookup = {
            tag_name: tag_ips
            for tag_name, tag_ips in zip(df_tags['name'], df_tags['ips'])
            if tag_ips
        }

    return df_tags, ip_lookup
@ts_lru_cache()
def _open_preprocess(model_name: str):
    """
    Load the preprocessing pipeline configuration from Hugging Face Hub with caching.

    This function downloads ``preprocess.json`` for the model and builds the transform
    pipeline from its ``stages`` entry via ``create_pillow_transforms``.

    :param model_name: Name of the model
    :type model_name: str

    :return: Preprocessing transform pipeline
    """
    config_path = hf_hub_download(
        repo_id=_get_repo_id(model_name),
        repo_type='model',
        filename='preprocess.json'
    )
    # JSON is UTF-8 by specification; be explicit so decoding does not depend on the
    # platform's default locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        data_ = json.load(f)

    return create_pillow_transforms(data_['stages'])
@ts_lru_cache()
def _open_default_category_thresholds(model_name: str) -> Tuple[Dict[int, float], Dict[int, str]]:
    """
    Load per-category default thresholds and category names from ``thresholds.csv``.

    The file is fetched from the model repository on Hugging Face Hub. If the repository
    has no such file (``EntryNotFoundError``), two empty dictionaries are returned.

    :param model_name: Name of the model
    :type model_name: str

    :return: Tuple containing (category_thresholds, category_names) dictionaries
    :rtype: tuple[Dict[int, float], Dict[int, str]]

    Example::
        >>> thresholds, names = _open_default_category_thresholds('v0.9')
        >>> print(thresholds)  # {0: 0.35, 1: 0.4, ...}
        >>> print(names)  # {0: 'general', 1: 'character', ...}
    """
    thresholds_by_category: Dict[int, float] = {}
    names_by_category: Dict[int, str] = {}

    try:
        table = pd.read_csv(hf_hub_download(
            repo_id=_get_repo_id(model_name),
            repo_type='model',
            filename='thresholds.csv'
        ), keep_default_na=False)
    except (EntryNotFoundError,):
        # No thresholds file published for this model: nothing to load.
        return thresholds_by_category, names_by_category

    for row in table.to_dict('records'):
        category = row['category']
        # The first threshold seen for a category is kept; the name is refreshed on every row.
        thresholds_by_category.setdefault(category, row['threshold'])
        names_by_category[category] = row['name']

    return thresholds_by_category, names_by_category
def _raw_predict(image: ImageTyping, model_name: str):
    """
    Run the PixAI tagger on one image and return its raw outputs.

    The image is loaded as RGB on a white background, passed through the model's
    preprocessing pipeline, given a leading batch axis and fed to the ONNX model under
    the input name ``'input'``. Every model output is requested; for each one the first
    (and only) batch item is returned. No thresholding or other post-processing is done.

    :param image: The input image to analyze
    :type image: ImageTyping
    :param model_name: Name of the model to use for prediction
    :type model_name: str

    :return: Dictionary mapping each model output name (e.g. 'prediction', 'embedding',
        'logits') to its value for this image
    :rtype: dict

    Example::
        >>> raw_output = _raw_predict('anime_image.jpg', 'v0.9')
        >>> print(raw_output.keys())  # dict_keys(['prediction', 'embedding', 'logits'])
    """
    rgb_image = load_image(image, force_background='white', mode='RGB')
    session = _open_onnx_model(model_name=model_name)
    preprocess = _open_preprocess(model_name=model_name)

    # Add the batch axis expected by the model.
    batch = preprocess(rgb_image)[None, ...]

    names = [node.name for node in session.get_outputs()]
    batched_outputs = session.run(names, {'input': batch})
    # Strip the batch axis again: only one image was submitted.
    return dict(zip(names, (batched[0] for batched in batched_outputs)))
def get_pixai_tags(image: ImageTyping, model_name: str = 'v0.9',
                   thresholds: Union[float, Dict[Any, float], None] = None, fmt=FMT_UNSET):
    """
    Extract tags from an image using PixAI tagger models.

    This function processes an image through a PixAI tagger model and applies confidence
    thresholds to determine which tags to include in the results. The output format can
    be customized to return specific categories or all tags together.

    :param image: The input image to analyze (file path, PIL Image, numpy array, etc.)
    :type image: ImageTyping
    :param model_name: Name or repository ID of the PixAI tagger model to use
    :type model_name: str
    :param thresholds: Confidence threshold values. Can be a single real number (``float``,
        ``int`` or a NumPy scalar) applied to all categories, or a dictionary mapping category
        IDs/names to specific thresholds. Categories missing from the dictionary, and all
        categories when this is ``None``, use the model's default threshold (0.4 when the
        model does not define one)
    :type thresholds: Union[float, Dict[Any, float]], optional
    :param fmt: Output format specification. If FMT_UNSET, returns all available categories.
        Can be a tuple of category names to include in output
    :type fmt: Any

    :return: Formatted prediction results. Default returns tuple of (general_tags, character_tags, ...)
        based on available categories. Can return custom format based on fmt parameter
    :rtype: Any

    .. note::
        The fmt parameter can include the following keys:

        - Category names (e.g., 'general', 'character'): dictionaries containing category-specific
          tags and their confidence scores
        - ``tag``: a dictionary containing all tags across categories and their confidences
        - ``embedding``: a 1-dimensional embedding vector of the image, recommended for similarity
          search after L2 normalization
        - ``logits``: raw 1-dimensional logits output from the model
        - ``prediction``: 1-dimensional prediction probabilities from the model
        - ``ips_mapping``: a dictionary mapping detected character tags to their associated
          IP (Intellectual Property) names in Danbooru-style format
        - ``ips_count``: a dictionary containing IP names and their occurrence counts based
          on detected characters
        - ``ips``: a list of IP names sorted by occurrence count (descending) and name
          (ascending), representing the most likely anime/game series in the image

        The ``ips_mapping``, ``ips_count`` and ``ips`` keys are only available when the model's
        tag table has an ``ips`` column.

        Default category thresholds are used if not specified. These vary by model and category
        but typically range from 0.35 to 0.5.

        If the model publishes no category names, the numeric category ID is used as the
        key in place of the category name.

        You can extract embedding of the given image with the following code

        >>> from imgutils.tagging import get_pixai_tags
        >>>
        >>> embedding = get_pixai_tags('skadi.jpg', fmt='embedding')
        >>> embedding.shape
        (1024, )

        This embedding is valuable for constructing indices that enable rapid querying of images based on
        visual features within large-scale datasets.

    Example::
        >>> from imgutils.tagging.pixai import get_pixai_tags
        >>>
        >>> # Get tags with default format (all categories)
        >>> general_tags, character_tags = get_pixai_tags('anime_image.jpg', model_name='v0.9')
        >>> print("General tags:", general_tags)
        >>> print("Character tags:", character_tags)
        >>>
        >>> # Get all tags in a single dictionary
        >>> all_tags = get_pixai_tags('image.jpg', fmt='tag')
        >>> print("All tags:", all_tags)
        >>>
        >>> # Use custom thresholds
        >>> result = get_pixai_tags('image.jpg', thresholds={'general': 0.3, 'character': 0.5})
        >>>
        >>> # Get embedding for similarity search
        >>> embedding = get_pixai_tags('image.jpg', fmt='embedding')
        >>> # Normalize for cosine similarity
        >>> import numpy as np
        >>> normalized_embedding = embedding / np.linalg.norm(embedding)
        >>>
        >>> # Get IP information for character identification
        >>> ips_mapping = get_pixai_tags('image.jpg', fmt='ips_mapping')
        >>> print("Character to IP mapping:", ips_mapping)
        >>> # Example output: {'misaka_mikoto': ['toaru_kagaku_no_railgun'], 'hu_tao_(genshin_impact)': ['genshin_impact']}
        >>>
        >>> # Get most likely anime/game series
        >>> top_ips = get_pixai_tags('image.jpg', fmt='ips')
        >>> print("Most likely series:", top_ips)
        >>> # Example output: ['genshin_impact', 'toaru_kagaku_no_railgun']

    Here are some images for example

    .. image:: tagging_demo.plot.py.svg
        :align: center

    >>> general, character = get_pixai_tags('skadi.jpg')
    >>> general
    {'patreon_username': 0.9988852739334106, 'baseball_bat': 0.9977256059646606, 'holding_baseball_bat': 0.9858889579772949, 'navel': 0.9830228090286255, 'crop_top': 0.9666315317153931, 'sportswear': 0.9664723873138428, '1girl': 0.9572311639785767, 'long_hair': 0.9550737738609314, 'outdoors': 0.9501817226409912, 'solo': 0.9466996788978577, 'day': 0.9394471049308777, 'breasts': 0.938787579536438, 'web_address': 0.9387772679328918, 'stomach': 0.935083270072937, 'red_eyes': 0.9326196908950806, 'shorts': 0.9305683374404907, 'motion_blur': 0.9278550148010254, 'playing_sports': 0.9263769388198853, 'blue_sky': 0.9213213920593262, 'midriff': 0.9191423654556274, 'large_breasts': 0.9174768924713135, 'artist_name': 0.9089528322219849, 'sky': 0.9054281711578369, 'baseball': 0.904181957244873, 'gloves': 0.9033604860305786, 'thighs': 0.893738865852356, 'black_shorts': 0.8926981687545776, 'volleyball': 0.8198539614677429, 'very_long_hair': 0.7967187166213989, 'short_shorts': 0.7873305082321167, 'black_gloves': 0.7765249013900757, 'white_hair': 0.770541787147522, 'baseball_mitt': 0.7684446573257446, 'thigh_gap': 0.73811936378479, 'sweat': 0.7263807654380798, 'cowboy_shot': 0.7235408425331116, 'short_sleeves': 0.7062878012657166, 'parted_lips': 0.7025120258331299, 'patreon_logo': 0.6970672607421875, 'cloud': 0.6967148780822754, 'looking_at_viewer': 0.6898926496505737, 'holding': 0.6879030466079712, 'swinging': 0.6736525893211365, 'ass_visible_through_thighs': 0.6636734008789062, 'elbow_pads': 0.6630151867866516, 'shirt': 0.6250661611557007, 'hair_between_eyes': 0.6075361967086792, 'standing': 0.5285079479217529, 'black_shirt': 0.5173394680023193, 'linea_alba': 0.513701319694519, 'baseball_uniform': 0.48175835609436035, 'crop_top_overhang': 0.4682246744632721, 'ball': 0.43616628646850586, 'blurry': 0.4201475977897644, 'baseball_stadium': 0.41493287682533264, 'grey_hair': 0.39384859800338745, 'watermark': 0.3919041156768799, 'black_sports_bra': 0.3877854645252228, 'fanbox_username': 0.3790855407714844, 'narrow_waist': 0.36392998695373535}
    >>> character
    {'skadi_(arknights)': 0.8926791548728943}
    >>>
    >>> general, character = get_pixai_tags('hutao.jpg')
    >>> general
    {'bag': 0.9833353757858276, 'backpack': 0.9766197204589844, 'flower-shaped_pupils': 0.962916910648346, 'tongue_out': 0.960152804851532, 'tongue': 0.9526823163032532, 'ghost': 0.9514724016189575, 'plaid_skirt': 0.9499615430831909, '1girl': 0.9378864765167236, 'skirt': 0.9353114366531372, 'bag_charm': 0.9314961433410645, 'symbol-shaped_pupils': 0.9252510070800781, 'charm_(object)': 0.9249529242515564, 'twintails': 0.9239017367362976, 'flower': 0.9175764322280884, 'outdoors': 0.9175151586532593, 'hair_ornament': 0.9161680936813354, 'plaid_clothes': 0.9144806861877441, 'long_hair': 0.8768749833106995, 'pleated_skirt': 0.8597153425216675, 'school_uniform': 0.8573414087295532, 'looking_at_viewer': 0.8392735719680786, ':p': 0.8193913698196411, 'hair_between_eyes': 0.8070638179779053, 'hair_flower': 0.8054562211036682, 'nail_polish': 0.8011300563812256, 'building': 0.7961824536323547, 'jacket': 0.7647742629051208, 'brown_hair': 0.7541910409927368, 'solo': 0.7539198398590088, 'long_sleeves': 0.7471930980682373, 'ahoge': 0.7171378135681152, 'hair_ribbon': 0.6994943618774414, 'red_eyes': 0.6819245219230652, 'bowtie': 0.6639955043792725, 'sidelocks': 0.6275356411933899, 'bush': 0.6164096593856812, 'gate': 0.612525224685669, 'smile': 0.6077383160591125, 'shirt': 0.6042965650558472, 'contemporary': 0.5968752503395081, 'brick_floor': 0.5933602452278137, 'cardigan': 0.5810519456863403, 'gradient_hair': 0.5570307970046997, 'diagonal-striped_bowtie': 0.5565160512924194, 'alternate_costume': 0.5535630583763123, 'school_bag': 0.5535626411437988, 'black_hair': 0.5434530973434448, 'ribbon': 0.5332301259040833, 'hairclip': 0.523446261882782, 'day': 0.5164296627044678, 'street': 0.49987316131591797, 'bow': 0.4941294193267822, 'plum_blossoms': 0.4940766990184784, 'collared_shirt': 0.49013280868530273, 'standing': 0.4820355772972107, 'blue_cardigan': 0.47836002707481384, 'cowboy_shot': 0.4782080054283142, 'pocket': 0.477585107088089, 'pavement': 0.4712265729904175, 'multicolored_hair': 0.4610708951950073, 'blue_jacket': 0.45271334052085876, 'blush': 0.45005902647972107, 'sleeves_past_wrists': 0.440649151802063, 'black_nails': 0.4402858316898346, 'black_bag': 0.4206739366054535, 'miniskirt': 0.4187837243080139, 'red_bow': 0.414681613445282, 'very_long_hair': 0.4129619002342224, 'diagonal-striped_clothes': 0.4112803339958191, 'blazer': 0.40750616788864136, 'striped_bowtie': 0.40123170614242554, 'sunlight': 0.4008329212665558, 'grey_skirt': 0.3930213749408722, 'road': 0.3819067180156708, 'black_ribbon': 0.3776353895664215, 'thighs': 0.3722286820411682, 'hug': 0.37215015292167664, 'brick_wall': 0.3717171549797058, 'white_shirt': 0.3694952428340912, 'open_clothes': 0.36442798376083374, 'open_jacket': 0.3525886535644531, ':d': 0.3343055844306946, 'multicolored_nails': 0.32190075516700745, 'red_bowtie': 0.3157669007778168, 'star-shaped_pupils': 0.309164822101593, 'open_mouth': 0.30890953540802, 'beads': 0.3084579110145569, 'stone_stairs': 0.30559882521629333, 'randoseru': 0.30517613887786865}
    >>> character
    {'hu_tao_(genshin_impact)': 0.9997367858886719, 'boo_tao_(genshin_impact)': 0.999537467956543}
    """
    df_tags, d_ips = _open_tags(model_name=model_name)
    values = _raw_predict(image, model_name=model_name)
    prediction = values['prediction']
    default_category_thresholds, category_names = _open_default_category_thresholds(model_name=model_name)

    # The category list drives both the default ``fmt`` and the main loop; compute it once.
    categories = sorted(set(df_tags['category'].tolist()))
    # A model without ``thresholds.csv`` provides no category names. Fall back to the raw
    # category ID so such models are still usable instead of failing with ``KeyError``.
    names = {category: category_names.get(category, category) for category in categories}

    if fmt is FMT_UNSET:
        fmt = tuple(names[category] for category in categories)

    # Accept any real scalar (int, float, NumPy scalar) as a global threshold. ``bool`` is a
    # subclass of ``int`` but is not a meaningful threshold, so it is still ignored.
    has_global_threshold = isinstance(thresholds, numbers.Real) and not isinstance(thresholds, bool)
    threshold_map = thresholds if isinstance(thresholds, dict) else {}

    tags = {}
    for category in categories:
        category_name = names[category]
        in_category = df_tags['category'] == category
        tag_names = df_tags['name'][in_category]
        category_pred = prediction[in_category]

        # Resolution order: global scalar, per-category by ID, per-category by name,
        # model default, and finally the hard-coded 0.4.
        if has_global_threshold:
            category_threshold = thresholds
        elif category in threshold_map:
            category_threshold = threshold_map[category]
        elif category_name in threshold_map:
            category_threshold = threshold_map[category_name]
        else:
            category_threshold = default_category_thresholds.get(category, 0.4)

        selected = category_pred >= category_threshold
        tag_names = tag_names[selected].tolist()
        category_pred = category_pred[selected].tolist()

        # Highest confidence first; ties are broken alphabetically by tag name.
        cate_tags = dict(sorted(zip(tag_names, category_pred), key=lambda x: (-x[1], x[0])))
        values[category_name] = cate_tags
        tags.update(cate_tags)

    values['tag'] = tags

    if 'ips' in df_tags.columns:
        ips_mapping, ips_counts = {}, defaultdict(int)
        for tag in tags:
            if tag in d_ips:
                ips_mapping[tag] = d_ips[tag]
                for ip_name in d_ips[tag]:
                    ips_counts[ip_name] += 1
        values['ips_mapping'] = ips_mapping
        values['ips_count'] = dict(ips_counts)
        # Most frequently referenced IP first; ties are broken alphabetically.
        values['ips'] = [x for x, _ in sorted(ips_counts.items(), key=lambda x: (-x[1], x[0]))]

    return vreplace(fmt, values)