imgutils.generic.multilabel_timm

Multi-Label TIMM Model Module

This module provides functionality for working with multi-label image classification models trained with TIMM (PyTorch Image Models) and exported to ONNX format. It includes:

  1. The MultiLabelTIMMModel class for loading models hosted on the Hugging Face Hub and making predictions with them

  2. Functions for batch prediction and demo interface creation

  3. Support for custom thresholds at both category and tag levels

  4. Flexible output formatting options for different use cases

Each model is expected to be stored in a Hugging Face Hub repository containing the following files:

  • model.onnx: The ONNX model file

  • selected_tags.csv: CSV file containing tag information and categories

  • preprocess.json: JSON configuration for image preprocessing

  • thresholds.csv: Optional CSV file with recommended thresholds

  • categories.json: JSON file mapping category IDs to category names
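Before pointing the class at a new repository, it can help to check that the expected files are present. The sketch below uses the file names listed above and treats only thresholds.csv as optional, since the list does not say whether the other files may be omitted. The helper name check_repo_files is hypothetical; in practice the file listing could come from huggingface_hub.list_repo_files(repo_id).

```python
from typing import Iterable, List

# File names taken from the list above; only thresholds.csv is documented as optional.
REQUIRED_FILES = ("model.onnx", "selected_tags.csv", "preprocess.json", "categories.json")
OPTIONAL_FILES = ("thresholds.csv",)


def check_repo_files(available: Iterable[str]) -> List[str]:
    """Return the required files missing from `available` (hypothetical helper)."""
    present = set(available)
    return [name for name in REQUIRED_FILES if name not in present]


# Example listing; in practice this could come from huggingface_hub.list_repo_files(repo_id).
listing = ["model.onnx", "selected_tags.csv", "preprocess.json", "README.md"]
missing = check_repo_files(listing)
print(missing)  # ['categories.json']
```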

This module is designed to work with multi-label classification tasks where images can belong to multiple categories and have multiple tags within each category.
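To make the multi-label setting concrete: each tag receives its own independent score, so any number of tags in any number of categories can pass their threshold at once. The sketch below assumes, as is typical for multi-label classifiers, that scores come from an element-wise sigmoid over the logits. The tag names, their category assignments, and the 0.5 threshold are made up for illustration.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical tags as (name, category) pairs, in model output order.
TAGS: List[Tuple[str, str]] = [
    ("1girl", "general"),
    ("long_hair", "general"),
    ("hat", "general"),
    ("skadi_(arknights)", "character"),
]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def group_tags(logits: List[float], threshold: float = 0.5) -> Dict[str, Dict[str, float]]:
    """Score each tag independently and keep those >= threshold, grouped by category."""
    # Every category gets a dict, even if no tag in it passes the threshold.
    result: Dict[str, Dict[str, float]] = {category: {} for _, category in TAGS}
    for (name, category), logit in zip(TAGS, logits):
        score = sigmoid(logit)
        if score >= threshold:
            result[category][name] = score
    return result


grouped = group_tags([5.0, 3.0, -2.0, 4.0])
print(sorted(grouped["general"]))  # ['1girl', 'long_hair']
```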

MultiLabelTIMMModel

class imgutils.generic.multilabel_timm.MultiLabelTIMMModel(repo_id: str, hf_token: str | None = None)

A class for working with multi-label image classification models trained with TIMM.

This class handles loading models from the Hugging Face Hub, preprocessing images, and making predictions with customizable thresholds.

Parameters:
  • repo_id (str) – The Hugging Face Hub repository ID containing the model

  • hf_token (Optional[str]) – Optional Hugging Face authentication token for private repositories

__init__(repo_id: str, hf_token: str | None = None)

Initialize a MultiLabelTIMMModel.

Parameters:
  • repo_id (str) – The Hugging Face Hub repository ID containing the model

  • hf_token (Optional[str]) – Optional Hugging Face authentication token for private repositories

launch_demo(default_thresholds: float | Dict[Any, float] | None = None, default_use_tag_thresholds: bool = True, server_name: str | None = None, server_port: int | None = None, **kwargs)

Launch a Gradio demo for the model.

Parameters:
  • default_thresholds (Optional[Union[float, Dict[Any, float]]]) – Default threshold values to use in the demo

  • default_use_tag_thresholds (bool) – Whether to use tag-level thresholds by default

  • server_name (Optional[str]) – Server name for the Gradio app

  • server_port (Optional[int]) – Server port for the Gradio app

  • kwargs (Any) – Additional keyword arguments passed through to Gradio's launch() method

Returns:

None

Raises:

EnvironmentError – If Gradio is not installed
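The EnvironmentError above implies an availability check for the optional Gradio dependency. One common way to write such a check is sketched below with importlib; the helper name require_module is hypothetical, and the actual check in imgutils may be implemented differently.

```python
import importlib
import importlib.util
from types import ModuleType


def require_module(name: str, hint: str = "") -> ModuleType:
    """Import an optional top-level dependency, raising EnvironmentError if it is missing."""
    if importlib.util.find_spec(name) is None:
        raise EnvironmentError(f"{name!r} is not installed. {hint}".strip())
    return importlib.import_module(name)


# Hypothetical usage before building a demo:
# gr = require_module("gradio", "Install it with `pip install gradio`.")
```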

make_ui(default_thresholds: float | Dict[Any, float] | None = None, default_use_tag_thresholds: bool = True)

Create a Gradio UI for the model.

Parameters:
  • default_thresholds (Optional[Union[float, Dict[Any, float]]]) – Default threshold values to use in the UI

  • default_use_tag_thresholds (bool) – Whether to use tag-level thresholds by default

Returns:

None

Raises:

EnvironmentError – If Gradio is not installed

predict(image: str | os.PathLike | bytes | bytearray | typing.BinaryIO | PIL.Image.Image, preprocessor: typing.Literal['test', 'val'] = 'test', thresholds: float | typing.Dict[typing.Any, float] | None = None, use_tag_thresholds: bool = True, fmt=FMT_UNSET)

Make a prediction and format the results.

This method processes an image through the model and applies thresholds to determine which tags to include in the results. The output format can be customized using the fmt parameter.

Parameters:
  • image (ImageTyping) – The input image

  • preprocessor (Literal['test', 'val']) – Which preprocessor to use (‘test’ or ‘val’)

  • thresholds (Optional[Union[float, Dict[Any, float]]]) – Threshold values for tag confidence. Can be a single float applied to all categories, or a dictionary mapping category IDs or names to threshold values

  • use_tag_thresholds (bool) – Whether to use tag-level thresholds if available

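  • fmt (Any) – Output format specification: a single key, which returns one value, or a tuple of keys, which returns a tuple of values in the same order. Keys can be category names or the special keys listed in the note below. Use FMT_UNSET (the default) to return all of the model's categories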
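The thresholds parameter accepts either one float or a dict keyed by category ID or name. A minimal sketch of how such a value could be resolved into the threshold for a single tag follows. The category table, the 0.4 fallback, and the precedence rule (an explicit user value first, then a tag-level threshold when use_tag_thresholds is true, then the fallback) are all assumptions for illustration, not documented imgutils behavior.

```python
from typing import Any, Dict, Optional, Union

# Hypothetical category table: ID -> name (the real mapping comes from categories.json).
CATEGORIES: Dict[int, str] = {0: "general", 4: "character", 9: "rating"}
DEFAULT_THRESHOLD = 0.4  # assumed fallback value


def resolve_threshold(
    category_id: int,
    thresholds: Union[float, Dict[Any, float], None],
    tag_threshold: Optional[float] = None,
    use_tag_thresholds: bool = True,
) -> float:
    """Pick the threshold for one tag (assumed precedence: explicit > tag-level > default)."""
    if isinstance(thresholds, (int, float)):
        return float(thresholds)  # one value for every category
    if isinstance(thresholds, dict):
        # Accept either the numeric category ID or the category name as the key.
        for key in (category_id, CATEGORIES.get(category_id)):
            if key in thresholds:
                return float(thresholds[key])
    if use_tag_thresholds and tag_threshold is not None:
        return tag_threshold
    return DEFAULT_THRESHOLD
```

With this rule, a dict that mentions only some categories leaves the others to the tag-level or fallback thresholds.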

Returns:

Formatted prediction results according to the fmt parameter

Return type:

Any

Note

The fmt argument can include the following keys:

  • Category names: dicts containing category-specific tags and their confidences

  • tag: a dict containing all tags across categories and their confidences

  • embedding: a 1-D embedding vector of the image; L2-normalize it before using it to build a search index

  • logits: the 1-D array of raw logits for the image

  • prediction: the 1-D array of prediction scores for the image

You can use these keys to extract predictions for specific categories, all tags at once, or the raw embedding, logits and prediction arrays.

For more details, see the documentation of multilabel_timm_predict().
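The fmt mechanism can be pictured as picking named entries out of a dict of all available outputs: a single key yields one value, and a tuple of keys yields a tuple in the same order. The sketch below illustrates only that selection logic. select_outputs and the sample results dict are hypothetical, and how imgutils resolves FMT_UNSET to the model's default categories is not reproduced here.

```python
from typing import Any, Dict, Tuple, Union


def select_outputs(results: Dict[str, Any], fmt: Union[str, Tuple[str, ...]]) -> Any:
    """Return results[fmt] for a single key, or a tuple of values for a tuple of keys."""
    if isinstance(fmt, str):
        return results[fmt]
    return tuple(select_outputs(results, key) for key in fmt)


# Hypothetical full set of outputs for one image.
results = {
    "general": {"1girl": 0.99, "solo": 0.95},
    "character": {"skadi_(arknights)": 0.98},
    "rating": {"sensitive": 0.96},
    "tag": {"1girl": 0.99, "solo": 0.95, "skadi_(arknights)": 0.98, "sensitive": 0.96},
}

rating, character = select_outputs(results, ("rating", "character"))
print(rating)  # {'sensitive': 0.96}
```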

multilabel_timm_predict

imgutils.generic.multilabel_timm.multilabel_timm_predict(image: str | os.PathLike | bytes | bytearray | typing.BinaryIO | PIL.Image.Image, repo_id: str, preprocessor: typing.Literal['test', 'val'] = 'test', thresholds: float | typing.Dict[typing.Any, float] | None = None, use_tag_thresholds: bool = True, fmt=FMT_UNSET, hf_token: str | None = None)

Make predictions using a multi-label TIMM model.

This function provides a convenient interface for making predictions with models hosted on the Hugging Face Hub without directly instantiating a MultiLabelTIMMModel.

Parameters:
  • image (ImageTyping) – The input image

  • repo_id (str) – The Hugging Face Hub repository ID containing the model

  • preprocessor (Literal['test', 'val']) – Which preprocessor to use (‘test’ or ‘val’)

  • thresholds (Optional[Union[float, Dict[Any, float]]]) – Threshold values for tag confidence. Can be a single float applied to all categories, or a dictionary mapping category IDs or names to threshold values

  • use_tag_thresholds (bool) – Whether to use tag-level thresholds if available

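  • fmt (Any) – Output format specification: a single key, which returns one value, or a tuple of keys, which returns a tuple of values in the same order. Keys can be category names or the special keys listed in the note below. Use FMT_UNSET (the default) to return all of the model's categories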

  • hf_token (Optional[str]) – Optional Hugging Face authentication token for private repositories

Returns:

Formatted prediction results according to the fmt parameter

Return type:

Any
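A function-level wrapper like this usually avoids reloading the model on every call by caching one instance per repository. The sketch below shows that pattern with functools.lru_cache. Whether multilabel_timm_predict caches models this way is an assumption, and DummyModel and predict_with_cache are hypothetical stand-ins so the example runs without any downloads.

```python
from functools import lru_cache
from typing import Any, Optional


class DummyModel:
    """Hypothetical stand-in for MultiLabelTIMMModel (no network access)."""

    instances = 0

    def __init__(self, repo_id: str, hf_token: Optional[str] = None):
        DummyModel.instances += 1
        self.repo_id = repo_id

    def predict(self, image: Any, **kwargs: Any) -> dict:
        return {"repo_id": self.repo_id, "image": image, **kwargs}


@lru_cache()
def _open_model(repo_id: str, hf_token: Optional[str] = None) -> DummyModel:
    # One cached instance per (repo_id, hf_token) pair.
    return DummyModel(repo_id, hf_token)


def predict_with_cache(image: Any, repo_id: str, hf_token: Optional[str] = None, **kwargs: Any) -> dict:
    """Hypothetical wrapper: load (or reuse) the model, then delegate to its predict()."""
    return _open_model(repo_id, hf_token).predict(image, **kwargs)


predict_with_cache("a.jpg", repo_id="some/repo")
predict_with_cache("b.jpg", repo_id="some/repo")
print(DummyModel.instances)  # 1
```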

Example:

The examples below use the following images:

[Figure: demo images used in the examples (multilabel_demo.plot.py)]
>>> from imgutils.generic import multilabel_timm_predict
>>>
>>> general, character, rating = multilabel_timm_predict(
...     'skadi.jpg',
...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
... )
>>> general
{'1girl': 0.9963783025741577, 'long_hair': 0.9685494899749756, 'solo': 0.9548443555831909, 'navel': 0.9415484666824341, 'breasts': 0.9369214177131653, 'red_eyes': 0.9019639492034912, 'shirt': 0.873087465763092, 'outdoors': 0.866461992263794, 'crop_top': 0.862577497959137, 'midriff': 0.8544420003890991, 'sportswear': 0.849435567855835, 'shorts': 0.8209151029586792, 'short_sleeves': 0.817188560962677, 'holding': 0.811793327331543, 'very_long_hair': 0.8082301616668701, 'gloves': 0.7840366363525391, 'black_gloves': 0.7765430808067322, 'thighs': 0.7542579770088196, 'looking_at_viewer': 0.7331588268280029, 'day': 0.7203925251960754, 'hair_between_eyes': 0.7121687531471252, 'large_breasts': 0.6990523338317871, 'baseball_bat': 0.6809443831443787, 'grey_hair': 0.6790007948875427, 'sky': 0.6716539263725281, 'stomach': 0.6698249578475952, 'sweat': 0.6454322934150696, 'black_shirt': 0.6270318031311035, 'cowboy_shot': 0.6216483116149902, 'blue_sky': 0.5898874998092651, 'black_shorts': 0.5445142984390259, 'holding_baseball_bat': 0.5013713836669922, 'white_hair': 0.4999670684337616, 'blush': 0.4860053062438965, 'cloud': 0.474183052778244, 'standing': 0.4724341332912445, 'thigh_gap': 0.4330931305885315, 'short_shorts': 0.39793258905410767, 'parted_lips': 0.36694538593292236, 'crop_top_overhang': 0.3321989178657532, 'official_alternate_costume': 0.3157039284706116, 'blurry': 0.24181532859802246, 'groin': 0.21906554698944092, 'ass_visible_through_thighs': 0.2188207507133484, 'cropped_shirt': 0.18700966238975525, 'taut_shirt': 0.08612403273582458, 'taut_clothes': 0.0701744556427002}
>>> character
{'skadi_(arknights)': 0.9796262979507446}
>>> rating
{'sensitive': 0.9580697417259216}
>>>
>>> general, character, rating = multilabel_timm_predict(
...     'hutao.jpg',
...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
... )
>>> general
{'1girl': 0.988956093788147, 'twintails': 0.9650213718414307, 'ghost': 0.940951943397522, 'tongue_out': 0.9330000877380371, 'tongue': 0.9267600774765015, 'skirt': 0.9194451570510864, 'symbol-shaped_pupils': 0.9103127717971802, 'brown_hair': 0.9067947268486023, 'long_hair': 0.8872615098953247, 'red_eyes': 0.8631541728973389, 'looking_at_viewer': 0.8235997557640076, 'solo': 0.8214132785797119, 'long_sleeves': 0.7965610027313232, 'bag': 0.7958617210388184, 'jacket': 0.7932659387588501, 'flower-shaped_pupils': 0.7630170583724976, 'shirt': 0.7500981092453003, 'hair_ornament': 0.738053023815155, 'flower': 0.7321316599845886, 'plaid_skirt': 0.7173646688461304, 'white_shirt': 0.6631225347518921, 'pleated_skirt': 0.6344470977783203, 'hair_flower': 0.6293849945068359, 'nail_polish': 0.6136130094528198, 'multicolored_hair': 0.5703858733177185, 'blush': 0.5195141434669495, 'plaid_clothes': 0.503984808921814, 'gradient_hair': 0.49658203125, 'alternate_costume': 0.4947473704814911, ':p': 0.493851900100708, 'hair_between_eyes': 0.484821081161499, 'smile': 0.4778161942958832, 'black_nails': 0.4747253358364105, 'collared_shirt': 0.46951043605804443, 'outdoors': 0.46920245885849, 'holding': 0.45227500796318054, 'school_uniform': 0.4197554290294647, 'very_long_hair': 0.41959843039512634, 'miniskirt': 0.3916422426700592, 'cowboy_shot': 0.38207799196243286, 'blue_jacket': 0.3614964485168457, 'sleeves_past_wrists': 0.3611966073513031, 'backpack': 0.32487112283706665, 'colored_tips': 0.314140260219574, 'sidelocks': 0.3062695264816284, 'black_jacket': 0.299169659614563, 'standing': 0.29005059599876404, 'charm_(object)': 0.22183549404144287, 'multiple_rings': 0.2172674536705017, 'open_jacket': 0.2046721875667572, 'ring': 0.18625634908676147, 'brown_skirt': 0.18045437335968018, 'contemporary': 0.13890522718429565}
>>> character
{'hu_tao_(genshin_impact)': 0.9779937267303467, 'boo_tao_(genshin_impact)': 0.8973554372787476}
>>> rating
{'general': 0.6215817332267761, 'sensitive': 0.3872501254081726}
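The returned values are plain dicts mapping tag names to confidences, so they can be post-processed with ordinary Python. For instance, a stricter cutoff can be applied after the fact (only stricter, since tags below the original threshold are already gone), and the remaining tags can be joined into a caption-style string. The sketch below uses a few entries copied from the hutao.jpg output above; the 0.9 cutoff and the comma-separated format are arbitrary illustrative choices.

```python
from typing import Dict, List

# A subset of the `general` output for hutao.jpg shown above.
general: Dict[str, float] = {
    "1girl": 0.988956093788147,
    "twintails": 0.9650213718414307,
    "ghost": 0.940951943397522,
    "skirt": 0.9194451570510864,
    "brown_hair": 0.9067947268486023,
    "long_hair": 0.8872615098953247,
}


def filter_tags(tags: Dict[str, float], cutoff: float) -> List[str]:
    """Tags with confidence >= cutoff, highest confidence first."""
    kept = [(name, score) for name, score in tags.items() if score >= cutoff]
    kept.sort(key=lambda item: item[1], reverse=True)
    return [name for name, _ in kept]


caption = ", ".join(filter_tags(general, 0.9))
print(caption)  # 1girl, twintails, ghost, skirt, brown_hair
```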

Note

The default format depends on which categories a model supports, so it differs from model to model.

For example, the model animetimm/mobilenetv3_large_150d.dbv4-full-witha also supports an artist category, so it returns four dicts by default:

>>> from imgutils.generic import multilabel_timm_predict
>>>
>>> general, artist, character, rating = multilabel_timm_predict(
...     'skadi.jpg',
...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full-witha',
... )
>>> general
{'1girl': 0.9938606023788452, 'long_hair': 0.9691187143325806, 'red_eyes': 0.9463587403297424, 'solo': 0.944723904132843, 'navel': 0.9439248442649841, 'breasts': 0.9335891008377075, 'sportswear': 0.8865424394607544, 'shorts': 0.8601726293563843, 'very_long_hair': 0.8445472717285156, 'outdoors': 0.83197021484375, 'midriff': 0.8274217247962952, 'shirt': 0.8188955783843994, 'short_sleeves': 0.8183804750442505, 'crop_top': 0.8089936971664429, 'gloves': 0.8038264513015747, 'black_gloves': 0.7703496813774109, 'thighs': 0.7689077854156494, 'holding': 0.768336832523346, 'looking_at_viewer': 0.739115834236145, 'large_breasts': 0.7282243967056274, 'sky': 0.6852632761001587, 'hair_between_eyes': 0.6799711585044861, 'stomach': 0.6694454550743103, 'baseball_bat': 0.6693665385246277, 'black_shorts': 0.6493985652923584, 'day': 0.6425715684890747, 'cowboy_shot': 0.6186742186546326, 'black_shirt': 0.5906491279602051, 'holding_baseball_bat': 0.5860112905502319, 'sweat': 0.5825777649879456, 'cloud': 0.5549533367156982, 'blue_sky': 0.5523971915245056, 'white_hair': 0.5324308276176453, 'grey_hair': 0.52657151222229, 'short_shorts': 0.4896492063999176, 'standing': 0.45526784658432007, 'parted_lips': 0.4306206703186035, 'blush': 0.4149143397808075, 'thigh_gap': 0.4124316871166229, 'ass_visible_through_thighs': 0.34030789136886597, 'artist_name': 0.2679593563079834, 'groin': 0.2652612328529358, 'blurry': 0.2548949122428894, 'baseball': 0.24870169162750244, 'crop_top_overhang': 0.2240566909313202, 'stretching': 0.2012709677219391, 'cropped_shirt': 0.19828352332115173, 'official_alternate_costume': 0.1960265338420868, 'toned': 0.13941210508346558, 'exercising': 0.11270403861999512, 'lens_flare': 0.10835999250411987, 'taut_clothes': 0.08783495426177979, 'taut_shirt': 0.08448180556297302, 'linea_alba': 0.06583884358406067}
>>> artist
{}
>>> character
{'skadi_(arknights)': 0.8951651453971863}
>>> rating
{'sensitive': 0.9492285847663879}

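Its default fmt is therefore ('general', 'artist', 'character', 'rating'), compared with ('general', 'character', 'rating') for animetimm/mobilenetv3_large_150d.dbv4-full.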

In either case, the fmt argument lets you request exactly the outputs you need, as described in the next note.
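One way to picture how a default fmt can follow from a model's categories is to take the category names in a fixed order. The sketch below assumes that categories.json is a list of objects with "category" (ID) and "name" fields and that the default order follows ascending category ID. Both are assumptions, since the file layout and the ordering rule are not documented here.

```python
import json
from typing import Tuple

# Hypothetical categories.json content (the real layout may differ).
CATEGORIES_JSON = """
[
  {"category": 9, "name": "rating"},
  {"category": 0, "name": "general"},
  {"category": 4, "name": "character"},
  {"category": 1, "name": "artist"}
]
"""


def default_fmt(categories_json: str) -> Tuple[str, ...]:
    """Category names ordered by ascending category ID (assumed ordering rule)."""
    categories = json.loads(categories_json)
    ordered = sorted(categories, key=lambda item: item["category"])
    return tuple(item["name"] for item in ordered)


print(default_fmt(CATEGORIES_JSON))  # ('general', 'artist', 'character', 'rating')
```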

Note

The fmt argument can include the following keys:

  • Category names: dicts containing category-specific tags and their confidences

  • tag: a dict containing all tags across categories and their confidences

  • embedding: a 1-D embedding vector of the image; L2-normalize it before using it to build a search index

  • logits: the 1-D array of raw logits for the image

  • prediction: the 1-D array of prediction scores for the image

You can use these keys to extract predictions for specific categories, all tags at once, or the raw embedding, logits and prediction arrays.

>>> from imgutils.generic import multilabel_timm_predict
>>>
>>> # default usage
>>> general, character, rating = multilabel_timm_predict(
...     'skadi.jpg',
...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
... )
>>> general
{'1girl': 0.9963783025741577, 'long_hair': 0.9685494899749756, 'solo': 0.9548443555831909, 'navel': 0.9415484666824341, 'breasts': 0.9369214177131653, 'red_eyes': 0.9019639492034912, 'shirt': 0.873087465763092, 'outdoors': 0.866461992263794, 'crop_top': 0.862577497959137, 'midriff': 0.8544420003890991, 'sportswear': 0.849435567855835, 'shorts': 0.8209151029586792, 'short_sleeves': 0.817188560962677, 'holding': 0.811793327331543, 'very_long_hair': 0.8082301616668701, 'gloves': 0.7840366363525391, 'black_gloves': 0.7765430808067322, 'thighs': 0.7542579770088196, 'looking_at_viewer': 0.7331588268280029, 'day': 0.7203925251960754, 'hair_between_eyes': 0.7121687531471252, 'large_breasts': 0.6990523338317871, 'baseball_bat': 0.6809443831443787, 'grey_hair': 0.6790007948875427, 'sky': 0.6716539263725281, 'stomach': 0.6698249578475952, 'sweat': 0.6454322934150696, 'black_shirt': 0.6270318031311035, 'cowboy_shot': 0.6216483116149902, 'blue_sky': 0.5898874998092651, 'black_shorts': 0.5445142984390259, 'holding_baseball_bat': 0.5013713836669922, 'white_hair': 0.4999670684337616, 'blush': 0.4860053062438965, 'cloud': 0.474183052778244, 'standing': 0.4724341332912445, 'thigh_gap': 0.4330931305885315, 'short_shorts': 0.39793258905410767, 'parted_lips': 0.36694538593292236, 'crop_top_overhang': 0.3321989178657532, 'official_alternate_costume': 0.3157039284706116, 'blurry': 0.24181532859802246, 'groin': 0.21906554698944092, 'ass_visible_through_thighs': 0.2188207507133484, 'cropped_shirt': 0.18700966238975525, 'taut_shirt': 0.08612403273582458, 'taut_clothes': 0.0701744556427002}
>>> character
{'skadi_(arknights)': 0.9796262979507446}
>>> rating
{'sensitive': 0.9580697417259216}
>>>
>>> # get rating and character only
>>> rating, character = multilabel_timm_predict(
...     'skadi.jpg',
...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
...     fmt=('rating', 'character'),
... )
>>> rating
{'sensitive': 0.9580697417259216}
>>> character
{'skadi_(arknights)': 0.9796262979507446}
>>>
>>> # get embeddings only
>>> embedding = multilabel_timm_predict(
...     'skadi.jpg',
...     repo_id='animetimm/mobilenetv3_large_150d.dbv4-full',
...     fmt='embedding',
... )
>>> embedding.dtype, embedding.shape
(dtype('float32'), (1280,))
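As the notes above recommend, embeddings should be L2-normalized before building an index; after normalization, cosine similarity reduces to a dot product. The sketch below uses random float32 vectors of length 1280, matching the shape shown above, in place of real embeddings. The function names are hypothetical, and a production index would more likely rely on a dedicated similarity-search library.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row (or a single vector) to unit L2 norm; zero vectors stay zero."""
    norm = np.linalg.norm(x, axis=-1, keepdims=True)
    return x / np.maximum(norm, eps)


def top_k_similar(query: np.ndarray, index: np.ndarray, k: int = 3) -> np.ndarray:
    """Indices of the k rows in `index` with the highest cosine similarity to `query`."""
    scores = l2_normalize(index) @ l2_normalize(query)
    return np.argsort(-scores)[:k]


rng = np.random.default_rng(0)
# Stand-ins for embeddings returned with fmt='embedding' (float32, shape (1280,)).
index = rng.standard_normal((100, 1280)).astype(np.float32)
query = index[42] + 0.01 * rng.standard_normal(1280).astype(np.float32)

print(top_k_similar(query, index, k=1))  # [42]
```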