"""Source code for lore_sa.surrogate.decision_tree."""

import copy
import datetime
import operator
from collections import defaultdict

import numpy as np
from sklearn.metrics import confusion_matrix

from ..encoder_decoder import EncDec, ColumnTransformerEnc
from ..logger import logger
from sklearn.tree._tree import TREE_LEAF
from sklearn.tree import DecisionTreeClassifier
import sklearn.model_selection
from sklearn.experimental import enable_halving_search_cv

__all__ = ["Surrogate", "DecisionTreeSurrogate"]

from ..rule import Expression, Rule
from .surrogate import Surrogate
from ..util import vector2dict, multilabel2str


class DecisionTreeSurrogate(Surrogate):
    """Interpretable surrogate based on a scikit-learn DecisionTreeClassifier.

    The tree is trained to mimic a black-box model on a synthetic neighborhood
    of an instance; factual and counterfactual rules are then extracted from
    the decision paths of the fitted tree.
    """

    def __init__(self, kind=None, preprocessing=None, class_values=None, multi_label: bool = False,
                 one_vs_rest: bool = False, cv=5, prune_tree: bool = False):
        """
        :param kind: surrogate kind identifier, forwarded to ``Surrogate``
        :param preprocessing: preprocessing object, forwarded to ``Surrogate``
        :param class_values: list of the target class values (may be None)
        :param multi_label: whether the task is multi-label
        :param one_vs_rest: whether a one-vs-rest strategy is used
        :param cv: number of cross-validation folds for the hyper-parameter search
        :param prune_tree: if True, ``train`` runs a halving grid search and prunes
            duplicate leaves instead of fitting a plain tree
        """
        super().__init__(kind, preprocessing)
        self.dt = None                  # fitted DecisionTreeClassifier (set by train)
        self.fidelity = None            # accuracy of the surrogate on its training data
        self.confusion_matrix = None    # confusion matrix of surrogate vs. black-box labels
        self.prune_tree = prune_tree
        self.class_values = class_values
        self.multi_label = multi_label
        self.one_vs_rest = one_vs_rest
        self.cv = cv

    def train(self, Z, Yb, weights=None):
        """Fit the surrogate decision tree on the (encoded) neighborhood data.

        :param Z: training input samples
        :param Yb: target values (black-box class labels) as integers or strings
        :param weights: optional per-sample weights
        :return: the fitted ``DecisionTreeClassifier``
        """
        self.dt = DecisionTreeClassifier(class_weight='balanced', random_state=42)
        if self.prune_tree is True:
            param_list = {'min_samples_split': [0.01, 0.05, 0.1, 0.2, 3, 2],
                          'min_samples_leaf': [0.001, 0.01, 0.05, 0.1, 2, 4],
                          'splitter': ['best', 'random'],
                          'max_depth': [None, 2, 10, 12, 16, 20, 30],
                          'criterion': ['entropy', 'gini'],
                          # NOTE(review): 'auto' for max_features was removed in
                          # scikit-learn >= 1.3 — confirm against the pinned version.
                          'max_features': [0.2, 1, 5, 'auto', 'sqrt', 'log2']
                          }
            scoring = 'precision_micro'
            # HalvingGridSearchCV requires the enable_halving_search_cv import at module top.
            dt_search = sklearn.model_selection.HalvingGridSearchCV(self.dt, param_grid=param_list,
                                                                    scoring=scoring, cv=self.cv,
                                                                    n_jobs=-1)
            logger.info('Search the best estimator')
            logger.info('Start time: {0}'.format(datetime.datetime.now()))
            dt_search.fit(Z, Yb, sample_weight=weights)
            logger.info('End time: {0}'.format(datetime.datetime.now()))
            self.dt = dt_search.best_estimator_
            logger.info('Pruning')
            self.prune_duplicate_leaves(self.dt)
        else:
            # Fix: sample weights were previously dropped in the non-pruning branch.
            self.dt.fit(Z, Yb, sample_weight=weights)
        self.fidelity = self.dt.score(Z, Yb)
        self.confusion_matrix = confusion_matrix(Yb, self.dt.predict(Z))
        logger.info("Fidelity: {}".format(self.fidelity))
        return self.dt

    def is_leaf(self, inner_tree, index):
        """Check whether the node at ``index`` is a leaf of ``inner_tree``."""
        return (inner_tree.children_left[index] == TREE_LEAF
                and inner_tree.children_right[index] == TREE_LEAF)

    def prune_index(self, inner_tree, decisions, index=0):
        """
        Start pruning from the bottom - if we start from the top, we might miss
        nodes that become leaves during pruning.
        Do not use this directly - use prune_duplicate_leaves instead.
        """
        if not self.is_leaf(inner_tree, inner_tree.children_left[index]):
            self.prune_index(inner_tree, decisions, inner_tree.children_left[index])
        if not self.is_leaf(inner_tree, inner_tree.children_right[index]):
            self.prune_index(inner_tree, decisions, inner_tree.children_right[index])

        # Prune children if both children are leaves now and make the same decision:
        if (self.is_leaf(inner_tree, inner_tree.children_left[index])
                and self.is_leaf(inner_tree, inner_tree.children_right[index])
                and (decisions[index] == decisions[inner_tree.children_left[index]])
                and (decisions[index] == decisions[inner_tree.children_right[index]])):
            # turn node into a leaf by "unlinking" its children
            inner_tree.children_left[index] = TREE_LEAF
            inner_tree.children_right[index] = TREE_LEAF
            logger.info("Pruned {}".format(index))

    def prune_duplicate_leaves(self, dt):
        """Prune sibling leaves that predict the same class as their parent."""
        # Majority-class decision for each node of the fitted tree.
        decisions = dt.tree_.value.argmax(axis=2).flatten().tolist()
        self.prune_index(dt.tree_, decisions)

    def get_rule(self, z: np.ndarray, encoder: EncDec = None):
        """
        Extract the rule as premises and consequence {p -> y} from the decision
        path of the surrogate tree for a single encoded instance, e.g.
        {(income > 90) -> grant, (job = employer) -> grant}.

        :param z: encoded instance (1-d array) for which to extract the rule
        :param encoder: encoder/decoder used to map encoded features back to names
        :return [Rule]: the factual rule for ``z``
        """
        z = z.reshape(1, -1)
        feature = self.dt.tree_.feature
        threshold = self.dt.tree_.threshold
        predicted_class = self.dt.predict(z)

        inv_transform_predicted_class = encoder.decode_target_class([predicted_class])[0]
        target_feature_name = list(encoder.encoded_descriptor['target'].keys())[0]
        consequence = Expression(variable=target_feature_name, operator=operator.eq,
                                 value=inv_transform_predicted_class[0])

        leave_id = self.dt.apply(z)
        node_index = self.dt.decision_path(z).indices

        feature_names = list(encoder.encoded_features.values())
        numeric_columns = list(encoder.encoded_descriptor['numeric'].keys())

        premises = list()
        for node_id in node_index:
            if leave_id[0] == node_id:
                break
            else:
                attribute = feature_names[feature[node_id]]
                if attribute not in numeric_columns:
                    # Categorical (one-hot) feature: the split on the 0/1 indicator
                    # column becomes an equality test (True iff the indicator is above
                    # the threshold).
                    thr = False if z[0][feature[node_id]] <= threshold[node_id] else True
                    op = operator.eq
                else:
                    thr = threshold[node_id]
                    op = operator.le if z[0][feature[node_id]] <= threshold[node_id] else operator.gt
                premises.append(Expression(attribute, op, thr))

        premises = self.compact_premises(premises)
        return Rule(premises=premises, consequences=consequence, encoder=encoder)

    def compact_premises(self, premises_list):
        """
        Merge premises on the same attribute into at most one ``<=`` and one ``>``
        condition (keeping the tightest thresholds).

        :param premises_list: list of Expressions that define the premises
        :return: the compacted list of Expressions
        """
        attribute_list = defaultdict(list)
        for premise in premises_list:
            attribute_list[premise.variable].append(premise)

        compact_plist = list()
        for att, alist in attribute_list.items():
            if len(alist) > 1:
                min_thr = None
                max_thr = None
                for av in alist:
                    # Fix: compare against None explicitly — a threshold of 0/0.0 is
                    # falsy and was previously discarded by `if max_thr`.
                    if av.operator == operator.le:
                        max_thr = min(av.value, max_thr) if max_thr is not None else av.value
                    elif av.operator == operator.gt:
                        min_thr = max(av.value, min_thr) if min_thr is not None else av.value
                if max_thr is not None:
                    compact_plist.append(Expression(att, operator.le, max_thr))
                if min_thr is not None:
                    compact_plist.append(Expression(att, operator.gt, min_thr))
            else:
                compact_plist.append(alist[0])
        return compact_plist

    def get_counterfactual_rules(self, z: np.ndarray, neighborhood_train_X: np.ndarray,
                                 neighborhood_train_Y: np.ndarray, encoder: EncDec = None,
                                 filter_crules=None, constraints: dict = None,
                                 unadmittible_features: dict = None):
        """
        Extract the shortest counterfactual rules from the neighborhood, i.e. the
        rules supported by neighbors classified differently from ``z`` that falsify
        the fewest conditions on ``z``.

        :param z: encoded instance under explanation
        :param neighborhood_train_X: encoded neighborhood samples
        :param neighborhood_train_Y: surrogate/black-box labels of the neighborhood
        :param encoder: encoder/decoder for feature and target names
        :param filter_crules: optional callable (typically the black-box predict)
            used to keep only counterfactual rules the black box agrees with
        :param constraints: optional per-feature constraints {var: {'op':…, 'thr':…}}
        :param unadmittible_features: mapping {feature: operator-or-None} of features
            that must not appear in counterfactuals (None forbids any use)
        :return: (list of counterfactual Rules, list of falsified-premise lists)
        """
        feature_names = list(encoder.encoded_features.values())
        numeric_columns = list(encoder.encoded_descriptor['numeric'].keys())
        predicted_class = self.dt.predict(z.reshape(1, -1))[0]

        # TODO: modify to generalize to multiclasses
        class_name = list(encoder.encoded_descriptor['target'].keys())[0]
        class_values = list(encoder.encoded_descriptor['target'].values())[0]['distinct_values']

        clen = np.inf
        crule_list = list()
        delta_list = list()
        x_dict = vector2dict(z, feature_names)

        # Subset of the neighborhood classified differently from the input z.
        Z1 = neighborhood_train_X[np.where(neighborhood_train_Y != predicted_class)]

        # Search for the shortest rules among those supporting the elements of Z1.
        for zi in Z1:
            crule = self.get_rule(z=zi, encoder=encoder)
            delta = self.get_falsified_conditions(x_dict, crule)
            num_falsified_conditions = len(delta)

            if unadmittible_features is not None:
                is_feasible = self.check_feasibility_of_falsified_conditions(delta,
                                                                             unadmittible_features)
                if is_feasible is False:
                    continue

            if constraints is not None:
                # TODO: constraint handling is incomplete — it only breaks out of the
                # premise scan and never discards the rule.
                for p in crule.premises:
                    if p.variable in constraints.keys():
                        if p.operator == constraints[p.variable]['op']:
                            # Fix: premises store their threshold in `.value` (see
                            # Expression construction in get_rule), not `.thr`.
                            if p.value > constraints[p.variable]['thr']:
                                break

            if filter_crules is not None:
                # Fix: numeric_columns is now computed above and forwarded; the call
                # previously omitted it, crashing inside apply_counterfactual.
                xc = self.apply_counterfactual(z, delta, feature_names,
                                               numeric_columns=numeric_columns)
                bb_outcomec = filter_crules(xc.reshape(1, -1))[0]
                bb_outcomec = class_values[bb_outcomec] if isinstance(class_name, str) \
                    else multilabel2str(bb_outcomec, class_values)
                # NOTE(review): assumes Rule exposes its consequence as `.cons` —
                # verify against the Rule class definition.
                dt_outcomec = crule.cons
                if bb_outcomec == dt_outcomec:
                    if num_falsified_conditions < clen:
                        clen = num_falsified_conditions
                        crule_list = [crule]
                        delta_list = [delta]
                    elif num_falsified_conditions == clen:
                        if delta not in delta_list:
                            crule_list.append(crule)
                            delta_list.append(delta)
            else:
                if num_falsified_conditions < clen:
                    clen = num_falsified_conditions
                    crule_list = [crule]
                    delta_list = [delta]
                elif num_falsified_conditions == clen:
                    if delta not in delta_list:
                        crule_list.append(crule)
                        delta_list.append(delta)

        return crule_list, delta_list

    def get_falsified_conditions(self, x_dict: dict, crule: Rule):
        """
        Collect the premises of ``crule`` that the instance ``x_dict`` violates.

        :param x_dict: instance as a {feature_name: value} dict
        :param crule: candidate counterfactual rule
        :return: list of falsified premises
        """
        delta = []
        for p in crule.premises:
            try:
                if p.operator == operator.le and x_dict[p.variable] > p.value:
                    delta.append(p)
                elif p.operator == operator.gt and x_dict[p.variable] <= p.value:
                    delta.append(p)
            except KeyError:
                # Premise refers to a feature missing from x_dict: skip it
                # (was a bare except; only a missing key is expected here).
                continue
        return delta

    def check_feasibility_of_falsified_conditions(self, delta, unadmittible_features: dict):
        """
        Check that no falsified condition involves an unadmittible feature.

        :param delta: list of falsified premises
        :param unadmittible_features: mapping {feature: operator-or-None}; None
            forbids any use of the feature, otherwise only the given operator
        :return: True if feasible, False otherwise
        """
        for p in delta:
            if p.variable in unadmittible_features:
                if unadmittible_features[p.variable] is None:
                    return False
                else:
                    if unadmittible_features[p.variable] == p.operator:
                        return False
        return True

    def apply_counterfactual(self, x, delta, feature_names: list, features_map=None,
                             features_map_inv=None, numeric_columns=None):
        """
        Build a counterfactual instance by flipping, on a copy of ``x``, every
        falsified condition in ``delta``.

        :param x: encoded instance (1-d vector)
        :param delta: falsified premises to apply
        :param feature_names: names of the encoded features, aligned with ``x``
        :param features_map: optional map from feature index to its one-hot values
        :param features_map_inv: optional inverse of ``features_map``'s indexing
        :param numeric_columns: names of the numeric features; non-numeric ones are
            treated as one-hot categorical indicators
        :return: the counterfactual instance as a numpy vector
        """
        numeric_columns = numeric_columns if numeric_columns is not None else []
        x_dict = vector2dict(x, feature_names)
        x_copy_dict = copy.deepcopy(x_dict)
        for p in delta:
            if p.variable in numeric_columns:
                # Move just past the threshold: step by one unit for integer
                # thresholds, otherwise by one unit of the first non-zero decimal.
                if p.value == int(p.value):
                    gap = 1.0
                else:
                    decimals = list(str(p.value).split('.')[1])
                    for idx, e in enumerate(decimals):
                        if e != '0':
                            break
                    gap = 1 / (10 ** (idx + 1))
                if p.operator == operator.gt:
                    x_copy_dict[p.variable] = p.value + gap
                else:
                    x_copy_dict[p.variable] = p.value
            else:
                # Fix: premises expose `.variable` (see Expression construction in
                # get_rule); `.att` raised AttributeError here.
                fn = p.variable
                # NOTE(review): the 0/1 assignments below mirror the original
                # one-hot flipping logic — confirm they match the encoder layout.
                if p.operator == operator.gt:
                    if features_map is not None:
                        fi = list(feature_names).index(p.variable)
                        fi = features_map_inv[fi]
                        for fv in features_map[fi]:
                            x_copy_dict['%s=%s' % (fn, fv)] = 0.0
                    x_copy_dict[p.variable] = 1.0
                else:
                    if features_map is not None:
                        fi = list(feature_names).index(p.variable)
                        fi = features_map_inv[fi]
                        for fv in features_map[fi]:
                            x_copy_dict['%s=%s' % (fn, fv)] = 1.0
                    x_copy_dict[p.variable] = 0.0

        x_counterfactual = np.zeros(len(x_dict))
        for i, fn in enumerate(feature_names):
            x_counterfactual[i] = x_copy_dict[fn]
        return x_counterfactual