"""
This module provides utilities for creating and parsing torchvision transforms.
It includes functionality for registering custom transforms, handling interpolation modes,
and converting between different transform representations.
The module supports common image transformations like resize, center crop, tensor conversion
and normalization. It provides a flexible framework for extending with additional transforms.
"""
import copy
from functools import wraps
from typing import Union
from .base import NotParseTarget
# Probe for the optional torchvision dependency once, at import time. Other code in this
# module checks ``_HAS_TORCHVISION`` (directly or via ``_check_torchvision``) before
# relying on torchvision; most torchvision imports are done lazily inside functions.
try:
    import torchvision
    _HAS_TORCHVISION = True
except ImportError:  # also covers ModuleNotFoundError, which subclasses ImportError
    _HAS_TORCHVISION = False
def _check_torchvision():
    """
    Ensure torchvision can be used, failing loudly otherwise.

    :raises EnvironmentError: When the import-time probe could not import torchvision;
        the message includes the pip command that installs the optional extra.
    """
    if _HAS_TORCHVISION:
        return
    raise EnvironmentError('No torchvision available.\n'
                           'Please install it by `pip install dghs-imgutils[torchvision]`.')
def _get_interpolation_mode(value):
    """
    Convert different interpolation mode representations to torchvision.transforms.InterpolationMode.

    Accepted forms:

    * an ``InterpolationMode`` member, returned unchanged;
    * an ``int`` resampling code: 0=NEAREST, 1=LANCZOS, 2=BILINEAR, 3=BICUBIC, 4=BOX,
      5=HAMMING (the same numbering as Pillow's ``Image.Resampling`` constants);
    * a ``str`` equal to an ``InterpolationMode`` value, matched case-insensitively
      (e.g. ``'bilinear'`` or ``'BICUBIC'``).

    :param value: The interpolation mode value to convert. Can be int, string or InterpolationMode
    :return: The corresponding InterpolationMode enum value
    :raises ValueError: If an int or string does not name a known interpolation mode
    :raises TypeError: If the value type is not supported; ``bool`` is rejected even
        though it is a subclass of ``int``
    """
    from torchvision.transforms import InterpolationMode

    if isinstance(value, InterpolationMode):
        return value

    # bool subclasses int, so without this guard True/False would silently be treated
    # as codes 1/0 (LANCZOS/NEAREST). Booleans fall through to the TypeError below.
    if isinstance(value, int) and not isinstance(value, bool):
        int_to_mode = {
            0: InterpolationMode.NEAREST,
            1: InterpolationMode.LANCZOS,
            2: InterpolationMode.BILINEAR,
            3: InterpolationMode.BICUBIC,
            4: InterpolationMode.BOX,
            5: InterpolationMode.HAMMING,
        }
        if value not in int_to_mode:
            raise ValueError(f'Invalid interpolation value - {value!r}.')
        return int_to_mode[value]

    if isinstance(value, str):
        # Keyed by each member's string value, e.g. 'bilinear' -> InterpolationMode.BILINEAR.
        str_to_mode = {mode.value: mode for mode in InterpolationMode}
        mode = str_to_mode.get(value.lower())
        if mode is None:
            # Report the caller's original spelling, not the lower-cased lookup key.
            raise ValueError(f'Invalid interpolation value - {value!r}.')
        return mode

    raise TypeError(f'Unknown type of interpolation mode - {value!r}.')
# Registry of transform factories, keyed by transform name (e.g. 'resize').
_TRANS_CREATORS = {}


def _register_transform(name: str, safe: bool = True):
    """
    Build a decorator that records a transform factory under ``name``.

    The decorated function is stored in ``_TRANS_CREATORS`` and returned unchanged.
    Registering the same name twice overwrites the earlier entry.

    :param name: Key under which the factory is registered.
    :param safe: When true, verify immediately that torchvision is installed, raising
        ``EnvironmentError`` otherwise. The built-in registrations in this module pass
        ``False``, so defining them does not require torchvision.
    :return: Decorator that registers its argument and hands it back.
    """
    if safe:
        _check_torchvision()

    def _decorator(factory):
        _TRANS_CREATORS[name] = factory
        return factory

    return _decorator
# Registry of transform parsers, keyed by transform name. Each stored callable is the
# wrapper built by ``_register_parse``, whose result dict starts with a 'type' key.
_TRANS_PARSERS = {}


def _register_parse(name: str, safe: bool = True):
    """
    Build a decorator that registers a transform parser under ``name``.

    The decorated parser must return a mapping of parameters. It is wrapped so that
    its result becomes ``{'type': name, **params}``. The wrapper, not the original
    function, is stored in ``_TRANS_PARSERS`` and returned.

    :param name: Transform name used both as registry key and as the ``'type'`` tag.
    :param safe: When true, verify immediately that torchvision is installed, raising
        ``EnvironmentError`` otherwise.
    :return: Decorator producing the tagged, registered parser.
    """
    if safe:
        _check_torchvision()

    def _decorator(parser):
        @wraps(parser)
        def _tagged_parser(*args, **kwargs):
            params = parser(*args, **kwargs)
            # 'type' is placed first, so a parser that returns its own 'type' key wins.
            return {'type': name, **params}

        _TRANS_PARSERS[name] = _tagged_parser
        return _tagged_parser

    return _decorator
def register_torchvision_parse(name: str):
    """
    Register a torchvision transform parsing function.

    Following the parsers in this module, the decorated function takes a transform
    object. It should either return a dict of that transform's parameters or raise
    ``NotParseTarget`` when the object is not the transform it handles. The returned
    dict is automatically tagged with ``'type': name``.

    Unlike the built-in registrations here, this checks for torchvision immediately,
    so applying the decorator fails when torchvision is missing.

    :param name: Name of the transform parser
    :return: Decorator function
    :raises EnvironmentError: If torchvision is not installed
    """
    return _register_parse(name, safe=True)
@_register_transform('resize', safe=False)
def _create_resize(size, interpolation='bilinear', max_size=None, antialias=True):
    """
    Build a ``torchvision.transforms.Resize`` from plain configuration values.

    :param size: Target size, forwarded to ``Resize`` unchanged.
    :param interpolation: Interpolation mode in any form accepted by
        ``_get_interpolation_mode`` (``InterpolationMode``, int code or
        case-insensitive name). Defaults to ``'bilinear'``.
    :param max_size: Forwarded to ``Resize`` as ``max_size``.
    :param antialias: Forwarded to ``Resize`` as ``antialias``.
    :return: The configured ``Resize`` transform.
    """
    from torchvision.transforms import Resize

    mode = _get_interpolation_mode(interpolation)
    return Resize(size=size, interpolation=mode, max_size=max_size, antialias=antialias)
@_register_parse('resize', safe=False)
def _parse_resize(obj):
    """
    Extract the constructor arguments of a ``torchvision.transforms.Resize``.

    :param obj: Candidate transform object.
    :return: Dict with ``size``, ``interpolation`` (``obj.interpolation.value``, i.e. the
        string value of the mode), ``max_size`` and ``antialias``.
    :raises NotParseTarget: When ``obj`` is not a ``Resize`` instance.
    """
    from torchvision.transforms import Resize

    if isinstance(obj, Resize):
        return dict(
            size=obj.size,
            interpolation=obj.interpolation.value,
            max_size=obj.max_size,
            antialias=obj.antialias,
        )
    raise NotParseTarget
@_register_transform('center_crop', safe=False)
def _create_center_crop(size):
    """
    Build a ``torchvision.transforms.CenterCrop``.

    :param size: Crop size, forwarded to ``CenterCrop`` unchanged.
    :return: The ``CenterCrop`` transform.
    """
    from torchvision import transforms

    return transforms.CenterCrop(size=size)
@_register_parse('center_crop', safe=False)
def _parse_center_crop(obj):
    """
    Extract the constructor arguments of a ``torchvision.transforms.CenterCrop``.

    :param obj: Candidate transform object.
    :return: Dict holding the crop ``size``.
    :raises NotParseTarget: When ``obj`` is not a ``CenterCrop`` instance.
    """
    from torchvision.transforms import CenterCrop

    if isinstance(obj, CenterCrop):
        return {'size': obj.size}
    raise NotParseTarget
if _HAS_TORCHVISION:
    from torchvision.transforms import ToTensor

    class MaybeToTensor(ToTensor):
        """
        ``ToTensor`` variant that passes ``torch.Tensor`` inputs through unchanged.

        Any other input is converted with ``torchvision.transforms.functional.to_tensor``.
        """

        def __call__(self, pic):
            import torch
            import torchvision.transforms.functional as F

            # Tensors are returned as-is; everything else goes through to_tensor.
            return pic if isinstance(pic, torch.Tensor) else F.to_tensor(pic)

        def __repr__(self) -> str:
            return f"{self.__class__.__name__}()"
else:
    # Placeholder so the name exists even when torchvision is not installed.
    MaybeToTensor = None
@_register_transform('maybe_to_tensor', safe=False)
def _create_maybe_to_tensor():
    """
    Create a MaybeToTensor transform that converts input to tensor if not already a tensor.

    :return: MaybeToTensor transform
    :raises EnvironmentError: If torchvision is not installed. ``MaybeToTensor`` is
        ``None`` in that case, so without this check the call would fail with an
        unhelpful ``TypeError: 'NoneType' object is not callable``.
    """
    # Registered with safe=False so defining it never requires torchvision; the
    # availability check is done here instead, where torchvision is actually needed.
    _check_torchvision()
    return MaybeToTensor()
@_register_parse('maybe_to_tensor', safe=False)
def _parse_maybe_to_tensor(obj):
    """
    Recognise a ``MaybeToTensor`` transform.

    The match is on the exact class name, not ``isinstance``. Any class named
    ``MaybeToTensor`` is accepted, and subclasses with other names are rejected.
    No torchvision import is needed. NOTE(review): presumably name-based so that
    same-named classes from other libraries also match; confirm.

    :param obj: Candidate transform object.
    :return: An empty dict; the transform has no parameters.
    :raises NotParseTarget: When ``obj``'s class is not named ``MaybeToTensor``.
    """
    if type(obj).__name__ == 'MaybeToTensor':
        return {}
    raise NotParseTarget
@_register_transform('to_tensor', safe=False)
def _create_to_tensor():
    """
    Build a plain ``torchvision.transforms.ToTensor``.

    :return: A new ``ToTensor`` transform.
    """
    from torchvision import transforms

    return transforms.ToTensor()
@_register_parse('to_tensor', safe=False)
def _parse_to_tensor(obj):
    """
    Recognise a ``ToTensor`` transform.

    The match is on the exact class name. As a result, the ``MaybeToTensor``
    subclass defined in this module is not matched here.

    :param obj: Candidate transform object.
    :return: An empty dict; the transform has no parameters.
    :raises NotParseTarget: When ``obj``'s class is not named ``ToTensor``.
    """
    if type(obj).__name__ == 'ToTensor':
        return {}
    raise NotParseTarget
@_register_transform('normalize', safe=False)
def _create_normalize(mean, std, inplace=False):
    """
    Build a ``torchvision.transforms.Normalize``.

    ``mean`` and ``std`` are converted with ``torch.tensor`` before being handed to
    ``Normalize``.

    :param mean: Per-channel means.
    :param std: Per-channel standard deviations.
    :param inplace: Forwarded to ``Normalize`` as ``inplace``.
    :return: The configured ``Normalize`` transform.
    """
    import torch
    from torchvision.transforms import Normalize

    mean_t, std_t = torch.tensor(mean), torch.tensor(std)
    return Normalize(mean=mean_t, std=std_t, inplace=inplace)
@_register_parse('normalize', safe=False)
def _parse_normalize(obj):
    """
    Extract ``mean`` and ``std`` from a ``torchvision.transforms.Normalize``.

    The ``inplace`` flag is not included in the result.

    :param obj: Candidate transform object.
    :return: Dict with ``mean`` and ``std``. Values exposing ``tolist`` (e.g. the tensors
        built by ``_create_normalize``) are converted to lists; others are returned as-is.
    :raises NotParseTarget: When ``obj`` is not a ``Normalize`` instance.
    """
    from torchvision.transforms import Normalize

    if not isinstance(obj, Normalize):
        raise NotParseTarget

    def _plain(stat):
        return stat.tolist() if hasattr(stat, 'tolist') else stat

    return {'mean': _plain(obj.mean), 'std': _plain(obj.std)}