import os
from typing import List, Union, Optional, Tuple, Dict
import numpy as np
from PIL import Image
from fairxai.data.dataset import Dataset
from fairxai.data.descriptor.image_descriptor import ImageDatasetDescriptor
from fairxai.logger import logger
class ImageDataset(Dataset):
    """
    Represents an image dataset that can be loaded either from a folder
    containing image files or directly from in-memory NumPy arrays.

    The dataset supports two serialization modes:

    - **Folder-based dataset**:
      Only the folder path is serialized. Images are reloaded at project load time.
    - **Memory-based dataset**:
      Raw NumPy arrays are saved to a compressed ``.npz`` file inside the project
      folder. This allows reconstruction of datasets not tied to an external
      file system.

    Parameters
    ----------
    data : str | np.ndarray | list[np.ndarray]
        Either a folder path or image arrays.
    class_name : str | None, optional
        Optional class label for the dataset.

    Raises
    ------
    TypeError
        If ``data`` is not a supported type.
    ValueError
        If no valid images can be loaded.
    """

    def __init__(
        self,
        data: Union[str, np.ndarray, List[np.ndarray]],
        class_name: Optional[str] = None
    ) -> None:
        super().__init__(data=None, class_name=class_name)

        # Tracks how the dataset was created ("folder" or "memory").
        self.source_type: str
        # Path of the source folder; set only for folder-mode datasets.
        self.folder_path: Optional[str] = None
        # Filenames corresponding to loaded images (valid only for folder datasets).
        self.filenames: List[str] = []

        if isinstance(data, str):
            # FOLDER MODE: load images from a directory.
            self.source_type = "folder"
            self.folder_path = data
            logger.info(f"Loading images from folder: {data}")
            self.data, self.filenames = self._load_from_folder(data)
        elif isinstance(data, np.ndarray):
            # MEMORY MODE: a single NumPy array is wrapped into a one-element list.
            self.source_type = "memory"
            self.data = [data]
            self.filenames = []
        elif isinstance(data, list) and all(isinstance(a, np.ndarray) for a in data):
            # MEMORY MODE: a list of NumPy arrays is used as-is.
            self.source_type = "memory"
            self.data = data
            self.filenames = []
        else:
            raise TypeError(
                "Data must be a folder path, a NumPy array, or a list of NumPy arrays."
            )

        # Build the dataset descriptor; log, then re-raise so callers see
        # invalid image data as early as possible.
        try:
            self.update_descriptor()
        except ValueError as exc:
            logger.error(f"Error computing descriptor: {exc}")
            raise

    # ======================================================================
    # SERIALIZATION
    # ======================================================================

    def to_dict(self) -> Dict:
        """
        Serialize dataset metadata into a dictionary.

        Notes
        -----
        - Folder datasets store only ``folder_path``.
        - Memory datasets do not store raw image arrays here; arrays are saved
          separately via `save_memory_data`.

        Returns
        -------
        dict
            Metadata describing how to reconstruct the dataset.
        """
        return {
            "type": "image",
            "source_type": self.source_type,
            "folder_path": self.folder_path,
            "class_name": self.class_name,
        }

    @classmethod
    def from_dict(cls, meta: Dict, project_path: str) -> "ImageDataset":
        """
        Reconstruct an `ImageDataset` instance from serialized metadata.

        Parameters
        ----------
        meta : dict
            Serialized dataset information.
        project_path : str
            Filesystem path to the root of the project.

        Returns
        -------
        ImageDataset

        Raises
        ------
        FileNotFoundError
            If memory-based dataset arrays are missing.
        ValueError
            For unknown dataset source types.
        """
        source = meta["source_type"]

        # FOLDER MODE RECONSTRUCTION: the constructor reloads images from disk.
        if source == "folder":
            return cls(
                data=meta["folder_path"],
                class_name=meta.get("class_name")
            )

        # MEMORY MODE RECONSTRUCTION (.npz)
        if source == "memory":
            npz_path = os.path.join(project_path, "dataset", "images.npz")
            if not os.path.exists(npz_path):
                raise FileNotFoundError(f"Missing memory dataset file: {npz_path}")
            # np.load on an .npz returns a lazily-read NpzFile holding an open
            # zip handle; use a context manager so the file is always closed.
            # Arrays are fully materialized inside the `with` block, so they
            # stay valid after it exits.
            with np.load(npz_path) as npz:
                # `npz.files` preserves archive order, which matches the
                # positional order used by `save_memory_data` (arr_0, arr_1, ...).
                arrays = [npz[key] for key in npz.files]
            return cls(
                data=arrays,
                class_name=meta.get("class_name")
            )

        raise ValueError(f"Unknown dataset source type: {source}")

    # ======================================================================
    # SAVE MEMORY-BASED DATASETS
    # ======================================================================

    def save_memory_data(self, dest_folder: str) -> None:
        """
        Save memory-based image arrays into a compressed ``.npz`` file.

        Parameters
        ----------
        dest_folder : str
            Directory where the ``images.npz`` will be written.

        Notes
        -----
        If the dataset originates from a folder, this method does nothing.
        """
        if self.source_type != "memory":
            return
        os.makedirs(dest_folder, exist_ok=True)
        npz_path = os.path.join(dest_folder, "images.npz")
        # Positional arguments are stored under the keys "arr_0", "arr_1", ...
        # in save order; `from_dict` relies on this ordering to rebuild the list.
        np.savez_compressed(npz_path, *self.data)
        logger.info(f"Saved memory image arrays to {npz_path}")

    # ======================================================================
    # IMAGE LOADING UTILITIES
    # ======================================================================

    def _load_from_folder(
        self,
        folder_path: str,
        extensions: Optional[List[str]] = None,
        recursive: bool = True
    ) -> Tuple[List[np.ndarray], List[str]]:
        """
        Load all images inside a folder.

        Parameters
        ----------
        folder_path : str
            Path to the image folder.
        extensions : list[str], optional
            Allowed file extensions. Default is common image formats.
        recursive : bool
            If ``True``, walk directories recursively.

        Returns
        -------
        (list[np.ndarray], list[str])
            Loaded images and their corresponding filenames.

        Raises
        ------
        ValueError
            If no valid images can be loaded.
        """
        if extensions is None:
            extensions = [".png", ".jpg", ".jpeg", ".bmp", ".tiff"]

        # Collect candidate image paths, matching extensions case-insensitively.
        image_paths: List[str] = []
        if recursive:
            for root, _, files in os.walk(folder_path):
                for filename in files:
                    if os.path.splitext(filename)[1].lower() in extensions:
                        image_paths.append(os.path.join(root, filename))
        else:
            for filename in os.listdir(folder_path):
                if os.path.splitext(filename)[1].lower() in extensions:
                    image_paths.append(os.path.join(folder_path, filename))

        if not image_paths:
            raise ValueError(f"No images found in folder: {folder_path}")

        images: List[np.ndarray] = []
        filenames: List[str] = []
        # Load each image; unreadable/corrupt files are skipped with a warning
        # rather than aborting the whole load (best-effort behavior).
        for path in image_paths:
            try:
                # PIL opens lazily and keeps the file handle alive; the context
                # manager guarantees it is closed. np.array() forces a full read
                # before the handle is released.
                with Image.open(path) as img:
                    images.append(np.array(img))
                filenames.append(os.path.basename(path))
            except Exception as exc:
                logger.warning(f"Skipping image {path}: {exc}")

        if not images:
            raise ValueError(
                f"No valid images could be loaded from folder: {folder_path}"
            )
        return images, filenames

    # ======================================================================
    # DESCRIPTOR + ACCESSORS
    # ======================================================================

    def update_descriptor(
        self,
        hwc_permutation: Optional[List[int]] = None
    ) -> Dict:
        """
        Compute and attach the dataset descriptor.

        Parameters
        ----------
        hwc_permutation : list[int] | None
            Optional permutation of axes (H, W, C).

        Returns
        -------
        dict
            The computed descriptor.
        """
        logger.info("Creating descriptor for image dataset")
        descriptor = ImageDatasetDescriptor(self.data).describe(
            hwc_permutation=hwc_permutation
        )
        self.set_descriptor(descriptor)
        return descriptor

    def get_instance(self, key: Union[int, str]) -> np.ndarray:
        """
        Retrieve a single image instance either by index or filename.

        Parameters
        ----------
        key : int | str
            Integer index or filename.

        Returns
        -------
        np.ndarray
            The requested image.

        Raises
        ------
        IndexError
            If index is out of range.
        ValueError
            If filename lookup fails or filenames are unavailable.
        TypeError
            If key is neither int nor str.
        """
        if isinstance(key, int):
            # Negative indices are deliberately rejected: instances are
            # addressed by absolute position only.
            if key < 0 or key >= len(self.data):
                raise IndexError(f"Index {key} is out of range.")
            return self.data[key]
        if isinstance(key, str):
            if not self.filenames:
                raise ValueError("No filenames available for lookup.")
            try:
                idx = self.filenames.index(key)
                return self.data[idx]
            except ValueError:
                # Suppress the implicit chain from list.index so callers see
                # only the meaningful lookup error.
                raise ValueError(f"Filename '{key}' not found in dataset.") from None
        raise TypeError("Key must be an integer index or filename string.")