imgutils.generic.classify
Generic tools for classification models.
This module provides utilities and classes for working with classification models, particularly those stored in Hugging Face repositories. It includes functions for image encoding, model loading, and prediction, as well as a main ClassifyModel class that manages the interaction with classification models.
The module is designed to work with ONNX models and supports various image input formats. It also handles token-based authentication for accessing private Hugging Face repositories.
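Since a token is needed for private repositories, one way to avoid hard-coding it is to read it from the environment. The following is a minimal sketch only: `resolve_hf_token` is a hypothetical helper that is not part of imgutils, and storing the secret in an `HF_TOKEN` environment variable is an assumption about your setup.

```python
import os
from typing import Optional


def resolve_hf_token(explicit: Optional[str] = None) -> Optional[str]:
    """Hypothetical helper: prefer an explicit token, else fall back to HF_TOKEN."""
    if explicit:
        return explicit
    # An unset or empty variable is treated as "no token" (public repositories).
    return os.environ.get("HF_TOKEN") or None
```

The returned value can then be passed as the `hf_token` argument documented below.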
ClassifyModel
- class imgutils.generic.classify.ClassifyModel(repo_id: str, fn_preprocess: Callable[[Image], Image] | None = None, hf_token: str | None = None)[source]
A comprehensive manager for classification models from Hugging Face repositories.
- Parameters:
repo_id (str) – Hugging Face repository identifier
fn_preprocess (Optional[ImagePreprocessFunc]) – Optional custom preprocessing function
hf_token (Optional[str]) – Hugging Face authentication token
- Variables:
repo_id – Repository identifier
_model_names – Cached list of available models
_models – Dictionary of loaded ONNX models
_labels – Dictionary of model labels
_hf_token – Authentication token
- Usage:
>>> classifier = ClassifyModel("org/model-repo")
>>> with Image.open("image.jpg") as img:
...     label, score = classifier.predict(img, "model-name")
- __init__(repo_id: str, fn_preprocess: Callable[[Image], Image] | None = None, hf_token: str | None = None)[source]
Initialize a new ClassifyModel instance.
- Parameters:
repo_id (str) – Hugging Face repository identifier
fn_preprocess (Optional[ImagePreprocessFunc]) – Optional custom preprocessing function
hf_token (Optional[str]) – Authentication token for private repositories
- clear()[source]
Clear the cached models and labels.
This method frees up memory by removing all loaded models and labels from the cache.
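The effect of clearing a cache of lazily loaded objects can be illustrated with a small stand-in. This is a sketch under assumptions, not the imgutils implementation: `LazyCache`, `fake_loader`, and the string "models" are all hypothetical.

```python
from typing import Callable, Dict, List


class LazyCache:
    """Illustrative name -> object cache that loads on first access."""

    def __init__(self, loader: Callable[[str], object]):
        self._loader = loader
        self._items: Dict[str, object] = {}

    def get(self, name: str) -> object:
        # Load on first access, then reuse the cached object.
        if name not in self._items:
            self._items[name] = self._loader(name)
        return self._items[name]

    def clear(self) -> None:
        # Dropping the references allows the objects to be garbage-collected.
        self._items.clear()


load_calls: List[str] = []


def fake_loader(name: str) -> str:
    # Stand-in for an expensive model load; records each call.
    load_calls.append(name)
    return f"model:{name}"


cache = LazyCache(fake_loader)
cache.get("model-a")
cache.get("model-a")  # served from the cache, no second load
cache.clear()
cache.get("model-a")  # loaded again after clear()
```

After `clear()`, the next access pays the loading cost again, which is the trade-off for the freed memory.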
- launch_demo(default_model_name: str | None = None, server_name: str | None = None, server_port: int | None = None, **kwargs)[source]
Launch the Gradio demo for the classifier model.
This method creates a Gradio Blocks interface, sets up the UI components using make_ui(), and launches the demo server.
- Parameters:
default_model_name (Optional[str]) – The name of the default model to be selected in the dropdown.
server_name (Optional[str]) – The name of the server to run the demo on. Defaults to None.
server_port (Optional[int]) – The port number to run the demo on. Defaults to None.
kwargs – Additional keyword arguments to pass to the Gradio launch method.
- Raises:
ImportError – If Gradio is not installed or properly configured.
- Example:
>>> model = ClassifyModel("username/repo_name")
>>> model.launch_demo(default_model_name="model_v1", server_name="0.0.0.0", server_port=7860)
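Because the demo raises ImportError when Gradio is missing, a caller may want to check for the package up front. The following sketch uses only the standard library; `gradio_available` is a hypothetical helper name, and it only checks that a package named `gradio` can be located, not that it is properly configured.

```python
import importlib.util


def gradio_available() -> bool:
    """Return True if a top-level package named 'gradio' can be located."""
    return importlib.util.find_spec("gradio") is not None


if not gradio_available():
    print("Gradio not found; install it before launching the demo.")
```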
- make_ui(default_model_name: str | None = None)[source]
Create the user interface components for the classifier model demo.
This method sets up the Gradio UI components including an image input, model selection dropdown, submit button, and output label. It also configures the interaction between these components.
- Parameters:
default_model_name (Optional[str]) – The name of the default model to be selected in the dropdown. If None, the most recently updated model will be selected.
- Raises:
ImportError – If Gradio is not installed or properly configured.
- Example:
>>> model = ClassifyModel("username/repo_name")
>>> model.make_ui(default_model_name="model_v1")
- predict(image: str | PathLike | bytes | bytearray | BinaryIO | Image, model_name: str) -> Tuple[str, float][source]
Predict the class with the highest score for the given image.
This method runs the image through the model and returns the predicted class and its score.
- Parameters:
image (ImageTyping) – The input image to classify.
model_name (str) – The name of the model to use for prediction.
- Returns:
A tuple containing the predicted class label and its score.
- Return type:
Tuple[str, float]
- Raises:
ValueError – If the model name is invalid.
RuntimeError – If there’s an error during prediction.
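Conceptually, the returned tuple corresponds to the highest-scoring entry of a label-to-score mapping. The sketch below shows that selection on a hand-written dictionary. The labels and scores are made up, `best_label` is a hypothetical helper, and this is not a copy of the library's internal code.

```python
from typing import Dict, Tuple


def best_label(scores: Dict[str, float]) -> Tuple[str, float]:
    """Return the (label, score) pair with the highest score."""
    if not scores:
        raise ValueError("scores must not be empty")
    label = max(scores, key=scores.get)
    return label, scores[label]


# Made-up scores, shaped like a label -> score dictionary.
example_scores = {"illustration": 0.82, "photo": 0.15, "other": 0.03}
label, score = best_label(example_scores)
```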
- predict_score(image: str | PathLike | bytes | bytearray | BinaryIO | Image, model_name: str) -> Dict[str, float][source]
Predict the scores for each class using the specified model.
This method runs the image through the model and returns a dictionary of class scores.
- Parameters:
image (ImageTyping) – The input image to classify.
model_name (str) – The name of the model to use for prediction.
- Returns:
A dictionary mapping class labels to their predicted scores.
- Return type:
Dict[str, float]
- Raises:
ValueError – If the model name is invalid.
RuntimeError – If there’s an error during prediction.
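A dictionary of class scores is often post-processed before use. The sketch below ranks the labels and filters them by a threshold; `top_k` and `above_threshold` are hypothetical helpers, and the scores are invented for illustration.

```python
from typing import Dict, List, Tuple


def top_k(scores: Dict[str, float], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k highest-scoring (label, score) pairs, best first."""
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]


def above_threshold(scores: Dict[str, float], threshold: float) -> Dict[str, float]:
    """Keep only the labels whose score is at least `threshold`."""
    return {label: s for label, s in scores.items() if s >= threshold}


# Made-up scores, shaped like the Dict[str, float] return value.
example_scores = {"illustration": 0.82, "photo": 0.15, "other": 0.03}
top2 = top_k(example_scores, k=2)
confident = above_threshold(example_scores, 0.1)
```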
classify_predict_score
- imgutils.generic.classify.classify_predict_score(image: str | PathLike | bytes | bytearray | BinaryIO | Image, repo_id: str, model_name: str, hf_token: str | None = None) -> Dict[str, float][source]
Predict the scores for each class using the specified model and repository.
This function is a convenience wrapper around ClassifyModel’s predict_score method.
- Parameters:
image (ImageTyping) – The input image to classify.
repo_id (str) – The repository ID containing the models.
model_name (str) – The name of the model to use for prediction.
hf_token (Optional[str]) – Optional Hugging Face authentication token.
- Returns:
A dictionary mapping class labels to their predicted scores.
- Return type:
Dict[str, float]
- Raises:
ValueError – If the model name or repository ID is invalid.
RuntimeError – If there’s an error during prediction.
classify_predict
- imgutils.generic.classify.classify_predict(image: str | PathLike | bytes | bytearray | BinaryIO | Image, repo_id: str, model_name: str, hf_token: str | None = None) -> Tuple[str, float][source]
Predict the class with the highest score using the specified model and repository.
This function is a convenience wrapper around ClassifyModel’s predict method.
- Parameters:
image (ImageTyping) – The input image to classify.
repo_id (str) – The repository ID containing the models.
model_name (str) – The name of the model to use for prediction.
hf_token (Optional[str]) – Optional Hugging Face authentication token.
- Returns:
A tuple containing the predicted class label and its score.
- Return type:
Tuple[str, float]
- Raises:
ValueError – If the model name or repository ID is invalid.
RuntimeError – If there’s an error during prediction.
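Since the convenience functions document ValueError and RuntimeError, a caller might wrap them to fall back gracefully. The sketch below takes the prediction function as a parameter so that it can be run without downloading a model. `try_predict` and `fake_predict` are hypothetical names, and `fake_predict` merely mimics the documented signature.

```python
from typing import Any, Callable, Optional, Tuple


def try_predict(
    predict_fn: Callable[..., Tuple[str, float]],
    image: Any,
    repo_id: str,
    model_name: str,
    hf_token: Optional[str] = None,
) -> Optional[Tuple[str, float]]:
    """Call predict_fn and return None instead of raising on documented errors."""
    try:
        return predict_fn(image, repo_id=repo_id, model_name=model_name, hf_token=hf_token)
    except (ValueError, RuntimeError) as err:
        print(f"Prediction failed for {model_name!r}: {err}")
        return None


def fake_predict(image, repo_id, model_name, hf_token=None):
    # Stand-in used only to exercise the wrapper; not a real model call.
    if model_name != "model-name":
        raise ValueError(f"unknown model: {model_name}")
    return "illustration", 0.82


ok = try_predict(fake_predict, "image.jpg", "org/model-repo", "model-name")
failed = try_predict(fake_predict, "image.jpg", "org/model-repo", "missing")
```

With imgutils installed, `classify_predict` imported from `imgutils.generic.classify` could be passed as `predict_fn` in place of the stand-in.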