run_blocking_model.py 4.27 KB
Newer Older
Martin Řepa's avatar
Martin Řepa committed
1 2
import os
import pickle
Martin Řepa's avatar
Martin Řepa committed
3
import sys
Martin Řepa's avatar
Martin Řepa committed
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
import time
from dataclasses import dataclass
from os.path import dirname
from pathlib import Path
from typing import List

import torch
import yaml

from src.config import RootConfig
from src.game import Game

CONFIG_FILE = 'configuration.yaml'


@dataclass()
class AttackerAction:
    """An attacker action taken from a solved game, with its probability."""

    # The attacker's action as found in the game result's
    # ``ordered_actions_p1``; presumably a vector of feature values — TODO
    # confirm.
    action: List[float]
    # Probability assigned to this action by the solution (``probs_p1``).
    prob: float


@dataclass()
class DefenderAction:
    """A defender action (a trained model) taken from a solved game."""

    # Identifier of the model; main() also uses it as the file name
    # ``<id>.pt`` under which the model's state dict is saved.
    model_file_id: str
    # Probability assigned to this action by the solution (``probs_p2``).
    prob: float
    # The model's ``final_loss`` value.
    loss: float
    # The model's ``final_fp_cost`` value (a false-positive cost part,
    # judging by the name — TODO confirm).
    fp_part: float
    fp_part: float


@dataclass()
class SubResult:
    """Outcome of a single repetition of the experiment."""

    # Name of this repetition's output sub-folder. main() stores ``str(i)``
    # (the repetition index) here, not the base legacy folder.
    legacy_folder: str
    # Elapsed seconds spent solving the game.
    time: float
    # ``zero_sum_nash_val`` reported by the game result.
    zero_sum_game_value: float
    # ``attacker_value`` reported by the game result.
    almost_zero_attacker_value: float
    # ``defender_value`` reported by the game result.
    almost_zero_defender_value: float
    # Attacker actions recorded for this repetition (main() keeps only those
    # with non-zero probability).
    attacker_actions: List[AttackerAction]
    # Defender actions recorded for this repetition (likewise non-zero only).
    defender_actions: List[DefenderAction]


@dataclass()
class Experiment:
    """All repetitions of one experiment run; this is what main() pickles."""

    # One entry per repetition, in execution order.
    sub_results: List[SubResult]


def get_configuration() -> dict:
    """Load the experiment configuration that lives next to this script.

    The file named by ``CONFIG_FILE`` is read (UTF-8) from the directory
    containing this module and parsed as YAML with ``yaml.FullLoader``.

    Returns:
        The parsed YAML document (expected to be a mapping).
    """
    config_path = Path(dirname(__file__)) / CONFIG_FILE
    raw_yaml = config_path.read_text(encoding='utf-8')
    return yaml.load(raw_yaml, Loader=yaml.FullLoader)


def get_root_conf(conf_of_conf: dict) -> RootConfig:
    """Build the ``RootConfig`` used for an experiment run.

    Debug mode and plotting are switched off, and the model configuration is
    filled in from the experiment configuration mapping.

    Args:
        conf_of_conf: Experiment configuration. The keys read here are
            ``use_blocking_model``, ``benign_ratio``, ``data_file`` and
            ``nn_epochs``.

    Returns:
        The populated ``RootConfig``.
    """
    root = RootConfig()
    # Turn off debug mode and plotting.
    root.debug = False
    root.plot_conf.plot_enabled = False

    model_conf = root.model_conf
    model_conf.use_blocking_model = bool(conf_of_conf['use_blocking_model'])
    # The first two arguments are placeholders (the original author noted
    # their values do not matter here); only the benign ratio is taken from
    # the experiment configuration.
    model_conf.set_ia_id_benign_ration(0.5, 0.5, conf_of_conf['benign_ratio'])
    model_conf.set_data_file(conf_of_conf['data_file'])
    model_conf.defender_conf.nn_conf.epochs = conf_of_conf['nn_epochs']
    return root


Martin Řepa's avatar
Martin Řepa committed
69
def _attacker_actions_of(result) -> List['AttackerAction']:
    """Return the attacker actions of *result* that have non-zero probability."""
    return [AttackerAction(action, prob)
            for action, prob in zip(result.ordered_actions_p1,
                                    result.probs_p1)
            if prob != 0]


def _defender_actions_of(result, model_dir: Path) -> List['DefenderAction']:
    """Collect the defender actions of *result* with non-zero probability.

    For every recorded action, the ``state_dict`` of its model is saved with
    ``torch.save`` to ``model_dir / '<action id>.pt'``.
    """
    defender_actions = []
    for action, prob in zip(result.ordered_actions_p2, result.probs_p2):
        if prob == 0:
            continue
        defender_actions.append(DefenderAction(action.id,
                                               prob,
                                               action.final_loss,
                                               action.final_fp_cost))
        torch.save(action.model.state_dict(), model_dir / f'{action.id}.pt')
    return defender_actions


def _pickle_to(file_path: Path, obj) -> None:
    """Serialize *obj* with pickle into *file_path*, overwriting it."""
    with open(file_path, 'wb') as file:
        pickle.dump(obj, file)


def main(experiment_conf, base_dir):
    """Solve the blocking-model game repeatedly and persist every result.

    For each of ``experiment_conf['num_of_experiments']`` repetitions, a new
    ``Game`` is solved and the elapsed solve time is measured. The actions
    with non-zero probability are recorded, and the defender models are
    saved to ``<base_dir>/<i>/<model id>.pt``. Afterwards the collected
    ``Experiment`` is pickled to ``<base_dir>/data``. The root config, with a
    few attributes cleared, is pickled to ``<base_dir>/model_config``.

    Args:
        experiment_conf: Experiment configuration mapping. It must contain
            ``num_of_experiments`` plus the keys read by ``get_root_conf``.
        base_dir: Existing directory that receives all outputs.

    Raises:
        FileExistsError: If a per-repetition folder ``<base_dir>/<i>``
            already exists (e.g. left over from an earlier run).
    """
    conf = get_root_conf(experiment_conf)
    repetitions = experiment_conf['num_of_experiments']
    base_path = Path(base_dir)

    sub_results = []
    for i in range(repetitions):
        # When run as a script, stdout is redirected to a log file, which is
        # block-buffered. Flush so progress shows up during long solves.
        print(f'Starting {i+1}. iteration.', flush=True)

        model_dir = base_path / str(i)
        # Same semantics as os.mkdir: fails if the folder already exists.
        model_dir.mkdir()

        # perf_counter is monotonic, so unlike time.time() the measured
        # duration is not skewed by system clock adjustments.
        start = time.perf_counter()
        result = Game(conf).solve_game()
        time_taken = time.perf_counter() - start

        attacker_actions = _attacker_actions_of(result)
        defender_actions = _defender_actions_of(result, model_dir)

        sub_results.append(SubResult(str(i),
                                     time_taken,
                                     result.zero_sum_nash_val,
                                     result.attacker_value,
                                     result.defender_value,
                                     attacker_actions,
                                     defender_actions))

    done_experiment = Experiment(sub_results)
    print('Experiment done.')

    # Save the result data
    data_file = base_path / 'data'
    print(f'Saving result data to {data_file} file.')
    _pickle_to(data_file, done_experiment)
    print('File saved.\n')

    # Save the model config just to be sure. The attributes below are cleared
    # first, presumably because they hold objects that cannot be pickled
    # (TODO confirm). Clearing them before opening the file means a failure
    # here does not leave an empty config file behind.
    model_config_file = base_path / 'model_config'
    print(f'Saving model config to {model_config_file} file.')
    conf.__attrs_post_init__ = None
    conf.model_conf.attacker_torch_utility = None
    conf.model_conf.attacker_utility = None
    conf.model_conf.defender_conf.nn_conf.loss_function = None
    _pickle_to(model_config_file, conf)
    print('File saved.')
Martin Řepa's avatar
Martin Řepa committed
129 130 131 132 133 134 135 136 137 138 139 140


if __name__ == "__main__":
    # Local imports: only needed when this file is run as a script.
    import contextlib
    import traceback

    print('Starting game theory blocking model experiment')
    experiment_conf = get_configuration()

    base_dir = experiment_conf["legacy_folder"]
    if not os.path.exists(base_dir):
        print(f'Creating base dir {base_dir}')
        # makedirs also creates any missing parent directories.
        os.makedirs(base_dir, exist_ok=True)

    # From here on, stdout and stderr go to <base_dir>/log.
    # redirect_stdout/redirect_stderr restore the real streams when the
    # block exits, so nothing is written to the log after it has been closed.
    # The file is line-buffered so the log can be followed while running.
    log_file = f'{base_dir}/log'
    with open(log_file, 'a', buffering=1, encoding='utf-8') as log, \
            contextlib.redirect_stdout(log), \
            contextlib.redirect_stderr(log):
        try:
            main(experiment_conf, base_dir)
        except BaseException:
            # Write the traceback into the log while it is still open, then
            # re-raise so the run still terminates with an error.
            traceback.print_exc()
            raise