diff --git a/src/datasets/load.py b/src/datasets/load.py index 98e5f3b0b8f..43c098833cb 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -941,6 +941,7 @@ def __init__( download_config: Optional[DownloadConfig] = None, download_mode: Optional[Union[DownloadMode, str]] = None, use_exported_dataset_infos: bool = False, + cache_dir: Optional[str] = None, ): self.name = name self.commit_hash = commit_hash @@ -949,6 +950,7 @@ def __init__( self.download_config = download_config or DownloadConfig() self.download_mode = download_mode self.use_exported_dataset_infos = use_exported_dataset_infos + self.cache_dir = cache_dir increase_load_count(name) def get_module(self) -> DatasetModule: @@ -967,6 +969,7 @@ def get_module(self) -> DatasetModule: repo_type="dataset", revision=self.commit_hash, proxies=self.download_config.proxies, + cache_dir=self.cache_dir ) dataset_card_data = DatasetCard.load(dataset_readme_path).data except EntryNotFoundError: @@ -1537,6 +1540,7 @@ def dataset_module_factory( repo_type="dataset", revision=revision, proxies=download_config.proxies, + cache_dir=cache_dir ) commit_hash = os.path.basename(os.path.dirname(dataset_readme_path)) except LocalEntryNotFoundError as e: @@ -1583,6 +1587,7 @@ def dataset_module_factory( repo_type="dataset", revision=commit_hash, proxies=download_config.proxies, + cache_dir=cache_dir ) if _require_custom_configs or (revision and revision != "main"): can_load_config_from_parquet_export = False @@ -1626,6 +1631,7 @@ def dataset_module_factory( download_config=download_config, download_mode=download_mode, use_exported_dataset_infos=use_exported_dataset_infos, + cache_dir=cache_dir ).get_module() except GatedRepoError as e: message = f"Dataset '{path}' is a gated dataset on the Hub." diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index bbd19859b65..aa959e55854 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -193,6 +193,7 @@ def cached_path( filename=resolved_path.path_in_repo, force_download=download_config.force_download, proxies=download_config.proxies, + cache_dir=cache_dir ) except ( huggingface_hub.utils.RepositoryNotFoundError,