#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import os.path as osp
from pathlib import Path
from typing import Any, Callable, ClassVar, Dict, List, Optional, Union
import torch
import torchaudio
from torch import Tensor
from typing_extensions import NotRequired, TypedDict
try:
# To support torchaudio >= 2.1.0
from torchaudio import AudioMetaData # type: ignore
except ImportError:
from torchaudio.backend.common import AudioMetaData
from aac_datasets.datasets.base import AACDataset
from aac_datasets.datasets.functional.audiocaps import (
AudioCapsCard,
_get_audio_subset_dpath,
download_audiocaps_dataset,
load_audiocaps_dataset,
)
from aac_datasets.utils.globals import _get_ffmpeg_path, _get_root, _get_ytdlp_path
pylog = logging.getLogger(__name__)
[docs]class AudioCapsItem(TypedDict):
r"""Class representing a single AudioCaps item."""
# Common attributes
audio: Tensor
captions: List[str]
dataset: str
fname: str
index: int
subset: str
sr: int
duration: float
# AudioCaps-specific attributes
audiocaps_ids: List[int]
start_time: int
tags: NotRequired[List[int]]
youtube_id: str
[docs]class AudioCaps(AACDataset[AudioCapsItem]):
r"""Unofficial AudioCaps PyTorch dataset.
Subsets available are 'train', 'val' and 'test'.
Audio is a waveform tensor of shape (1, n_times) of 10 seconds max, sampled at 32kHz by default.
Target is a list of strings containing the captions.
The 'train' subset has only 1 caption per sample and 'val' and 'test' have 5 captions.
Download requires 'yt-dlp' and 'ffmpeg' commands.
AudioCaps paper : https://www.aclweb.org/anthology/N19-1011.pdf
.. code-block:: text
:caption: Dataset folder tree
{root}
└── AUDIOCAPS
├── train.csv
├── val.csv
├── test.csv
└── audio_32000Hz
├── train
│ └── (46231/49838 flac files, ~42G for 32kHz)
├── val
│ └── (465/495 flac files, ~425M for 32kHz)
└── test
└── (913/975 flac files, ~832M for 32kHz)
"""
# Common globals
CARD: ClassVar[AudioCapsCard] = AudioCapsCard()
# Initialization
def __init__(
    self,
    # Common args
    root: Union[str, Path, None] = None,
    subset: str = AudioCapsCard.DEFAULT_SUBSET,
    download: bool = False,
    transform: Optional[Callable[[AudioCapsItem], Any]] = None,
    verbose: int = 0,
    force_download: bool = False,
    verify_files: bool = False,
    *,
    # AudioCaps-specific args
    audio_duration: float = 10.0,
    audio_format: str = "flac",
    audio_n_channels: int = 1,
    download_audio: bool = True,
    exclude_removed_audio: bool = True,
    ffmpeg_path: Union[str, Path, None] = None,
    flat_captions: bool = False,
    max_workers: Optional[int] = 1,
    sr: int = 32_000,
    with_tags: bool = False,
    ytdlp_path: Union[str, Path, None] = None,
) -> None:
    """Build the dataset, optionally downloading it first.

    :param root: Dataset root directory.
        The data will be stored in the 'AUDIOCAPS' subdirectory.
        defaults to None, in which case the root is resolved by ``_get_root``.
    :param subset: The subset of AudioCaps to use. Can be one of :attr:`~AudioCapsCard.SUBSETS`.
        defaults to :attr:`~AudioCapsCard.DEFAULT_SUBSET`.
    :param download: If True, call the download function before loading the dataset.
        defaults to False.
    :param transform: The transform to apply to the global dict item. This transform is applied only in getitem method when argument is an integer.
        defaults to None.
    :param verbose: Verbose level.
        defaults to 0.
    :param force_download: If True, force to re-download files even if they exist on disk.
        defaults to False.
    :param verify_files: If True, check that all files already downloaded are valid (hash value is checked when possible).
        defaults to False.
    :param audio_duration: Extracted duration for each audio file in seconds.
        defaults to 10.0.
    :param audio_format: Audio format and extension name.
        defaults to "flac".
    :param audio_n_channels: Number of channels extracted for each audio file.
        defaults to 1.
    :param download_audio: If True, download audio, metadata and labels files. Otherwise it will only download metadata and labels files.
        defaults to True.
    :param exclude_removed_audio: If True, the dataset will exclude the audio not downloaded from youtube (i.e. not present on disk).
        If False, invalid audios will return an empty tensor of shape (0,).
        defaults to True.
    :param ffmpeg_path: Path to ffmpeg executable file.
        defaults to None, in which case the path is resolved by ``_get_ffmpeg_path``.
    :param flat_captions: If True, map captions to audio instead of audio to caption.
        defaults to False.
    :param max_workers: Number of threads to download audio files in parallel.
        Do not use a value too high to avoid "Too Many Requests" error.
        The value None will use `min(32, os.cpu_count() + 4)` workers, which is the default of ThreadPoolExecutor.
        defaults to 1.
    :param sr: The sample rate used for audio files in the dataset (in Hz).
        Since original YouTube videos are recorded in various settings, this parameter allows to download audio files with a specific sample rate.
        defaults to 32000.
    :param with_tags: If True, load the tags from AudioSet dataset.
        Note: tags needs to be downloaded with download=True & with_tags=True before being used.
        defaults to False.
    :param ytdlp_path: Path to yt-dlp or ytdlp executable.
        defaults to None, in which case the path is resolved by ``_get_ytdlp_path``.
    :raises ValueError: If ``subset`` is not one of :attr:`~AudioCapsCard.SUBSETS`.
    """
    # Reject unknown subsets before touching the disk or the network.
    if subset not in AudioCapsCard.SUBSETS:
        raise ValueError(
            f"Invalid argument subset={subset} for AudioCaps. (expected one of {AudioCapsCard.SUBSETS})"
        )

    # Resolve the optional locations into concrete values.
    root = _get_root(root)
    ytdlp_path = _get_ytdlp_path(ytdlp_path)
    ffmpeg_path = _get_ffmpeg_path(ffmpeg_path)

    if download:
        download_audiocaps_dataset(
            root=root,
            subset=subset,
            force=force_download,
            verbose=verbose,
            verify_files=verify_files,
            audio_duration=audio_duration,
            audio_format=audio_format,
            audio_n_channels=audio_n_channels,
            download_audio=download_audio,
            ffmpeg_path=ffmpeg_path,
            max_workers=max_workers,
            sr=sr,
            with_tags=with_tags,
            ytdlp_path=ytdlp_path,
        )

    raw_data, index_to_name = load_audiocaps_dataset(
        root=root,
        subset=subset,
        verbose=verbose,
        audio_format=audio_format,
        exclude_removed_audio=exclude_removed_audio,
        sr=sr,
        with_tags=with_tags,
    )
    audio_subset_dpath = _get_audio_subset_dpath(root, subset, sr)

    # The number of items is taken from the length of the first loaded column.
    num_items = len(next(iter(raw_data.values())))

    # Add the columns that are constant or derived for the whole subset.
    raw_data["dataset"] = [AudioCapsCard.NAME] * num_items
    raw_data["subset"] = [subset] * num_items
    raw_data["fpath"] = [
        osp.join(audio_subset_dpath, audio_fname)
        for audio_fname in raw_data["fname"]
    ]
    raw_data["index"] = list(range(num_items))

    # Expose every key declared by AudioCapsItem, minus "tags" when not requested.
    column_names = [
        *AudioCapsItem.__required_keys__,
        *AudioCapsItem.__optional_keys__,
    ]
    if not with_tags:
        column_names.remove("tags")

    super().__init__(
        raw_data=raw_data,
        transform=transform,
        column_names=column_names,
        flat_captions=flat_captions,
        sr=sr,
        verbose=verbose,
    )

    # Attributes
    self._root = root
    self._subset = subset
    self._download = download
    self._exclude_removed_audio = exclude_removed_audio
    self._with_tags = with_tags
    self._index_to_name = index_to_name

    # Columns computed on access by the corresponding loader method.
    online_columns = {
        "audio": AudioCaps._load_audio,
        "audio_metadata": AudioCaps._load_audio_metadata,
        "duration": AudioCaps._load_duration,
        "num_channels": AudioCaps._load_num_channels,
        "num_frames": AudioCaps._load_num_frames,
        "sr": AudioCaps._load_sr,
    }
    self.add_online_columns(online_columns)
# Properties
@property
def download(self) -> bool:
    """Value of the ``download`` argument given to the constructor."""
    return self._download
@property
def exclude_removed_audio(self) -> bool:
    """Value of the ``exclude_removed_audio`` argument given to the constructor."""
    return self._exclude_removed_audio
@property
def index_to_name(self) -> Dict[int, str]:
    """Index-to-name mapping returned by ``load_audiocaps_dataset`` at construction."""
    return self._index_to_name
@property
def root(self) -> str:
    """Dataset root directory, as resolved by ``_get_root`` in the constructor."""
    return self._root
@property
def sr(self) -> int:
    """Value of the ``_sr`` attribute.

    NOTE(review): ``_sr`` is not assigned in this class; presumably the base
    class initializer stores the ``sr`` argument there. To be confirmed.
    """
    return self._sr  # type: ignore
@property
def subset(self) -> str:
    """Name of the subset given to the constructor."""
    return self._subset
@property
def with_tags(self) -> bool:
    """Value of the ``with_tags`` argument given to the constructor."""
    return self._with_tags
# Magic methods
def __repr__(self) -> str:
    """Return a one-line summary of the dataset.

    The result has the form ``<pretty name>(key=value, ...)`` and lists the
    subset, the number of items, the number of columns and the
    ``with_tags`` / ``exclude_removed_audio`` flags, in that order.
    """
    summary_fields = (
        ("subset", self._subset),
        ("size", len(self)),
        ("num_columns", len(self.column_names)),
        ("with_tags", self._with_tags),
        ("exclude_removed_audio", self._exclude_removed_audio),
    )
    joined = ", ".join(f"{key}={value}" for key, value in summary_fields)
    return f"{AudioCapsCard.PRETTY_NAME}({joined})"
# Private methods
def _load_audio(self, index: int) -> Tensor:
    """Load the audio tensor of the item at ``index``.

    :param index: Index of the item in the raw data.
    :returns: The audio returned by ``torchaudio.load`` for the item's ``fpath``,
        or an empty tensor of shape (0,) when the item's ``is_on_disk`` flag is falsy.
    :raises RuntimeError: If the loaded audio contains no element, or if
        ``self._sr`` is not None and differs from the sample rate of the file.
    """
    is_on_disk = self._raw_data["is_on_disk"][index]
    if not is_on_disk:
        # Missing files yield an empty placeholder instead of an error.
        return torch.empty((0,))

    fpath = self.at(index, "fpath")
    audio, sr = torchaudio.load(fpath)  # type: ignore

    # Sanity check
    is_empty = audio.nelement() == 0
    if is_empty:
        raise RuntimeError(
            f"Invalid audio number of elements in {fpath}. (expected audio.nelement()={audio.nelement()} > 0)"
        )

    # The file on disk must match the sample rate the dataset was built with.
    expected_sr = self._sr
    has_sr_mismatch = expected_sr is not None and expected_sr != sr
    if has_sr_mismatch:
        raise RuntimeError(
            f"Invalid sample rate {sr}Hz for audio {fpath}. (expected {self._sr}Hz)"
        )

    return audio
def _load_audio_metadata(self, index: int) -> AudioMetaData:
    """Read the audio metadata of the item at ``index``.

    :param index: Index of the item in the raw data.
    :returns: The result of ``torchaudio.info`` for the item's ``fpath``, or a
        placeholder ``AudioMetaData(-1, -1, -1, -1, "unknown_encoding")`` when
        the item's ``is_on_disk`` flag is falsy.
    """
    if self._raw_data["is_on_disk"][index]:
        return torchaudio.info(self.at(index, "fpath"))  # type: ignore

    # No file to inspect: return a placeholder object.
    return AudioMetaData(-1, -1, -1, -1, "unknown_encoding")