aac_datasets.utils.collate module

class AdvancedCollate(fill_values: Dict[str, float | int])

Bases: object

Advanced collate object for DataLoader.

Merges a list of per-sample dicts into a single dict of lists. Audio is padded when a fill value for it is given in __init__.

Example
>>> collate = AdvancedCollate({"audio": 0.0})
>>> loader = DataLoader(..., collate_fn=collate)
>>> next(iter(loader))
{"audio": tensor([[...]]), ...}
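
The padding behaviour described above can be illustrated with a small stand-alone sketch. It is not the library's implementation: the function name `pad_collate_sketch` is hypothetical, and the sketch assumes that every key listed in `fill_values` holds tensors that differ only in the length of their last dimension and that padding is appended at the end.

```python
from typing import Any, Dict, List, Union

import torch
import torch.nn.functional as F


def pad_collate_sketch(
    batch: List[Dict[str, Any]],
    fill_values: Dict[str, Union[float, int]],
) -> Dict[str, Any]:
    """Hypothetical sketch: merge a list of dicts, then pad and stack selected keys."""
    # Step 1: list of dicts -> dict of lists (assumes all samples share the same keys).
    merged: Dict[str, Any] = {key: [sample[key] for sample in batch] for key in batch[0]}

    # Step 2: for each key with a fill value, pad to the longest item and stack.
    for key, fill in fill_values.items():
        items = merged[key]
        max_len = max(t.shape[-1] for t in items)
        padded = [F.pad(t, (0, max_len - t.shape[-1]), value=fill) for t in items]
        merged[key] = torch.stack(padded)
    return merged


batch = [
    {"audio": torch.ones(1, 3), "fname": "a.wav"},
    {"audio": torch.ones(1, 5), "fname": "b.wav"},
]
out = pad_collate_sketch(batch, {"audio": 0.0})
print(out["audio"].shape)  # both clips now share the last-dim length of the longest one
```

Keys without a fill value, such as `fname` here, stay as plain lists.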

class BasicCollate

Bases: object

Collate object for DataLoader.

Merges a list of per-sample dicts into a single dict of lists. No padding is applied.
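
As a rough illustration of the merge step, the following pure-Python sketch turns a list of sample dicts into one dict of lists. The helper name `merge_dicts` and the sample fields are hypothetical, and the sketch assumes every sample has the same keys as the first one.

```python
from typing import Any, Dict, List


def merge_dicts(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical sketch: turn a list of per-sample dicts into one dict of per-key lists."""
    if not batch:
        return {}
    # Assumption: every sample has the same keys as the first one.
    keys = batch[0].keys()
    return {key: [sample[key] for sample in batch] for key in keys}


batch = [
    {"fname": "a.wav", "captions": ["a dog barks"]},
    {"fname": "b.wav", "captions": ["rain falls", "water drips"]},
]
merged = merge_dicts(batch)
print(merged["fname"])
```

Because nothing is padded or stacked, variable-length fields such as the caption lists are kept exactly as they were in each sample.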

can_be_stacked(tensors: List[Tensor]) → bool

Returns True if the tensors in the list can be stacked with torch.stack, which requires all of them to have the same shape.
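
A check of this kind might look like the sketch below. It is an assumption about the logic, not the library's code: the name `can_be_stacked_sketch` is hypothetical, and treating an empty list as not stackable is a choice made here because torch.stack rejects empty input.

```python
from typing import List

import torch
from torch import Tensor


def can_be_stacked_sketch(tensors: List[Tensor]) -> bool:
    """Hypothetical sketch: True if torch.stack would accept this list."""
    # torch.stack needs a non-empty list (assumption: empty -> False here).
    if len(tensors) == 0:
        return False
    # All tensors must share the shape of the first one.
    first_shape = tensors[0].shape
    return all(t.shape == first_shape for t in tensors)


same = [torch.zeros(2, 3), torch.ones(2, 3)]
ragged = [torch.zeros(2, 3), torch.zeros(2, 5)]
print(can_be_stacked_sketch(same), can_be_stacked_sketch(ragged))
```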

pad_last_dim(tensor: Tensor, target_length: int, pad_value: float) → Tensor

Left padding of a tensor along its last dimension.

Parameters:
  • tensor – Tensor with at least one dimension, of shape (…, T).

  • target_length – Target length of the last dimension. If target_length <= T, the tensor is returned unchanged.

  • pad_value – Fill value used to pad tensor.

Returns:

A tensor of shape (…, target_length), or of the original shape (…, T) when target_length <= T.
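
The parameter description above can be sketched with torch.nn.functional.pad. This is a minimal illustration, not the library's implementation: the name `pad_last_dim_sketch` is hypothetical, and because the wording "left padding" does not make clear which side receives the fill values, the sketch assumes they are appended at the end so the data stays left-aligned. Swapping the pad tuple to `(pad_len, 0)` would pad at the start instead.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def pad_last_dim_sketch(tensor: Tensor, target_length: int, pad_value: float) -> Tensor:
    """Hypothetical sketch: pad the last dim of `tensor` up to `target_length`."""
    # Number of missing elements, clamped at 0 so longer inputs are left untouched.
    pad_len = max(target_length - tensor.shape[-1], 0)
    # F.pad takes (before, after) amounts for the last dim.
    # Assumption: fill values go after the existing data.
    return F.pad(tensor, (0, pad_len), mode="constant", value=pad_value)


short = torch.tensor([[1.0, 2.0, 3.0]])
padded = pad_last_dim_sketch(short, 5, 0.0)     # last dim grows from 3 to 5
unchanged = pad_last_dim_sketch(short, 2, 0.0)  # target_length <= T: no effect
print(padded.shape, unchanged.shape)
```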