Source code for fairxai.project.project

import json
import os
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional, Type

import numpy as np
import yaml

from fairxai.bbox.bbox_factory import ModelFactory
from fairxai.data.dataset.dataset_factory import DatasetFactory
from fairxai.explain.adapter.generic_explainer_adapter import GenericExplainerAdapter
from fairxai.explain.explaination.generic_explanation import GenericExplanation
from fairxai.explain.explainer_manager.explainer_manager import ExplainerManager
from fairxai.logger import logger


class Project:
    """
    Represents an explainability project binding a dataset, a trained model,
    and a set of compatible explainers.

    Supports saving/loading pipelines, explanations, and project metadata in a
    dedicated workspace.

    This version uses the "framework-first" ModelFactory approach, where the
    user specifies the ML framework (sklearn, torch, etc.) and optionally a
    model file path. The correct black-box wrapper is automatically
    instantiated and loaded.

    Attributes:
        id (str): UUID of the project.
        created_at (datetime): Timestamp of creation.
        dataset_instance (Dataset): The dataset instance (tabular, image, etc.).
        blackbox (Any): Wrapped ML model.
        explainers (List[Type[GenericExplainerAdapter]]): Compatible explainer classes.
        explanations (List[Dict[str, Any]]): Stored explanation records.
    """

    def __init__(
        self,
        project_name: str,
        data: Any = None,
        dataset_metadata: Optional[dict] = None,
        dataset_type: Optional[str] = None,
        framework: Optional[str] = None,
        model_path: Optional[str] = None,
        model_params: Optional[Dict[str, Any]] = None,
        workspace_path: str = "./workspace",
        target_variable: Optional[str] = None,
        categorical_columns: Optional[List[str]] = None,
        ordinal_columns: Optional[List[str]] = None,
        device: str = "cpu",
        is_reload: bool = False,
        id: Optional[str] = None,
    ):
        """
        Initialize a new project or reload an existing one.

        Either `data` (for a new project) or `dataset_metadata` (for a reload)
        must be provided.

        :param project_name: Name of the project
        :param data: Raw dataset used to create the dataset instance (new project)
        :param dataset_metadata: Metadata used to reload the dataset (existing project)
        :param dataset_type: Dataset type ('tabular', 'image', 'text', 'timeseries')
        :param framework: ML framework ('sklearn', 'torch', etc.)
        :param model_path: Optional path to a pre-trained model
        :param model_params: Optional parameters for model instantiation
        :param workspace_path: Base folder for the project workspace
        :param target_variable: Name of the target variable in the dataset (optional)
        :param categorical_columns: Categorical columns (tabular only)
        :param ordinal_columns: Ordinal columns (tabular only)
        :param device: Device for Torch models ('cpu' or 'cuda')
        :param is_reload: If True, the project is being reloaded and folders
            should not be recreated
        :param id: Existing project UUID, used when reloading
        """
        self.id: str = id if is_reload else str(uuid.uuid4())
        self.created_at: datetime = datetime.now()
        self.project_name: str = project_name
        self.dataset_type: Optional[str] = dataset_type
        self.framework: Optional[str] = framework

        # Decide workspace path: reuse it if reloading, else create a new UUID folder
        if is_reload and dataset_metadata is not None:
            # Use the workspace from metadata when reloading
            self.workspace_path: str = workspace_path
            os.makedirs(self.workspace_path, exist_ok=True)  # ensure it exists
        else:
            # New project: create a unique workspace folder
            self.workspace_path: str = os.path.join(workspace_path, self.id)
            os.makedirs(self.workspace_path, exist_ok=True)
            os.makedirs(os.path.join(self.workspace_path, "results"), exist_ok=True)
            os.makedirs(os.path.join(self.workspace_path, "pipelines"), exist_ok=True)
            os.makedirs(os.path.join(self.workspace_path, "logs"), exist_ok=True)

        # Store model & dataset parameters
        self.model_path: Optional[str] = model_path
        self.model_params: Optional[Dict[str, Any]] = model_params
        self.device: str = device
        self.target_variable: Optional[str] = target_variable
        self.categorical_columns: Optional[List[str]] = categorical_columns
        self.ordinal_columns: Optional[List[str]] = ordinal_columns
        # -------------------------
        # Dataset initialization
        # -------------------------
        if data is not None:
            # New dataset instance from raw data
            self.dataset_instance = DatasetFactory.create(
                data,
                dataset_type,
                class_name=target_variable,
                categorical_columns=categorical_columns,
                ordinal_columns=ordinal_columns,
            )
        elif dataset_metadata is not None:
            # Reload the dataset from saved metadata using the dataset-specific from_dict
            dataset_class = DatasetFactory.get_class(dataset_type)
            self.dataset_instance = dataset_class.from_dict(
                dataset_metadata, project_path=self.workspace_path
            )
        else:
            raise ValueError("Either `data` or `dataset_metadata` must be provided")

        # -------------------------
        # Model / blackbox wrapper
        # -------------------------
        self.blackbox = ModelFactory.create(
            framework=framework,
            model_path=model_path,
            model_params=model_params,
            device=device,
        )
        self.model_type: str = type(self.blackbox.model).__name__

        # -------------------------
        # Explainer discovery
        # -------------------------
        self.explainer_manager = ExplainerManager(self.dataset_type, self.model_type)
        self.explainers: List[Type[GenericExplainerAdapter]] = (
            self.explainer_manager.list_available_compatible_explainers()
        )

        # -------------------------
        # Explanation storage
        # -------------------------
        self.explanations: List[Dict[str, Any]] = []

        logger.info(
            f"Project {self.id} initialized with dataset {self.dataset_type} "
            f"and model {self.model_type}."
        )
        logger.info(f"Found {len(self.explainers)} compatible explainers.")

        # Persist metadata immediately
        self._save_metadata()
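    # Illustrative construction sketch (not executed here); the DataFrame,
    # column names, and model file below are hypothetical placeholders:
    #
    #   import pandas as pd
    #   df = pd.read_csv("adult.csv")
    #   project = Project(
    #       project_name="adult-income",
    #       data=df,
    #       dataset_type="tabular",
    #       framework="sklearn",
    #       model_path="models/rf_classifier.pkl",
    #       target_variable="income",
    #       categorical_columns=["workclass", "education"],
    #   )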
    # -------------------------------------------------------------------------
    # Pipeline execution
    # -------------------------------------------------------------------------
    def run_explanation_pipeline(self, pipeline: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Execute a user-defined explainability pipeline.

        Each step should include:
            - 'explainer': str (explainer class name)
            - 'mode': 'local' | 'global' (optional; inferred if missing)
            - 'params': dict (optional; 'local' requires 'instance_index')

        :param pipeline: List of pipeline steps
        :return: List of explanation records
        """
        results: List[Dict[str, Any]] = []
        explainer_map = {cls.explainer_name: cls for cls in self.explainers}

        for step in pipeline:
            explainer_name = step.get("explainer")
            if not explainer_name:
                raise ValueError("Pipeline step missing 'explainer' field")

            mode = step.get("mode", "").lower()
            params = step.get("params", {})
            if not mode:
                mode = "local" if "instance_index" in params else "global"
            if mode not in ("local", "global"):
                raise ValueError(f"Invalid mode '{mode}' for explainer '{explainer_name}'")

            if explainer_name not in explainer_map:
                raise ValueError(
                    f"Explainer '{explainer_name}' not compatible with dataset '{self.dataset_type}' "
                    f"and framework '{self.framework}'. Available: {list(explainer_map.keys())}"
                )

            explainer_cls = explainer_map[explainer_name]
            explainer: GenericExplainerAdapter = explainer_cls(
                model=self.blackbox, dataset=self.dataset_instance
            )

            if mode == "local":
                # -----------------------------
                # 1. Fetch the instance correctly
                # -----------------------------
                instance = None
                if "instance_filename" in params:
                    key = params["instance_filename"]
                    if hasattr(self.dataset_instance, "get_instance"):
                        instance = self.dataset_instance.get_instance(key)
                        instance_index = None
                    else:
                        raise TypeError(
                            f"Dataset {type(self.dataset_instance).__name__} does not "
                            f"support filename-based access."
                        )
                elif "instance_index" in params:
                    instance_index = params["instance_index"]
                    if hasattr(self.dataset_instance, "get_instance"):
                        instance = self.dataset_instance.get_instance(instance_index)
                    else:
                        instance = self.dataset_instance[instance_index]
                else:
                    raise ValueError(
                        "Local explanation requires either 'instance_index' or "
                        "'instance_filename' in params."
                    )

                # ---------------------------------------
                # 2. Pass params to the explainer
                #    (this includes hwc_permutation)
                # ---------------------------------------
                explanations_list = explainer.explain_instance(
                    instance,
                    params=params,  # forwards hwc_permutation, instance_filename, and so on
                )
            else:
                instance_index = None
                instance = None
                explanations_list: List[GenericExplanation] = explainer.explain_global()

            # Build and save the record
            record = self._create_explanation_record(
                explainer_cls, mode, instance_index, instance, explanations_list
            )
            self.explanations.append(record)
            self._save_result(record)
            results.append(record)

        # Update project metadata after pipeline execution
        self._save_metadata()
        return results
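    # Illustrative pipeline sketch; the explainer names are hypothetical and
    # must match `explainer_name` values of the discovered explainers:
    #
    #   pipeline = [
    #       {"explainer": "lime_tabular", "mode": "local",
    #        "params": {"instance_index": 0}},
    #       {"explainer": "shap_global", "mode": "global"},
    #   ]
    #   records = project.run_explanation_pipeline(pipeline)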
    # -------------------------------------------------------------------------
    # Pipeline from YAML
    # -------------------------------------------------------------------------
    def run_pipeline_from_yaml(self, yaml_path: str) -> List[Dict[str, Any]]:
        """
        Load a pipeline definition from a YAML file and execute it.

        :param yaml_path: Path to the YAML file containing the pipeline specification.
        :return: List of explanation records
        """
        yaml_path = os.path.expanduser(yaml_path)
        if not os.path.exists(yaml_path):
            raise FileNotFoundError(f"Pipeline YAML not found: {yaml_path}")

        with open(yaml_path, "r") as f:
            content = yaml.safe_load(f) or {}

        pipeline = content.get("pipeline")
        if pipeline is None:
            raise KeyError("YAML must contain a top-level 'pipeline' key")

        # Save a YAML copy in the project workspace
        saved_pipeline_path = os.path.join(
            self.workspace_path, "pipelines", os.path.basename(yaml_path)
        )
        with open(saved_pipeline_path, "w") as out:
            yaml.safe_dump(content, out, default_flow_style=False)

        return self.run_explanation_pipeline(pipeline)
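    # Illustrative YAML layout expected by run_pipeline_from_yaml; explainer
    # names and parameters below are hypothetical:
    #
    #   # pipeline.yaml
    #   pipeline:
    #     - explainer: lime_tabular
    #       mode: local
    #       params:
    #         instance_index: 0
    #     - explainer: shap_global
    #       mode: global
    #
    #   records = project.run_pipeline_from_yaml("pipeline.yaml")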
    # -------------------------------------------------------------------------
    # Serialization / persistence
    # -------------------------------------------------------------------------
    def _create_explanation_record(
        self,
        explainer_cls: Type[GenericExplainerAdapter],
        mode: str,
        instance_index: Optional[int],
        instance,
        explanations_list: List[GenericExplanation],
    ) -> Dict[str, Any]:
        """Build a dictionary record of a single explanation for storage/logging."""
        return {
            "timestamp": datetime.now().isoformat(),
            "explainer": explainer_cls.__name__,
            "mode": mode,
            "instance_index": instance_index if instance_index is not None else -1,
            "instance": self._serialize_instance(instance),
            "result": [explanation.visualize() for explanation in explanations_list],
        }

    @staticmethod
    def _serialize_instance(instance) -> Any:
        """Serialize an instance for storage. Uses `to_dict()` if available."""
        if instance is None:
            return None
        if isinstance(instance, np.ndarray):
            return np.char.mod('%s', instance).tolist()
        elif hasattr(instance, "to_dict"):
            return instance.to_dict()
        return str(instance)

    def _save_result(self, record: Dict[str, Any]) -> None:
        """Save an explanation record as JSON in the results folder."""
        results_dir = os.path.join(self.workspace_path, "results")
        os.makedirs(results_dir, exist_ok=True)
        ts = record["timestamp"].replace(":", "_")
        filename = f"{ts}_{record['explainer']}_{record['mode']}.json"
        path = os.path.join(results_dir, filename)
        with open(path, "w") as f:
            json.dump(record, f, indent=2, default=str)
        logger.debug(f"Saved explanation result to {path}")

    def _save_metadata(self) -> None:
        """Persist project metadata (including the dataset) to JSON in the workspace."""
        metadata = self.to_dict()
        # Add dataset metadata
        metadata["dataset_metadata"] = self.dataset_instance.to_dict()
        path = os.path.join(self.workspace_path, "project.json")
        with open(path, "w") as f:
            json.dump(metadata, f, indent=2, default=str)
    def to_dict(self) -> Dict[str, Any]:
        """Return a serializable dictionary of project metadata."""
        return {
            "project_name": self.project_name,
            "id": self.id,
            "created_at": self.created_at.isoformat(),
            "dataset_type": self.dataset_type,
            "framework": self.framework,
            "model_type": self.model_type,
            "model_path": self.model_path,
            "model_params": self.model_params,
            "device": self.device,
            "target_variable": self.target_variable,
            "categorical_columns": self.categorical_columns,
            "ordinal_columns": self.ordinal_columns,
            "workspace_path": self.workspace_path,
            "num_explainers": len(self.explainers),
            "num_explanations": len(self.explanations),
        }
    @classmethod
    def load_from_dict(cls, project_metadata: Dict[str, Any]) -> "Project":
        """
        Restore a project from a metadata dictionary.

        :param project_metadata: Project metadata dictionary.
        :return: Project instance with dataset and blackbox reconstructed.
        """
        project = cls(
            project_name=project_metadata["project_name"],
            data=None,
            dataset_metadata=project_metadata.get("dataset_metadata"),
            dataset_type=project_metadata["dataset_type"],
            framework=project_metadata["framework"],
            workspace_path=project_metadata["workspace_path"],
            model_path=project_metadata["model_path"],
            model_params=project_metadata["model_params"],
            target_variable=project_metadata["target_variable"],
            categorical_columns=project_metadata["categorical_columns"],
            ordinal_columns=project_metadata["ordinal_columns"],
            device=project_metadata["device"],
            is_reload=True,
            id=project_metadata["id"],
        )
        project.created_at = datetime.fromisoformat(project_metadata["created_at"])
        return project
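    # Illustrative reload sketch, assuming a project.json previously written
    # by _save_metadata (the workspace path below is a placeholder):
    #
    #   with open("./workspace/<project-uuid>/project.json") as f:
    #       metadata = json.load(f)
    #   project = Project.load_from_dict(metadata)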
    def __repr__(self):
        return (
            f"<Project id={self.id}, name={self.project_name}, "
            f"dataset={self.dataset_type}, framework={self.framework}, model={self.model_type}>"
        )