-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path audit_prediction_shap.mpc
56 lines (34 loc) · 1.63 KB
/
audit_prediction_shap.mpc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from Compiler.script_utils import output_utils
from Compiler.script_utils.data import data
from Compiler import ml
from Compiler import library
from Compiler.script_utils.audit import shap
from Compiler.script_utils import config, timers
class AuditConfig(config.BaseAuditModel):
    """Configuration model for this SHAP audit script.

    Adds ``test_var`` on top of whatever fields are inherited from
    ``config.BaseAuditModel``; the script parses ``program.args`` into this
    class via ``config.from_program_args``.
    """

    # NOTE(review): `test_var` is not read anywhere in this script; presumably a
    # placeholder option -- confirm before removing it.
    test_var: int = 1

    # `cfg.n_batches` is read by the script, so it presumably comes from
    # BaseAuditModel; a local override (`n_batches: int = 10`) used to live here.
# ---------------------------------------------------------------------------
# Compiler setup: parse the command-line arguments into the typed config.
# ---------------------------------------------------------------------------
program.options_from_args()
cfg = config.from_program_args(program.args, AuditConfig)

# The configured truncation mode is only applied when not emulating.
if not cfg.emulate:
    program.use_trunc_pr = cfg.trunc_pr
    # Optional protocol tweaks, currently disabled:
    #   program.use_edabits = True
    #   program.use_split(4)

ml.set_n_threads(cfg.n_threads)

# ---------------------------------------------------------------------------
# Stage 1 (timed): build the input loader.
# ---------------------------------------------------------------------------
library.start_timer(timer_id=timers.TIMER_LOAD_DATA)

# A non-positive batch count means "no explicit input size" (None).
if cfg.n_batches > 0:
    input_shape_size = cfg.batch_size * cfg.n_batches
else:
    input_shape_size = None

input_loader = data.get_input_loader(
    dataset=cfg.dataset,
    audit_trigger_idx=cfg.audit_trigger_idx,
    batch_size=cfg.batch_size,
    debug=cfg.debug,
    emulate=cfg.emulate,
    consistency_check=cfg.consistency_check,
    sha3_approx_factor=cfg.sha3_approx_factor,
    input_shape_size=input_shape_size,
)

library.stop_timer(timer_id=timers.TIMER_LOAD_DATA)

# NOTE(review): original author's expected count, origin undocumented here:
# expect 91 * (13136) + 13136 + (91 * 32) + 2*32 + 1

# ---------------------------------------------------------------------------
# Stage 2 (timed): run the audit from script_utils.audit.shap.
# ---------------------------------------------------------------------------
library.start_timer(timer_id=timers.TIMER_AUDIT)
result, debug_output = shap.audit(input_loader, cfg, debug=cfg.debug)
library.stop_timer(timer_id=timers.TIMER_AUDIT)

# ---------------------------------------------------------------------------
# Stage 3: emit results, then debug values.
# ---------------------------------------------------------------------------
# NOTE(review): a plain Python `print` in an .mpc script presumably executes at
# compile time rather than during the MPC run -- confirm if runtime output is
# wanted.
print("Done with audit, outputting results")

# Revealing many values may add a lot of compilation time / overhead.
for name, value in result.items():
    output_utils.output_value(name=name, value=value)

for name, value in debug_output.items():
    output_utils.output_value_debug(name=name, value=value, repeat=False)