Source code for lore_sa.dataset.dataset

# Public API of this module: only the Dataset base class is exported.
__all__ = ["Dataset"]

from abc import abstractmethod

from lore_sa.logger import logger


class Dataset:
    """
    Generic class to handle datasets.

    The methods below work on ``self.descriptor``, a dict mapping a column
    category (the code reads ``'numeric'``, ``'categorical'`` and ``'target'``)
    to a ``{column_name: column_info}`` dict, and on ``self.class_name``, the
    name of the target column.
    """

    @abstractmethod
    def update_descriptor(self):
        """
        It creates the dataset descriptor dictionary.

        NOTE(review): ``Dataset`` does not use ``abc.ABCMeta``, so this
        ``@abstractmethod`` is not enforced when the class is instantiated.
        """

    def set_target_label(self, descriptor):
        """
        Set the target column into the dataset descriptor.

        The column named ``self.class_name`` is removed from the category that
        holds it and stored as ``descriptor['target'] = {class_name: info}``.
        The descriptor is modified in place.

        :param descriptor: the dataset descriptor dictionary
        :return: the same descriptor object, with a ``'target'`` key when the
            class column was found
        """
        class_name = getattr(self, 'class_name', None)
        if class_name is None:
            logger.warning("No target class is defined")
            return descriptor

        # Already the target (e.g. set_class_name after set_descriptor with the
        # same name): nothing to move. Treating 'target' as an ordinary category
        # would rebind it to a new dict and then pop the column back out of it,
        # leaving an empty target.
        if class_name in descriptor.get('target', {}):
            return descriptor

        # Find the owning category first, then mutate, so the descriptor is
        # never resized while it is being iterated.
        source = next(
            (category for category, columns in descriptor.items()
             if category != 'target' and class_name in columns),
            None,
        )
        if source is None:
            logger.warning(f"Target column '{class_name}' not found in the descriptor")
            return descriptor

        # NOTE(review): a previous, different target column is overwritten here
        # and is not returned to its original category.
        descriptor['target'] = {class_name: descriptor[source].pop(class_name)}
        return descriptor

    def set_descriptor(self, descriptor):
        """
        Store the dataset descriptor and move the class column into ``'target'``.

        :param descriptor: the dataset descriptor dictionary (modified in place)
        """
        self.descriptor = descriptor
        self.descriptor = self.set_target_label(self.descriptor)

    def set_class_name(self, class_name: str):
        """
        Set the class name. Only the column name string.

        If a descriptor is already stored, the column is moved into its
        ``'target'`` entry immediately; otherwise this happens on the next
        ``set_descriptor`` call.

        :param [str] class_name: name of the target column
        """
        self.class_name = class_name
        if getattr(self, 'descriptor', None) is not None:
            self.descriptor = self.set_target_label(self.descriptor)

    def get_class_values(self):
        """
        Return the list of values of the target column.

        :return: the ``'distinct_values'`` entry of the target column
        :raises Exception: if no class name has been set
        """
        class_name = getattr(self, 'class_name', None)
        if class_name is None:
            raise Exception("ERR: class_name is None. Set class_name with set_class_name('<column name>')")
        return self.descriptor['target'][class_name]['distinct_values']

    def get_numeric_columns(self):
        """Return the names of the numeric columns (empty if there are none)."""
        return list(self.descriptor.get('numeric', {}))

    def get_categorical_columns(self):
        """Return the names of the categorical columns (empty if there are none)."""
        return list(self.descriptor.get('categorical', {}))

    def get_feature_names(self):
        """Return the numeric column names followed by the categorical ones."""
        return self.get_numeric_columns() + self.get_categorical_columns()

    def get_number_of_features(self):
        """Return the number of numeric plus categorical columns."""
        return len(self.get_feature_names())

    def get_feature_name(self, index):
        """
        Get the feature name by index.

        Every category of the descriptor is searched, including ``'target'``.

        :param index: the value to match against each column's ``'index'`` entry
        :return: the name of the corresponding feature, or ``None`` if no
            column has that index
        """
        for columns in self.descriptor.values():
            for name, info in columns.items():
                if info['index'] == index:
                    return name
        return None