imgutils.generic.classify_timm
TIMM-based Image Classification Module
This module provides functionality for using pre-trained TIMM (PyTorch Image Models) models for image classification tasks. It includes capabilities for:
- Loading TIMM models from Hugging Face repositories
- Processing and classifying images
- Creating interactive web demos with Gradio
- Retrieving and formatting prediction results
The module is designed to make it easy to use pre-trained image classification models with minimal setup, supporting both programmatic use and interactive demos.
ClassifyTIMMModel
- class imgutils.generic.classify_timm.ClassifyTIMMModel(repo_id: str, hf_token: str | None = None)
A class for handling TIMM-based image classification models.
This class provides functionality to load models from Hugging Face repositories, perform predictions, and create interactive demos.
- Parameters:
repo_id (str) – The Hugging Face repository ID for the model
hf_token (Optional[str]) – Optional Hugging Face authentication token
- __init__(repo_id: str, hf_token: str | None = None)
Initialize a ClassifyTIMMModel instance.
- Parameters:
repo_id (str) – The Hugging Face repository ID for the model
hf_token (Optional[str]) – Optional Hugging Face authentication token
- launch_demo(server_name: str | None = None, server_port: int | None = None, **kwargs)
Launch a standalone Gradio demo for the model.
This method creates and launches a complete Gradio web application for interactive image classification.
- Parameters:
server_name (Optional[str]) – Server name for the Gradio app
server_port (Optional[int]) – Server port for the Gradio app
kwargs – Additional keyword arguments to pass to gr.Blocks.launch()
- Raises:
EnvironmentError – If Gradio is not installed
- make_ui()
Create a Gradio UI component for the model.
This method builds a user interface for interactive image classification that can be embedded in a larger Gradio application.
- Raises:
EnvironmentError – If Gradio is not installed
- predict(image: str | PathLike | bytes | bytearray | BinaryIO | Image, preprocessor: Literal['test', 'val'] = 'test', fmt='scores-top5')
Predict classification results for an image.
This method processes the image, runs inference, and formats the results according to the specified format.
- Parameters:
image (ImageTyping) – The input image to classify
preprocessor (Literal['test', 'val']) – Which preprocessor to use (‘test’ or ‘val’)
fmt – Output format specification, e.g., ‘scores-top5’
- Returns:
Formatted prediction results
- Return type:
dict or other type depending on fmt
Note
The fmt argument can include the following keys:
- scores: a dict containing the prediction scores of all classes, which may be very large
- scores-top<k>: a dict containing the top-k classes and their scores, e.g. scores-top5 means the top-5 classes
- embedding: a 1-dim embedding of the image, recommended for index building after L2 normalization
- logits: the 1-dim logits of the image
- prediction: the 1-dim prediction result of the image
You can extract specific category predictions or all class scores, depending on your needs.
For more details, see the documentation of classify_timm_predict().
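To make the relationship between the scores and scores-top<k> keys concrete, here is a minimal pure-Python sketch that derives a top-k dict from a full score dict. The helpers parse_topk and topk_from_scores are hypothetical, written only for this illustration; they are not part of imgutils, and the library's own ordering and tie-breaking may differ.

```python
import re
from typing import Dict


def parse_topk(fmt: str) -> int:
    """Extract k from a 'scores-top<k>' key (hypothetical helper, not imgutils API)."""
    match = re.fullmatch(r"scores-top(\d+)", fmt)
    if match is None:
        raise ValueError(f"not a 'scores-top<k>' key: {fmt!r}")
    return int(match.group(1))


def topk_from_scores(scores: Dict[str, float], k: int) -> Dict[str, float]:
    """Keep the k highest-scoring classes, ordered from highest to lowest score."""
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:k])


# Toy scores standing in for the output of fmt='scores'.
all_scores = {"cat": 0.05, "dog": 0.70, "fox": 0.20, "owl": 0.05}
top2 = topk_from_scores(all_scores, parse_topk("scores-top2"))
print(top2)  # {'dog': 0.7, 'fox': 0.2}
```

Because Python dicts preserve insertion order, the returned dict lists classes from most to least likely, which matches how the example outputs for classify_timm_predict() are ordered.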
classify_timm_predict
- imgutils.generic.classify_timm.classify_timm_predict(image: str | PathLike | bytes | bytearray | BinaryIO | Image, repo_id: str, preprocessor: Literal['test', 'val'] = 'test', fmt='scores-top5', hf_token: str | None = None)
Perform image classification using a TIMM model from a Hugging Face repository.
This is a convenience function that handles model loading and prediction in one call.
- Parameters:
image (ImageTyping) – The input image to classify
repo_id (str) – The Hugging Face repository ID for the model
preprocessor (Literal['test', 'val']) – Which preprocessor to use (‘test’ or ‘val’)
fmt – Output format specification, e.g., ‘scores-top5’
hf_token (Optional[str]) – Optional Hugging Face authentication token
- Returns:
Formatted prediction results
- Return type:
dict or other type depending on fmt
- Example:
The following examples classify the sample images classify_timm/img1.jpg through classify_timm/img6.jpg:
>>> from imgutils.generic import classify_timm_predict
>>>
>>> classify_timm_predict(
...     'classify_timm/img1.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'jia_redian_ruzi_ruzi': 0.9890832304954529, 'siya_ho': 0.005189628805965185, 'bai_qi-qsr': 0.0015026535838842392, 'kkuem': 0.0012714712647721171, 'teddy_(khanshin)': 0.00035598213435150683}
>>>
>>> classify_timm_predict(
...     'classify_timm/img2.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'monori_rogue': 0.6921895742416382, 'stanley_lau': 0.2040979117155075, 'neoartcore': 0.03475344926118851, 'ayya_sap': 0.005350438412278891, 'goomrrat': 0.004616163671016693}
>>>
>>> classify_timm_predict(
...     'classify_timm/img3.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'shexyo': 0.9998241066932678, 'oroborus': 0.0001537767384434119, 'jeneral': 7.268482477229554e-06, 'free_style_(yohan1754)': 3.4537688406999223e-06, 'kakeku': 2.5340586944366805e-06}
>>>
>>> classify_timm_predict(
...     'classify_timm/img4.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'z.taiga': 0.9999995231628418, 'tina_(tinafya)': 1.2290533391023928e-07, 'arind_yudha': 6.17258208990279e-08, 'chixiao': 4.949555076905199e-08, 'zerotwenty_(020)': 4.218352955831506e-08}
>>>
>>> classify_timm_predict(
...     'classify_timm/img5.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'spam_(spamham4506)': 0.9999998807907104, 'falken_(yutozin)': 4.501828954062148e-08, 'yuki_(asayuki101)': 3.285677863118508e-08, 'danbal': 5.452678752959628e-09, 'buri_(retty9349)': 3.757136379789472e-09}
>>>
>>> classify_timm_predict(
...     'classify_timm/img6.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'mashuu_(neko_no_oyashiro)': 1.0, 'minaba_hideo': 4.543745646401476e-08, 'simosi': 6.499865978781827e-09, 'maoh_yueer': 4.302619149854081e-09, '7nite': 3.6548184478846224e-09}
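With the default fmt, the result is a plain dict mapping class names to scores, so post-processing is ordinary Python. The img2.jpg result above is a useful case: its top score (about 0.69) is far less certain than the others. The sketch below picks the most likely class and rejects low-confidence results; best_label and its 0.5 default threshold are hypothetical choices for illustration, not part of imgutils.

```python
from typing import Dict, Optional, Tuple


def best_label(scores: Dict[str, float], threshold: float = 0.5) -> Optional[Tuple[str, float]]:
    """Return (label, score) of the highest-scoring class, or None if below threshold.

    Hypothetical helper: the 0.5 default is an arbitrary assumption, not an imgutils setting.
    """
    if not scores:
        return None
    label, score = max(scores.items(), key=lambda item: item[1])
    return (label, score) if score >= threshold else None


# Top-5 scores copied from the classify_timm/img2.jpg example above.
img2_scores = {
    'monori_rogue': 0.6921895742416382,
    'stanley_lau': 0.2040979117155075,
    'neoartcore': 0.03475344926118851,
    'ayya_sap': 0.005350438412278891,
    'goomrrat': 0.004616163671016693,
}

print(best_label(img2_scores))                 # ('monori_rogue', 0.6921895742416382)
print(best_label(img2_scores, threshold=0.9))  # None
```

Using max() rather than relying on dict order keeps the helper correct even for a scores dict that is not sorted.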
Note
The fmt argument can include the following keys:
- scores: a dict containing the prediction scores of all classes, which may be very large
- scores-top<k>: a dict containing the top-k classes and their scores, e.g. scores-top5 means the top-5 classes
- embedding: a 1-dim embedding of the image, recommended for index building after L2 normalization
- logits: the 1-dim logits of the image
- prediction: the 1-dim prediction result of the image
You can extract specific category predictions or all class scores, depending on your needs.
>>> from imgutils.generic import classify_timm_predict
>>>
>>> classify_timm_predict(
...     'classify_timm/img1.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls'
... )
{'jia_redian_ruzi_ruzi': 0.9890832304954529, 'siya_ho': 0.005189628805965185, 'bai_qi-qsr': 0.0015026535838842392, 'kkuem': 0.0012714712647721171, 'teddy_(khanshin)': 0.00035598213435150683}
>>> embedding = classify_timm_predict(
...     'classify_timm/img1.jpg',
...     repo_id='animetimm/swinv2_base_window8_256.dbv4a-fullxx-cls',
...     fmt='embedding'
... )
>>> embedding.shape, embedding.dtype
((1024,), dtype('float32'))
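The note recommends L2-normalizing the embedding output before building an index. A minimal NumPy sketch of that idea follows: once every vector has unit length, a dot product equals cosine similarity, so a brute-force nearest-neighbour search becomes a single matrix-vector product. The helpers l2_normalize and search are hypothetical, and random 1024-dim float32 vectors stand in for real embeddings (matching the shape and dtype shown above).

```python
import numpy as np


def l2_normalize(embeddings: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Scale each row to unit L2 norm; eps guards against division by zero."""
    embeddings = np.atleast_2d(np.asarray(embeddings, dtype=np.float32))
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / np.maximum(norms, eps)


def search(index: np.ndarray, query: np.ndarray, top_n: int = 3) -> list:
    """Return (row, cosine similarity) pairs for the index rows closest to query.

    Assumes every row of `index` is already L2-normalized.
    """
    sims = index @ l2_normalize(query)[0]
    order = np.argsort(-sims)[:top_n]
    return [(int(i), float(sims[i])) for i in order]


# Hypothetical stand-ins for embeddings returned with fmt='embedding'.
rng = np.random.default_rng(0)
raw = rng.standard_normal((5, 1024)).astype(np.float32)
index = l2_normalize(raw)

# Querying with row 2's own embedding: row 2 should rank first with similarity ~1.0.
results = search(index, raw[2])
print(results)
```

Brute-force search is fine for small collections; for large ones, the same normalized vectors can be loaded into a dedicated vector index that supports inner-product search.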