import json
import pickle
import bitarray
import numpy as np
from lore_sa.rule import RuleEncoder, ExpressionEncoder
from lore_sa.rule import json2rule, json2expression
[docs]class Explanation(object):
[docs] def __init__(self):
self.bb_pred = None
self.dt_pred = None
self.rule = None
self.crules = None
self.deltas = None
self.feature_importance = None
self.feature_importance_all = None
self.exemplars = None
self.cexemplars = None
self.fidelity = None
self.dt = None
def __str__(self):
deltas_str = '{ '
for i, delta in enumerate(self.deltas):
deltas_str += ' { ' if i > 0 else '{ '
deltas_str += ', '.join([str(s) for s in delta])
deltas_str += ' },\n'
deltas_str = deltas_str[:-2] + ' }'
return 'r = %s\nc = %s\nfi = %s\nfia = %s\nex = %s\ncex = %s' % (self.rule, deltas_str, self.feature_importance, self.feature_importance_all, self.exemplars, self.cexemplars)
def rstr(self):
return self.rule
def cstr(self):
deltas_str = '{ '
for i, delta in enumerate(self.deltas):
deltas_str += '{ ' if i > 0 else '{ '
deltas_str += ', '.join([str(s) for s in delta])
deltas_str += ' } --> %s, ' % self.crules[i]._cstr()
deltas_str = deltas_str[:-2] + ' }'
return deltas_str
[docs]class ExplanationEncoder(json.JSONEncoder):
""" Special json encoder for Rule types """
[docs] def default(self, obj):
if isinstance(obj, Explanation):
re = RuleEncoder()
ce = ExpressionEncoder()
ba = bitarray.bitarray()
ba.frombytes(pickle.dumps(obj.dt))
bal = ba.tolist()
json_obj = {
'bb_pred': obj.bb_pred,
'dt_pred': obj.dt_pred,
'rule': re.default(obj.rule),
'crules': [re.default(c) for c in obj.crules],
'deltas': [[ce.default(c) for c in cs] for cs in obj.deltas],
'fidelity': obj.fidelity,
'dt': bal,
}
return json_obj
return ExpressionEncoder().default(obj)
[docs]def json2explanation(obj):
exp = Explanation()
exp.bb_pred = obj['bb_pred']
exp.dt_pred = obj['dt_pred']
exp.rule = json2rule(obj['rule'])
exp.crules = [json2rule(c) for c in obj['crules']]
exp.deltas = [[json2expression(c) for c in cs] for cs in obj['deltas']]
exp.dt = pickle.loads(bitarray.bitarray(obj['dt']).tobytes())
exp.fidelity = obj['fidelity']
return exp
[docs]class MultilabelExplanation(Explanation):
[docs] def __init__(self):
super(MultilabelExplanation).__init__()
self.dt_list = None
self.rule_list = None
self.crules_list = None
self.deltas_list = None
[docs]class ImageExplanation(Explanation):
[docs] def __init__(self, img, segments):
super(ImageExplanation).__init__()
self.img = img
self.segments = segments
def get_image_rule(self, hide_rest=False, num_features=None, min_importance=0.0):
mask = np.zeros(self.segments.shape, self.segments.dtype)
if hide_rest:
img2show = np.zeros(self.img.shape).astype(int)
else:
img2show = np.copy(self.img)
num_features = len(self.dt.feature_importances_) if num_features is None else num_features
features = np.argsort(self.dt.feature_importances_)[:num_features]
for p in self.rule.premises:
if p.att not in features or self.dt.feature_importances_[p.att] < min_importance:
continue
f = p.att
w = -1 if p.op == '<=' else 1
c = 0 if w < 0 else 1
mask[self.segments == f] = 1 if w < 0 else 2
img2show[self.segments == f] = self.img[self.segments == f].copy()
if not hide_rest:
img2show[self.segments == f, c] = np.max(self.img)
for cp in [0, 1, 2]:
if c == cp:
continue
return img2show, mask
def get_image_counterfactuals(self, hide_rest=False, num_features=None, min_importance=0.0):
imgs2show, masks = list(), list()
coutcomes = list()
for delta, crule in zip(self.deltas, self.crules):
mask = np.zeros(self.segments.shape, self.segments.dtype)
if hide_rest:
img2show = np.zeros(self.img.shape).astype(int)
else:
img2show = np.copy(self.img)
num_features = len(self.dt.feature_importances_) if num_features is None else num_features
features = np.argsort(self.dt.feature_importances_)[:num_features]
for p in delta:
if p.att not in features or self.dt.feature_importances_[p.att] < min_importance:
continue
f = p.att
w = -1 if p.op == '<=' else 1
c = 0 if w < 0 else 1
mask[self.segments == f] = 1 if w < 0 else 2
img2show[self.segments == f] = self.img[self.segments == f].copy()
if not hide_rest:
img2show[self.segments == f, c] = np.max(self.img)
for cp in [0, 1, 2]:
if c == cp:
continue
imgs2show.append(img2show)
masks.append(mask)
coutcomes.append(crule.cons)
return imgs2show, masks, coutcomes
[docs]class TextExplanation(Explanation):
[docs] def __init__(self, text, indexed_text):
super(TextExplanation).__init__()
self.text = text
self.indexed_text = indexed_text
def get_text_rule(self, num_features=None, min_importance=0.0):
num_features = len(self.dt.feature_importances_) if num_features is None else num_features
features = np.argsort(self.dt.feature_importances_)[:num_features]
inwords, outwords = list(), list()
for p in self.rule.premises:
if p.att not in features or self.dt.feature_importances_[p.att] < min_importance:
continue
word = self.indexed_text.word(p.att)
if p.op == '<=':
outwords.append(word)
else:
inwords.append(word)
text_premise = ', '.join(inwords) if len(inwords) > 0 else ''
text_premise += ', ' if len(inwords) > 0 and len(outwords) > 0 else ''
text_premise += ', '.join(['¬ %s' % word for word in outwords]) if len(outwords) > 0 else ''
text_rule = '{ %s } --> %s' % (text_premise, self.rule._cstr())
return text_rule
def get_text_counterfactuals(self, num_features=None, min_importance=0.0):
num_features = len(self.dt.feature_importances_) if num_features is None else num_features
features = np.argsort(self.dt.feature_importances_)[:num_features]
text_counterfactuals = list()
for delta, crule in zip(self.deltas, self.crules):
inwords, outwords = list(), list()
for p in delta:
if p.att not in features or self.dt.feature_importances_[p.att] < min_importance:
continue
word = self.indexed_text.word(p.att)
if p.op == '<=':
outwords.append(word)
else:
inwords.append(word)
text_premise = ', '.join(inwords) if len(inwords) > 0 else ''
text_premise += ', ' if len(inwords) > 0 and len(outwords) > 0 else ''
text_premise += ', '.join(['¬ %s' % word for word in outwords]) if len(outwords) > 0 else ''
text_rule = '{ %s } --> %s' % (text_premise, crule._cstr())
text_counterfactuals.append(text_rule)
return text_counterfactuals