"""
Overview:
Generic tools for classification models.
"""
import json
import os
from functools import lru_cache
from typing import Tuple, Optional, List, Dict
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download, HfFileSystem
from ..data import rgb_encode, ImageTyping, load_image
from ..utils import open_onnx_model
# Names exported by ``from <this module> import *``.
__all__ = ['ClassifyModel', 'classify_predict_score', 'classify_predict']
def _img_encode(image: Image.Image, size: Tuple[int, int] = (384, 384),
                normalize: Optional[Tuple[float, float]] = (0.5, 0.5)):
    """
    Turn a PIL image into a ``float32`` numpy array suitable as model input.

    The image is resized with bilinear interpolation, passed through
    ``rgb_encode`` with ``order_='CHW'``, and, when ``normalize`` is given,
    transformed as ``(data - mean) / std``.

    :param image: The input image.
    :type image: Image.Image
    :param size: Target size handed to ``Image.resize``, defaults to (384, 384).
    :type size: Tuple[int, int], optional
    :param normalize: ``(mean, std)`` pair used for normalization, or ``None``
        to skip normalization entirely. Defaults to (0.5, 0.5).
    :type normalize: Optional[Tuple[float, float]], optional
    :return: The encoded image, cast to ``float32``.
    :rtype: np.ndarray
    """
    resized = image.resize(size, Image.BILINEAR)
    encoded = rgb_encode(resized, order_='CHW')

    # Without normalization only the dtype cast remains.
    if normalize is None:
        return encoded.astype(np.float32)

    mean_value, std_value = normalize
    # Reshape to (-1, 1, 1) so the values broadcast against the leading axis
    # of the encoded array (presumably the channel axis, given order_='CHW').
    mean_arr = np.asarray([mean_value]).reshape((-1, 1, 1))
    std_arr = np.asarray([std_value]).reshape((-1, 1, 1))
    normalized = (encoded - mean_arr) / std_arr
    return normalized.astype(np.float32)
class ClassifyModel:
    """
    Manager for the classification models stored in one Hugging Face repository.

    Each model lives in its own sub-directory of the repository, holding a
    ``model.onnx`` file and a ``meta.json`` file whose ``labels`` entry lists
    the class names. Models and labels are fetched lazily on first use and
    cached on the instance, keyed by model name.

    Methods:
        predict_score: Predicts the scores for each class.
        predict: Predicts the class with the highest score.
        clear: Clears the loaded models and labels.
    """

    def __init__(self, repo_id: str):
        """
        Initialize the ClassifyModel instance.

        No network access happens here; everything is loaded on demand.

        :param repo_id: The repository ID containing the models.
        :type repo_id: str
        """
        self.repo_id = repo_id
        # ``None`` means the repository has not been listed yet.
        self._model_names = None
        # Per-model caches, keyed by model name.
        self._models = {}
        self._labels = {}

    @classmethod
    def _get_hf_token(cls) -> Optional[str]:
        """
        Get the Hugging Face token from the ``HF_TOKEN`` environment variable.

        :return: The token, or ``None`` when the variable is not set.
        :rtype: Optional[str]
        """
        return os.environ.get('HF_TOKEN')

    @property
    def model_names(self) -> List[str]:
        """
        Get the model names available in the repository.

        The repository is listed once, on first access; the result is cached
        on the instance and is not dropped by :meth:`clear`.

        :return: The list of model names.
        :rtype: List[str]
        """
        if self._model_names is None:
            hf_fs = HfFileSystem(token=self._get_hf_token())
            # Every ``<repo_id>/<name>/model.onnx`` match contributes ``<name>``.
            self._model_names = [
                os.path.dirname(os.path.relpath(item, self.repo_id))
                for item in hf_fs.glob(f'{self.repo_id}/*/model.onnx')
            ]

        return self._model_names

    def _check_model_name(self, model_name: str):
        """
        Check if the model name is valid.

        :param model_name: The name of the model.
        :type model_name: str
        :raises ValueError: If the model name is not in :attr:`model_names`.
        """
        if model_name not in self.model_names:
            raise ValueError(f'Unknown model {model_name!r} in model repository {self.repo_id!r}, '
                             f'models {self.model_names!r} are available.')

    def _open_model(self, model_name: str):
        """
        Open the specified model, downloading it on first use.

        :param model_name: The name of the model.
        :type model_name: str
        :return: The opened model, as returned by ``open_onnx_model``.
        :rtype: Any
        :raises ValueError: If the model name is invalid.
        """
        if model_name not in self._models:
            # Validate before downloading so unknown names fail with a clear message.
            self._check_model_name(model_name)
            self._models[model_name] = open_onnx_model(hf_hub_download(
                self.repo_id,
                f'{model_name}/model.onnx',
                token=self._get_hf_token(),
            ))

        return self._models[model_name]

    def _open_label(self, model_name: str) -> List[str]:
        """
        Open the labels file for the specified model, downloading it on first use.

        The labels are read from the ``labels`` entry of ``<model_name>/meta.json``.

        :param model_name: The name of the model.
        :type model_name: str
        :return: The list of labels.
        :rtype: List[str]
        :raises ValueError: If the model name is invalid.
        """
        if model_name not in self._labels:
            self._check_model_name(model_name)
            meta_file = hf_hub_download(
                self.repo_id,
                f'{model_name}/meta.json',
                token=self._get_hf_token(),
            )
            # Read explicitly as UTF-8: the platform default encoding may not
            # be able to decode non-ASCII label names.
            with open(meta_file, 'r', encoding='utf-8') as f:
                self._labels[model_name] = json.load(f)['labels']

        return self._labels[model_name]

    def _raw_predict(self, image: ImageTyping, model_name: str):
        """
        Make a raw prediction on the specified image using the specified model.

        The image is loaded through ``load_image`` with a forced white
        background and RGB mode, encoded with ``_img_encode`` and run through
        the model as a batch of one.

        :param image: The input image.
        :type image: ImageTyping
        :param model_name: The name of the model.
        :type model_name: str
        :return: The raw prediction (the model's ``output`` tensor).
        :rtype: np.ndarray
        :raises ValueError: If the model name is invalid.
        :raises RuntimeError: If the model's input does not declare 3 channels.
        """
        image = load_image(image, force_background='white', mode='RGB')
        model = self._open_model(model_name)
        batch, channels, height, width = model.get_inputs()[0].shape
        if channels != 3:
            raise RuntimeError(f'Model {model_name!r} required {[batch, channels, height, width]!r}, '
                               f'channels not 3.')  # pragma: no cover

        if isinstance(height, int) and isinstance(width, int):
            # The model declares a fixed spatial size, so resize to it.
            input_ = _img_encode(image, size=(width, height))[None, ...]
        else:
            # Non-integer dimensions (presumably dynamic axes): use the
            # default size of ``_img_encode``.
            input_ = _img_encode(image)[None, ...]

        # Reuse the model opened above instead of looking it up a second time.
        output, = model.run(['output'], {'input': input_})
        return output

    def predict_score(self, image: ImageTyping, model_name: str) -> Dict[str, float]:
        """
        Predict the scores for each class.

        :param image: The input image.
        :type image: ImageTyping
        :param model_name: The name of the model.
        :type model_name: str
        :return: The dictionary mapping each label to its score.
        :rtype: Dict[str, float]
        :raises ValueError: If the model name is invalid.
        """
        output = self._raw_predict(image, model_name)
        labels = self._open_label(model_name)
        # ``.item()`` converts each numpy scalar into a plain Python number.
        return {label: score.item() for label, score in zip(labels, output[0])}

    def predict(self, image: ImageTyping, model_name: str) -> Tuple[str, float]:
        """
        Predict the class with the highest score.

        :param image: The input image.
        :type image: ImageTyping
        :param model_name: The name of the model.
        :type model_name: str
        :return: The predicted class and its score.
        :rtype: Tuple[str, float]
        :raises ValueError: If the model name is invalid.
        """
        output = self._raw_predict(image, model_name)[0]
        max_id = int(np.argmax(output))
        return self._open_label(model_name)[max_id], output[max_id].item()

    def clear(self):
        """
        Clear the loaded models and labels.

        The cached list of model names is kept.
        """
        self._models.clear()
        self._labels.clear()
@lru_cache()
def _open_models_for_repo_id(repo_id: str) -> ClassifyModel:
    """
    Build the model manager for a repository, memoized per ``repo_id``.

    Because of ``lru_cache``, repeated calls with the same ``repo_id`` return
    the same ClassifyModel object while it stays in the cache.

    :param repo_id: The repository ID containing the models.
    :type repo_id: str
    :return: The ClassifyModel instance for the repository.
    :rtype: ClassifyModel
    """
    manager = ClassifyModel(repo_id)
    return manager
def classify_predict_score(image: ImageTyping, repo_id: str, model_name: str) -> Dict[str, float]:
    """
    Predict the scores for each class using the specified model.

    Delegates to ``predict_score`` on the ClassifyModel instance obtained
    from ``_open_models_for_repo_id`` for ``repo_id``.

    :param image: The input image.
    :type image: ImageTyping
    :param repo_id: The repository ID containing the models.
    :type repo_id: str
    :param model_name: The name of the model.
    :type model_name: str
    :return: The dictionary containing class scores.
    :rtype: Dict[str, float]
    """
    return _open_models_for_repo_id(repo_id).predict_score(image, model_name)
def classify_predict(image: ImageTyping, repo_id: str, model_name: str) -> Tuple[str, float]:
    """
    Predict the class with the highest score using the specified model.

    Delegates to ``predict`` on the ClassifyModel instance obtained from
    ``_open_models_for_repo_id`` for ``repo_id``.

    :param image: The input image.
    :type image: ImageTyping
    :param repo_id: The repository ID containing the models.
    :type repo_id: str
    :param model_name: The name of the model.
    :type model_name: str
    :return: The predicted class and its score.
    :rtype: Tuple[str, float]
    """
    return _open_models_for_repo_id(repo_id).predict(image, model_name)