Source code for aac_datasets.datasets.base

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import copy
import logging
import os.path as osp
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import pythonwrench as pw
import torchaudio
import tqdm
from datasets import Dataset as HFDataset
from torch import Tensor
from torch.utils.data.dataset import Dataset
from torchwrench import IntegralTensor0D, IntegralTensor1D
from typing_extensions import TypeAlias, TypeGuard

from aac_datasets.datasets.functional.common import DatasetCard
from aac_datasets.utils.typing import AudioMetaData

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type of a single item produced by a dataset (covariant generic parameter of AACDataset).
ItemType = TypeVar("ItemType", covariant=True)
# Row selectors accepted by AACDataset.get_item and AACDataset.__getitem__.
IndexType: TypeAlias = Union[int, Iterable[int], Iterable[bool], Tensor, slice, None]
# Column selectors accepted by AACDataset.get_item: one name, several names, or None.
ColumnType: TypeAlias = Union[str, Iterable[str], None]

# Human-readable names of the accepted index types, interpolated in the TypeError
# message raised by AACDataset.get_item for an unsupported index.
_INDEX_TYPES = ("int", "Iterable[int]", "Iterable[bool]", "Tensor", "slice", "None")


class AACDataset(Generic[ItemType], Dataset[ItemType]):
    """Base class for AAC datasets.

    Values are stored column-wise in two containers:

    - *raw* columns: a dict mapping a column name to a list with one value per row;
    - *online* columns: a dict mapping a column name to a function called as
      ``fn(dataset, index)`` every time the value is requested.

    When a name exists in both containers, the online function takes precedence.
    """

    # Initialization
    def __init__(
        self,
        raw_data: Optional[Dict[str, List[Any]]] = None,
        transform: Optional[Callable[[ItemType], Any]] = None,
        column_names: Optional[Iterable[str]] = None,
        flat_captions: bool = False,
        sr: Union[int, Iterable[int], None] = None,
        verbose: int = 0,
    ) -> None:
        """
        :param raw_data: Mapping from column name to the list of row values.
            All lists must have the same length. defaults to None (no raw column).
        :param transform: Callable applied by ``__getitem__`` to an item fetched with an
            integer index and all the selected columns. defaults to None.
        :param column_names: Columns selected when no column is specified.
            defaults to None (the keys of raw_data).
        :param flat_captions: If True, flatten raw data so that each row holds one caption.
            defaults to False.
        :param sr: Expected sample rate, or collection of accepted sample rates, checked
            when loading audio. defaults to None (no check).
        :param verbose: Verbosity level. Progress bars are displayed when >= 2. defaults to 0.
        :raises ValueError: If raw columns do not all contain the same number of items.
        """
        if raw_data is None:
            raw_data = {}
        if column_names is None:
            column_names = raw_data.keys()
        column_names = list(column_names)

        # Materialize one-shot iterables so the accepted rates can be checked repeatedly.
        if isinstance(sr, Iterable):
            sr = list(sr)

        if len(raw_data) > 1:
            size = len(next(iter(raw_data.values())))
            invalid_columns = [col for col, lst in raw_data.items() if len(lst) != size]
            if len(invalid_columns) > 0:
                msg = f"Invalid raw_data number of items in the following columns: {tuple(invalid_columns)}."
                raise ValueError(msg)

        super().__init__()
        self._raw_data = raw_data
        self._transform = transform
        self._columns = column_names
        self._flat_captions = flat_captions
        self._sr = sr
        self._verbose = verbose

        self._online_fns: Dict[str, Callable[[Any, int], Any]] = {}
        # Number of captions per original row, filled when raw data is flattened.
        self._sizes: List[int] = []

        if self._flat_captions:
            self._flat_raw_data()

    @staticmethod
    def new_empty() -> "AACDataset":
        """Create a new empty dataset."""
        return AACDataset(
            raw_data={},
            transform=None,
            column_names=(),
            flat_captions=False,
            sr=None,
            verbose=0,
        )

    # Properties
    @property
    def all_columns(self) -> List[str]:
        """The name of all columns of the dataset (raw and online)."""
        return list(pw.union_dicts([self._raw_data, self._online_fns]))

    @property
    def column_names(self) -> List[str]:
        """The name of all selected column of the dataset."""
        return self._columns

    @property
    def flat_captions(self) -> bool:
        """Returns true if captions has been flattened."""
        return self._flat_captions

    @property
    def num_columns(self) -> int:
        """Number of selected columns in the dataset."""
        return len(self.column_names)

    @property
    def num_rows(self) -> int:
        """Number of rows in the dataset (same as len())."""
        return len(self)

    @property
    def raw_data(self) -> Dict[str, List[Any]]:
        """The raw columns, as a dict of lists (not a copy)."""
        return self._raw_data

    @property
    def shape(self) -> Tuple[int, int]:
        """Shape of the dataset (number of rows, number of selected columns)."""
        return len(self), len(self.column_names)

    @property
    def sr(self) -> Union[int, List[int], None]:
        """Expected sample rate(s), or None when no check is performed."""
        return self._sr

    @property
    def transform(self) -> Optional[Callable]:
        """Transform applied to items by ``__getitem__``, if any."""
        return self._transform

    @property
    def verbose(self) -> int:
        """Verbosity level."""
        return self._verbose

    @column_names.setter
    def column_names(
        self,
        columns: Iterable[str],
    ) -> None:
        columns = list(columns)
        self._check_columns(columns)
        self._columns = columns

    @transform.setter
    def transform(self, transform: Optional[Callable[[ItemType], Any]]) -> None:
        self._transform = transform

    @verbose.setter
    def verbose(self, verbose: int) -> None:
        self._verbose = verbose

    # Public methods
    @pw.deprecated_function()
    def at(self, *args, **kwargs) -> Any:
        """Deprecated: Use get_item method instead."""
        return self.get_item(*args, **kwargs)

    @overload
    def get_item(self, index: int) -> ItemType: ...

    @overload
    def get_item(  # type: ignore
        self,
        index: Union[Iterable[int], Iterable[bool], slice, None],
        column: str,
    ) -> List: ...

    @overload
    def get_item(
        self,
        index: Union[Iterable[int], Iterable[bool], slice, None],
        column: Union[Iterable[str], None] = None,
    ) -> Dict[str, List]: ...

    @overload
    def get_item(self, index: IndexType, column: ColumnType) -> Any: ...

    def get_item(
        self,
        index: IndexType = None,
        column: ColumnType = None,
    ) -> Any:
        """Get a specific data field.

        :param index: The row selector: an integer, a slice, an iterable of integers,
            a boolean mask with one entry per row, an integral 0d or 1d tensor,
            or None to select all rows.
        :param column: The name(s) of the column(s). None selects :attr:`column_names`.
        :returns: The field value for a single index and a single column name, a list of
            values for several indices, and a dict ``{column: value(s)}`` when several
            columns (or None) are requested.
        :raises TypeError: If the index type or the tensor dtype/ndim is not supported.
        :raises IndexError: If a boolean mask does not have one entry per row.
        :raises ValueError: If the column is neither a raw nor an online column.
        """
        if index is None:
            index = slice(None)
        elif isinstance(index, Tensor):
            # Both scalar and 1-D integral tensors are valid row selectors.
            if not pw.isinstance_generic(index, (IntegralTensor0D, IntegralTensor1D)):
                msg = "Invalid tensor dtype. (expected integral 0d or 1d tensor)"
                raise TypeError(msg)
            index = index.tolist()
        elif not isinstance(index, (int, slice)) and isinstance(index, Iterable):
            # Materialize once: a one-shot iterator would otherwise be exhausted by the
            # first column when several columns are requested.
            index = list(index)

        if column is None:
            column = self.column_names

        if not isinstance(column, str) and isinstance(column, Iterable):
            return {column_i: self.get_item(index, column_i) for column_i in column}

        # Fast path: raw columns support int and slice indexing directly.
        if isinstance(index, (int, slice)) and (
            column in self._raw_data.keys() and column not in self._online_fns
        ):
            return self._raw_data[column][index]  # type: ignore

        if isinstance(index, slice):
            index = range(len(self))[index]

        if isinstance(index, Iterable):
            index = list(index)
            if pw.isinstance_generic(index, Iterable[bool]):
                if len(index) != len(self):
                    msg = f"The length of the mask ({len(index)}) does not match the length of the dataset ({len(self)})."
                    raise IndexError(msg)
                # Convert the mask to the positions of its true entries.
                index = [i for i, idx_i in enumerate(index) if idx_i]
            elif not pw.isinstance_generic(index, Iterable[int]):
                msg = f"Invalid input type for {index=}. (expected Iterable[int], not Iterable[{index[0].__class__.__name__}])"
                raise TypeError(msg)

            values = [
                self.get_item(idx_i, column)
                for idx_i in tqdm.tqdm(
                    index,
                    desc=f"Loading column '{column}'...",
                    disable=self._verbose < 2,
                )
            ]
            return values

        if not isinstance(index, int):
            msg = (
                f"Invalid argument type {type(index)}. (expected one of {_INDEX_TYPES})"
            )
            raise TypeError(msg)

        return self._load_online_value(column, index)

    def has_raw_column(self, column: str) -> bool:
        """Returns True if column name exists in raw data."""
        return column in self._raw_data

    def has_post_column(self, column: str) -> bool:
        """Returns True if column name exists in post processed data."""
        return column in self._online_fns

    def has_column(self, column: str) -> bool:
        """Returns True if column name exists in data."""
        return self.has_raw_column(column) or self.has_post_column(column)

    def remove_column(self, column: str) -> Union[List[Any], Callable]:
        """Removes a column from this dataset.

        Raw columns are looked up first, then online columns.

        :param column: The name of the column to remove.
        :returns: The removed list of values (raw column) or loading function (online column).
        :raises ValueError: If the column does not exist.
        """
        if column in self._raw_data:
            column_data = self._raw_data.pop(column, [])
            return column_data
        elif column in self._online_fns:
            fn = self._online_fns.pop(column)
            return fn
        else:
            raise ValueError(f"Column '{column}' does not exists in dataset.")

    def rename_column(
        self,
        old_column: str,
        new_column: str,
        allow_replace: bool = False,
    ) -> None:
        """Renames a column from this dataset.

        If the column cannot be stored under its new name, it is restored under its
        old name before the error is propagated, so a refused rename loses no data.

        :param old_column: The current name of the column.
        :param new_column: The new name of the column.
        :param allow_replace: If True, overwrite an existing column named new_column.
        :raises ValueError: If old_column does not exist, or new_column already exists
            and allow_replace is False.
        :raises TypeError: If the removed column is neither a list nor a callable.
        """
        was_raw = old_column in self._raw_data
        column_data_or_fn = self.remove_column(old_column)

        try:
            if isinstance(column_data_or_fn, list):
                self.add_raw_column(new_column, column_data_or_fn, allow_replace)
            elif callable(column_data_or_fn):
                self.add_online_column(new_column, column_data_or_fn, allow_replace)
            else:
                msg = f"Invalid type {type(column_data_or_fn)}. (expected List or Callable)"
                raise TypeError(msg)
        except Exception:
            # Put the column back where it came from (it ends up last in dict order).
            store = self._raw_data if was_raw else self._online_fns
            store[old_column] = column_data_or_fn  # type: ignore
            raise

    def add_raw_column(
        self,
        column_name: str,
        column_data: List[Any],
        allow_replace: bool = False,
    ) -> None:
        """Add a new raw column to this dataset.

        :param column_name: The name of the new column.
        :param column_data: The values of the column, one per row.
        :param allow_replace: If True, overwrite an existing raw column with the same name.
        :raises ValueError: If the column already exists and allow_replace is False, or if
            the number of values differs from the number of rows of a non-empty dataset.
        """
        if not allow_replace and column_name in self._raw_data:
            msg = f"Column '{column_name}' already exists. Please choose another name or set allow_replace arg to True."
            raise ValueError(msg)
        if len(self._raw_data) > 0 and len(column_data) != len(self):
            msg = f"Invalid number of rows in column '{column_name}'."
            raise ValueError(msg)
        self._raw_data[column_name] = column_data

    def add_online_column(
        self,
        column: str,
        load_fn: Callable[[Any, int], Any],
        allow_replace: bool = False,
    ) -> None:
        """Add a new post-processed column to this dataset.

        :param column: The name of the new column.
        :param load_fn: Function called as ``load_fn(dataset, index)`` to compute a value.
        :param allow_replace: If True, overwrite an existing online column with the same name.
        :raises ValueError: If the online column already exists and allow_replace is False.
        """
        if not allow_replace and column in self._online_fns:
            msg = f"Column '{column}' already exists in {self} and found argument {allow_replace=}."
            raise ValueError(msg)
        self._online_fns[column] = load_fn

    def add_online_columns(
        self,
        post_columns_fns: Dict[str, Callable[[Any, int], Any]],
        allow_replace: bool = False,
    ) -> None:
        """Add several new post-processed columns to this dataset."""
        for name, load_fn in post_columns_fns.items():
            self.add_online_column(name, load_fn, allow_replace)

    def preload_online_column(
        self,
        column: str,
        allow_replace: bool = False,
    ) -> Callable[[Any, int], Any]:
        """Load all data from a post-column data into raw data.

        :param column: The name of the online column to load.
        :param allow_replace: If True, overwrite an existing raw column with the same name.
        :returns: The loading function that was attached to the column.
        :raises ValueError: If the column is not an online column, or if a raw column with
            the same name exists and allow_replace is False.
        """
        if column not in self._online_fns:
            msg = f"Invalid argument {column=}."
            raise ValueError(msg)

        # Fail before the (possibly expensive) loading loop rather than after it.
        if not allow_replace and column in self._raw_data:
            msg = f"Column '{column}' already exists. Please choose another name or set allow_replace arg to True."
            raise ValueError(msg)

        column_data = [
            self._load_online_value(column, i)
            for i in tqdm.trange(
                len(self),
                disable=self._verbose < 2,
                desc=f"Preloading column '{column}'",
            )
        ]

        # Detach the loader only once the raw column is stored, so a failure keeps it.
        self.add_raw_column(column, column_data, allow_replace=allow_replace)
        fn = self._online_fns.pop(column)
        return fn

    def to_dict(self, load_online_values: bool = False) -> Dict[str, List[Any]]:
        """Convert dataset to dictionary.

        :param load_online_values: If True, load ALL online values (e.g. audio waveform). Otherwise load only the raw data of the dataset. defaults to False.
        :returns: A shallow copy of the raw data, extended with the online columns if requested.
        """
        raw_data = copy.copy(self._raw_data)
        if load_online_values:
            for column_name in self._online_fns.keys():
                column_data = self.get_item(None, column_name)
                raw_data[column_name] = column_data
        return raw_data

    def to_list(self, load_online_values: bool = False) -> List[ItemType]:
        """Convert dataset to list.

        :param load_online_values: If True, load ALL online values (e.g. audio waveform). Otherwise load only the raw data of the dataset. defaults to False.
        :returns: One dict per row, keyed by column name.
        """
        raw_data = self.to_dict(load_online_values)
        return pw.dict_list_to_list_dict(raw_data, key_mode="same")  # type: ignore

    def to_hf_dataset(self, load_online_values: bool = False) -> HFDataset:
        """Convert dataset to a HuggingFace dataset.

        The dataset is built from the raw columns. When the class defines a ``CARD``
        attribute that is a DatasetCard, it provides the dataset info, and the ``subset``
        attribute of the instance, if any, is used as split.

        :param load_online_values: If True, each online column is computed with ``map``,
            reading the row index from the "index" column. defaults to False.
        :returns: The converted dataset.
        """
        datadict = self.to_dict(load_online_values=False)

        dataset_info = None
        if hasattr(self.__class__, "CARD"):
            card = self.__class__.CARD  # type: ignore
            if isinstance(card, DatasetCard):
                dataset_info = card.to_dataset_info()

        split = getattr(self, "subset", None)
        hf_dataset = HFDataset.from_dict(datadict, info=dataset_info, split=split)
        del datadict

        if load_online_values:
            for colname, fn in self._online_fns.items():

                def wrapped_fn(index: int) -> Any:
                    return {colname: fn(self, index)}

                hf_dataset = hf_dataset.map(
                    wrapped_fn,
                    input_columns="index",
                    load_from_cache_file=False,
                    keep_in_memory=False,
                )

        return hf_dataset

    # Magic methods
    @overload
    def __getitem__(self, index: int) -> ItemType: ...

    @overload
    def __getitem__(self, index: Tuple[Union[Iterable[int], slice, None], str]) -> List:  # type: ignore
        ...

    @overload
    def __getitem__(
        self, index: Union[Iterable[int], slice, None]
    ) -> Dict[str, List]: ...

    @overload
    def __getitem__(
        self,
        index: Tuple[Union[Iterable[int], slice, None], Union[Iterable[str], None]],
    ) -> Dict[str, List]: ...

    @overload
    def __getitem__(self, index: Any) -> Any: ...

    def __getitem__(self, index: Union[IndexType, Tuple[IndexType, ColumnType]]) -> Any:  # type: ignore
        """Get item(s) with ``dataset[index]`` or ``dataset[index, column]``.

        The transform is applied only when a single integer index is given together
        with all the selected columns (or no column).
        """
        # A 2-tuple made of a valid index and a valid column selector is read as
        # (index, column); any other value is treated as an index.
        if (
            isinstance(index, tuple)
            and len(index) == 2
            and _is_index(index[0])
            and _is_column(index[1])
        ):
            index, column = index
        else:
            column = None

        item = self.get_item(index, column)  # type: ignore
        if (
            isinstance(index, int)
            and self._transform is not None
            and (
                column is None
                or (
                    isinstance(column, Iterable)
                    and not isinstance(column, str)
                    and set(column) == set(self._columns)
                )
            )
        ):
            item = self._transform(item)  # type: ignore
        return item

    def __len__(self) -> int:
        """
        :return: The number of items in the dataset.
        """
        if len(self._raw_data) > 0:
            return len(next(iter(self._raw_data.values())))
        else:
            return 0

    def __repr__(self) -> str:
        info = {
            "size": len(self),
            "num_columns": self.num_columns,
        }
        repr_str = ", ".join(f"{k}={v}" for k, v in info.items())
        return f"{self.__class__.__name__}({repr_str})"

    # Private methods
    def _check_columns(self, columns: List[str]) -> None:
        """Raise ValueError if a name is neither a raw nor an online column."""
        expected_columns = dict.fromkeys(self.all_columns)
        invalid_columns = [name for name in columns if name not in expected_columns]
        if len(invalid_columns) > 0:
            msg = f"Invalid argument {columns=}. (found {len(invalid_columns)} invalids column names for {self.__class__.__name__}: {invalid_columns})"
            raise ValueError(msg)

    def _flat_raw_data(self) -> None:
        """Flatten raw data in place and store the number of captions per row."""
        raw_data, sizes = _flat_raw_data(self._raw_data)
        self._raw_data = raw_data
        self._sizes = sizes

    def _unflat_raw_data(self) -> None:
        """Rebuild in place the raw data from its flattened form using the stored sizes."""
        raw_data = _unflat_raw_data(self._raw_data, self._sizes)
        self._raw_data = raw_data

    def _load_online_value(self, column: str, index: int) -> Any:
        """Compute the value of an online column for one row.

        :raises ValueError: If the column is not an online column.
        """
        if column in self._online_fns:
            fn = self._online_fns[column]
            return fn(self, index)
        else:
            msg = f"Invalid argument column={column} at {index=}. (expected one of {self.all_columns})"
            raise ValueError(msg)

    def _load_audio(self, index: int) -> Tensor:
        """Load the waveform of the file stored in the 'fpath' column.

        :raises RuntimeError: If the audio is empty, or if its sample rate does not
            match the expected sample rate(s).
        """
        fpath = self.get_item(index, "fpath")
        audio_and_sr: Tuple[Tensor, int] = torchaudio.load(fpath)  # type: ignore
        audio, sr = audio_and_sr

        # Sanity check
        if audio.nelement() == 0:
            msg = f"Invalid audio number of elements in {fpath}. (expected {audio.nelement()=} > 0)"
            raise RuntimeError(msg)

        if self._sr is not None:
            # The expected value is either one rate or a collection of accepted rates.
            if isinstance(self._sr, Iterable):
                is_valid_sr = sr in self._sr
            else:
                is_valid_sr = sr == self._sr

            if not is_valid_sr:
                msg = (
                    f"Invalid sample rate {sr}Hz for audio {fpath}. (expected {self._sr}Hz)"
                )
                raise RuntimeError(msg)
        return audio

    def _load_audio_metadata(self, index: int) -> AudioMetaData:
        """Read the metadata of the file stored in the 'fpath' column."""
        fpath = self.get_item(index, "fpath")
        audio_metadata = torchaudio.info(fpath)  # type: ignore
        return audio_metadata

    def _load_duration(self, index: int) -> float:
        """Compute num_frames / sample_rate from the 'audio_metadata' column."""
        audio_metadata: AudioMetaData = self.get_item(index, "audio_metadata")
        duration = audio_metadata.num_frames / audio_metadata.sample_rate
        return duration

    def _load_fname(self, index: int) -> str:
        """Return the base name of the path stored in the 'fpath' column."""
        fpath = self.get_item(index, "fpath")
        fname = osp.basename(fpath)
        return fname

    def _load_num_channels(self, index: int) -> int:
        """Return the number of channels from the 'audio_metadata' column."""
        audio_metadata = self.get_item(index, "audio_metadata")
        num_channels = audio_metadata.num_channels
        return num_channels

    def _load_num_frames(self, index: int) -> int:
        """Return the number of frames from the 'audio_metadata' column."""
        audio_metadata = self.get_item(index, "audio_metadata")
        num_frames = audio_metadata.num_frames
        return num_frames

    def _load_sr(self, index: int) -> int:
        """Return the sample rate from the 'audio_metadata' column."""
        audio_metadata = self.get_item(index, "audio_metadata")
        sr = audio_metadata.sample_rate
        return sr
def _is_index(index: Any) -> "TypeGuard[IndexType]":
    """Returns True if the value is a supported row selector."""
    return pw.isinstance_generic(
        index,
        (
            int,
            Iterable[int],
            Iterable[bool],
            slice,
            pw.NoneType,
            IntegralTensor0D,
            IntegralTensor1D,
        ),
    )


def _is_column(column: Any) -> "TypeGuard[ColumnType]":
    """Returns True if the value is a supported column selector."""
    return pw.isinstance_generic(column, (Iterable[str], pw.NoneType))


def _flat_raw_data(
    raw_data: Dict[str, List[Any]],
    caps_column: str = "captions",
) -> Tuple[Dict[str, List[Any]], List[int]]:
    """Flatten raw data so that each row holds at most one caption.

    A row with N > 0 captions becomes N rows, each one holding a single-caption list
    and a copy of the other column values. A row without caption is kept as one row.

    :param raw_data: Mapping from column name to the list of row values.
    :param caps_column: Name of the column holding the list of captions of each row.
    :returns: The flattened raw data and the number of captions of each original row.
    :raises ValueError: If caps_column is not a column of raw_data.
    """
    if caps_column not in raw_data:
        msg = f"Cannot flat raw data without '{caps_column}' column. (found only columns {tuple(raw_data.keys())})"
        raise ValueError(msg)

    mcaps: List[List[str]] = raw_data[caps_column]
    raw_data_flat: Dict[str, List[Any]] = {key: [] for key in raw_data.keys()}

    for i, caps in enumerate(mcaps):
        if len(caps) == 0:
            # Keep the row as is, so that rows without caption are not dropped.
            for key in raw_data.keys():
                raw_data_flat[key].append(raw_data[key][i])
        else:
            for cap in caps:
                for key in raw_data.keys():
                    if key == caps_column:
                        continue
                    raw_data_flat[key].append(raw_data[key][i])
                # Overwrite cap
                raw_data_flat[caps_column].append([cap])

    sizes = [len(caps) for caps in mcaps]
    return raw_data_flat, sizes


def _unflat_raw_data(
    raw_data_flat: Dict[str, List[Any]],
    sizes: List[int],
    caps_column: str = "captions",
) -> Dict[str, List[Any]]:
    """Rebuild raw data from its flattened form (inverse of ``_flat_raw_data``).

    :param raw_data_flat: Flattened mapping from column name to the list of row values.
    :param sizes: Number of captions of each original row.
    :param caps_column: Name of the column holding the captions.
    :returns: The raw data with one row per entry of sizes.
    :raises ValueError: If caps_column is not a column of raw_data_flat.
    """
    if caps_column not in raw_data_flat:
        msg = f"Cannot unflat raw data without '{caps_column}' column. (found only columns {tuple(raw_data_flat.keys())})"
        raise ValueError(msg)

    raw_data: Dict[str, List[Any]] = {key: [] for key in raw_data_flat.keys()}

    cumsize = 0
    for size in sizes:
        for key in raw_data.keys():
            if key == caps_column:
                caps = [
                    raw_data_flat[key][index][0]
                    for index in range(cumsize, cumsize + size)
                ]
                raw_data[key].append(caps)
            else:
                raw_data[key].append(raw_data_flat[key][cumsize])

        # A row without caption still occupies one flattened row, so the cursor must
        # advance by at least one to stay aligned with the flattened data.
        cumsize += max(size, 1)

    return raw_data