aac_datasets.utils.collate module

class AdvancedCollate(
fill_values: Dict[str, float | int],
)

Bases: AdvancedCollateDict

Advanced collate object for DataLoader.

Merge a list of dicts into a single dict of lists. Values under a key (e.g. "audio") are padded to a common length if a fill value is given for that key in __init__.

Example
>>> collate = AdvancedCollate({"audio": 0.0})
>>> loader = DataLoader(..., collate_fn=collate)
>>> next(iter(loader))
{"audio": tensor([[...]]), ...}

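The merge-and-pad behavior described above can be sketched in plain Python, with lists standing in for tensors so it runs without PyTorch. This is a hypothetical illustration, not the library's implementation: the name `advanced_collate_sketch`, taking the keys from the first item, and appending fill values after the data are all assumptions.

```python
from typing import Any, Dict, List, Union

Number = Union[int, float]


def advanced_collate_sketch(
    batch: List[Dict[str, Any]],
    fill_values: Dict[str, Number],
) -> Dict[str, List[Any]]:
    """Hypothetical stand-in for AdvancedCollate that works on plain lists, not tensors."""
    if not batch:
        return {}
    # Merge the list of dicts into a dict of lists (keys taken from the first item).
    merged: Dict[str, List[Any]] = {key: [item[key] for item in batch] for key in batch[0]}
    for key, fill in fill_values.items():
        if key not in merged:
            continue
        values = merged[key]
        max_len = max(len(v) for v in values)
        # Assumption: fill values are appended after the data, up to the longest item.
        merged[key] = [list(v) + [fill] * (max_len - len(v)) for v in values]
    return merged


batch = [
    {"audio": [0.1, 0.2, 0.3], "caption": "a dog barks"},
    {"audio": [0.5], "caption": "rain"},
]
print(advanced_collate_sketch(batch, {"audio": 0.0}))
```

Keys without a fill value, such as "caption" here, are simply collected into a list, while every "audio" item is padded to the length of the longest one in the batch.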
class BasicCollate(
key_mode: Literal['intersect', 'same', 'union'] = 'intersect',
)

Bases: CollateDict

Collate object for DataLoader.

Merge a list of dicts into a single dict of lists. No padding is applied.
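The parameter key_mode is not described here; the sketch below assumes the option names mean what they suggest: "intersect" keeps only keys present in every item, "same" requires all items to share the same keys, and "union" keeps every key and fills gaps with a placeholder. The function name `basic_collate_sketch` and its `default` parameter are hypothetical, and plain Python lists replace tensors.

```python
from typing import Any, Dict, List, Literal

KeyMode = Literal["intersect", "same", "union"]


def basic_collate_sketch(
    batch: List[Dict[str, Any]],
    key_mode: KeyMode = "intersect",
    default: Any = None,  # hypothetical placeholder for keys missing in "union" mode
) -> Dict[str, List[Any]]:
    """Hypothetical stand-in for BasicCollate: merge dicts into a dict of lists, no padding."""
    if not batch:
        return {}
    key_sets = [set(item) for item in batch]
    # Collect all keys in encounter order: the first item's keys first, then new ones.
    ordered: List[str] = []
    for item in batch:
        for key in item:
            if key not in ordered:
                ordered.append(key)
    if key_mode == "intersect":
        common = set.intersection(*key_sets)
        keys = [k for k in ordered if k in common]
    elif key_mode == "same":
        if any(ks != key_sets[0] for ks in key_sets):
            raise ValueError("All items must have the same keys when key_mode='same'.")
        keys = ordered
    elif key_mode == "union":
        keys = ordered
    else:
        raise ValueError(f"Invalid key_mode: {key_mode!r}")
    # Missing keys (only possible in "union" mode) are filled with `default`.
    return {k: [item.get(k, default) for item in batch] for k in keys}


batch = [{"audio": [1, 2], "caption": "a"}, {"audio": [3], "index": 1}]
print(basic_collate_sketch(batch))           # only "audio" is shared by both items
print(basic_collate_sketch(batch, "union"))  # every key, gaps filled with None
```

Note that the sequences under "audio" keep their original lengths, since this collate applies no padding.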

pad_last_dim(
tensor: Tensor,
target_length: int,
pad_value: float,
) → Tensor

Left-pad a tensor along its last dimension.

Parameters:
  • tensor – Tensor of at least 1 dim, with shape (…, T).

  • target_length – Target length of the last dim. If target_length <= T, the function has no effect.

  • pad_value – Fill value used to pad the tensor.

Returns:

A tensor of shape (…, target_length).
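The shape contract above, (…, T) to (…, target_length) with no effect when target_length <= T, can be illustrated on nested Python lists. This is a hypothetical sketch, not the library's code: the name `pad_last_dim_sketch` and its `left` flag are assumptions, added because "left padding" does not say by itself whether the fill values go before or after the existing data.

```python
from typing import Any, List


def pad_last_dim_sketch(
    values: List[Any],
    target_length: int,
    pad_value: float,
    left: bool = False,  # hypothetical: True puts the fill values before the data
) -> List[Any]:
    """Pad the innermost lists of `values` to `target_length` with `pad_value`."""
    # Innermost level: a flat list of numbers (the last dim, of length T).
    # An empty list is treated as an empty last-dim row.
    if not values or not isinstance(values[0], list):
        missing = max(target_length - len(values), 0)  # no effect if target_length <= T
        padding = [pad_value] * missing
        return padding + list(values) if left else list(values) + padding
    # Outer dims: recurse so every last-dim row is padded the same way.
    return [pad_last_dim_sketch(v, target_length, pad_value, left) for v in values]


print(pad_last_dim_sketch([[1, 2], [3, 4]], 4, 0.0))
print(pad_last_dim_sketch([1, 2, 3], 2, 0.0))  # target_length <= T: unchanged
```

With PyTorch, the same contract can be met with torch.nn.functional.pad(tensor, (0, pad_len), value=pad_value), where pad_len = max(target_length - tensor.shape[-1], 0); passing (pad_len, 0) instead places the fill values before the data.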