# Source code for fairxai.data.dataset.tabular_dataset

from __future__ import annotations

import os
from typing import Optional, List, Union, Dict, Any

import numpy as np
import pandas as pd
from pandas import DataFrame, Series

from fairxai.data.dataset import Dataset
from fairxai.data.descriptor.tabular_descriptor import TabularDatasetDescriptor
from fairxai.logger import logger


class TabularDataset(Dataset):
    """
    Tabular dataset container for explainers.

    Supports initialization from:
    - pandas DataFrame
    - CSV file path
    - dict or list of dict

    The dataset supports two serialization modes:
    - **CSV-based dataset**: Only the CSV path is stored. Data is reloaded
      from the original CSV on load.
    - **Memory-based dataset**: Raw DataFrame is saved as CSV inside the
      project folder for persistence.

    Features and target are separated:
    - self._data: features-only DataFrame
    - self._target: target Series (if class_name provided)
    """

    def __init__(
            self,
            data: Union[DataFrame, str, dict, List[dict]],
            class_name: Optional[str] = None,
            categorical_columns: Optional[List[str]] = None,
            ordinal_columns: Optional[List[str]] = None,
            dropna: bool = False
    ) -> None:
        """
        Initialize TabularDataset.

        Parameters
        ----------
        data : DataFrame | str | dict | list[dict]
            Source data: an in-memory DataFrame, a path to an existing
            ``.csv`` file, a column dict, or a list of row dicts.
        class_name : str | None
            Target column name, if present
        categorical_columns : list[str] | None
            Column names to treat as categorical
        ordinal_columns : list[str] | None
            Column names to treat as ordinal
        dropna : bool
            If reading CSV, drop rows with missing values

        Raises
        ------
        FileNotFoundError
            If ``data`` is a string but no file exists at that path.
        ValueError
            If ``data`` is a path that does not end in ``.csv``.
        TypeError
            If ``data`` is of an unsupported type.
        """
        super().__init__(data, class_name)
        self.source_type: str  # "csv" or "memory"
        self.csv_path: Optional[str] = None
        # Copy the hint lists so later sanitization never mutates caller data.
        self._categorical_columns: List[str] = list(categorical_columns) if categorical_columns else []
        self._ordinal_columns: List[str] = list(ordinal_columns) if ordinal_columns else []
        self.class_name: Optional[str] = class_name

        # --------------------------
        # CSV-based dataset
        # --------------------------
        if isinstance(data, str):
            # Any string is treated as a CSV path. Fail with a precise error
            # here instead of falling through to the generic TypeError below,
            # which used to report a confusing "unsupported type <class 'str'>".
            if not os.path.exists(data):
                raise FileNotFoundError(f"CSV file not found: {data}")
            if not data.lower().endswith(".csv"):
                raise ValueError(f"Expected a .csv file path, got: {data}")
            self.source_type = "csv"
            self.csv_path = data
            df = pd.read_csv(data, skipinitialspace=True, na_values="?", keep_default_na=True)
            if dropna:
                df = df.dropna()
        # --------------------------
        # Memory-based dataset
        # --------------------------
        # TODO: call save_memory_data() here, once there is a way to pass it
        #       the correct destination path (see project persistence layer).
        elif isinstance(data, pd.DataFrame):
            self.source_type = "memory"
            df = data
        elif isinstance(data, dict):
            self.source_type = "memory"
            df = pd.DataFrame(data)
        elif isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict):
            self.source_type = "memory"
            df = pd.DataFrame(data)
        else:
            raise TypeError(
                f"Unsupported data type for TabularDataset: {type(data)}. "
                "Supported: DataFrame, CSV path, dict, list[dict]."
            )

        # --------------------------
        # Extract target
        # --------------------------
        if self.class_name is not None and self.class_name in df.columns:
            self._target: Optional[Series] = df[self.class_name].copy()
            df = df.drop(columns=[self.class_name])
        else:
            self._target = None

        self._data: DataFrame = df

        # Remove target from hints if present
        self._sanitize_column_hints()

        # Compute descriptor
        try:
            self.update_descriptor()
        except ValueError as e:
            logger.error(f"Error computing descriptor: {e}")
            raise

    # --------------------------
    # Serialization
    # --------------------------
[docs] def to_dict(self) -> Dict[str, Any]: """ Serialize dataset metadata for project persistence. Returns ------- dict Metadata describing dataset type, source, and column hints. """ return { "type": "tabular", "source_type": self.source_type, "csv_path": self.csv_path, "class_name": self.class_name, "categorical_columns": self._categorical_columns, "ordinal_columns": self._ordinal_columns, "columns": list(self._data.columns) }
[docs] @classmethod def from_dict(cls, meta: Dict[str, Any], project_path: Optional[str] = None) -> TabularDataset: """ Reconstruct a TabularDataset from metadata. Parameters ---------- meta : dict Serialized metadata project_path : str | None Project folder path; used to locate memory-based CSV if needed Returns ------- TabularDataset """ source = meta["source_type"] if source == "csv": return cls(data=meta["csv_path"], class_name=meta.get("class_name"), categorical_columns=meta.get("categorical_columns"), ordinal_columns=meta.get("ordinal_columns")) elif source == "memory": # Memory-based dataset: load CSV inside project folder if project_path is None: raise ValueError("project_path must be provided to load memory-based dataset") csv_file = os.path.join(project_path, "dataset", "tabular.csv") if not os.path.exists(csv_file): raise FileNotFoundError(f"Missing memory dataset file: {csv_file}") df = pd.read_csv(csv_file) return cls(data=df, class_name=meta.get("class_name"), categorical_columns=meta.get("categorical_columns"), ordinal_columns=meta.get("ordinal_columns")) else: raise ValueError(f"Unknown source_type: {source}")
# -------------------------- # Save memory-based dataset # --------------------------
[docs] def save_memory_data(self, dest_folder: str) -> None: """ Save memory-based dataset to CSV inside project folder. Parameters ---------- dest_folder : str Destination folder path """ if self.source_type != "memory": return os.makedirs(dest_folder, exist_ok=True) csv_file = os.path.join(dest_folder, "tabular.csv") self._data.to_csv(csv_file, index=False) # Also save target column separately if needed if self._target is not None and self.class_name: target_file = os.path.join(dest_folder, f"{self.class_name}.csv") self._target.to_csv(target_file, index=False) logger.info(f"Saved memory-based tabular dataset to {csv_file}")
# -------------------------- # Internal helpers # -------------------------- def _sanitize_column_hints(self) -> None: """Remove target column from categorical/ordinal hints if present.""" if self.class_name: if self.class_name in self._categorical_columns: logger.warning( f"Target column '{self.class_name}' found in categorical_columns; removing it." ) self._categorical_columns = [c for c in self._categorical_columns if c != self.class_name] if self.class_name in self._ordinal_columns: logger.warning( f"Target column '{self.class_name}' found in ordinal_columns; removing it." ) self._ordinal_columns = [c for c in self._ordinal_columns if c != self.class_name] # -------------------------- # Descriptor # --------------------------
[docs] def update_descriptor( self, categorical_columns: Optional[List[str]] = None, ordinal_columns: Optional[List[str]] = None ) -> Dict[str, Any]: """Compute dataset descriptor based on features-only DataFrame.""" categorical_columns = categorical_columns or self._categorical_columns ordinal_columns = ordinal_columns or self._ordinal_columns # Keep only valid columns categorical_columns = [c for c in categorical_columns if c in self._data.columns] ordinal_columns = [c for c in ordinal_columns if c in self._data.columns] self.descriptor = TabularDatasetDescriptor( data=self._data, categorical_columns=categorical_columns, ordinal_columns=ordinal_columns ).describe(target=self._target, target_name=self.class_name) return self.descriptor
# -------------------------- # Indexing / length # -------------------------- def __len__(self) -> int: return len(self._data) def __getitem__(self, idx: int) -> np.ndarray: """Return feature vector for index idx as 1-D numpy array.""" row = self._data.iloc[idx] return row.values # -------------------------- # Properties # -------------------------- @property def X(self) -> DataFrame: return self._data @property def y(self) -> Optional[Series]: return self._target @property def target(self) -> Optional[Series]: return self._target @property def features(self) -> List[str]: """Return feature names according to descriptor.""" if hasattr(self, "descriptor") and self.descriptor: names = [] for section in ["numeric", "categorical", "ordinal"]: names.extend(self.descriptor.get(section, {}).keys()) return names return list(self._data.columns)
[docs] def get_feature_names(self) -> List[str]: return self.features
[docs] def get_feature_name(self, index: int) -> str: for section in ["numeric", "categorical", "ordinal"]: for name, info in self.descriptor.get(section, {}).items(): if info.get("index") == index: return name # fallback cols = list(self._data.columns) if 0 <= index < len(cols): return cols[index] raise IndexError(f"No feature found with index {index}")
[docs] def get_class_values(self) -> List[Any]: if not self.class_name: raise Exception("class_name is None.") if self._target is not None: return list(self._target.unique()) raise KeyError(f"Target '{self.class_name}' not found in dataset.")
[docs] def set_class_name(self, class_name: str) -> None: """Change target column name, moving column from features to target.""" if class_name == self.class_name: return if class_name in self._data.columns: self._target = self._data[class_name].copy() self._data = self._data.drop(columns=[class_name]) self.class_name = class_name self._sanitize_column_hints() self.update_descriptor() return raise KeyError(f"Column '{class_name}' not found among features.")