# Source code for imgutils.validate.teen

"""
Overview:
    A model for classifying anime images into 3 teen-related classes (``contentious``, ``safe_teen``, ``non_teen``).

    The following are sample images for testing.

    .. image:: teen.plot.py.svg
        :align: center

    This is an overall benchmark of all the classification validation models:

    .. image:: teen_benchmark.plot.py.svg
        :align: center

    The models are hosted on
    `huggingface - deepghs/anime_teen <https://huggingface.co/deepghs/anime_teen>`_.
"""
from typing import Tuple, Dict

from ..data import ImageTyping
from ..generic import classify_predict, classify_predict_score

# Public API of this module.
__all__ = ['anime_teen_score', 'anime_teen']

# Hugging Face repository that hosts the classification models.
_REPO_ID = 'deepghs/anime_teen'
# Model used by the public functions when ``model_name`` is not given.
_DEFAULT_MODEL_NAME = 'mobilenetv3_v0_dist'


def anime_teen_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]:
    """
    Overview:
        Predict the class of the given image and return the score of every class as a dict.

    :param image: Image to classify.
    :param model_name: Model to use. Default is ``mobilenetv3_v0_dist``. All available models are listed
        on the benchmark plot above. If you need better accuracy, just set this to ``caformer_s36_v0``.
    :return: A dict mapping each class name (``contentious``, ``safe_teen``, ``non_teen``) to its score.

    Examples::

        >>> from imgutils.validate import anime_teen_score
        >>>
        >>> anime_teen_score('teen/contentious/1.jpg')
        {'contentious': 0.9998493194580078, 'safe_teen': 3.0378791052498855e-05, 'non_teen': 0.00012023092131130397}
        >>> anime_teen_score('teen/contentious/2.jpg')
        {'contentious': 0.9790042638778687, 'safe_teen': 0.0017522255657240748, 'non_teen': 0.01924353837966919}
        >>> anime_teen_score('teen/contentious/3.jpg')
        {'contentious': 0.9998124241828918, 'safe_teen': 4.19778298237361e-05, 'non_teen': 0.0001456339523429051}
        >>> anime_teen_score('teen/safe_teen/4.jpg')
        {'contentious': 0.0008521362324245274, 'safe_teen': 0.9989691972732544, 'non_teen': 0.00017870066221803427}
        >>> anime_teen_score('teen/safe_teen/5.jpg')
        {'contentious': 6.0992944781901315e-05, 'safe_teen': 0.9994398951530457, 'non_teen': 0.0004991036257706583}
        >>> anime_teen_score('teen/safe_teen/6.jpg')
        {'contentious': 5.2035720727872103e-05, 'safe_teen': 0.9994019269943237, 'non_teen': 0.0005460577667690814}
        >>> anime_teen_score('teen/non_teen/7.jpg')
        {'contentious': 3.0478151529678144e-05, 'safe_teen': 3.524079147609882e-05, 'non_teen': 0.999934196472168}
        >>> anime_teen_score('teen/non_teen/8.jpg')
        {'contentious': 9.786742884898558e-05, 'safe_teen': 8.653994154883549e-05, 'non_teen': 0.9998156428337097}
        >>> anime_teen_score('teen/non_teen/9.jpg')
        {'contentious': 0.0001218809193233028, 'safe_teen': 0.00013706681784242392, 'non_teen': 0.9997410178184509}
    """
    # Delegate to the generic classifier, pointing it at this module's model repository.
    return classify_predict_score(image, _REPO_ID, model_name)


def anime_teen(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]:
    """
    Overview:
        Predict the class of the given image and return the predicted class together with its score.

    :param image: Image to classify.
    :param model_name: Model to use. Default is ``mobilenetv3_v0_dist``. All available models are listed
        on the benchmark plot above. If you need better accuracy, just set this to ``caformer_s36_v0``.
    :return: A tuple of the predicted class name (one of ``contentious``, ``safe_teen``, ``non_teen``)
        and its score.

    Examples::

        >>> from imgutils.validate import anime_teen
        >>>
        >>> anime_teen('teen/contentious/1.jpg')
        ('contentious', 0.9998493194580078)
        >>> anime_teen('teen/contentious/2.jpg')
        ('contentious', 0.9790042638778687)
        >>> anime_teen('teen/contentious/3.jpg')
        ('contentious', 0.9998124241828918)
        >>> anime_teen('teen/safe_teen/4.jpg')
        ('safe_teen', 0.9989691972732544)
        >>> anime_teen('teen/safe_teen/5.jpg')
        ('safe_teen', 0.9994398951530457)
        >>> anime_teen('teen/safe_teen/6.jpg')
        ('safe_teen', 0.9994019269943237)
        >>> anime_teen('teen/non_teen/7.jpg')
        ('non_teen', 0.999934196472168)
        >>> anime_teen('teen/non_teen/8.jpg')
        ('non_teen', 0.9998156428337097)
        >>> anime_teen('teen/non_teen/9.jpg')
        ('non_teen', 0.9997410178184509)
    """
    # Delegate to the generic classifier, pointing it at this module's model repository.
    return classify_predict(image, _REPO_ID, model_name)