# Source code for lore_sa.rule

import json

from lore_sa.encoder_decoder import EncDec
from lore_sa.util import vector2dict
from typing import Callable
import operator

# Inverse of Expression.operator2string: maps a serialized operator symbol back
# to the callable from the ``operator`` module ('==' is accepted as an alias of '=').
_STRING2OPERATOR = {
    '>': operator.gt,
    '<': operator.lt,
    '!=': operator.ne,
    '=': operator.eq,
    '==': operator.eq,
    '>=': operator.ge,
    '<=': operator.le,
}


def json2expression(obj):
    """
    Rebuild an Expression from its JSON-friendly dict form.

    :param obj: dict with keys 'att' (variable name), 'op' (operator) and 'thr'
        (value), as produced by ExpressionEncoder.default. 'op' may be the
        operator symbol written by the encoder (e.g. '>', '<=') or already a
        callable such as ``operator.gt``.
    :return: an Expression instance.
    """
    op = obj['op']
    # ExpressionEncoder stores the operator as a symbol string, but Expression
    # needs the callable for operator2string()/__str__ to work. Unknown strings
    # are kept as-is, so loading stays lenient as before.
    if isinstance(op, str):
        op = _STRING2OPERATOR.get(op, op)
    return Expression(obj['att'], op, obj['thr'])
def json2rule(obj, encoder=None):
    """
    Rebuild a Rule from its JSON-friendly dict form.

    :param obj: dict with a 'premise' list of expression dicts (each decoded
        with json2expression) and a 'cons' entry, as produced by
        RuleEncoder.default.
    :param encoder: optional EncDec forwarded to Rule, which uses it to decode
        categorical expressions. NOTE(review): serialized rules appear to be
        already decoded, so ``None`` is presumably right for round-trips —
        confirm with callers.
    :return: a Rule instance.
    """
    premises = [json2expression(p) for p in obj['premise']]
    cons = obj['cons']
    # When the consequence was serialized as an expression dict, turn it back
    # into an Expression; any other stored value is passed through unchanged.
    if isinstance(cons, dict) and {'att', 'op', 'thr'} <= cons.keys():
        cons = json2expression(cons)
    # Rule.__init__ takes the encoder as its third argument, so pass it explicitly.
    return Rule(premises, cons, encoder)
class Expression(object):
    """
    Utility object to define a logical expression. It is used to define the
    premises of a Rule emitted from a surrogate model.

    Two expressions are equal when their variable, operator and value are equal.
    """

    # Symbol used for each supported comparison operator when rendering an expression.
    _OPERATOR_STRINGS = {
        operator.gt: '>',
        operator.lt: '<',
        operator.ne: '!=',
        operator.eq: '=',
        operator.ge: '>=',
        operator.le: '<=',
    }

    def __init__(self, variable: str, operator: Callable, value):
        """
        :param[str] variable: name of the variable that defines the rule
        :param[Callable] operator: logical operator involved in the rule
        :param value: numerical value to define the rule. E.g. variable > value
        """
        self.variable = variable
        self.operator = operator
        self.value = value

    def operator2string(self):
        """
        It converts the logical operator into a string representation.
        E.g.: operator2string(operator.gt) = ">"

        :raises ValueError: if the operator is not one of operator.gt, operator.lt,
            operator.ne, operator.eq, operator.ge, operator.le
        """
        try:
            return self._OPERATOR_STRINGS[self.operator]
        except KeyError:
            raise ValueError(
                "logical operator not recognized. Use one of [operator.gt, operator.lt, "
                "operator.ne, operator.eq, operator.ge, operator.le]") from None

    def __str__(self):
        """
        It writes the expression as a string, e.g. "age > 30".
        """
        return "%s %s %s" % (self.variable, self.operator2string(), self.value)

    def __repr__(self):
        return "Expression(%r, %r, %r)" % (self.variable, self.operator, self.value)

    def __eq__(self, other):
        # Value equality lets Rule.__eq__ compare premise lists meaningfully.
        if not isinstance(other, Expression):
            return NotImplemented
        return (self.variable, self.operator, self.value) == \
            (other.variable, other.operator, other.value)

    def __hash__(self):
        # Must agree with __eq__. NOTE(review): Rule.decode_rule mutates these
        # attributes, so avoid mutating expressions already stored in sets/dicts.
        return hash((self.variable, self.operator, self.value))
class Rule(object):
    """
    A rule made of a list of premise Expressions and a consequence.

    When an encoder is given, categorical premises/consequences are rewritten
    by ``decode_rule`` at construction time.
    """

    def __init__(self, premises: list, consequences: Expression, encoder: EncDec = None):
        """
        :param [list] premises: list of Expression objects representing the premises
        :param [Expression] consequences: Expression representing the consequence
        :param [EncDec] encoder: encoder to decode categorical rules. When None,
            premises and consequences are stored unchanged.
        """
        self.encoder = encoder
        self.premises = [self.decode_rule(p) for p in premises]
        self.consequences = self.decode_rule(consequences)

    def _pstr(self):
        return '{ %s }' % (', '.join([str(p) for p in self.premises]))

    def _cstr(self):
        # Wrap in a 1-tuple so a tuple-valued consequence is not unpacked by '%'.
        return '{ %s }' % (self.consequences,)

    def __str__(self):
        str_out = 'premises:\n' + '%s \n' % ("\n".join([str(e) for e in self.premises]))
        str_out += 'consequence: %s' % (str(self.consequences))
        return str_out

    def __eq__(self, other):
        # Compare with the attribute actually set in __init__ ('consequences').
        if not isinstance(other, Rule):
            return NotImplemented
        return self.premises == other.premises and self.consequences == other.consequences

    def __len__(self):
        return len(self.premises)

    def __hash__(self):
        return hash(str(self))

    def decode_rule(self, rule: Expression):
        """
        Rewrite an expression on an encoded categorical column into an
        expression on the original categorical feature.

        If the encoder's ``dataset_descriptor['categorical']`` is non-empty and
        the expression's variable has the form ``"<feature>=<category>"`` with
        ``<feature>`` among its keys, the expression is modified in place:
        ``variable`` becomes ``<feature>``, ``value`` becomes ``<category>``,
        ``<=`` becomes ``!=`` and ``>`` becomes ``==``. Any other expression is
        returned unchanged.

        :param [Expression] rule: expression to decode (mutated in place)
        :return: the (possibly decoded) expression
        """
        if self.encoder is None:
            return rule
        descriptor = self.encoder.dataset_descriptor
        if 'categorical' not in descriptor.keys() or descriptor['categorical'] == {}:
            return rule
        feature, separator, category = rule.variable.partition('=')
        # Without '=' there is nothing to decode: either a numeric feature or an
        # expression that was already decoded (e.g. shared between two Rules).
        if not separator or feature not in descriptor['categorical'].keys():
            return rule
        rule.variable = feature
        rule.value = category
        # NOTE(review): presumably "<feature>=<category> <= thr" means the
        # category is absent and "> thr" means present — confirm thresholds.
        if rule.operator == operator.le:
            rule.operator = operator.ne
        elif rule.operator == operator.gt:
            rule.operator = operator.eq
        return rule

    def is_covered(self, x, feature_names):
        """
        Check whether instance ``x`` satisfies the rule's premises.

        :param x: instance, converted with ``vector2dict(x, feature_names)``
        :param feature_names: names used as keys of the converted instance
        :return: False as soon as a ``<=`` or ``>`` premise is violated, True otherwise
        """
        # NOTE(review): only <= and > premises are checked; premises with other
        # operators (e.g. the != / == produced by decode_rule) are ignored and
        # never make the rule uncovered — confirm this is intended.
        xd = vector2dict(x, feature_names)
        for p in self.premises:
            if p.operator == operator.le and xd[p.variable] > p.value:
                return False
            elif p.operator == operator.gt and xd[p.variable] <= p.value:
                return False
        return True
class ExpressionEncoder(json.JSONEncoder):
    """
    JSON encoder that knows how to serialize Expression objects.

    An Expression becomes a dict with keys 'att' (variable), 'op' (operator
    symbol from Expression.operator2string) and 'thr' (value). Any other object
    is handed to the base json.JSONEncoder.default, which raises TypeError.
    """

    def default(self, obj):
        if not isinstance(obj, Expression):
            return json.JSONEncoder.default(self, obj)
        return dict(att=obj.variable, op=obj.operator2string(), thr=obj.value)
class RuleEncoder(json.JSONEncoder):
    """
    Special json encoder for Rule types.

    A Rule becomes {'premise': [expression dicts], 'cons': consequence}; an
    Expression consequence (or any nested Expression) is serialized with
    ExpressionEncoder. Other objects go to the base json.JSONEncoder.default.
    """

    def default(self, obj):
        expression_encoder = ExpressionEncoder()
        if isinstance(obj, Rule):
            consequences = obj.consequences
            # Serialize Expression consequences explicitly: left as-is they are
            # routed back to this default() and rejected with TypeError.
            if isinstance(consequences, Expression):
                consequences = expression_encoder.default(consequences)
            return {
                'premise': [expression_encoder.default(p) for p in obj.premises],
                'cons': consequences,
            }
        if isinstance(obj, Expression):
            return expression_encoder.default(obj)
        return json.JSONEncoder.default(self, obj)