# Source code for lore_sa.explanation

import json
import pickle
import bitarray
import numpy as np

from lore_sa.rule import RuleEncoder, ExpressionEncoder
from lore_sa.rule import json2rule, json2expression


class Explanation(object):
    """Container for a LORE local explanation.

    Every attribute starts as ``None``; callers assign them after
    construction.  ``deltas`` and ``crules`` are parallel lists:
    ``crules[i]`` is the counterfactual rule paired with the premises in
    ``deltas[i]`` (see ``cstr``).
    """

    def __init__(self):
        self.bb_pred = None
        self.dt_pred = None
        self.rule = None
        self.crules = None
        self.deltas = None
        self.feature_importance = None
        self.feature_importance_all = None
        self.exemplars = None
        self.cexemplars = None
        self.fidelity = None
        self.dt = None

    @staticmethod
    def _braces(items, sep):
        """Join *items* with *sep* inside ``'{ ... }'``; return ``'{ }'`` when empty."""
        return '{ %s }' % sep.join(items) if items else '{ }'

    def _delta_strings(self):
        """Render each premise list of ``deltas`` as ``'{ p1, p2 }'``.

        ``deltas`` being ``None`` (its initial value) is treated as empty, so
        printing a freshly built explanation does not raise.
        """
        deltas = [] if self.deltas is None else self.deltas
        return ['{ %s }' % ', '.join(str(s) for s in delta) for delta in deltas]

    def __str__(self):
        """Multi-line summary of rule, deltas, feature importances and exemplars."""
        deltas_str = self._braces(self._delta_strings(), ',\n ')
        return 'r = %s\nc = %s\nfi = %s\nfia = %s\nex = %s\ncex = %s' % (
            self.rule, deltas_str, self.feature_importance,
            self.feature_importance_all, self.exemplars, self.cexemplars)

    def rstr(self):
        """Return the factual rule object as stored (it is not converted to str)."""
        return self.rule

    def cstr(self):
        """Render counterfactuals as ``'{ { premises } --> crule, ... }'``.

        Each ``deltas[i]`` is paired with ``crules[i]._cstr()``; an
        IndexError is raised if ``crules`` is shorter than ``deltas``.
        """
        parts = ['%s --> %s' % (delta_str, self.crules[i]._cstr())
                 for i, delta_str in enumerate(self._delta_strings())]
        return self._braces(parts, ', ')


class ExplanationEncoder(json.JSONEncoder):
    """ Special json encoder for Rule types """

    def default(self, obj):
        """Return a JSON-serializable form of *obj*.

        An ``Explanation`` becomes a dict with the keys ``bb_pred``,
        ``dt_pred``, ``rule``, ``crules``, ``deltas``, ``fidelity`` and
        ``dt``.  The surrogate tree ``dt`` is pickled and stored as the list
        of its bits.  Feature importances and exemplars are not written.
        Any other object is delegated to ``ExpressionEncoder().default``.
        """
        if not isinstance(obj, Explanation):
            return ExpressionEncoder().default(obj)

        rule_encoder = RuleEncoder()
        expr_encoder = ExpressionEncoder()

        # Plain JSON cannot hold the tree object, so pickle it and store the
        # bytes as a bit list; json2explanation performs the inverse.
        tree_bits = bitarray.bitarray()
        tree_bits.frombytes(pickle.dumps(obj.dt))
        tree_bit_list = tree_bits.tolist()

        encoded_deltas = []
        for premises in obj.deltas:
            encoded_deltas.append([expr_encoder.default(p) for p in premises])

        return {
            'bb_pred': obj.bb_pred,
            'dt_pred': obj.dt_pred,
            'rule': rule_encoder.default(obj.rule),
            'crules': [rule_encoder.default(cr) for cr in obj.crules],
            'deltas': encoded_deltas,
            'fidelity': obj.fidelity,
            'dt': tree_bit_list,
        }


def json2explanation(obj):
    """Rebuild an ``Explanation`` from the dict written by ``ExplanationEncoder``.

    Restores ``bb_pred``, ``dt_pred``, ``rule``, ``crules``, ``deltas``,
    ``dt`` and ``fidelity``; every other attribute stays ``None``.

    NOTE(security): ``obj['dt']`` is decoded with ``pickle.loads``.  Only
    call this on JSON from a trusted source -- unpickling attacker-controlled
    bytes can execute arbitrary code.
    """
    explanation = Explanation()
    explanation.bb_pred = obj['bb_pred']
    explanation.dt_pred = obj['dt_pred']
    explanation.rule = json2rule(obj['rule'])
    explanation.crules = list(map(json2rule, obj['crules']))
    explanation.deltas = [list(map(json2expression, premises))
                          for premises in obj['deltas']]
    # 'dt' is a list of bits encoding the pickled surrogate tree.
    tree_bytes = bitarray.bitarray(obj['dt']).tobytes()
    explanation.dt = pickle.loads(tree_bytes)
    explanation.fidelity = obj['fidelity']
    return explanation


class MultilabelExplanation(Explanation):
    """Explanation variant holding list-valued counterparts of the base fields.

    The ``*_list`` attributes start as ``None`` and are assigned by the
    caller (presumably one entry per label -- confirm against the producer).
    """

    def __init__(self):
        # Two-argument super(): the one-argument form returns an unbound
        # proxy and would not run Explanation.__init__, leaving the inherited
        # attributes undefined.
        super(MultilabelExplanation, self).__init__()
        self.dt_list = None
        self.rule_list = None
        self.crules_list = None
        self.deltas_list = None


class ImageExplanation(Explanation):
    """Explanation of an image classifier whose features are image segments.

    A premise ``p`` refers to the pixels where ``segments == p.att``; the
    same index ``p.att`` is used into ``dt.feature_importances_``.
    """

    def __init__(self, img, segments):
        # Two-argument super() so Explanation.__init__ really runs (the
        # one-argument form is unbound and skips it).
        super(ImageExplanation, self).__init__()
        self.img = img
        self.segments = segments

    def _selected_features(self, num_features):
        """Return the set of indices of the *num_features* most important features.

        ``None`` selects every feature of the surrogate tree ``dt``.
        """
        importances = self.dt.feature_importances_
        if num_features is None:
            num_features = len(importances)
        # argsort is ascending: reverse it so the top-k are the most important.
        top = np.argsort(importances)[::-1][:num_features]
        return set(top.tolist())

    def _render(self, premises, features, hide_rest, min_importance):
        """Build ``(img2show, mask)`` highlighting the segments of *premises*.

        Premises whose feature is outside *features* or whose importance is
        below *min_importance* are skipped.  ``'<='`` premises get mask label
        1 and channel 0; all other operators get label 2 and channel 1.
        """
        mask = np.zeros(self.segments.shape, self.segments.dtype)
        if hide_rest:
            img2show = np.zeros(self.img.shape).astype(int)
        else:
            img2show = np.copy(self.img)
        importances = self.dt.feature_importances_

        for p in premises:
            if p.att not in features or importances[p.att] < min_importance:
                continue
            region = self.segments == p.att
            is_le = p.op == '<='
            mask[region] = 1 if is_le else 2
            img2show[region] = self.img[region].copy()
            if not hide_rest:
                # Saturate one colour channel of the segment to highlight it.
                img2show[region, 0 if is_le else 1] = np.max(self.img)
        return img2show, mask

    def get_image_rule(self, hide_rest=False, num_features=None, min_importance=0.0):
        """Return ``(img2show, mask)`` for the premises of the factual rule.

        :param hide_rest: if True, pixels outside the selected segments are 0.
        :param num_features: keep only this many most-important features
            (all when None).
        :param min_importance: skip features whose tree importance is lower.
        """
        features = self._selected_features(num_features)
        return self._render(self.rule.premises, features, hide_rest, min_importance)

    def get_image_counterfactuals(self, hide_rest=False, num_features=None, min_importance=0.0):
        """Return ``(images, masks, outcomes)``, one entry per counterfactual.

        Each ``deltas[i]`` is rendered like ``get_image_rule`` does; the
        outcome is ``crules[i].cons``.
        """
        pairs = list(zip(self.deltas, self.crules))
        imgs2show, masks, coutcomes = list(), list(), list()
        if not pairs:
            return imgs2show, masks, coutcomes

        features = self._selected_features(num_features)
        for delta, crule in pairs:
            img2show, mask = self._render(delta, features, hide_rest, min_importance)
            imgs2show.append(img2show)
            masks.append(mask)
            coutcomes.append(crule.cons)
        return imgs2show, masks, coutcomes


class TextExplanation(Explanation):
    """Explanation of a text classifier whose features are words.

    A premise ``p`` names the word ``indexed_text.word(p.att)``; the same
    index ``p.att`` is used into ``dt.feature_importances_``.
    """

    def __init__(self, text, indexed_text):
        # Two-argument super() so Explanation.__init__ really runs (the
        # one-argument form is unbound and skips it).
        super(TextExplanation, self).__init__()
        self.text = text
        self.indexed_text = indexed_text

    def _selected_features(self, num_features):
        """Return the set of indices of the *num_features* most important features.

        ``None`` selects every feature of the surrogate tree ``dt``.
        """
        importances = self.dt.feature_importances_
        if num_features is None:
            num_features = len(importances)
        # argsort is ascending: reverse it so the top-k are the most important.
        top = np.argsort(importances)[::-1][:num_features]
        return set(top.tolist())

    def _premise_text(self, premises, features, min_importance):
        """Render *premises* as a comma-separated word list.

        Words of non-``'<='`` premises come first, unprefixed; words of
        ``'<='`` premises follow, prefixed with a negation mark.  Premises
        outside *features* or below *min_importance* are skipped.
        """
        importances = self.dt.feature_importances_
        inwords, outwords = [], []
        for p in premises:
            if p.att not in features or importances[p.att] < min_importance:
                continue
            word = self.indexed_text.word(p.att)
            if p.op == '<=':
                outwords.append(word)
            else:
                inwords.append(word)
        # NOTE(review): the negation prefix was garbled in the scraped source
        # (a SyntaxError); '¬' is a reconstruction -- confirm.
        return ', '.join(inwords + ['¬%s' % word for word in outwords])

    def get_text_rule(self, num_features=None, min_importance=0.0):
        """Return the factual rule as ``'{ words } --> outcome'``.

        :param num_features: keep only this many most-important features
            (all when None).
        :param min_importance: skip features whose tree importance is lower.
        """
        features = self._selected_features(num_features)
        text_premise = self._premise_text(self.rule.premises, features, min_importance)
        return '{ %s } --> %s' % (text_premise, self.rule._cstr())

    def get_text_counterfactuals(self, num_features=None, min_importance=0.0):
        """Return one ``'{ words } --> outcome'`` string per counterfactual.

        Each ``deltas[i]`` is paired with ``crules[i]``.
        """
        features = self._selected_features(num_features)
        return [
            '{ %s } --> %s' % (self._premise_text(delta, features, min_importance),
                               crule._cstr())
            for delta, crule in zip(self.deltas, self.crules)
        ]