# Source code for lore_sa.surrogate.decision_tree

import copy
import datetime
import operator
from collections import defaultdict

import numpy as np

from lore_sa.dataset import TabularDataset
from lore_sa.encoder_decoder import EncDec, OneHotEnc, TabularEnc
from lore_sa.logger import logger
from sklearn.tree._tree import TREE_LEAF
from sklearn.tree import DecisionTreeClassifier
# HalvingGridSearchCV is still experimental in scikit-learn and must be
# explicitly enabled before use.
from sklearn.experimental import enable_halving_search_cv  # noqa: F401
import sklearn.model_selection

__all__ = ["Surrogate","DecisionTreeSurrogate"]

from lore_sa.rule import Expression, Rule
from lore_sa.surrogate.surrogate import Surrogate
from lore_sa.util import vector2dict, multilabel2str
import lore_sa


class DecisionTreeSurrogate(Surrogate):

    def __init__(self, kind=None, preprocessing=None):
        super().__init__(kind, preprocessing)
        self.dt = None
    def train(self, Z, Yb, weights=None, class_values=None, multi_label: bool = False,
              one_vs_rest: bool = False, cv=5, prune_tree: bool = False):
        """
        :param Z: the training input samples
        :param Yb: the target values (class labels) as integers or strings
        :param weights: sample weights
        :param class_values: list of the distinct class values; used to choose the scoring metric
        :param [bool] multi_label: whether the task is multi-label
        :param [bool] one_vs_rest: whether a one-vs-rest strategy is used for multi-label tasks
        :param [int] cv: number of cross-validation folds for the hyperparameter search
        :param [bool] prune_tree: if True, run a hyperparameter search and prune duplicate leaves
        :return: the fitted DecisionTreeClassifier
        """
        self.dt = DecisionTreeClassifier()
        if prune_tree is True:
            param_list = {'min_samples_split': [0.01, 0.05, 0.1, 0.2, 3, 2],
                          'min_samples_leaf': [0.001, 0.01, 0.05, 0.1, 2, 4],
                          'splitter': ['best', 'random'],
                          'max_depth': [None, 2, 10, 12, 16, 20, 30],
                          'criterion': ['entropy', 'gini'],
                          # 'auto' was removed from max_features in recent scikit-learn releases
                          'max_features': [0.2, 1, 5, 'sqrt', 'log2']
                          }

            if multi_label is False or (multi_label is True and one_vs_rest is True):
                if len(class_values) == 2 or (multi_label and one_vs_rest):
                    scoring = 'precision'
                else:
                    scoring = 'precision_macro'
            else:
                scoring = 'precision_samples'

            dt_search = sklearn.model_selection.HalvingGridSearchCV(self.dt, param_grid=param_list,
                                                                    scoring=scoring, cv=cv, n_jobs=-1)
            logger.info('Search the best estimator')
            logger.info('Start time: {0}'.format(datetime.datetime.now()))
            dt_search.fit(Z, Yb, sample_weight=weights)
            logger.info('End time: {0}'.format(datetime.datetime.now()))
            self.dt = dt_search.best_estimator_
            logger.info('Pruning')
            self.prune_duplicate_leaves(self.dt)
        else:
            self.dt.fit(Z, Yb)

        return self.dt
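    # A minimal usage sketch (hypothetical names: `Z` is an encoded neighborhood
    # matrix and `Yb` the labels assigned to it by the black box):
    #
    #   surrogate = DecisionTreeSurrogate()
    #   dt = surrogate.train(Z, Yb, class_values=[0, 1], prune_tree=True)
    #   dt.predict(Z[:1])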
    def is_leaf(self, inner_tree, index):
        """Check whether the node at ``index`` is a leaf node"""
        return (inner_tree.children_left[index] == TREE_LEAF and
                inner_tree.children_right[index] == TREE_LEAF)
    def prune_index(self, inner_tree, decisions, index=0):
        """
        Start pruning from the bottom - if we start from the top, we might miss
        nodes that become leaves during pruning.
        Do not use this directly - use prune_duplicate_leaves instead.
        """
        if not self.is_leaf(inner_tree, inner_tree.children_left[index]):
            self.prune_index(inner_tree, decisions, inner_tree.children_left[index])
        if not self.is_leaf(inner_tree, inner_tree.children_right[index]):
            self.prune_index(inner_tree, decisions, inner_tree.children_right[index])

        # Prune children if both children are leaves now and make the same decision:
        if (self.is_leaf(inner_tree, inner_tree.children_left[index]) and
                self.is_leaf(inner_tree, inner_tree.children_right[index]) and
                (decisions[index] == decisions[inner_tree.children_left[index]]) and
                (decisions[index] == decisions[inner_tree.children_right[index]])):
            # turn node into a leaf by "unlinking" its children
            inner_tree.children_left[index] = TREE_LEAF
            inner_tree.children_right[index] = TREE_LEAF
            logger.info("Pruned {}".format(index))
    def prune_duplicate_leaves(self, dt):
        """Prune splits whose two leaf children predict the same class"""
        decisions = dt.tree_.value.argmax(axis=2).flatten().tolist()  # decision for each node
        self.prune_index(dt.tree_, decisions)
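    # Sketch of the pruning effect (illustrative): sibling leaves that agree on
    # the predicted class are merged bottom-up into their parent, so
    #
    #   surrogate.prune_duplicate_leaves(surrogate.dt)
    #
    # leaves the tree with no split whose two children predict the same class.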
    def get_rule(self, x: np.ndarray, dataset: TabularDataset, encoder: EncDec = None):
        """
        Extract the rule as premises and consequence {p -> y}, starting from the decision tree,
        e.g. {(income > 90) -> grant, (job = employer) -> grant}

        :param [np.ndarray] x: encoded instance of the dataset for which the rule is extracted
        :param [TabularDataset] dataset: neighborhood instances
        :param [EncDec] encoder: encoder used to decode the predicted class
        :return [Rule]: Rule object
        """
        x = x.reshape(1, -1)
        feature = self.dt.tree_.feature
        threshold = self.dt.tree_.threshold
        predicted_class = self.dt.predict(x)

        consequence = Expression(variable=dataset.class_name, operator=operator.eq,
                                 value=encoder.decode_target_class(predicted_class))

        leave_id = self.dt.apply(x)
        node_index = self.dt.decision_path(x).indices

        feature_names = dataset.get_features_names()
        numeric_columns = dataset.get_numeric_columns()

        premises = list()
        for node_id in node_index:
            if leave_id[0] == node_id:
                break
            else:
                if encoder is not None:
                    if isinstance(encoder, (OneHotEnc, TabularEnc)):
                        attribute = feature_names[feature[node_id]]
                        if attribute not in numeric_columns:
                            # one-hot encoded attribute: the split becomes an equality test
                            thr = False if x[0][feature[node_id]] <= threshold[node_id] else True
                            op = operator.eq
                        else:
                            thr = threshold[node_id]
                            op = operator.le if x[0][feature[node_id]] <= threshold[node_id] else operator.gt
                    else:
                        raise Exception('Unknown encoder instance: {}'.format(type(encoder)))
                else:
                    op = operator.le if x[0][feature[node_id]] <= threshold[node_id] else operator.gt
                    attribute = feature_names[feature[node_id]]
                    thr = threshold[node_id]

                premises.append(Expression(attribute, op, thr))

        premises = self.compact_premises(premises)
        return Rule(premises=premises, consequences=consequence, encoder=encoder)
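    # Usage sketch (hypothetical `neighborhood` and `encoder` built elsewhere in lore_sa):
    #
    #   rule = surrogate.get_rule(z, dataset=neighborhood, encoder=encoder)
    #   print(rule)   # e.g. {income > 90.0} -> {class = grant}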
    def compact_premises(self, premises_list):
        """
        Merge premises on the same attribute that differ only in their threshold,
        keeping the tightest bound for each operator.

        :param premises_list: list of Expressions that define the premises
        :return: the compacted list of premises
        """
        attribute_list = defaultdict(list)
        for premise in premises_list:
            attribute_list[premise.variable].append(premise)

        compact_plist = list()
        for att, alist in attribute_list.items():
            if len(alist) > 1:
                min_thr = None
                max_thr = None
                for av in alist:
                    if av.operator == operator.le:
                        max_thr = min(av.value, max_thr) if max_thr else av.value
                    elif av.operator == operator.gt:
                        min_thr = max(av.value, min_thr) if min_thr else av.value

                if max_thr:
                    compact_plist.append(Expression(att, operator.le, max_thr))

                if min_thr:
                    compact_plist.append(Expression(att, operator.gt, min_thr))
            else:
                compact_plist.append(alist[0])
        return compact_plist
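    # Example of the compaction (illustrative values):
    #
    #   [age > 20, age > 30, age <= 50]  ->  [age <= 50, age > 30]
    #
    # `age > 30` subsumes `age > 20`, so only the tightest bound per operator
    # is kept for each attribute.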
    def get_counterfactual_rules(self, x: np.ndarray, class_name, feature_names,
                                 neighborhood_dataset: TabularDataset, features_map_inv=None,
                                 multi_label: bool = False, encoder: EncDec = None,
                                 filter_crules=None, constraints: dict = None,
                                 unadmittible_features: dict = None):
        """
        :param [np.ndarray] x: encoded instance of the dataset
        :param class_name: name of the target class (overridden by the one in ``neighborhood_dataset``)
        :param feature_names: names of the features
        :param [TabularDataset] neighborhood_dataset: neighborhood instances
        :param features_map_inv:
        :param [bool] multi_label:
        :param [EncDec] encoder:
        :param filter_crules: black-box predict function used to validate candidate counterfactuals
        :param [dict] constraints:
        :param [dict] unadmittible_features: features that must not be changed
            (see check_feasibility_of_falsified_conditions)
        :return: the list of counterfactual rules and the list of falsified premises
        """
        class_values = neighborhood_dataset.get_class_values()
        class_name = neighborhood_dataset.class_name
        clen = np.inf
        crule_list = list()
        delta_list = list()

        Z_list = neighborhood_dataset.df[[c for c in feature_names if c != class_name]].to_numpy()
        x_dict = vector2dict(x, neighborhood_dataset.get_features_names())
        for z in Z_list:
            # extract the rule for each neighborhood instance
            crule = self.get_rule(x=z, dataset=neighborhood_dataset, encoder=encoder)
            delta = self.get_falsified_conditions(x_dict, crule)
            num_falsified_conditions = len(delta)
            if unadmittible_features is not None:
                is_feasible = self.check_feasibility_of_falsified_conditions(delta, unadmittible_features)
                if is_feasible is False:
                    continue
            if constraints is not None:
                ## TODO
                to_remove = list()
                for p in crule.premises:
                    if p.variable in constraints.keys():
                        if p.operator == constraints[p.variable]['op']:
                            if p.value > constraints[p.variable]['thr']:
                                break  # the premise satisfies the constraint
                ## TODO

            if filter_crules is not None:
                xc = self.apply_counterfactual(x, delta, neighborhood_dataset)
                bb_outcomec = filter_crules(xc.reshape(1, -1))[0]
                bb_outcomec = class_values[bb_outcomec] if isinstance(class_name, str) else multilabel2str(
                    bb_outcomec, class_values)
                dt_outcomec = crule.consequences

                if bb_outcomec == dt_outcomec:
                    if num_falsified_conditions < clen:
                        clen = num_falsified_conditions
                        crule_list = [crule]
                        delta_list = [delta]
                    elif num_falsified_conditions == clen:
                        if delta not in delta_list:
                            crule_list.append(crule)
                            delta_list.append(delta)
            else:
                if num_falsified_conditions < clen:
                    clen = num_falsified_conditions
                    crule_list = [crule]
                    delta_list = [delta]
                    logger.debug('{0} {1}'.format(crule, delta))
                elif num_falsified_conditions == clen:
                    if delta not in delta_list:
                        crule_list.append(crule)
                        delta_list.append(delta)

        return crule_list, delta_list
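    # Usage sketch (hypothetical `bb.predict` is the black-box classifier):
    #
    #   crules, deltas = surrogate.get_counterfactual_rules(
    #       x, class_name, feature_names, neighborhood,
    #       encoder=encoder, filter_crules=bb.predict)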
    def get_falsified_conditions(self, x_dict: dict, crule: Rule):
        """
        Return the premises of ``crule`` that the instance does not satisfy

        :param x_dict: the instance as a feature-name -> value dictionary
        :param crule: the counterfactual rule to check
        :return: list of falsified premises
        """
        delta = []
        for p in crule.premises:
            try:
                if p.operator == operator.le and x_dict[p.variable] > p.value:
                    delta.append(p)
                elif p.operator == operator.gt and x_dict[p.variable] <= p.value:
                    delta.append(p)
            except KeyError:
                # the premise refers to a feature missing from the instance; skip it
                continue
        return delta
    def check_feasibility_of_falsified_conditions(self, delta, unadmittible_features: dict):
        """
        Check whether any falsified condition involves a feature that must not be changed

        :param delta: list of falsified premises
        :param unadmittible_features: dictionary mapping each unadmittible feature to the
            forbidden operator, or to None if the feature cannot be changed at all
        :return: True if the conditions are feasible, False otherwise
        """
        for p in delta:
            if p.variable in unadmittible_features:
                if unadmittible_features[p.variable] is None:
                    return False
                else:
                    if unadmittible_features[p.variable] == p.operator:
                        return False
        return True
    def apply_counterfactual(self, x, delta, dataset, features_map=None, features_map_inv=None,
                             numeric_columns=None):
        """Build a counterfactual instance by minimally changing ``x`` so that it
        satisfies the falsified premises in ``delta``"""
        feature_names = dataset.get_features_names()
        if numeric_columns is None:
            numeric_columns = dataset.get_numeric_columns()
        x_dict = vector2dict(x, feature_names)
        x_copy_dict = copy.deepcopy(x_dict)
        for p in delta:
            if p.variable in numeric_columns:
                if p.value == int(p.value):
                    gap = 1.0
                else:
                    # the gap is one unit at the first non-zero decimal place of the threshold
                    decimals = list(str(p.value).split('.')[1])
                    for idx, e in enumerate(decimals):
                        if e != '0':
                            break
                    gap = 1 / (10 ** (idx + 1))
                if p.operator == operator.gt:
                    x_copy_dict[p.variable] = p.value + gap
                else:
                    x_copy_dict[p.variable] = p.value
            else:
                fn = p.variable
                if p.operator == operator.gt:
                    if features_map is not None:
                        # reset the sibling one-hot columns of the same original feature
                        fi = list(feature_names).index(p.variable)
                        fi = features_map_inv[fi]
                        for fv in features_map[fi]:
                            x_copy_dict['%s=%s' % (fn, fv)] = 0.0
                    x_copy_dict[p.variable] = 1.0
                else:
                    if features_map is not None:
                        fi = list(feature_names).index(p.variable)
                        fi = features_map_inv[fi]
                        for fv in features_map[fi]:
                            x_copy_dict['%s=%s' % (fn, fv)] = 1.0
                    x_copy_dict[p.variable] = 0.0

        x_counterfactual = np.zeros(len(x_dict))
        for i, fn in enumerate(feature_names):
            x_counterfactual[i] = x_copy_dict[fn]

        return x_counterfactual
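

# A minimal smoke-test sketch, not part of the library API: it exercises only
# `train` (including the halving grid search and leaf pruning) on synthetic
# data. Rule extraction additionally needs a TabularDataset and an EncDec
# instance, which are built elsewhere in lore_sa.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    Z = rng.normal(size=(400, 8))                 # synthetic encoded neighborhood
    Yb = (Z[:, 0] + Z[:, 1] > 0).astype(int)      # synthetic black-box labels
    surrogate = DecisionTreeSurrogate()
    # prune_tree=True is slow here: it runs the full hyperparameter search
    dt = surrogate.train(Z, Yb, class_values=[0, 1], prune_tree=True)
    print('tree depth after pruning:', dt.get_depth())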