aac_datasets.utils.collate module

class AdvancedCollate(fill_values: dict[str, float | int])

Bases: AdvancedCollateDict

Advanced collate object for DataLoader.

Merge a list of dicts into a single dict of lists. Audio is padded if a fill value for it is given in __init__.

Example
>>> collate = AdvancedCollate({"audio": 0.0})
>>> loader = DataLoader(..., collate_fn=collate)
>>> next(iter(loader))
... {"audio": tensor([[...]]), ...}
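
To make the merge-and-pad behaviour concrete, here is a minimal pure-Python sketch in which plain lists stand in for tensors. It is not the library's implementation: `collate_with_padding` is a hypothetical helper, and both the key handling (keys shared by every item are kept) and the padding side (values appended at the end) are assumptions made for illustration.

```python
from typing import Any


def collate_with_padding(
    batch: list[dict[str, Any]],
    fill_values: dict[str, float],
) -> dict[str, list[Any]]:
    """Merge a list of dicts into a dict of lists, padding keys found in fill_values."""
    if not batch:
        return {}
    # Assumption: only keys present in every item are kept.
    keys = [k for k in batch[0] if all(k in item for item in batch)]
    merged = {k: [item[k] for item in batch] for k in keys}

    for key, fill in fill_values.items():
        if key not in merged:
            continue
        longest = max(len(seq) for seq in merged[key])
        # Assumption: padding is appended at the end of each sequence.
        merged[key] = [
            list(seq) + [fill] * (longest - len(seq)) for seq in merged[key]
        ]
    return merged


batch = [
    {"audio": [0.1, 0.2, 0.3], "fname": "a.wav"},
    {"audio": [0.4], "fname": "b.wav"},
]
out = collate_with_padding(batch, {"audio": 0.0})
print(out["audio"])  # [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]]
print(out["fname"])  # ['a.wav', 'b.wav']
```

Keys without a fill value, such as "fname" here, are merged but left unpadded.
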

class BasicCollate(key_mode: 'intersect' | 'same' | 'union' = 'intersect')

Bases: CollateDict

Collate object for DataLoader.

Merge a list of dicts into a single dict of lists. No padding is applied.
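
The three `key_mode` values are not described on this page, so the following sketch shows one plausible reading of them. It assumes 'intersect' keeps keys common to all items, 'union' keeps every key, and 'same' requires identical keys across items. `merge_dicts` is a hypothetical helper, and filling missing entries with None in 'union' mode is an assumption, not confirmed library behaviour.

```python
from typing import Any


def merge_dicts(
    batch: list[dict[str, Any]],
    key_mode: str = "intersect",
) -> dict[str, list[Any]]:
    """Merge a list of dicts into a dict of lists according to key_mode."""
    if not batch:
        return {}
    key_sets = [set(item) for item in batch]
    if key_mode == "intersect":
        kept = set.intersection(*key_sets)
    elif key_mode == "union":
        kept = set.union(*key_sets)
    elif key_mode == "same":
        if any(ks != key_sets[0] for ks in key_sets):
            raise ValueError("all items must have the same keys")
        kept = key_sets[0]
    else:
        raise ValueError(f"unknown key_mode: {key_mode!r}")
    # Keep keys in first-seen order so the output is deterministic.
    ordered = list(dict.fromkeys(k for item in batch for k in item if k in kept))
    # Assumption: a key missing from an item yields None ('union' mode only).
    return {k: [item.get(k) for item in batch] for k in ordered}


batch = [
    {"audio": [0.1, 0.2], "captions": ["a dog barks"]},
    {"audio": [0.3], "fname": "b.wav"},
]
print(merge_dicts(batch, "intersect"))  # {'audio': [[0.1, 0.2], [0.3]]}
print(list(merge_dicts(batch, "union")))  # ['audio', 'captions', 'fname']
```

With this batch, 'same' raises ValueError in the sketch because the two items do not share the same keys.
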

pad_last_dim(tensor: Tensor, target_length: int, pad_value: float) -> Tensor

Left-pad a tensor along its last dimension.

Parameters:
tensor: Tensor

Tensor with at least one dimension, of shape (…, T).

target_length: int

Target length of the last dim. If target_length <= T, the function has no effect.

pad_value: float

Fill value used to pad tensor.

Returns:

A tensor of shape (…, target_length), or the unchanged input of shape (…, T) when target_length <= T.
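
The length contract above can be sketched on nested Python lists, where the innermost lists play the role of the last dimension. This is an illustration only: `pad_last_dim_nested` and its `side` argument are hypothetical. The default follows the "left padding" wording of the docstring literally, and the argument exists because this sketch does not assume which end the library actually pads.

```python
from typing import Any


def pad_last_dim_nested(
    data: list[Any],
    target_length: int,
    pad_value: float,
    side: str = "left",  # hypothetical argument, not part of the library API
) -> list[Any]:
    """Pad the innermost lists of a nested list up to target_length."""
    # Recurse until the innermost (last-dimension) lists are reached.
    if data and isinstance(data[0], list):
        return [
            pad_last_dim_nested(row, target_length, pad_value, side) for row in data
        ]
    # No effect when the list is already at least target_length long.
    missing = max(target_length - len(data), 0)
    padding = [pad_value] * missing
    if side == "left":
        return padding + list(data)
    return list(data) + padding


rows = [[1.0, 2.0], [3.0, 4.0]]
print(pad_last_dim_nested(rows, 4, 0.0))
# [[0.0, 0.0, 1.0, 2.0], [0.0, 0.0, 3.0, 4.0]]
print(pad_last_dim_nested([1.0, 2.0, 3.0], 2, 0.0))
# [1.0, 2.0, 3.0]
```
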