-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaudit_fairness.mpc
53 lines (31 loc) · 1.66 KB
/
audit_fairness.mpc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from Compiler.script_utils import output_utils
from Compiler.script_utils.data import data
from Compiler import ml
from Compiler import library
from Compiler.script_utils.audit import rand_smoothing
from Compiler.script_utils import config, timers
class AuditConfig(config.BaseAuditModel):
    """Options specific to this audit script, on top of ``config.BaseAuditModel``.

    Every field declared here has a default value.
    """

    # Kind of audit to run: "robustness" or "fairness".
    type: str = "fairness"

    # Seed chosen by the audit requestor to sample perturbations.
    seed: int = 42

    # Parameters taken from the paper, as used on their MNIST model.
    L: float = 2.97

    # NOTE: it is not understood how the paper sets this value for adult.
    theta: float = 1

    # The script requires n to be divisible by batch_size.
    batch_size: int = 16
    n: int = 1024

    # 1 - alpha is the confidence level.
    alpha: float = 0.001
# Let the program object pick up its options from the command-line arguments,
# then build the audit configuration from the program arguments.
program.options_from_args()
cfg = config.from_program_args(program.args, AuditConfig)

# Validate explicitly instead of with `assert`: assertions are removed when
# Python runs with -O, and a zero batch size would otherwise only show up as a
# ZeroDivisionError raised by the modulo below.
if cfg.batch_size <= 0:
    raise ValueError("batch_size must be a positive integer")
if cfg.n % cfg.batch_size != 0:
    raise ValueError("n must be divisible by batch_size")

# Forward the truncation, rounding and threading settings from the config.
program.use_trunc_pr = cfg.trunc_pr
sfix.round_nearest = cfg.round_nearest
ml.set_n_threads(cfg.n_threads)
# Build the input loader; the call is bracketed by the TIMER_LOAD_DATA timer.
library.start_timer(timer_id=timers.TIMER_LOAD_DATA)
input_loader = data.get_input_loader(
    dataset=cfg.dataset,
    audit_trigger_idx=cfg.audit_trigger_idx,
    batch_size=cfg.batch_size,
    debug=cfg.debug,
    emulate=cfg.emulate,
    consistency_check=cfg.consistency_check,
    sha3_approx_factor=cfg.sha3_approx_factor,
    # Model weights are requested; the dataset is not.
    load_model_weights=True,
    load_dataset=False,
)
library.stop_timer(timer_id=timers.TIMER_LOAD_DATA)
# Run the rand_smoothing audit; the call is bracketed by the TIMER_AUDIT timer.
library.start_timer(timer_id=timers.TIMER_AUDIT)
result = rand_smoothing.audit(input_loader, config=cfg, debug=cfg.debug)
library.stop_timer(timer_id=timers.TIMER_AUDIT)

# Hand every named entry of the audit result to the debug output helper.
for name, value in result.items():
    output_utils.output_value_debug(name=name, value=value, repeat=False)