Source code for aac_datasets.utils.collate

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Dict, Union

from pythonwrench.warnings import deprecated_alias, deprecated_function
from torch import Tensor
from torchwrench.nn.functional.padding import pad_dim
from torchwrench.nn.functional.predicate import is_stackable
from torchwrench.utils.data.collate import AdvancedCollateDict, CollateDict, KeyMode


class BasicCollate(CollateDict):
    """Basic collate callable for :class:`~torch.utils.data.dataloader.DataLoader`.

    Gathers a list of per-item dicts into one dict of lists.
    No padding is performed on the collected values.

    :param key_mode: Key selection mode, forwarded unchanged to
        :class:`CollateDict`. Defaults to ``"intersect"``.
    """

    def __init__(
        self,
        key_mode: KeyMode = "intersect",
    ) -> None:
        # All of the collate behavior lives in the parent class; this
        # subclass only fixes the default value of ``key_mode``.
        super().__init__(key_mode)
class AdvancedCollate(AdvancedCollateDict):
    """Padding-aware collate callable for :class:`~torch.utils.data.dataloader.DataLoader`.

    Gathers a list of per-item dicts into one dict of lists.
    Audio will be padded if a fill value is given for it in `__init__`.

    .. code-block:: python
        :caption: Example

        >>> collate = AdvancedCollate({"audio": 0.0})
        >>> loader = DataLoader(..., collate_fn=collate)
        >>> next(iter(loader))
        ... {"audio": tensor([[...]]), ...}

    :param fill_values: Mapping from a key name to the fill value used for
        that key; forwarded unchanged to :class:`AdvancedCollateDict`.
    """

    def __init__(
        self,
        fill_values: Dict[str, Union[float, int]],
    ) -> None:
        # The key mode is not configurable here: it is always "intersect".
        super().__init__(
            fill_values,
            key_mode="intersect",
        )
@deprecated_function()
def pad_last_dim(tensor: Tensor, target_length: int, pad_value: float) -> Tensor:
    """Pad the last dim of a tensor up to a target length (deprecated).

    Thin wrapper that delegates to ``pad_dim``.

    NOTE(review): the previous summary described this as "left padding".
    The padding side and the padded dim are decided by the defaults of
    ``pad_dim``, which are not set here -- confirm against its documentation.

    :param tensor: Tensor of at least 1 dim. (..., T)
    :param target_length: Target length of the last dim.
        If target_length <= T, the function has no effect.
    :param pad_value: Fill value used to pad tensor.
    :returns: A tensor of shape (..., target_length).
    """
    padded = pad_dim(
        tensor,
        target_length,
        pad_value=pad_value,
    )
    return padded
@deprecated_alias(is_stackable)
def can_be_stacked(*args, **kwargs):
    """Deprecated alias declared through ``deprecated_alias(is_stackable)``.

    The body is intentionally empty. Presumably the decorator routes calls
    to ``is_stackable`` -- confirm against the ``deprecated_alias``
    documentation. New code should call ``is_stackable`` directly.
    """