imgutils.generic.classify_timm

TIMM-based Image Classification Module

This module provides utilities for running image classification with pre-trained models from TIMM (PyTorch Image Models). It supports:

  • Loading TIMM models from Hugging Face repositories

  • Processing and classifying images

  • Creating interactive web demos with Gradio

  • Retrieving and formatting prediction results

The module is designed to make it easy to use pre-trained image classification models with minimal setup, supporting both programmatic use and interactive demos.

ClassifyTIMMModel

class imgutils.generic.classify_timm.ClassifyTIMMModel(repo_id: str, hf_token: str | None = None)

A class for handling TIMM-based image classification models.

This class provides functionality to load models from Hugging Face repositories, perform predictions, and create interactive demos.

Parameters:
  • repo_id (str) – The Hugging Face repository ID for the model

  • hf_token (Optional[str]) – Optional Hugging Face authentication token

__init__(repo_id: str, hf_token: str | None = None)

Initialize a ClassifyTIMMModel instance.

Parameters:
  • repo_id (str) – The Hugging Face repository ID for the model

  • hf_token (Optional[str]) – Optional Hugging Face authentication token

launch_demo(server_name: str | None = None, server_port: int | None = None, **kwargs)

Launch a standalone Gradio demo for the model.

This method creates and launches a complete Gradio web application for interactive image classification.

Parameters:
  • server_name (Optional[str]) – Server name for the Gradio app

  • server_port (Optional[int]) – Server port for the Gradio app

  • kwargs – Additional keyword arguments to pass to gr.Blocks.launch()

Raises:

EnvironmentError – If Gradio is not installed
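As a minimal sketch of how launch_demo() might be called from a script: the helper name run_demo and its default host and port are illustrative assumptions, not part of the library. The imports are deferred into the function body so the helper can be defined even when imgutils or Gradio is not installed.

```python
def run_demo(repo_id: str, host: str = '127.0.0.1', port: int = 7860) -> None:
    """Load a classifier from the Hugging Face Hub and serve its Gradio demo."""
    # Imported lazily so defining this helper does not require imgutils
    from imgutils.generic.classify_timm import ClassifyTIMMModel

    model = ClassifyTIMMModel(repo_id)
    # Assumption: like gr.Blocks.launch(), this call keeps running
    # until the server is stopped.
    model.launch_demo(server_name=host, server_port=port)
```

Calling run_demo('animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls') should then serve the demo on the local machine; binding to '0.0.0.0' instead would expose it to other hosts on the network.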

make_ui()

Create a Gradio UI component for the model.

This method builds a user interface for interactive image classification that can be embedded in a larger Gradio application.

Raises:

EnvironmentError – If Gradio is not installed
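The following sketch shows one way make_ui() could be embedded in a larger Gradio application. It assumes that make_ui() adds its components to whichever gr.Blocks context is active when it is called; the helper name build_app and the Markdown heading are hypothetical.

```python
def build_app(repo_id: str):
    """Embed the classifier UI inside a larger Gradio Blocks layout."""
    # Imported lazily so this sketch can be defined without Gradio installed
    import gradio as gr
    from imgutils.generic.classify_timm import ClassifyTIMMModel

    model = ClassifyTIMMModel(repo_id)
    with gr.Blocks() as app:
        # Any surrounding components belong to the host application
        gr.Markdown(f'## Classifier: {repo_id}')
        with gr.Row():
            # Assumption: make_ui() renders into the active Blocks context
            model.make_ui()
    return app
```

The returned object is an ordinary gr.Blocks instance, so it can be started with its own launch() method or combined with further components before launching.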

predict(image: str | PathLike | bytes | bytearray | BinaryIO | Image, preprocessor: Literal['test', 'val'] = 'test', fmt='scores-top5')

Predict classification results for an image.

This method processes the image, runs inference, and formats the results according to the specified format.

Parameters:
  • image (ImageTyping) – The input image to classify

  • preprocessor (Literal['test', 'val']) – Which preprocessor to use (‘test’ or ‘val’)

  • fmt – Output format specification, e.g., ‘scores-top5’

Returns:

Formatted prediction results

Return type:

dict or other type depending on fmt

Note

The fmt argument can include the following keys:

  • scores: dict containing the prediction scores for every class; this dict can be very large

  • scores-top<k>: dict containing the top-k classes and their scores, e.g. scores-top5 returns the top 5 classes

  • embedding: a 1-dimensional embedding of the image; L2-normalize it before using it to build an index

  • logits: a 1-dimensional array of logits for the image

  • prediction: a 1-dimensional array of prediction results for the image

Choose the key that matches the output you need.

For more details see documentation of classify_timm_predict().
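To illustrate how scores-top<k> relates to scores, here is a minimal sketch that assumes the top-k dict is simply the k highest-scoring entries of the full dict, ordered from highest to lowest. The function top_k_scores and the sample values are hypothetical and are not the library's implementation.

```python
from typing import Dict


def top_k_scores(scores: Dict[str, float], k: int) -> Dict[str, float]:
    """Keep the k highest-scoring classes, ordered from highest to lowest."""
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])


# Hypothetical full score dict (illustrative values, not model output)
all_scores = {'cat': 0.10, 'dog': 0.70, 'bird': 0.15, 'fish': 0.05}
top2 = top_k_scores(all_scores, 2)
```

Here top2 keeps only 'dog' and 'bird', in that order. Requesting scores-top<k> directly avoids materializing the full dict on the caller's side when only the leading classes matter.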

classify_timm_predict

imgutils.generic.classify_timm.classify_timm_predict(image: str | PathLike | bytes | bytearray | BinaryIO | Image, repo_id: str, preprocessor: Literal['test', 'val'] = 'test', fmt='scores-top5', hf_token: str | None = None)

Perform image classification using a TIMM model from a Hugging Face repository.

This is a convenience function that handles model loading and prediction in one call.

Parameters:
  • image (ImageTyping) – The input image to classify

  • repo_id (str) – The Hugging Face repository ID for the model

  • preprocessor (Literal['test', 'val']) – Which preprocessor to use (‘test’ or ‘val’)

  • fmt – Output format specification, e.g., ‘scores-top5’

  • hf_token (Optional[str]) – Optional Hugging Face authentication token

Returns:

Formatted prediction results

Return type:

dict or other type depending on fmt

Example:

The examples below classify a set of sample images (figure: classify_timm_demo.plot.py.svg).

>>> from imgutils.generic import classify_timm_predict
>>>
>>> classify_timm_predict(
...     'classify_timm/img1.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'jia_redian_ruzi_ruzi': 0.9890832304954529, 'siya_ho': 0.005189628805965185, 'bai_qi-qsr': 0.0015026535838842392, 'kkuem': 0.0012714712647721171, 'teddy_(khanshin)': 0.00035598213435150683}
>>>
>>> classify_timm_predict(
...     'classify_timm/img2.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'monori_rogue': 0.6921895742416382, 'stanley_lau': 0.2040979117155075, 'neoartcore': 0.03475344926118851, 'ayya_sap': 0.005350438412278891, 'goomrrat': 0.004616163671016693}
>>>
>>> classify_timm_predict(
...     'classify_timm/img3.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'shexyo': 0.9998241066932678, 'oroborus': 0.0001537767384434119, 'jeneral': 7.268482477229554e-06, 'free_style_(yohan1754)': 3.4537688406999223e-06, 'kakeku': 2.5340586944366805e-06}
>>>
>>> classify_timm_predict(
...     'classify_timm/img4.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'z.taiga': 0.9999995231628418, 'tina_(tinafya)': 1.2290533391023928e-07, 'arind_yudha': 6.17258208990279e-08, 'chixiao': 4.949555076905199e-08, 'zerotwenty_(020)': 4.218352955831506e-08}
>>>
>>> classify_timm_predict(
...     'classify_timm/img5.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'spam_(spamham4506)': 0.9999998807907104, 'falken_(yutozin)': 4.501828954062148e-08, 'yuki_(asayuki101)': 3.285677863118508e-08, 'danbal': 5.452678752959628e-09, 'buri_(retty9349)': 3.757136379789472e-09}
>>>
>>> classify_timm_predict(
...     'classify_timm/img6.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'mashuu_(neko_no_oyashiro)': 1.0, 'minaba_hideo': 4.543745646401476e-08, 'simosi': 6.499865978781827e-09, 'maoh_yueer': 4.302619149854081e-09, '7nite': 3.6548184478846224e-09}

Note

The fmt argument can include the following keys:

  • scores: dict containing the prediction scores for every class; this dict can be very large

  • scores-top<k>: dict containing the top-k classes and their scores, e.g. scores-top5 returns the top 5 classes

  • embedding: a 1-dimensional embedding of the image; L2-normalize it before using it to build an index

  • logits: a 1-dimensional array of logits for the image

  • prediction: a 1-dimensional array of prediction results for the image

Choose the key that matches the output you need. For example, the default format returns the top-5 scores, while fmt='embedding' returns the embedding vector:

>>> from imgutils.generic import classify_timm_predict
>>>
>>> classify_timm_predict(
...     'classify_timm/img1.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'jia_redian_ruzi_ruzi': 0.9890832304954529, 'siya_ho': 0.005189628805965185, 'bai_qi-qsr': 0.0015026535838842392, 'kkuem': 0.0012714712647721171, 'teddy_(khanshin)': 0.00035598213435150683}
>>> embedding = classify_timm_predict(
...     'classify_timm/img1.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls',
...     fmt='embedding'
... )
>>> embedding.shape, embedding.dtype
((1024,), dtype('float32'))
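Because the embedding is recommended for index building after L2 normalization, the sketch below shows what that step might look like with NumPy. It uses synthetic vectors as stand-ins for real embeddings, and the helpers l2_normalize and nearest are hypothetical names rather than functions provided by imgutils.

```python
import numpy as np


def l2_normalize(vectors: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each vector to unit L2 norm; eps guards against division by zero."""
    vectors = np.asarray(vectors, dtype=np.float32)
    norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
    return vectors / np.maximum(norms, eps)


def nearest(index: np.ndarray, query: np.ndarray) -> int:
    """Return the row of a unit-norm index most similar to the query (cosine)."""
    sims = index @ l2_normalize(query)
    return int(np.argmax(sims))


# Synthetic stand-ins for vectors returned with fmt='embedding'
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(4, 1024)).astype(np.float32)
index = l2_normalize(embeddings)

# A slightly perturbed copy of row 2 should be matched back to row 2
query = embeddings[2] + 0.01 * rng.normal(size=1024).astype(np.float32)
best = nearest(index, query)
```

After normalization the dot product of two vectors equals their cosine similarity, which is why a plain matrix product is enough for the lookup. For larger collections a dedicated vector index would likely replace the brute-force product.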