"""
Overview:
A model for classifying anime images into 5 classes (``3d``, ``bangumi``, ``comic``, ``illustration``, and ``not_painting``).
The following are sample images for testing.
.. image:: classify.plot.py.svg
:align: center
This is an overall benchmark of all the classification validation models:
.. image:: classify_benchmark.plot.py.svg
:align: center
The models are hosted on
`huggingface - deepghs/anime_classification <https://huggingface.co/deepghs/anime_classification>`_.
.. note::
In older versions of models, there are 4 classes, which means ``not_painting`` do not exist.
"""
from typing import Tuple, Dict
from ..data import ImageTyping
from ..generic import classify_predict, classify_predict_score
__all__ = [
'anime_classify_score',
'anime_classify',
]
_DEFAULT_MODEL_NAME = 'mobilenetv3_v1.5_dist'
_REPO_ID = 'deepghs/anime_classification'
[docs]def anime_classify_score(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Dict[str, float]:
"""
Overview:
Predict the class of the given image, return the score with as a dict object.
:param image: Image to classify.
:param model_name: Model to use. Default is ``mobilenetv3_v1.3_dist``. All available models are listed
on the benchmark plot above. If you need better accuracy, just set this to ``caformer_s36_v1.3_focal``.
:return: A dict with classes and scores.
Examples::
>>> from imgutils.validate import anime_classify_score
>>>
>>> anime_classify_score('classify/3d/1.jpg')
{'3d': 0.8346158862113953, 'bangumi': 0.004201625939458609, 'comic': 0.0028638991061598063, 'illustration': 0.15633030235767365, 'not_painting': 0.001988308737054467}
>>> anime_classify_score('classify/3d/2.jpg')
{'3d': 0.9868855476379395, 'bangumi': 0.001178382197394967, 'comic': 0.00015886101755313575, 'illustration': 0.0005986307514831424, 'not_painting': 0.011178601533174515}
>>> anime_classify_score('classify/3d/3.jpg')
{'3d': 0.9933090209960938, 'bangumi': 0.0012440024875104427, 'comic': 0.00040085514774546027, 'illustration': 0.004924307577311993, 'not_painting': 0.00012189441622467712}
>>> anime_classify_score('classify/bangumi/4.jpg')
{'3d': 0.00031298911198973656, 'bangumi': 0.9968050718307495, 'comic': 5.182305903872475e-05, 'illustration': 0.0027923565357923508, 'not_painting': 3.7805559259140864e-05}
>>> anime_classify_score('classify/bangumi/5.jpg')
{'3d': 0.0004650334012694657, 'bangumi': 0.996709942817688, 'comic': 3.736721191671677e-05, 'illustration': 0.0027629584074020386, 'not_painting': 2.4619508621981367e-05}
>>> anime_classify_score('classify/bangumi/6.jpg')
{'3d': 0.0003803370927926153, 'bangumi': 0.998649537563324, 'comic': 5.190127922105603e-05, 'illustration': 0.0008622839814051986, 'not_painting': 5.595230686594732e-05}
>>> anime_classify_score('classify/comic/7.jpg')
{'3d': 0.0004573142796289176, 'bangumi': 0.00031435859273187816, 'comic': 0.8671838641166687, 'illustration': 0.13199880719184875, 'not_painting': 4.563074617180973e-05}
>>> anime_classify_score('classify/comic/8.jpg')
{'3d': 7.153919796110131e-06, 'bangumi': 8.290010737255216e-05, 'comic': 0.9727378487586975, 'illustration': 0.027150526642799377, 'not_painting': 2.162296004826203e-05}
>>> anime_classify_score('classify/comic/9.jpg')
{'3d': 2.4933258828241378e-05, 'bangumi': 0.0004275702522136271, 'comic': 0.995402455329895, 'illustration': 0.002233930164948106, 'not_painting': 0.001911122351884842}
>>> anime_classify_score('classify/illustration/10.jpg')
{'3d': 0.1603819727897644, 'bangumi': 0.0007561995880678296, 'comic': 0.00017044576816260815, 'illustration': 0.838487982749939, 'not_painting': 0.0002034590725088492}
>>> anime_classify_score('classify/illustration/11.jpg')
{'3d': 0.005001617129892111, 'bangumi': 0.000932251859921962, 'comic': 0.009352140128612518, 'illustration': 0.9846979379653931, 'not_painting': 1.6018555470509455e-05}
>>> anime_classify_score('classify/illustration/12.jpg')
{'3d': 0.004064667969942093, 'bangumi': 9.464051254326478e-05, 'comic': 0.025772539898753166, 'illustration': 0.9699516296386719, 'not_painting': 0.00011656546121230349}
>>> anime_classify_score('classify/not_painting/13.jpg')
{'3d': 5.287263775244355e-05, 'bangumi': 3.370255853951676e-06, 'comic': 0.01098843663930893, 'illustration': 0.0031668643932789564, 'not_painting': 0.9857884049415588}
>>> anime_classify_score('classify/not_painting/14.jpg')
{'3d': 7.499273488065228e-05, 'bangumi': 2.8419872251106426e-05, 'comic': 0.0003471920208539814, 'illustration': 0.029472889378666878, 'not_painting': 0.9700765609741211}
>>> anime_classify_score('classify/not_painting/15.jpg')
{'3d': 0.0012387704337015748, 'bangumi': 0.001172148622572422, 'comic': 9.787473391043022e-05, 'illustration': 0.003680602880194783, 'not_painting': 0.9938107132911682}
"""
return classify_predict_score(image, _REPO_ID, model_name)
[docs]def anime_classify(image: ImageTyping, model_name: str = _DEFAULT_MODEL_NAME) -> Tuple[str, float]:
"""
Overview:
Predict the class of the given image, return the class and its score.
:param image: Image to classify.
:param model_name: Model to use. Default is ``mobilenetv3_v1.3_dist``. All available models are listed
on the benchmark plot above. If you need better accuracy, just set this to ``caformer_s36_v1.3_focal``.
:return: A tuple contains the class and its score.
Examples::
>>> from imgutils.validate import anime_classify
>>>
>>> anime_classify('classify/3d/1.jpg')
('3d', 0.8346157073974609)
>>> anime_classify('classify/3d/2.jpg')
('3d', 0.9868855476379395)
>>> anime_classify('classify/3d/3.jpg')
('3d', 0.9933090209960938)
>>> anime_classify('classify/bangumi/4.jpg')
('bangumi', 0.9968050718307495)
>>> anime_classify('classify/bangumi/5.jpg')
('bangumi', 0.996709942817688)
>>> anime_classify('classify/bangumi/6.jpg')
('bangumi', 0.998649537563324)
>>> anime_classify('classify/comic/7.jpg')
('comic', 0.8671836853027344)
>>> anime_classify('classify/comic/8.jpg')
('comic', 0.9727378487586975)
>>> anime_classify('classify/comic/9.jpg')
('comic', 0.995402455329895)
>>> anime_classify('classify/illustration/10.jpg')
('illustration', 0.8384883403778076)
>>> anime_classify('classify/illustration/11.jpg')
('illustration', 0.9846979975700378)
>>> anime_classify('classify/illustration/12.jpg')
('illustration', 0.9699516296386719)
>>> anime_classify('classify/not_painting/13.jpg')
('not_painting', 0.9857884049415588)
>>> anime_classify('classify/not_painting/14.jpg')
('not_painting', 0.9700766801834106)
>>> anime_classify('classify/not_painting/15.jpg')
('not_painting', 0.9938107132911682)
"""
return classify_predict(image, _REPO_ID, model_name)