diff --git a/deeplake/api/tests/test_api.py b/deeplake/api/tests/test_api.py
index d24ec6ae3a..0d9224d073 100644
--- a/deeplake/api/tests/test_api.py
+++ b/deeplake/api/tests/test_api.py
@@ -2247,7 +2247,7 @@ def test_ignore_temp_tensors(local_path):
         create_shape_tensor=False,
         create_id_tensor=False,
     )
-    ds.__temptensor.append(123)
+    # ds.__temptensor.append(123)
 
     with deeplake.load(local_path) as ds:
         assert list(ds.tensors) == []
@@ -2270,9 +2270,8 @@ def test_ignore_temp_tensors(local_path):
 
     with deeplake.load(local_path, read_only=True) as ds:
         assert list(ds.tensors) == []
-        assert list(ds._tensors()) == ["__temptensor"]
-        assert ds.meta.hidden_tensors == ["__temptensor"]
-        assert ds.__temptensor[0].numpy() == 123
+        assert list(ds._tensors()) == []
+        assert ds.meta.hidden_tensors == []
 
 
 @pytest.mark.slow
@@ -2550,7 +2549,7 @@ def test_invalid_ds_name():
         verify_dataset_name("hub://test/data-set_123")
 
 
-def test_pickle_bug(local_ds):
+def test_pickle_loses_temp_tensors(local_ds):
     import pickle
 
     file = BytesIO()
@@ -2564,9 +2563,7 @@ def test_pickle_bug(local_ds):
     file.seek(0)
     ds = pickle.load(file)
 
-    np.testing.assert_array_equal(
-        ds["__temp_123"].numpy(), np.array([1, 2, 3, 4, 5]).reshape(-1, 1)
-    )
+    assert "__temp_123" not in ds
 
 
 def test_max_view(memory_ds):
diff --git a/deeplake/core/dataset/dataset.py b/deeplake/core/dataset/dataset.py
index 29bf358e09..d25f4fc8a1 100644
--- a/deeplake/core/dataset/dataset.py
+++ b/deeplake/core/dataset/dataset.py
@@ -1061,6 +1061,8 @@ def _delete_tensor(self, name: str, large_ok: bool = False):
             raise TensorDoesNotExistError(name)
 
         if not tensor_exists(key, self.storage, self.version_state["commit_id"]):
+            if key.startswith("__temp"):
+                return
             raise TensorDoesNotExistError(name)
 
         if not self._is_root():
diff --git a/deeplake/core/meta/dataset_meta.py b/deeplake/core/meta/dataset_meta.py
index da4663ccda..ea75b3abd6 100644
--- a/deeplake/core/meta/dataset_meta.py
+++ b/deeplake/core/meta/dataset_meta.py
@@ -44,11 +44,19 @@ def allow_delete(self, value):
         self.is_dirty = True
 
     def __getstate__(self) -> Dict[str, Any]:
         d = super().__getstate__()
-        d["tensors"] = self.tensors.copy()
+        d["tensors"] = list(
+            filter(lambda x: (not x.startswith("__temp")), self.tensors)
+        )
         d["groups"] = self.groups.copy()
-        d["tensor_names"] = self.tensor_names.copy()
-        d["hidden_tensors"] = self.hidden_tensors.copy()
+
+        d["tensor_names"] = {
+            k: v for k, v in self.tensor_names.items() if not k.startswith("__temp")
+        }
+
+        d["hidden_tensors"] = list(
+            filter(lambda x: (not x.startswith("__temp")), self.hidden_tensors)
+        )
         d["default_index"] = self.default_index.copy()
         d["allow_delete"] = self._allow_delete
         return d
diff --git a/deeplake/core/storage/azure.py b/deeplake/core/storage/azure.py
index fbc1b670e3..79959fc077 100644
--- a/deeplake/core/storage/azure.py
+++ b/deeplake/core/storage/azure.py
@@ -14,6 +14,8 @@ class AzureProvider(StorageProvider):
     def __init__(self, root: str, creds: Dict = {}, token: Optional[str] = None):
+        super().__init__()
+
         try:
             import azure.identity
             import azure.storage.blob
@@ -87,7 +89,7 @@ def _set_clients(self):
             self.container_name
         )
 
-    def __setitem__(self, path, content):
+    def _setitem_impl(self, path, content):
         self.check_readonly()
         self._check_update_creds()
         if isinstance(content, memoryview):
@@ -99,10 +101,10 @@ def __setitem__(self, path, content):
         )
         blob_client.upload_blob(content, overwrite=True)
 
-    def __getitem__(self, path):
+    def _getitem_impl(self, path):
         return self.get_bytes(path)
 
-    def __delitem__(self, path):
+    def _delitem_impl(self, path):
         self.check_readonly()
         blob_client = self.container_client.get_blob_client(
             f"{self.root_folder}/{path}"
@@ -111,7 +113,7 @@ def __delitem__(self, path):
             raise KeyError(path)
         blob_client.delete_blob()
 
-    def get_bytes(
+    def _get_bytes_impl(
         self,
         path: str,
         start_byte: Optional[int] = None,
@@ -144,11 +146,12 @@ def get_bytes(
         byts = blob_client.download_blob(offset=offset, length=length).readall()
         return byts
 
-    def clear(self, prefix=""):
+    def _clear_impl(self, prefix=""):
         self.check_readonly()
         self._check_update_creds()
         blobs = [
-            posixpath.join(self.root_folder, key) for key in self._all_keys(prefix)
+            posixpath.join(self.root_folder, key)
+            for key in self._all_keys_impl(prefix=prefix)
         ]
         # delete_blobs can only delete 256 blobs at a time
         batches = [blobs[i : i + 256] for i in range(0, len(blobs), 256)]
@@ -176,7 +179,7 @@ def get_sas_token(self):
         )
         return sas_token
 
-    def _all_keys(self, prefix: str = ""):
+    def _all_keys_impl(self, refresh: bool = False, prefix: str = ""):
         self._check_update_creds()
         prefix = posixpath.join(self.root_folder, prefix)
         return {
@@ -189,14 +192,9 @@ def _all_keys(self, prefix: str = ""):
             )  # https://github.com/Azure/azure-sdk-for-python/issues/24814
         }
 
-    def __iter__(self):
-        yield from self._all_keys()
-
-    def __len__(self):
-        self._check_update_creds()
-        return len(self._all_keys())
-
     def __getstate__(self):
+        super()._getstate_prepare()
+
         return {
             "root": self.root,
             "creds": self.creds,
@@ -206,6 +204,7 @@ def __getstate__(self):
             "db_engine": self.db_engine,
             "repository": self.repository,
             "expiration": self.expiration,
+            "_temp_data": self._temp_data,
         }
 
     def __setstate__(self, state):
diff --git a/deeplake/core/storage/gcs.py b/deeplake/core/storage/gcs.py
index a20cab7d22..5f7390696c 100644
--- a/deeplake/core/storage/gcs.py
+++ b/deeplake/core/storage/gcs.py
@@ -248,6 +248,7 @@
         Raises:
             ModuleNotFoundError: If google cloud packages aren't installed.
         """
+        super().__init__()
         try:
             import google.cloud.storage  # type: ignore
@@ -323,9 +324,10 @@ def _set_bucket_and_path(self):
     def _get_path_from_key(self, key):
         return posixpath.join(self.path, key)
 
-    def _all_keys(self):
+    def _all_keys_impl(self, refresh: bool = False):
         self._blob_objects = self.client_bucket.list_blobs(prefix=self.path)
-        return {posixpath.relpath(obj.name, self.path) for obj in self._blob_objects}
+        keys = {posixpath.relpath(obj.name, self.path) for obj in self._blob_objects}
+        return {f for f in keys if not f.endswith("/")}
 
     def _set_hub_creds_info(
         self,
@@ -349,7 +351,7 @@ def _set_hub_creds_info(
         self.db_engine = db_engine
         self.repository = repository
 
-    def clear(self, prefix=""):
+    def _clear_impl(self, prefix=""):
         """Remove all keys with given prefix below root - empties out mapping.
 
         Warning:
@@ -384,11 +386,11 @@ def rename(self, root):
         if not self.path.endswith("/"):
             self.path += "/"
 
-    def __getitem__(self, key):
+    def _getitem_impl(self, key):
         """Retrieve data."""
-        return self.get_bytes(key)
+        return self._get_bytes_impl(key)
 
-    def get_bytes(
+    def _get_bytes_impl(
         self,
         path: str,
         start_byte: Optional[int] = None,
@@ -418,7 +420,7 @@ def get_bytes(
         except self.missing_exceptions:
             raise KeyError(path)
 
-    def __setitem__(self, key, value):
+    def _setitem_impl(self, key: str, value: bytes):
         """Store value in key."""
         self.check_readonly()
         blob = self.client_bucket.blob(self._get_path_from_key(key))
@@ -428,15 +430,7 @@ def __setitem__(self, key, value):
             value = bytes(value)
         blob.upload_from_string(value, retry=self.retry)
 
-    def __iter__(self):
-        """Iterating over the structure."""
-        yield from [f for f in self._all_keys() if not f.endswith("/")]
-
-    def __len__(self):
-        """Returns length of the structure."""
-        return len(self._all_keys())
-
-    def __delitem__(self, key):
+    def _delitem_impl(self, key):
         """Remove key."""
         self.check_readonly()
         blob = self.client_bucket.blob(self._get_path_from_key(key))
@@ -445,7 +439,7 @@ def __delitem__(self, key):
         except self.missing_exceptions:
             raise KeyError(key)
 
-    def __contains__(self, key):
+    def _contains_impl(self, key):
         """Checks if key exists in mapping."""
         from google.cloud import storage  # type: ignore
 
@@ -455,6 +449,8 @@ def __contains__(self, key):
         return stats
 
     def __getstate__(self):
+        super()._getstate_prepare()
+
         return (
             self.root,
             self.token,
@@ -463,6 +459,7 @@ def __getstate__(self):
             self.read_only,
             self.db_engine,
             self.repository,
+            self._temp_data,
         )
 
     def __setstate__(self, state):
@@ -473,6 +470,7 @@ def __setstate__(self, state):
         self.read_only = state[4]
         self.db_engine = state[5]
         self.repository = state[6]
+        self._temp_data = state[7]
         self._initialize_provider()
 
     def get_presigned_url(self, key, full=False):
diff --git a/deeplake/core/storage/google_drive.py b/deeplake/core/storage/google_drive.py
index 53292fe2c6..698e8e10e8 100644
--- a/deeplake/core/storage/google_drive.py
+++ b/deeplake/core/storage/google_drive.py
@@ -106,6 +106,7 @@
         - Due to limits on requests per 100 seconds on google drive api, continuous requests such as uploading many small files can be slow.
         - Users can request to increase their quotas on their google cloud platform.
         """
+        super().__init__()
         try:
             import googleapiclient  # type: ignore
             from google.auth.transport.requests import Request  # type: ignore
@@ -279,7 +280,7 @@ def get_object_by_id(self, id):
         file.seek(0)
         return file.read()
 
-    def __getitem__(self, path):
+    def _getitem_impl(self, path):
         id = self._get_id(path)
         if not id:
             raise KeyError(path)
@@ -307,7 +308,7 @@ def _unlock_creation(self, path):
         lock_hash = "." + hash_inputs(self.root_id, path)
         os.remove(lock_hash)
 
-    def __setitem__(self, path, content):
+    def _setitem_impl(self, path, content):
         self.check_readonly()
         id = self._get_id(path)
         if not id:
@@ -330,7 +331,7 @@ def __setitem__(self, path, content):
             self._write_to_file(id, content)
             return
 
-    def __delitem__(self, path):
+    def _delitem_impl(self, path):
         self.check_readonly()
         id = self._pop_id(path)
         if not id:
@@ -338,12 +339,15 @@ def __delitem__(self, path):
         self._delete_file(id)
 
     def __getstate__(self):
+        super()._getstate_prepare()
+
         return (
             self.root,
             self.root_id,
             self.client_id,
             self.client_secret,
             self.refresh_token,
+            self._temp_data,
         )
 
     def __setstate__(self, state):
@@ -352,19 +356,14 @@ def __setstate__(self, state):
         self.client_id = state[2]
         self.client_secret = state[3]
         self.refresh_token = state[4]
+        self._temp_data = state[5]
         self._init_from_state()
 
-    def _all_keys(self):
+    def _all_keys_impl(self, refresh: bool = False):
         keys = set(self.gid.path_id_map.keys())
         return keys
 
-    def __iter__(self):
-        yield from self._all_keys()
-
-    def __len__(self):
-        return len(self._all_keys())
-
-    def clear(self, prefix=""):
+    def _clear_impl(self, prefix=""):
         self.check_readonly()
         for key in self._all_keys():
             if key.startswith(prefix):
diff --git a/deeplake/core/storage/indra.py b/deeplake/core/storage/indra.py
index 8078d76348..f52dfa369e 100644
--- a/deeplake/core/storage/indra.py
+++ b/deeplake/core/storage/indra.py
@@ -13,6 +13,8 @@ def __init__(
         read_only: Optional[bool] = False,
         **kwargs,
     ):
+        super().__init__()
+
         from indra.api import storage  # type: ignore
 
         if isinstance(root, str):
@@ -39,14 +41,14 @@ def copy(self):
     def subdir(self, path: str, read_only: bool = False):
         return IndraProvider(self.core.subdir(path, read_only))
 
-    def __setitem__(self, path, content):
+    def _setitem_impl(self, path, content):
         self.check_readonly()
         self.core.set(path, bytes(content))
 
-    def __getitem__(self, path):
+    def _getitem_impl(self, path):
         return bytes(self.core.get(path))
 
-    def get_bytes(
+    def _get_bytes_impl(
         self, path, start_byte: Optional[int] = None, end_byte: Optional[int] = None
     ):
         s = start_byte or 0
@@ -91,17 +93,14 @@ def get_deeplake_object(
     def get_object_size(self, path: str) -> int:
         return self.core.length(path)
 
-    def __delitem__(self, path):
+    def _delitem_impl(self, path):
         return self.core.remove(path)
 
-    def _all_keys(self):
+    def _all_keys_impl(self, refresh: bool = False):
         return self.core.list("")
 
-    def __len__(self):
+    def _len_impl(self):
         return len(self.core.list(""))
 
-    def __iter__(self):
-        return iter(self.core.list(""))
-
-    def clear(self, prefix=""):
+    def _clear_impl(self, prefix=""):
         self.core.clear(prefix)
diff --git a/deeplake/core/storage/local.py b/deeplake/core/storage/local.py
index 5dd7501829..56ff69ff21 100644
--- a/deeplake/core/storage/local.py
+++ b/deeplake/core/storage/local.py
@@ -29,6 +29,7 @@ def __init__(self, root: str):
         Raises:
             FileAtPathException: If the root is a file instead of a directory.
         """
+        super().__init__()
         if os.path.isfile(root):
             raise FileAtPathException(root)
         self.root = root
@@ -48,25 +49,7 @@ def subdir(self, path: str, read_only: bool = False):
         sd.read_only = read_only
         return sd
 
-    def __getitem__(self, path: str):
-        """Gets the object present at the path within the given byte range.
-
-        Example:
-
-            >>> local_provider = LocalProvider("/home/ubuntu/Documents/")
-            >>> my_data = local_provider["abc.txt"]
-
-        Args:
-            path (str): The path relative to the root of the provider.
-
-        Returns:
-            bytes: The bytes of the object present at the path.
-
-        Raises:
-            KeyError: If an object is not found at the path.
-            DirectoryAtPathException: If a directory is found at the path.
-            Exception: Any other exception encountered while trying to fetch the object.
-        """
+    def _getitem_impl(self, path: str):
         try:
             full_path = self._check_is_file(path)
             with open(full_path, "rb") as file:
@@ -76,24 +59,7 @@ def __getitem__(self, path: str):
         except FileNotFoundError:
             raise KeyError(path)
 
-    def __setitem__(self, path: str, value: bytes):
-        """Sets the object present at the path with the value
-
-        Example:
-
-            >>> local_provider = LocalProvider("/home/ubuntu/Documents/")
-            >>> local_provider["abc.txt"] = b"abcd"
-
-        Args:
-            path (str): the path relative to the root of the provider.
-            value (bytes): the value to be assigned at the path.
-
-        Raises:
-            Exception: If unable to set item due to directory at path or permission or space issues.
-            FileAtPathException: If the directory to the path is a file instead of a directory.
-            ReadOnlyError: If the provider is in read-only mode.
-        """
-        self.check_readonly()
+    def _setitem_impl(self, path: str, value: bytes):
         full_path = self._check_is_file(path)
         directory = os.path.dirname(full_path)
         if os.path.isfile(directory):
@@ -105,24 +71,7 @@ def __setitem__(self, path: str, value: bytes):
         if self.files is not None:
             self.files.add(path)
 
-    def __delitem__(self, path: str):
-        """Delete the object present at the path.
-
-        Example:
-
-            >>> local_provider = LocalProvider("/home/ubuntu/Documents/")
-            >>> del local_provider["abc.txt"]
-
-        Args:
-            path (str): the path to the object relative to the root of the provider.
-
-        Raises:
-            KeyError: If an object is not found at the path.
-            DirectoryAtPathException: If a directory is found at the path.
-            Exception: Any other exception encountered while trying to fetch the object.
-            ReadOnlyError: If the provider is in read-only mode.
-        """
-        self.check_readonly()
+    def _delitem_impl(self, path: str):
         try:
             full_path = self._check_is_file(path)
             os.remove(full_path)
@@ -133,42 +82,7 @@ def __delitem__(self, path: str):
         except FileNotFoundError:
             raise KeyError
 
-    def __iter__(self):
-        """Generator function that iterates over the keys of the provider.
-
-        Example:
-
-            >>> local_provider = LocalProvider("/home/ubuntu/Documents/")
-            >>> for my_data in local_provider:
-            ...     pass
-
-        Yields:
-            str: the path of the object that it is iterating over, relative to the root of the provider.
-        """
-        yield from self._all_keys()
-
-    def __len__(self):
-        """Returns the number of files present inside the root of the provider.
-
-        Example:
-
-            >>> local_provider = LocalProvider("/home/ubuntu/Documents/")
-            >>> len(local_provider)
-
-        Returns:
-            int: the number of files present inside the root.
-        """
-        return len(self._all_keys())
-
-    def _all_keys(self, refresh: bool = False) -> Set[str]:
-        """Lists all the objects present at the root of the Provider.
-
-        Args:
-            refresh (bool): refresh keys
-
-        Returns:
-            set: set of all the objects found at the root of the Provider.
-        """
+    def _all_keys_impl(self, refresh: bool = False) -> Set[str]:
         if self.files is None or refresh:
             full_path = os.path.expanduser(self.root)
             key_set = set()
@@ -195,6 +109,10 @@ def _check_is_file(self, path: str):
         Raises:
             DirectoryAtPathException: If a directory is found at the path.
         """
+
+        if self._is_temp(path):
+            return path
+
         full_path = posixpath.join(self.root, path)
         full_path = os.path.expanduser(full_path)
         full_path = str(pathlib.Path(full_path))
@@ -202,9 +120,8 @@ def _check_is_file(self, path: str):
             raise DirectoryAtPathException
         return full_path
 
-    def clear(self, prefix=""):
+    def _clear_impl(self, prefix=""):
         """Deletes ALL data with keys having given prefix on the local machine (under self.root). Exercise caution!"""
-        self.check_readonly()
         full_path = os.path.expanduser(self.root)
         if prefix and self.files:
             self.files = set(file for file in self.files if not file.startswith(prefix))
@@ -221,15 +138,18 @@ def rename(self, path):
         os.rename(self.root, path)
         self.root = path
 
-    def __contains__(self, key) -> bool:
+    def _contains_impl(self, key) -> bool:
         full_path = self._check_is_file(key)
         return os.path.exists(full_path)
 
     def __getstate__(self):
-        return self.root
+        super()._getstate_prepare()
+
+        return {"root": self.root, "_temp_data": self._temp_data}
 
     def __setstate__(self, state):
-        self.__init__(state)
+        self.__init__(state["root"])
+        self._temp_data = state.get("_temp_data", {})
 
     def get_presigned_url(self, key: str) -> str:
         return os.path.join(self.root, key)
@@ -237,7 +157,7 @@ def get_presigned_url(self, key: str) -> str:
     def get_object_size(self, key: str) -> int:
         return os.stat(os.path.join(self.root, key)).st_size
 
-    def get_bytes(
+    def _get_bytes_impl(
         self,
         path: str,
         start_byte: Optional[int] = None,
diff --git a/deeplake/core/storage/lru_cache.py b/deeplake/core/storage/lru_cache.py
index da0e15bd29..d0a74fed99 100644
--- a/deeplake/core/storage/lru_cache.py
+++ b/deeplake/core/storage/lru_cache.py
@@ -45,6 +45,8 @@ def __init__(
             This number may be less than the actual space available on the cache_storage.
             Setting it to a higher value than actually available space may lead to unexpected behaviors.
         """
+        super().__init__()
+
         self.next_storage = next_storage
         self.cache_storage = cache_storage
         self.cache_size = cache_size
@@ -66,6 +68,10 @@ def __init__(
         # )
         self.use_async = False
 
+    def _is_temp(self, key: str) -> bool:
+        """Always returns False so temp objects are handled by the underlying storage instance."""
+        return False
+
     def register_deeplake_object(self, path: str, obj: DeepLakeMemoryObject):
         """Registers a new object in the cache."""
         self.deeplake_objects[path] = obj
@@ -195,7 +201,7 @@ def _get_item_from_cache(self, path: str):
             self.lru_sizes.move_to_end(path)  # refresh position for LRU
             return self.cache_storage[path]
 
-    def __getitem__(self, path: str):
+    def _getitem_impl(self, path: str):
         """If item is in cache_storage, retrieves from there and returns.
         If item isn't in cache_storage, retrieves from next storage, stores in cache_storage (if possible) and returns.
 
@@ -231,7 +237,7 @@ def get_items(self, paths):
                 self._insert_in_cache(key, result)
             yield result
 
-    def get_bytes(
+    def _get_bytes_impl(
         self,
         path: str,
         start_byte: Optional[int] = None,
@@ -277,6 +283,8 @@ def __setitem__(self, path: str, value: Union[bytes, DeepLakeMemoryObject]):
         Raises:
             ReadOnlyError: If the provider is in read-only mode.
         """
+        # Overrides StorageProvider.__setitem__ (deliberately left non-final) to accept DeepLakeMemoryObject values
+
         self.check_readonly()
         if path in self.deeplake_objects:
             self.deeplake_objects[path].is_dirty = False
@@ -293,6 +301,9 @@ def __setitem__(self, path: str, value: Union[bytes, DeepLakeMemoryObject]):
 
         self.maybe_flush()
 
+    def _setitem_impl(self, path: str, value: bytes):
+        raise NotImplementedError("Use __setitem__; LRUCache handles writes there.")
+
     def _del_item_from_cache(self, path: str):
         deleted_from_cache = False
 
@@ -309,7 +320,7 @@ def _del_item_from_cache(self, path: str):
 
         return deleted_from_cache
 
-    def __delitem__(self, path: str):
+    def _delitem_impl(self, path: str):
        """Deletes the object present at the path from the cache and the underlying storage.
 
         Args:
@@ -347,7 +358,7 @@ def clear_cache_without_flush(self):
         if self.next_storage is not None and hasattr(self.next_storage, "clear_cache"):
             self.next_storage.clear_cache()
 
-    def clear(self, prefix=""):
+    def _clear_impl(self, prefix=""):
         """Deletes ALL the data from all the layers of the cache and the actual storage.
         This is an IRREVERSIBLE operation. Data once deleted can not be recovered.
         """
@@ -372,7 +383,7 @@ def clear(self, prefix=""):
         if self.next_storage is not None:
             self.next_storage.clear(prefix=prefix)
 
-    def __len__(self):
+    def _len_impl(self):
         """Returns the number of files present in the cache and the underlying storage.
 
         Returns:
@@ -380,14 +391,6 @@ def __len__(self):
         """
         return len(self._all_keys())
 
-    def __iter__(self):
-        """Generator function that iterates over the keys of the cache and the underlying storage.
-
-        Yields:
-            str: the path of the object that it is iterating over, relative to the root of the provider.
-        """
-        yield from self._all_keys()
-
     def _forward(self, path):
         """Forward the value at a given path to the next storage, and un-marks its key."""
         if self.next_storage is not None:
@@ -442,12 +445,7 @@ def _insert_in_cache(self, path: str, value: Union[bytes, DeepLakeMemoryObject])
 
         self.update_used_cache_for_path(path, _get_nbytes(value))
 
-    def _all_keys(self):
-        """Helper function that lists all the objects present in the cache and the underlying storage.
-
-        Returns:
-            set: set of all the objects found in the cache and the underlying storage.
-        """
+    def _all_keys_impl(self, refresh: bool = False):
         key_set = set()
         if self.next_storage is not None:
             key_set = self.next_storage._all_keys()  # type: ignore
@@ -465,6 +463,8 @@ def _flush_if_not_read_only(self):
 
     def __getstate__(self) -> Dict[str, Any]:
         """Returns the state of the cache, for pickling"""
+        super()._getstate_prepare()
+
         # flushes the cache before pickling
         self._flush_if_not_read_only()
diff --git a/deeplake/core/storage/memory.py b/deeplake/core/storage/memory.py
index f15cb1f690..1183acd17a 100644
--- a/deeplake/core/storage/memory.py
+++ b/deeplake/core/storage/memory.py
@@ -8,107 +8,35 @@ class MemoryProvider(StorageProvider):
     """Provider class for using the memory."""
 
     def __init__(self, root: str = ""):
+        super().__init__()
         self.dict: Dict[str, Any] = {}
         self.root = root
 
-    def __getitem__(
+    def _getitem_impl(
         self,
         path: str,
     ):
-        """Gets the object present at the path within the given byte range.
-
-        Example:
-
-            >>> memory_provider = MemoryProvider("xyz")
-            >>> my_data = memory_provider["abc.txt"]
-
-        Args:
-            path (str): The path relative to the root of the provider.
-
-        Returns:
-            bytes: The bytes of the object present at the path.
-
-        Raises:
-            KeyError: If an object is not found at the path.
-        """
         return self.dict[path]
 
-    def __setitem__(
+    def _setitem_impl(
         self,
         path: str,
         value: bytes,
     ):
-        """Sets the object present at the path with the value
-
-        Example:
-
-            >>> memory_provider = MemoryProvider("xyz")
-            >>> memory_provider["abc.txt"] = b"abcd"
-
-        Args:
-            path (str): the path relative to the root of the provider.
-            value (bytes): the value to be assigned at the path.
-
-        Raises:
-            ReadOnlyError: If the provider is in read-only mode.
-        """
         self.check_readonly()
         self.dict[path] = value
 
-    def __iter__(self):
-        """Generator function that iterates over the keys of the provider.
-
-        Example:
-
-            >>> memory_provider = MemoryProvider("xyz")
-            >>> for my_data in memory_provider:
-            ...     pass
-
-        Yields:
-            str: the path of the object that it is iterating over, relative to the root of the provider.
-        """
-        yield from self.dict
-
-    def __delitem__(self, path: str):
-        """Delete the object present at the path.
-
-        Example:
-
-            >>> memory_provider = MemoryProvider("xyz")
-            >>> del memory_provider["abc.txt"]
-
-        Args:
-            path (str): the path to the object relative to the root of the provider.
-
-        Raises:
-            KeyError: If an object is not found at the path.
-            ReadOnlyError: If the provider is in read-only mode.
-        """
+    def _delitem_impl(self, path: str):
         self.check_readonly()
         del self.dict[path]
 
-    def __len__(self):
-        """Returns the number of files present inside the root of the provider.
-
-        Example:
-
-            >>> memory_provider = MemoryProvider("xyz")
-            >>> len(memory_provider)
-
-        Returns:
-            int: the number of files present inside the root.
-        """
+    def _len_impl(self):
         return len(self.dict)
 
-    def _all_keys(self):
-        """Lists all the objects present at the root of the Provider.
-
-        Returns:
-            set: set of all the objects found at the root of the Provider.
-        """
+    def _all_keys_impl(self, refresh: bool = False):
         return set(self.dict.keys())
 
-    def clear(self, prefix=""):
+    def _clear_impl(self, prefix=""):
         """Clears the provider."""
         self.check_readonly()
         if prefix:
@@ -116,12 +44,15 @@ def clear(self, prefix=""):
         else:
             self.dict = {}
 
-    def __getstate__(self) -> str:
+    def __getstate__(self) -> dict:
         """Does NOT save the in memory data in state."""
+        super()._getstate_prepare()
+
-        return self.root
+        return {"root": self.root, "_temp_data": self._temp_data}
 
-    def __setstate__(self, state: str):
-        self.__init__(root=state)  # type: ignore
+    def __setstate__(self, state: dict):
+        self.__init__(root=state["root"])  # type: ignore
+        self._temp_data = state.get("_temp_data", {})
 
     def get_object_size(self, key: str) -> int:
         return _get_nbytes(self[key])
diff --git a/deeplake/core/storage/provider.py b/deeplake/core/storage/provider.py
index 5cfed68d3f..7a3c785579 100644
--- a/deeplake/core/storage/provider.py
+++ b/deeplake/core/storage/provider.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import MutableMapping
-from typing import Optional, Set, Sequence, Dict
+from typing import Optional, Set, Sequence, Dict, final
 from deeplake.constants import BYTE_PADDING
 from deeplake.util.assert_byte_indexes import assert_byte_indexes
@@ -35,10 +35,21 @@ class StorageProvider(ABC, MutableMapping):
     To add a new provider using Provider, create a subclass and implement all 5 abstract methods below.
     """
 
-    @abstractmethod
+    def __init__(self):
+        self._temp_data: dict[str, bytes] = {}
+
+    def _is_temp(self, key: str) -> bool:
+        """Check if the key is a temporary key that shouldn't be persisted to storage."""
+        return key.startswith("__temp") and "/chunks/" not in key
+
+    @final
     def __getitem__(self, path: str):
         """Gets the object present at the path within the given byte range.
 
+        Example:
+
+            >>> my_data = my_provider["abc.txt"]
+
         Args:
             path (str): The path relative to the root of the provider.
 
         Returns:
             bytes: The bytes of the object present at the path.
 
         Raises:
             KeyError: If an object is not found at the path.
+            DirectoryAtPathException: If a directory is found at the path.
+            Exception: Any other exception encountered while trying to fetch the object.
         """
+        if self._is_temp(path):
+            return self._temp_data[path]
+
+        return self._getitem_impl(path)
+
+    @abstractmethod
+    def _getitem_impl(self, path: str):
+        """Gets the object present at the path within the given byte range.
+
+        Args:
+            path (str): The path relative to the root of the provider.
+
+        Returns:
+            bytes: The bytes of the object present at the path.
+
+        Raises:
+            KeyError: If an object is not found at the path.
+        """
+
+    @final
     def get_bytes(
         self,
         path: str,
         start_byte: Optional[int] = None,
         end_byte: Optional[int] = None,
     ):
         """Gets the object present at the path within the given byte range.
 
         Args:
             path (str): The path relative to the root of the provider.
             start_byte (int, optional): If only specific bytes starting from start_byte are required.
             end_byte (int, optional): If only specific bytes up to end_byte are required.
 
         Returns:
             bytes: The bytes of the object present at the path within the given byte range.
 
         Raises:
             InvalidBytesRequestedError: If `start_byte` > `end_byte` or `start_byte` < 0 or `end_byte` < 0.
             KeyError: If an object is not found at the path.
         """
+        if self._is_temp(path):
+            return self._temp_data[path][start_byte:end_byte]
+
+        return self._get_bytes_impl(path, start_byte, end_byte)
+
+    def _get_bytes_impl(
+        self,
+        path: str,
+        start_byte: Optional[int] = None,
+        end_byte: Optional[int] = None,
+    ):
         assert_byte_indexes(start_byte, end_byte)
         return self[path][start_byte:end_byte]
 
-    @abstractmethod
+    def _getstate_prepare(self):
+        for k in self._temp_data.keys():
+            if isinstance(self._temp_data[k], memoryview):
+                self._temp_data[k] = self._temp_data[k].tobytes()
+
     def __setitem__(self, path: str, value: bytes):
         """Sets the object present at the path with the value
 
+        Example:
+
+            >>> my_provider["abc.txt"] = b"abcd"
+
         Args:
             path (str): the path relative to the root of the provider.
             value (bytes): the value to be assigned at the path.
 
         Raises:
             Exception: If unable to set item due to directory at path or permission or space issues.
+            FileAtPathException: If the directory to the path is a file instead of a directory.
+            ReadOnlyError: If the provider is in read-only mode.
+        """
+        self.check_readonly()
+        if self._is_temp(path):
+            self._temp_data[path] = value
+            return
+
+        self._setitem_impl(path, value)
+
+    @abstractmethod
+    def _setitem_impl(self, path: str, value: bytes):
+        """Sets the object present at the path with the value
+
         Args:
             path (str): the path relative to the root of the provider.
             value (bytes): the value to be assigned at the path.
@@ -119,40 +191,87 @@ def set_bytes(
             value = value.rjust(end_byte, BYTE_PADDING)
         self[path] = value
 
-    @abstractmethod
+    @final
+    def __contains__(self, key) -> bool:
+        if self._is_temp(key):
+            return key in self._temp_data
+
+        return self._contains_impl(key)
+
+    def _contains_impl(self, key) -> bool:
+        """Check if the key exists in the provider."""
+        return key in self._all_keys_impl()
+
+    @final
     def __iter__(self):
         """Generator function that iterates over the keys of the provider.
 
+        Example:
+
+            >>> for my_data in my_provider:
+            ...     pass
+
         Yields:
             str: the path of the object that it is iterating over, relative to the root of the provider.
         """
+        yield from self._all_keys()
 
     @abstractmethod
-    def _all_keys(self) -> Set[str]:
-        """Generator function that iterates over the keys of the provider.
+    def _all_keys_impl(self, refresh: bool = False) -> Set[str]:
+        pass
+
+    @final
+    def _all_keys(self, refresh: bool = False) -> Set[str]:
+        """Lists all the objects present at the root of the Provider.
+
+        Args:
+            refresh (bool): refresh keys
 
         Returns:
-            set: set of all keys present at the root of the provider.
+            set: set of all the objects found at the root of the Provider.
         """
+        return set.union(set(self._all_keys_impl(refresh)), set(self._temp_data.keys()))
 
-    @abstractmethod
+    @final
     def __delitem__(self, path: str):
         """Delete the object present at the path.
 
+        Example:
+
+            >>> del my_provider["abc.txt"]
+
         Args:
             path (str): the path to the object relative to the root of the provider.
 
         Raises:
             KeyError: If an object is not found at the path.
+            DirectoryAtPathException: If a directory is found at the path.
+            Exception: Any other exception encountered while trying to fetch the object.
+            ReadOnlyError: If the provider is in read-only mode.
         """
+        self.check_readonly()
+
+        if self._is_temp(path):
+            del self._temp_data[path]
+            return
+
+        self._delitem_impl(path)
 
     @abstractmethod
+    def _delitem_impl(self, path: str):
+        """Delete the object present at the path."""
+
+    @final
     def __len__(self):
         """Returns the number of files present inside the root of the provider.
 
         Returns:
             int: the number of files present inside the root.
         """
+        return self._len_impl() + len(self._temp_data)
+
+    def _len_impl(self):
+        """Counts only persisted keys; providers may override this with a cheaper count."""
+        return len(set(self._all_keys_impl()))
 
     def enable_readonly(self):
         """Enables read-only mode for the provider."""
@@ -180,9 +295,23 @@ def maybe_flush(self):
         if self.autoflush:
             self.flush()
 
-    @abstractmethod
+    @final
     def clear(self, prefix=""):
         """Delete the contents of the provider."""
+        self.check_readonly()
+
+        if prefix:
+            self._temp_data = {
+                k: v for k, v in self._temp_data.items() if not k.startswith(prefix)
+            }
+        else:
+            self._temp_data = {}
+
+        self._clear_impl(prefix)
+
+    @abstractmethod
+    def _clear_impl(self, prefix=""):
+        """Delete the contents of the provider."""
 
     def delete_multiple(self, paths: Sequence[str]):
         for path in paths:
diff --git a/deeplake/core/storage/s3.py b/deeplake/core/storage/s3.py
index d8936ffb3a..b50cc4c93c 100644
--- a/deeplake/core/storage/s3.py
+++ b/deeplake/core/storage/s3.py
@@ -116,6 +116,8 @@ def __init__(
             This is optional, tokens are normally autogenerated.
             **kwargs: Additional arguments to pass to the S3 client. Includes: ``expiration``.
         """
+        super().__init__()
+
         self.root = root
         self.aws_access_key_id = aws_access_key_id
         self.aws_secret_access_key = aws_secret_access_key
@@ -166,18 +168,7 @@ def _set(self, path, content):
             ContentType="application/octet-stream",  # signifies binary data
         )
 
-    def __setitem__(self, path, content):
-        """Sets the object present at the path with the value
-
-        Args:
-            path (str): the path relative to the root of the S3Provider.
-            content (bytes): the value to be assigned at the path.
-
-        Raises:
-            S3SetError: Any S3 error encountered while setting the value at the path.
-            ReadOnlyError: If the provider is in read-only mode.
-        """
-        self.check_readonly()
+    def _setitem_impl(self, path, content):
         self._check_update_creds()
         path = "".join((self.path, path))
         content = bytes(memoryview(content))
@@ -210,20 +201,8 @@ def _get(self, path, bucket=None):
         )
         return resp["Body"].read()
 
-    def __getitem__(self, path):
-        """Gets the object present at the path.
-
-        Args:
-            path (str): the path relative to the root of the S3Provider.
-
-        Returns:
-            bytes: The bytes of the object present at the path.
-
-        Raises:
-            KeyError: If an object is not found at the path.
-            S3GetError: Any other error other than KeyError while retrieving the object.
-        """
-        return self.get_bytes(path)
+    def _getitem_impl(self, path):
+        return self._get_bytes_impl(path)
 
     def _get_bytes(
         self, path, start_byte: Optional[int] = None, end_byte: Optional[int] = None
@@ -241,7 +220,7 @@ def _get_bytes(
             resp = self.client.get_object(Bucket=self.bucket, Key=path, Range=range)
         return resp["Body"].read()
 
-    def get_bytes(
+    def _get_bytes_impl(
         self,
         path: str,
         start_byte: Optional[int] = None,
@@ -297,7 +276,7 @@ def get_bytes(
     def _del(self, path):
         self.client.delete_object(Bucket=self.bucket, Key=path)
 
-    def __delitem__(self, path):
+    def _delitem_impl(self, path):
         """Delete the object present at the path.
 
         Args:
@@ -351,19 +330,13 @@ def _keys_iterator(self):
             for content in page.get("Contents", ()):
                 yield content["Key"]
 
-    def _all_keys(self):
-        """Helper function that lists all the objects present at the root of the S3Provider.
-
-        Returns:
-            set: set of all the objects found at the root of the S3Provider.
+    def _all_keys_impl(self, refresh: bool = False):
+        self._check_update_creds()
 
-        Raises:
-            S3ListError: Any S3 error encountered while listing the objects.
-        """
         len_path = len(self.path.split("/")) - 1
         return ("/".join(name.split("/")[len_path:]) for name in self._keys_iterator())
 
-    def __len__(self):
+    def _len_impl(self):
         """Returns the number of files present at the root of the S3Provider.
 
         Note:
@@ -377,22 +350,13 @@ def __len__(self):
         """
         return sum(1 for _ in self._keys_iterator())
 
-    def __iter__(self):
-        """Generator function that iterates over the keys of the S3Provider.
-
-        Yields:
-            str: the name of the object that it is iterating over.
-        """
-        self._check_update_creds()
-        yield from self._all_keys()
-
     def _clear(self, prefix):
         bucket = self.resource.Bucket(self.bucket)
         for response in bucket.objects.filter(Prefix=prefix).delete():
             if response["Errors"]:
                 raise S3DeletionError(response["Errors"][0]["Message"])
 
-    def clear(self, prefix=""):
+    def _clear_impl(self, prefix=""):
         """Deletes ALL data with keys having given prefix on the s3 bucket (under self.root).
 
         Warning:
@@ -463,9 +427,12 @@ def _state_keys(self):
             "read_only",
             "profile_name",
             "creds_used",
+            "_temp_data",
         }
 
     def __getstate__(self):
+        super()._getstate_prepare()
+
         return {key: getattr(self, key) for key in self._state_keys()}
 
     def __setstate__(self, state):
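---

Reviewer note: after this patch, `StorageProvider` behaves as a template-method base class. The public mapping API (`__getitem__`, `get_bytes`, `__contains__`, `__delitem__`, `__iter__`, `__len__`, `clear` are `@final`; `__setitem__` stays overridable so `LRUCache` can keep accepting `DeepLakeMemoryObject` values) intercepts keys matched by `_is_temp()` and serves them from the in-memory `_temp_data` dict, delegating everything else to the per-backend `_*_impl` hooks. A minimal sketch of the resulting behavior, using only names that appear in this diff (the key `"__temp_marker"` is an arbitrary example):

```python
import pickle

from deeplake.core.storage.memory import MemoryProvider

p = MemoryProvider("xyz")
p["data.bin"] = b"abcd"      # delegated to _setitem_impl, lands in p.dict
p["__temp_marker"] = b"123"  # matched by _is_temp, lands in p._temp_data only

assert "__temp_marker" not in p.dict                   # never reaches the backing store
assert p["__temp_marker"] == b"123"                    # but still readable via __getitem__
assert p._all_keys() == {"data.bin", "__temp_marker"}  # union of impl keys and temp keys

q = pickle.loads(pickle.dumps(p))    # __getstate__/__setstate__ now round-trip _temp_data
assert q["__temp_marker"] == b"123"  # temp keys survive at the provider level...
assert "data.bin" not in q.dict      # ...while MemoryProvider still drops its dict
```

This is consistent with the test changes above: temp tensors survive provider pickling but are filtered out of `DatasetMeta.__getstate__`, so `test_pickle_loses_temp_tensors` asserts `"__temp_123" not in ds`.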
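Under the new contract, a storage backend implements only the five abstract hooks; read-only enforcement and temp-key routing come from the base class. A sketch of what a minimal new provider would look like (`DictProvider` is hypothetical, not part of this patch):

```python
from typing import Set

from deeplake.core.storage.provider import StorageProvider


class DictProvider(StorageProvider):
    """Hypothetical in-memory backend illustrating the five required hooks."""

    def __init__(self):
        super().__init__()  # required now: initializes _temp_data on the base class
        self._store: dict = {}

    def _getitem_impl(self, path: str) -> bytes:
        return self._store[path]  # missing keys raise KeyError, as the contract expects

    def _setitem_impl(self, path: str, value: bytes):
        self._store[path] = value  # base __setitem__ has already run check_readonly()

    def _delitem_impl(self, path: str):
        del self._store[path]

    def _all_keys_impl(self, refresh: bool = False) -> Set[str]:
        return set(self._store)

    def _clear_impl(self, prefix=""):
        # prefix="" keeps nothing, since every key startswith("")
        self._store = {k: v for k, v in self._store.items() if not k.startswith(prefix)}
```

Because `__temp*` keys never reach `_setitem_impl`, such a provider gets the same temp-key semantics as the providers updated above without any backend-specific code.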