#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy
import logging
import os.path as osp
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union
from torch import Tensor
from typing_extensions import TypedDict
from aac_datasets.datasets.base import AACDataset
from aac_datasets.datasets.functional.macs import (
MACSCard,
_get_audio_dpath,
download_macs_dataset,
load_macs_dataset,
)
from aac_datasets.utils.globals import _get_root
# Module-level logger, named after this module's import path.
pylog = logging.getLogger(__name__)
class MACSItem(TypedDict):
    r"""Typed dictionary describing a single MACS item.

    This is a ``TypedDict``: it only declares the expected keys and value
    types of an item. The ``MACS`` dataset class passes
    ``MACSItem.__required_keys__`` as its column names.
    """

    # Common attributes
    audio: Tensor  # audio data of the item
    captions: List[str]  # captions attached to the audio
    dataset: str  # dataset name (MACS fills it with MACSCard.NAME)
    fname: str  # audio file name (joined with the audio directory to build the path)
    index: int  # position of the item in the dataset (MACS fills it with range(size))
    subset: str  # subset name the item was loaded from
    sr: int  # sample rate
    duration: float  # audio duration -- presumably in seconds, TODO confirm

    # MACS-specific attributes
    # NOTE(review): ids are declared as str here, while the MACS methods
    # annotate annotator ids as int -- confirm the actual type.
    annotators_ids: List[str]
    competences: List[float]  # one competence value per entry of annotators_ids
    identifier: str
    scene_label: str
    tags: List[List[str]]  # presumably one list of tags per caption -- TODO confirm
class MACS(AACDataset[MACSItem]):
    r"""Unofficial MACS PyTorch dataset.

    .. code-block:: text
        :caption: Dataset folder tree

        {root}
        └── MACS
            ├── audio
            │   └── (3930 wav files, ~13GB)
            ├── LICENCE.txt
            ├── MACS.yaml
            ├── MACS_competence.csv
            └── tau_meta
                ├── fold1_evaluate.csv
                ├── fold1_test.csv
                ├── fold1_train.csv
                └── meta.csv
    """

    # Common globals
    CARD: ClassVar[MACSCard] = MACSCard()

    # Initialization
    def __init__(
        self,
        # Common args
        root: Union[str, Path, None] = None,
        subset: str = MACSCard.DEFAULT_SUBSET,
        download: bool = False,
        transform: Optional[Callable[[MACSItem], Any]] = None,
        verbose: int = 0,
        force_download: bool = False,
        verify_files: bool = False,
        *,
        # MACS-specific args
        clean_archives: bool = True,
        flat_captions: bool = False,
    ) -> None:
        """Load the MACS dataset, optionally downloading it first.

        :param root: The parent of the dataset root directory.
            The data will be stored in the 'MACS' subdirectory.
            The given value (including None) is passed to ``_get_root`` to obtain the effective root.
            defaults to None.
        :param subset: The subset of the dataset. This parameter is here only to accept the same interface than the other datasets.
            It must be one of ``MACSCard.SUBSETS``, other values will raise a ValueError.
            defaults to ``MACSCard.DEFAULT_SUBSET``.
        :param download: Download the dataset if download=True and if the dataset is not already downloaded.
            defaults to False.
        :param transform: The transform to apply to the global dict item. This transform is applied only in getitem method when argument is an integer.
            defaults to None.
        :param verbose: Verbose level to use. Can be 0 or 1.
            defaults to 0.
        :param force_download: If True, force to re-download file even if they exists on disk.
            defaults to False.
        :param verify_files: If True, check hash value when possible.
            defaults to False.
        :param clean_archives: If True, remove the compressed archives from disk to save space.
            defaults to True.
        :param flat_captions: If True, map captions to audio instead of audio to caption.
            defaults to False.
        :raises ValueError: If ``subset`` is not one of ``MACSCard.SUBSETS``.
        """
        # Validate before touching the disk or the network.
        if subset not in MACSCard.SUBSETS:
            raise ValueError(
                f"Invalid argument subset={subset} for MACS. (expected one of {MACSCard.SUBSETS})"
            )

        root = _get_root(root)

        if download:
            download_macs_dataset(
                root=root,
                subset=subset,
                force=force_download,
                verbose=verbose,
                clean_archives=clean_archives,
                verify_files=verify_files,
            )

        raw_data, annotator_id_to_competence = load_macs_dataset(
            root=root,
            subset=subset,
            verbose=verbose,
        )

        audio_dpath = _get_audio_dpath(root)
        # The number of items is read from the first column of raw_data;
        # this assumes every column has the same length.
        size = len(next(iter(raw_data.values())))

        # Add the columns that are not provided by load_macs_dataset.
        raw_data["dataset"] = [MACSCard.NAME] * size
        raw_data["subset"] = [subset] * size
        raw_data["fpath"] = [
            osp.join(audio_dpath, fname) for fname in raw_data["fname"]
        ]
        raw_data["index"] = list(range(size))

        super().__init__(
            raw_data=raw_data,
            transform=transform,
            column_names=MACSItem.__required_keys__,
            flat_captions=flat_captions,
            sr=MACSCard.SAMPLE_RATE,
            verbose=verbose,
        )

        # Attributes backing the read-only properties and public getters below.
        self._root = root
        self._subset = subset
        self._download = download
        self._transform = transform
        self._flat_captions = flat_captions
        self._verbose = verbose
        self._annotator_id_to_competence = annotator_id_to_competence

        # Columns whose values come from loader callables instead of raw_data.
        # Only _load_competences is defined in this class; the other loaders
        # are presumably inherited from AACDataset -- confirm in the base class.
        self.add_online_columns(
            {
                "audio": MACS._load_audio,
                "audio_metadata": MACS._load_audio_metadata,
                "duration": MACS._load_duration,
                "num_channels": MACS._load_num_channels,
                "num_frames": MACS._load_num_frames,
                "sr": MACS._load_sr,
                "competences": MACS._load_competences,
            }
        )

    # Properties
    @property
    def download(self) -> bool:
        """Value of the ``download`` argument given to the constructor."""
        return self._download

    @property
    def root(self) -> str:
        """Effective root directory, as returned by ``_get_root``."""
        return self._root

    @property
    def sr(self) -> int:
        """Sample rate stored in ``self._sr``.

        ``_sr`` is not assigned in this class; presumably it is set by
        ``AACDataset.__init__`` from the ``sr`` argument.
        """
        return self._sr  # type: ignore

    @property
    def subset(self) -> str:
        """Name of the loaded subset."""
        return self._subset

    # Public methods
    # NOTE(review): annotator ids are annotated as int in the methods below,
    # but MACSItem declares `annotators_ids: List[str]` -- confirm the actual
    # key type returned by load_macs_dataset.
    def get_annotator_id_to_competence_dict(self) -> Dict[int, float]:
        """Get annotator to competence dictionary.

        :returns: A deep copy of the internal mapping, so the caller cannot modify the dataset state.
        """
        # Note : copy to prevent any changes on this attribute
        return copy.deepcopy(self._annotator_id_to_competence)

    def get_competence(self, annotator_id: int) -> float:
        """Get competence value for a specific annotator id.

        :param annotator_id: The id of the annotator to look up.
        :returns: The competence stored for this annotator.
        :raises KeyError: If the annotator id is not in the competence mapping.
        """
        return self._annotator_id_to_competence[annotator_id]

    def _load_competences(self, index: int) -> List[float]:
        """Return the competence of each annotator of the item at ``index``.

        Registered as the loader of the "competences" online column.
        """
        annotators_ids: List[int] = self.at(index, "annotators_ids")
        competences = [self.get_competence(id_) for id_ in annotators_ids]
        return competences

    # Magic methods
    def __repr__(self) -> str:
        """Return ``<pretty name>(subset=..., size=..., num_columns=...)``."""
        repr_dic = {
            "subset": self._subset,
            "size": len(self),
            "num_columns": len(self.column_names),
        }
        repr_str = ", ".join(f"{k}={v}" for k, v in repr_dic.items())
        return f"{MACSCard.PRETTY_NAME}({repr_str})"