Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds EXR format to store depth images in float32 #7463

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@

VISION_REQUIRE = [
"Pillow>=9.4.0", # When PIL.Image.ExifTags was introduced
"openexr_numpy>=0.0.6", # for EXR format support for depth
]

BENCHMARKS_REQUIRE = [
Expand Down
3 changes: 2 additions & 1 deletion src/datasets/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"Array4D",
"Array5D",
"ClassLabel",
"Exr",
"Features",
"LargeList",
"Sequence",
Expand All @@ -15,7 +16,7 @@
"Video",
]
from .audio import Audio
from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Features, LargeList, Sequence, Value
from .features import Array2D, Array3D, Array4D, Array5D, ClassLabel, Exr, Features, LargeList, Sequence, Value
from .image import Image
from .translation import Translation, TranslationVariableLanguages
from .video import Video
237 changes: 237 additions & 0 deletions src/datasets/features/exr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
from openexr_numpy import imwrite as exwrite
from openexr_numpy import imread as exload

import os
from dataclasses import dataclass, field
from io import BytesIO
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union, Tuple

import numpy as np
import pyarrow as pa

from .. import config
from ..download.download_config import DownloadConfig
from ..table import array_cast
from ..utils.file_utils import is_local_path, xopen
from ..utils.py_utils import no_op_if_value_is_null, string_to_dict

if TYPE_CHECKING:
from .features import FeatureType

@dataclass
class Exr:
"""Exr [`Feature`] to read Exr image data from an Exr file.

Input: The Exr feature accepts as input:
- A `str`: Absolute path to the Exr file (i.e., random access is allowed).
- A `dict` with the keys:

- `path`: String with relative path of the Exr file to the archive file.
- `bytes`: Bytes of the Exr file.

This is useful for archived files with sequential access.

- An `np.ndarray`: NumPy array representing the Exr image.

Args:
decode (`bool`, defaults to `True`):
Whether to decode the Exr data. If `False`,
returns the underlying dictionary in the format `{"path": exr_path, "bytes": exr_bytes}`.

Example:

```py
>>> from datasets import load_dataset, Exr
>>> ds = load_dataset("my_dataset")
>>> ds = ds.cast_column("exr_image", Exr())
>>> ds[0]["exr_image"]
{'array': array([...], dtype=float32),
'path': '/path/to/file.exr'}
```
"""

decode: bool = True
id: Optional[str] = None
# Automatically constructed
dtype: ClassVar[str] = "np.ndarray"
pa_type: ClassVar[Any] = pa.struct({"bytes": pa.binary(), "path": pa.string()})
_type: str = field(default="Exr", init=False, repr=False)

def __call__(self):
return self.pa_type

def encode_example(self, value: Union[str, bytes, dict, np.ndarray]) -> dict:
"""Encode example into a format for Arrow.

Args:
value (`str`, `np.ndarray` or `dict`):
Data passed as input to Exr feature.

Returns:
`dict` with "path" and "bytes" fields
"""
if isinstance(value, str):
return {"path": value, "bytes": None}
elif isinstance(value, bytes):
return {"path": None, "bytes": value}
elif isinstance(value, np.ndarray):
# Convert the Exr array to bytes using the provided exwrite function
buffer = BytesIO()
exwrite(buffer, value) # exwrite is your custom function to save Exr to bytes
return {"path": None, "bytes": buffer.getvalue()}
elif value.get("path") is not None and os.path.isfile(value["path"]):
return {"bytes": None, "path": value.get("path")}
elif value.get("bytes") is not None or value.get("path") is not None:
return {"bytes": value.get("bytes"), "path": value.get("path")}
else:
raise ValueError(f"An Exr sample should have one of 'path' or 'bytes' but they are missing or None in {value}.")

def decode_example(self, value: Tuple[Dict, str], token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None) -> np.ndarray:
"""Decode example Exr file into image data.

Args:
value (`str` or `dict`):
A string with the absolute Exr file path, a dictionary with keys:

- `path`: String with absolute or relative Exr file path.
- `bytes`: The bytes of the Exr file.
token_per_repo_id (`dict`, *optional*):
To access and decode Exr files from private repositories on the Hub, you can pass
a dictionary repo_id (`str`) -> token (`bool` or `str`).

Returns:
`np.ndarray`
"""
if not self.decode:
raise RuntimeError("Decoding is disabled for this feature. Please use Exr(decode=True) instead.")

if token_per_repo_id is None:
token_per_repo_id = {}

if isinstance(value, str):
path = value
bytes_ = None
else:
path, bytes_ = value["path"], value["bytes"]
if bytes_ is None:
if path is None:
raise ValueError(f"An Exr sample should have one of 'path' or 'bytes' but both are None in {value}.")
else:
if is_local_path(path):
array = exload(path) # exload is your custom function to load Exr
else:
source_url = path.split("::")[-1]
pattern = (
config.HUB_DATASETS_URL
if source_url.startswith(config.HF_ENDPOINT)
else config.HUB_DATASETS_HFFS_URL
)
try:
repo_id = string_to_dict(source_url, pattern)["repo_id"]
token = token_per_repo_id.get(repo_id)
except ValueError:
token = None
download_config = DownloadConfig(token=token)
with xopen(path, "rb", download_config=download_config) as f:
bytes_ = BytesIO(f.read())
array = exload(bytes_) # exload can handle file-like objects
else:
try:
bt = BytesIO(bytes_)
#print (len(bt.getvalue()))
array = exload(BytesIO(bytes_)) # exload can handle file-like objects
except Exception as e:
print (f"Warning, cannot read exr file because of {e}")
array = np.zeros((768, 1024), dtype=np.float64)
return array

def flatten(self) -> Union["FeatureType", Dict[str, "FeatureType"]]:
"""If in the decodable state, return the feature itself, otherwise flatten the feature into a dictionary."""
from .features import Value

return (
self
if self.decode
else {
"bytes": Value("binary"),
"path": Value("string"),
}
)

def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArray]) -> pa.StructArray:
"""Cast an Arrow array to the Exr arrow storage type.
The Arrow types that can be converted to the Exr pyarrow storage type are:

- `pa.string()` - it must contain the "path" data
- `pa.binary()` - it must contain the Exr bytes
- `pa.struct({"bytes": pa.binary()})`
- `pa.struct({"path": pa.string()})`
- `pa.struct({"bytes": pa.binary(), "path": pa.string()})` - order doesn't matter
- `pa.list(*)` - it must contain the Exr array data

Args:
storage (`Union[pa.StringArray, pa.StructArray, pa.ListArray]`):
PyArrow array to cast.

Returns:
`pa.StructArray`: Array in the Exr arrow storage type, that is
`pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
"""
if pa.types.is_string(storage.type):
bytes_array = pa.array([None] * len(storage), type=pa.binary())
storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
elif pa.types.is_binary(storage.type):
path_array = pa.array([None] * len(storage), type=pa.string())
storage = pa.StructArray.from_arrays([storage, path_array], ["bytes", "path"], mask=storage.is_null())
elif pa.types.is_struct(storage.type):
if storage.type.get_field_index("bytes") >= 0:
bytes_array = storage.field("bytes")
else:
bytes_array = pa.array([None] * len(storage), type=pa.binary())
if storage.type.get_field_index("path") >= 0:
path_array = storage.field("path")
else:
path_array = pa.array([None] * len(storage), type=pa.string())
storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=storage.is_null())
elif pa.types.is_list(storage.type):
bytes_array = pa.array(
[self.encode_example(np.array(arr))["bytes"] if arr is not None else None for arr in storage.to_pylist()],
type=pa.binary(),
)
path_array = pa.array([None] * len(storage), type=pa.string())
storage = pa.StructArray.from_arrays(
[bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null()
)
return array_cast(storage, self.pa_type)

def embed_storage(self, storage: pa.StructArray) -> pa.StructArray:
"""Embed Exr files into the Arrow array.

Args:
storage (`pa.StructArray`):
PyArrow array to embed.

Returns:
`pa.StructArray`: Array in the Exr arrow storage type, that is
`pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
"""

@no_op_if_value_is_null
def path_to_bytes(path):
with xopen(path, "rb") as f:
bytes_ = f.read()
return bytes_

bytes_array = pa.array(
[
(path_to_bytes(x["path"]) if x["bytes"] is None else x["bytes"]) if x is not None else None
for x in storage.to_pylist()
],
type=pa.binary(),
)
path_array = pa.array(
[os.path.basename(path) if path is not None else None for path in storage.field("path").to_pylist()],
type=pa.string(),
)
storage = pa.StructArray.from_arrays([bytes_array, path_array], ["bytes", "path"], mask=bytes_array.is_null())
return array_cast(storage, self.pa_type)
3 changes: 3 additions & 0 deletions src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from ..utils.py_utils import asdict, first_non_null_value, zip_dict
from .audio import Audio
from .image import Image, encode_pil_image
from .exr import Exr
from .translation import Translation, TranslationVariableLanguages
from .video import Video

Expand Down Expand Up @@ -1205,6 +1206,7 @@ class LargeList:
Array4D,
Array5D,
Audio,
Exr,
Image,
Video,
]
Expand Down Expand Up @@ -1421,6 +1423,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[dict[str, Uni
Array5D.__name__: Array5D,
Audio.__name__: Audio,
Image.__name__: Image,
Exr.__name__: Exr,
Video.__name__: Video,
}

Expand Down