Source code for aac_datasets.check

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os.path as osp
import random
from argparse import ArgumentParser, Namespace
from typing import Dict, Iterable, Union

import pythonwrench as pw
import yaml

from aac_datasets.datasets.audiocaps import AudioCaps, AudioCapsCard
from aac_datasets.datasets.clotho import Clotho, ClothoCard
from aac_datasets.datasets.macs import MACS, MACSCard
from aac_datasets.datasets.wavcaps import WavCaps, WavCapsCard
from aac_datasets.utils.globals import get_default_root

DATASETS_NAMES = (AudioCapsCard.NAME, ClothoCard.NAME, MACSCard.NAME, WavCapsCard.NAME)

logger = logging.getLogger(__name__)


def check_directory(
    root: str,
    verbose: int = 0,
    datasets: Union[Iterable[str], str] = DATASETS_NAMES,
) -> Dict[str, Dict[str, int]]:
    """Check which datasets are installed in root.

    Each subset of every requested dataset is instantiated from ``root``. A subset
    counts as found when this does not raise ``RuntimeError``; for a non-empty
    subset, one randomly chosen item is also loaded as part of that check.
    Every found subset is then scanned to verify that all of its ``fpath``
    entries are existing files. Subsets with missing files are reported through
    an error log but are still included in the returned dictionary.

    :param root: The directory to check.
    :param verbose: The verbose level. defaults to 0.
    :param datasets: The datasets to search in root directory. defaults to DATASETS_NAMES.
    :returns: A dictionary of datanames containing the length of each subset.
    """
    # A single name is accepted as shorthand for a one-element list.
    if isinstance(datasets, str):
        datasets = [datasets]
    else:
        datasets = list(datasets)

    data_infos = [
        (AudioCapsCard.NAME, AudioCaps),
        (ClothoCard.NAME, Clotho),
        (MACSCard.NAME, MACS),
        (WavCapsCard.NAME, WavCaps),
    ]
    # Names that match no known dataset are silently ignored.
    data_infos = [
        (ds_name, class_) for ds_name, class_ in data_infos if ds_name in datasets
    ]

    if verbose >= 1:
        logger.info(f"Start searching datasets in root='{root}'.")

    all_found_dsets = {}
    for ds_name, ds_class in data_infos:
        if verbose >= 1:
            logger.info(f"Searching for {ds_name}...")

        found_dsets = {}
        for subset in ds_class.CARD.SUBSETS:
            try:
                ds = ds_class(root, subset, verbose=0)
                if len(ds) > 0:
                    # Try to load a random item
                    index = random.randint(0, len(ds) - 1)
                    ds[index]
            except RuntimeError:
                if verbose >= 2:
                    logger.info(f"Cannot find {ds_name}_{subset}.")
            else:
                # Empty subsets are kept too: only a RuntimeError rejects a subset.
                found_dsets[subset] = ds

        if found_dsets:
            all_found_dsets[ds_name] = found_dsets

    if verbose >= 1:
        msg = f"Checking if audio files exists for {len(all_found_dsets)} datasets..."
        logger.info(msg)

    for ds_name, dsets in all_found_dsets.items():
        for subset, ds in dsets.items():
            fpaths = ds[:, "fpath"]
            # The generator lets all() stop at the first missing file.
            if not all(osp.isfile(fpath) for fpath in fpaths):
                # Errors are always reported, whatever the verbose level.
                logger.error(f"Cannot find all audio files for {ds_name}.{subset}.")
            elif verbose >= 1:
                # Informational output is gated like every other info log here.
                logger.info(f"Dataset {ds_name}.{subset} is valid.")

    all_valids_lens = {
        ds_name: {subset: len(ds) for subset, ds in dsets.items()}
        for ds_name, dsets in all_found_dsets.items()
    }
    return all_valids_lens
def _get_main_check_args() -> Namespace:
    """Build the command-line parser of the check script and parse the arguments."""
    cli = ArgumentParser(description="Check datasets in specified directory.")

    cli.add_argument(
        "--root",
        default=get_default_root(),
        type=str,
        help="The path to the parent directory of the datasets.",
    )
    cli.add_argument(
        "--verbose",
        default=1,
        type=int,
        help=(
            "Verbose level of the script. 0 means silent mode, 1 is default mode "
            "and 2 add additional debugging outputs."
        ),
    )
    cli.add_argument(
        "--datasets",
        default=DATASETS_NAMES,
        type=str,
        nargs="+",
        help="The datasets to check in root directory.",
    )

    return cli.parse_args()


def _main_check() -> None:
    """Entry point of the script: check the root directory and print a summary."""
    options = _get_main_check_args()
    verbosity = options.verbose
    pw.setup_logging_verbose("aac_datasets", verbosity)

    if verbosity >= 2:
        arguments_dump = yaml.dump({"Arguments": vars(options)}, sort_keys=False)
        logger.debug(arguments_dump)

    subset_lengths = check_directory(options.root, verbosity, options.datasets)

    # Nothing is printed in silent mode.
    if verbosity < 1:
        return

    num_found = len(subset_lengths)
    num_requested = len(options.datasets)
    print(f"Found {num_found}/{num_requested} dataset(s) in root='{options.root}':")

    if subset_lengths:
        print(yaml.dump(subset_lengths, sort_keys=False))


if __name__ == "__main__":
    _main_check()