aac_datasets.utils.collate module
- class AdvancedCollate( )
Bases: object
Advanced collate object for DataLoader. Merges a list of dicts into a single dict of lists. Audio is padded if a fill value is given in __init__.
>>> collate = AdvancedCollate({"audio": 0.0})
>>> loader = DataLoader(..., collate_fn=collate)
>>> next(iter(loader))
... {"audio": tensor([[...]]), ...}
- class BasicCollate
Bases: object
Collate object for DataLoader. Merges a list of dicts into a single dict of lists. No padding is applied.
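The merging step shared by both collate classes can be sketched in plain Python. This is a minimal illustration and not the library's implementation: the name `merge_dicts_sketch` is hypothetical, and the sketch assumes every item in the batch has the same keys as the first item.

```python
from typing import Any, Dict, List


def merge_dicts_sketch(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical helper: turn a list of dicts into a dict of lists."""
    if not batch:
        return {}
    # Assumption: all items share the keys of the first item.
    keys = batch[0].keys()
    return {key: [item[key] for item in batch] for key in keys}


batch = [
    {"fname": "a.wav", "captions": ["a dog barks"]},
    {"fname": "b.wav", "captions": ["rain falls"]},
]
merged = merge_dicts_sketch(batch)
# merged["fname"] is a list with one entry per item, in batch order.
```

Values are left as Python lists here, which matches the "no padding" behavior: items of different lengths can sit side by side because nothing is stacked.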
- can_be_stacked( ) → bool
Returns True if a list of tensors can be stacked with torch.stack.
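One plausible way to perform such a check is to compare shapes, since torch.stack requires all tensors to have the same shape. The sketch below is an assumption about the logic and not the library's code: the name `can_be_stacked_sketch` is hypothetical, an empty list is treated as not stackable, and all tensors are assumed to live on the same device.

```python
from typing import List

import torch
from torch import Tensor


def can_be_stacked_sketch(tensors: List[Tensor]) -> bool:
    """Hypothetical check: True if all tensors share one shape."""
    if len(tensors) == 0:
        # Assumption: torch.stack rejects an empty list, so report False.
        return False
    first_shape = tensors[0].shape
    return all(t.shape == first_shape for t in tensors)


same = [torch.zeros(2, 3), torch.ones(2, 3)]
mixed = [torch.zeros(2, 3), torch.zeros(2, 4)]

if can_be_stacked_sketch(same):
    stacked = torch.stack(same)  # new leading batch dimension
```

A collate function can use a check like this to decide between stacking a batch into one tensor and keeping it as a list.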
- pad_last_dim( ) → Tensor
Left padding tensor at last dim.
- Parameters:
tensor – Tensor with at least one dimension, of shape (…, T).
target_length – Target length of the last dim. If target_length <= T, the function has no effect.
pad_value – Fill value used to pad tensor.
- Returns:
A tensor of shape (…, target_length), or the input shape (…, T) unchanged if target_length <= T.
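The parameters above can be illustrated with torch.nn.functional.pad. This is a sketch under stated assumptions, not the library's implementation: the name `pad_last_dim_sketch` is hypothetical, and although the docstring calls the operation "left padding", it does not spell out which end of the last dim receives the fill values. The sketch assumes they are appended after the existing values.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def pad_last_dim_sketch(tensor: Tensor, target_length: int, pad_value: float) -> Tensor:
    """Hypothetical re-implementation of padding along the last dim."""
    # No padding is needed when the tensor is already long enough.
    pad_len = max(target_length - tensor.shape[-1], 0)
    # F.pad takes [before, after] amounts for the last dim.
    # Assumption: fill values go after the data; use [pad_len, 0] to prepend.
    return F.pad(tensor, [0, pad_len], value=pad_value)


x = torch.ones(2, 3)
padded = pad_last_dim_sketch(x, 5, 0.0)     # last dim grows from 3 to 5
unchanged = pad_last_dim_sketch(x, 2, 0.0)  # target_length <= T: no effect
```

Padding every audio tensor in a batch to the longest length in that batch makes the shapes equal, so the batch can then be stacked into a single tensor.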