Source code for imgutils.tagging.overlap
import copy
import json
from functools import lru_cache
from typing import Mapping, List, Union
from huggingface_hub import hf_hub_download
@lru_cache()
def _get_overlap_tags() -> Mapping[str, List[str]]:
"""
Retrieve the overlap tag information from the specified Hugging Face Hub repository.
This function downloads a JSON file containing tag overlap information and parses it into a dictionary.
:return: A dictionary where keys are tags and values are lists of overlapping tags.
:rtype: Mapping[str, List[str]]
"""
json_file = hf_hub_download(
'alea31415/tag_filtering',
'overlap_tags_simplified.json',
repo_type='dataset',
)
with open(json_file, 'r') as file:
data = json.load(file)
return data
[docs]def drop_overlap_tags(tags: Union[List[str], Mapping[str, float]]) -> Union[List[str], Mapping[str, float]]:
"""
Drop overlapping tags from the given list of tags.
This function removes tags that have overlaps with other tags based on precomputed overlap information.
:param tags: A list of tags.
:type tags: List[str]
:return: A list of tags without overlaps.
:rtype: List[str]
Examples::
>>> from imgutils.tagging import drop_overlap_tags
>>>
>>> tags = [
... '1girl', 'solo',
... 'long_hair', 'very_long_hair', 'red_hair',
... 'breasts', 'medium_breasts',
... ]
>>> drop_overlap_tags(tags)
['1girl', 'solo', 'very_long_hair', 'red_hair', 'medium_breasts']
>>>
>>> tags = {
... '1girl': 0.8849405313291128,
... 'solo': 0.8548297594823425,
... 'long_hair': 0.03910296474461261,
... 'very_long_hair': 0.6615180440330748,
... 'red_hair': 0.21552028866308015,
... 'breasts': 0.3165260620737027,
... 'medium_breasts': 0.47744464927382957,
... }
>>> drop_overlap_tags(tags)
{
'1girl': 0.8849405313291128,
'solo': 0.8548297594823425,
'very_long_hair': 0.6615180440330748,
'red_hair': 0.21552028866308015,
'medium_breasts': 0.47744464927382957
}
"""
overlap_tags_dict = _get_overlap_tags()
result_tags = []
_origin_tags = copy.deepcopy(tags)
if isinstance(tags, dict):
tags = list(tags.keys())
tags_underscore = [tag.replace(' ', '_') for tag in tags]
tags: List[str]
tags_underscore: List[str]
for tag, tag_ in zip(tags, tags_underscore):
to_remove = False
# Case 1: If the tag is a key and some of the associated values are in tags
if tag_ in overlap_tags_dict:
overlap_values = set(val for val in overlap_tags_dict[tag_])
if overlap_values.intersection(set(tags_underscore)):
to_remove = True
# Checking superword condition separately
for tag_another in tags:
if tag in tag_another and tag != tag_another:
to_remove = True
break
if not to_remove:
result_tags.append(tag)
if isinstance(_origin_tags, list):
return result_tags
elif isinstance(_origin_tags, dict):
_rtags_set = set(result_tags)
return {key: value for key, value in _origin_tags.items() if key in _rtags_set}
else:
raise TypeError(f'Unknown tags type - {_origin_tags!r}.') # pragma: no cover