Source code for imgutils.preprocess.torchvision
"""
This module provides utilities for creating and parsing torchvision transforms.
It includes functionality for registering custom transforms, handling interpolation modes,
and converting between different transform representations.
The module supports common image transformations like resize, center crop, tensor conversion
and normalization. It provides a flexible framework for extending with additional transforms.
"""
import copy
from functools import wraps
from typing import Union
from .base import NotParseTarget
def _check_torchvision():
"""
Check if torchvision is available and raise error if not installed.
:raises EnvironmentError: If torchvision is not installed
"""
try:
import torchvision
except (ImportError, ModuleNotFoundError):
raise EnvironmentError('No torchvision available.\n'
'Please install it by `pip install dghs-imgutils[torchvision]`.')
def _get_interpolation_mode(value):
"""
Convert different interpolation mode representations to torchvision.transforms.InterpolationMode.
:param value: The interpolation mode value to convert. Can be int, string or InterpolationMode
:return: The corresponding InterpolationMode enum value
:raises ValueError: If the interpolation value is invalid
:raises TypeError: If the value type is not supported
"""
from torchvision.transforms import InterpolationMode
_INT_TO_INTERMODE = {
0: InterpolationMode.NEAREST,
2: InterpolationMode.BILINEAR,
3: InterpolationMode.BICUBIC,
4: InterpolationMode.BOX,
5: InterpolationMode.HAMMING,
1: InterpolationMode.LANCZOS,
}
_STR_TO_INTERMODE = {
value.value: value
for key, value in InterpolationMode.__members__.items()
}
if isinstance(value, InterpolationMode):
return value
elif isinstance(value, int):
if value not in _INT_TO_INTERMODE:
raise ValueError(f'Invalid interpolation value - {value!r}.')
return _INT_TO_INTERMODE[value]
elif isinstance(value, str):
value = value.lower()
if value not in _STR_TO_INTERMODE:
raise ValueError(f'Invalid interpolation value - {value!r}.')
return _STR_TO_INTERMODE[value]
else:
raise TypeError(f'Unknown type of interpolation mode - {value!r}.')
_TRANS_CREATORS = {}
def _register_transform(name: str, safe: bool = True):
"""
Register a transform creation function.
:param name: Name of the transform
:param safe: Whether to check torchvision availability
:return: Decorator function
"""
if safe:
_check_torchvision()
def _fn(func):
_TRANS_CREATORS[name] = func
return func
return _fn
[docs]def register_torchvision_transform(name: str):
"""
Register a torchvision transform creation function.
:param name: Name of the transform
:return: Decorator function
"""
return _register_transform(name, safe=True)
_TRANS_PARSERS = {}
def _register_parse(name: str, safe: bool = True):
"""
Register a transform parsing function.
:param name: Name of the transform parser
:param safe: Whether to check torchvision availability
:return: Decorator function
"""
if safe:
_check_torchvision()
def _fn(func):
@wraps(func)
def _new_func(*args, **kwargs):
return {
'type': name,
**func(*args, **kwargs),
}
_TRANS_PARSERS[name] = _new_func
return _new_func
return _fn
[docs]def register_torchvision_parse(name: str):
"""
Register a torchvision transform parsing function.
:param name: Name of the transform parser
:return: Decorator function
"""
return _register_parse(name, safe=True)
@_register_transform('resize', safe=False)
def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True):
"""
Create a torchvision Resize transform.
:param size: Target size
:param interpolation: Interpolation mode
:param max_size: Maximum size
:param antialias: Whether to use anti-aliasing
:return: Resize transform
"""
from torchvision.transforms import Resize
return Resize(
size=size,
interpolation=_get_interpolation_mode(interpolation),
max_size=max_size,
antialias=antialias,
)
@_register_parse('resize', safe=False)
def _parse_resize(obj):
"""
Parse a Resize transform object.
:param obj: Transform object to parse
:return: Dict containing transform parameters
:raises NotParseTarget: If obj is not a Resize transform
"""
from torchvision.transforms import Resize
if not isinstance(obj, Resize):
raise NotParseTarget
obj: Resize
return {
'size': obj.size,
'interpolation': obj.interpolation.value,
'max_size': obj.max_size,
'antialias': obj.antialias,
}
@_register_transform('center_crop', safe=False)
def _create_center_crop(size):
"""
Create a torchvision CenterCrop transform.
:param size: Target size for cropping
:return: CenterCrop transform
"""
from torchvision.transforms import CenterCrop
return CenterCrop(
size=size,
)
@_register_parse('center_crop', safe=False)
def _parse_center_crop(obj):
"""
Parse a CenterCrop transform object.
:param obj: Transform object to parse
:return: Dict containing transform parameters
:raises NotParseTarget: If obj is not a CenterCrop transform
"""
from torchvision.transforms import CenterCrop
if not isinstance(obj, CenterCrop):
raise NotParseTarget
obj: CenterCrop
return {
'size': obj.size,
}
@_register_transform('maybe_to_tensor', safe=False)
def _create_maybe_to_tensor():
"""
Create a MaybeToTensor transform that converts input to tensor if not already a tensor.
:return: MaybeToTensor transform
"""
from torchvision.transforms import ToTensor
class MaybeToTensor(ToTensor):
def __init__(self) -> None:
super().__init__()
def __call__(self, pic):
import torchvision.transforms.functional as F
import torch
if isinstance(pic, torch.Tensor):
return pic
return F.to_tensor(pic)
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
return MaybeToTensor()
@_register_parse('maybe_to_tensor', safe=False)
def _parse_maybe_to_tensor(obj):
"""
Parse a MaybeToTensor transform object.
:param obj: Transform object to parse
:return: Empty dict since no parameters needed
:raises NotParseTarget: If obj is not a MaybeToTensor transform
"""
if type(obj).__name__ != 'MaybeToTensor':
raise NotParseTarget
return {}
@_register_transform('to_tensor', safe=False)
def _create_to_tensor():
"""
Create a torchvision ToTensor transform.
:return: ToTensor transform
"""
from torchvision.transforms import ToTensor
return ToTensor()
@_register_parse('to_tensor', safe=False)
def _parse_to_tensor(obj):
"""
Parse a ToTensor transform object.
:param obj: Transform object to parse
:return: Empty dict since no parameters needed
:raises NotParseTarget: If obj is not a ToTensor transform
"""
if type(obj).__name__ != 'ToTensor':
raise NotParseTarget
return {}
@_register_transform('normalize', safe=False)
def _create_normalize(mean, std, inplace=False):
"""
Create a torchvision Normalize transform.
:param mean: Sequence of means for each channel
:param std: Sequence of standard deviations for each channel
:param inplace: Whether to perform normalization in-place
:return: Normalize transform
"""
import torch
from torchvision.transforms import Normalize
return Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
inplace=inplace,
)
@_register_parse('normalize', safe=False)
def _parse_normalize(obj):
"""
Parse a Normalize transform object.
:param obj: Transform object to parse
:return: Dict containing transform parameters
:raises NotParseTarget: If obj is not a Normalize transform
"""
from torchvision.transforms import Normalize
if not isinstance(obj, Normalize):
raise NotParseTarget
obj: Normalize
return {
'mean': obj.mean.tolist() if hasattr(obj.mean, 'tolist') else obj.mean,
'std': obj.std.tolist() if hasattr(obj.std, 'tolist') else obj.std,
}
[docs]def create_torchvision_transforms(tvalue: Union[list, dict]):
"""
Create torchvision transforms from config.
:param tvalue: Transform configuration as list or dict
:return: Composed transforms or single transform
:raises TypeError: If tvalue has unsupported type
:example:
>>> from imgutils.preprocess import create_torchvision_transforms
>>>
>>> create_torchvision_transforms({
... 'type': 'resize',
... 'size': 384,
... 'interpolation': 'bicubic',
... })
Resize(size=384, interpolation=bicubic, max_size=None, antialias=True)
>>> create_torchvision_transforms({
... 'type': 'resize',
... 'size': (224, 256),
... 'interpolation': 'bilinear',
... })
Resize(size=(224, 256), interpolation=bilinear, max_size=None, antialias=True)
>>> create_torchvision_transforms({'type': 'center_crop', 'size': 224})
CenterCrop(size=(224, 224))
>>> create_torchvision_transforms({'type': 'to_tensor'})
ToTensor()
>>> create_torchvision_transforms({'type': 'maybe_to_tensor'})
MaybeToTensor()
>>> create_torchvision_transforms({'type': 'normalize', 'mean': 0.5, 'std': 0.5})
Normalize(mean=0.5, std=0.5)
>>> create_torchvision_transforms({
... 'type': 'normalize',
... 'mean': [0.485, 0.456, 0.406],
... 'std': [0.229, 0.224, 0.225],
... })
Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
>>> create_torchvision_transforms([
... {'antialias': True,
... 'interpolation': 'bicubic',
... 'max_size': None,
... 'size': 384,
... 'type': 'resize'},
... {'size': (224, 224), 'type': 'center_crop'},
... {'type': 'maybe_to_tensor'},
... {'mean': 0.5, 'std': 0.5, 'type': 'normalize'}
... ])
Compose(
Resize(size=384, interpolation=bicubic, max_size=None, antialias=True)
CenterCrop(size=(224, 224))
MaybeToTensor()
Normalize(mean=0.5, std=0.5)
)
.. note::
Currently the following transforms are supported:
- `torchvision.transforms.Resize`
- `torchvision.transforms.CenterCrop`
- `torchvision.transforms.ToTensor`
- `timm.data.MaybeToTensor`
- `torchvision.transforms.Normalize`
"""
_check_torchvision()
from torchvision.transforms import Compose
if isinstance(tvalue, list):
return Compose([create_torchvision_transforms(titem) for titem in tvalue])
elif isinstance(tvalue, dict):
tvalue = copy.deepcopy(tvalue)
ttype = tvalue.pop('type')
return _TRANS_CREATORS[ttype](**tvalue)
else:
raise TypeError(f'Unknown type of transforms - {tvalue!r}.')
[docs]def parse_torchvision_transforms(value):
"""
Parse torchvision transforms into config dict.
:param value: Transform object to parse
:return: Transform configuration as list or dict
:raises TypeError: If transform type is not supported
:example:
>>> from timm.data import MaybeToTensor
>>> from torchvision.transforms import Resize, InterpolationMode, CenterCrop, ToTensor, Normalize
>>>
>>> from imgutils.preprocess import parse_torchvision_transforms
>>>
>>> parse_torchvision_transforms(Resize(
... size=384,
... interpolation=InterpolationMode.BICUBIC,
... ))
{'type': 'resize', 'size': 384, 'interpolation': 'bicubic', 'max_size': None, 'antialias': True}
>>> parse_torchvision_transforms(Resize(
... size=(224, 256),
... interpolation=InterpolationMode.BILINEAR,
... ))
{'type': 'resize', 'size': (224, 256), 'interpolation': 'bilinear', 'max_size': None, 'antialias': True}
>>> parse_torchvision_transforms(CenterCrop(size=224))
{'type': 'center_crop', 'size': (224, 224)}
>>> parse_torchvision_transforms(ToTensor())
{'type': 'to_tensor'}
>>> parse_torchvision_transforms(MaybeToTensor())
{'type': 'maybe_to_tensor'}
>>> parse_torchvision_transforms(Normalize(mean=0.5, std=0.5))
{'type': 'normalize', 'mean': 0.5, 'std': 0.5}
>>> parse_torchvision_transforms(Normalize(
... mean=[0.485, 0.456, 0.406],
... std=[0.229, 0.224, 0.225],
... ))
{'type': 'normalize', 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}
>>> parse_torchvision_transforms(Compose([
... Resize(
... size=384,
... interpolation=Image.BICUBIC,
... ),
... CenterCrop(size=224),
... MaybeToTensor(),
... Normalize(mean=0.5, std=0.5),
... ]))
[{'antialias': True,
'interpolation': 'bicubic',
'max_size': None,
'size': 384,
'type': 'resize'},
{'size': (224, 224), 'type': 'center_crop'},
{'type': 'maybe_to_tensor'},
{'mean': 0.5, 'std': 0.5, 'type': 'normalize'}]
"""
_check_torchvision()
from torchvision.transforms import Compose
if isinstance(value, Compose):
return [
parse_torchvision_transforms(trans)
for trans in value.transforms
]
else:
for key, _parser in _TRANS_PARSERS.items():
try:
return _parser(value)
except NotParseTarget:
pass
raise TypeError(f'Unknown parse transform - {value!r}.')