learning_epochs.py 5.35 KB
Newer Older
Martin Řepa's avatar
Martin Řepa committed
1 2
import os
import pickle
Martin Řepa's avatar
Martin Řepa committed
3
import sys
Martin Řepa's avatar
Martin Řepa committed
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
import time
from os.path import dirname
from pathlib import Path
from typing import List

import torch
import yaml
from dataclasses import dataclass

from src.config import RootConfig
from src.game import Game

CONFIG_FILE = 'configuration.yaml'


@dataclass()
class AttackerAction:
    """One attacker action kept in the attacker's mixed strategy.

    ``exec_new_setup`` creates these from the solved game's
    ``ordered_actions_p1`` / ``probs_p1`` pairs, skipping pairs whose
    probability is zero.
    """
    # The attacker action exactly as reported by the game result.
    action: List[float]
    # Probability with which the attacker plays this action.
    prob: float


@dataclass()
class DefenderAction:
    """One defender action (a trained model) kept in the defender's strategy.

    ``exec_new_setup`` creates these for every defender action played with
    non-zero probability and stores the model weights next to them.
    """
    # Model identifier; the weights are saved as ``<iteration dir>/<id>.pt``.
    model_file_id: str
    # Probability with which the defender plays this model.
    prob: float
    # The model's ``final_loss`` as reported by the game result.
    loss: float
    # The model's ``final_fp_cost`` (presumably the false-positive part of
    # the cost — TODO confirm).
    fp_part: float


@dataclass
class SubResult:
    """Outcome of one accepted game solve within a ``Setup``.

    Filled in by ``exec_new_setup`` from the result of
    ``Game(conf).solve_game()``.
    """
    # ``result.iterations`` reported by the game solver (never 1: such
    # solves are discarded and retried).
    iterations: int
    # Name of the per-iteration subfolder (``str(i)``) inside the setup
    # folder; the defender models' ``.pt`` files are saved there.
    legacy_folder: str
    # Seconds spent in ``solve_game()`` for this iteration.
    time: float
    # ``result.zero_sum_nash_val``.
    zero_sum_game_value: float
    # ``result.attacker_value``.
    almost_zero_attacker_value: float
    # ``result.defender_value``.
    almost_zero_defender_value: float
    # Attacker actions played with non-zero probability.
    attacker_actions: List[AttackerAction]
    # Defender actions (models) played with non-zero probability.
    defender_actions: List[DefenderAction]


@dataclass()
class Setup:
    """All sub-results gathered for one defender-training epoch count."""
    # Value of ``conf.model_conf.defender_conf.nn_conf.epochs`` for this setup.
    epochs: int
    # One SubResult per accepted game solve, in run order.
    results: List[SubResult]


@dataclass()
class Experiment:
    """Complete result of the learning-epochs experiment.

    ``main`` builds a single instance and pickles it to ``<base_dir>/data``.
    """
    # One Setup per evaluated epoch count, in the order they were run.
    setups: List[Setup]


def get_configuration() -> dict:
    """Load the experiment configuration from ``CONFIG_FILE``.

    The file is resolved relative to the directory containing this script,
    not the current working directory, and parsed with PyYAML's
    ``FullLoader``.
    """
    config_path = Path(dirname(__file__)) / CONFIG_FILE
    with open(config_path, 'r', encoding='utf-8') as config_stream:
        raw_yaml = config_stream.read()
    parsed = yaml.load(raw_yaml, Loader=yaml.FullLoader)
    return parsed


def exec_new_setup(conf: RootConfig, folder: str, iterations: int) -> Setup:
    """Collect ``iterations`` accepted game solves for the current setup.

    For each iteration index ``i`` a subfolder ``<folder>/<i>`` is created.
    The weights of every defender model played with non-zero probability
    are saved there as ``<model id>.pt``.

    A solve whose ``result.iterations == 1`` is discarded and retried under
    the same index. The returned ``Setup`` therefore always holds exactly
    ``iterations`` sub-results.

    :param conf: configuration handed to ``Game``. Its defender NN epoch
        count is recorded in the returned ``Setup``.
    :param folder: existing directory in which per-iteration folders are
        created.
    :param iterations: number of accepted solves to collect.
    :return: ``Setup`` holding the epoch count and the collected sub-results.
    """
    cur_epochs = conf.model_conf.defender_conf.nn_conf.epochs
    sub_results = []

    i = 0
    while i < iterations:
        iteration_dir = os.path.join(folder, str(i))
        os.makedirs(iteration_dir, exist_ok=True)

        print(f'Starting {i+1}. iteration of setup with {cur_epochs} epochs')

        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments (unlike time.time()).
        start = time.perf_counter()
        result = Game(conf).solve_game()
        time_taken = time.perf_counter() - start

        # A solve that stopped after a single iteration is discarded and
        # retried with the same index i.
        # NOTE(review): if solve_game() keeps returning iterations == 1 this
        # loop never terminates — consider adding a retry limit.
        if result.iterations == 1:
            continue

        attacker_actions = _collect_attacker_actions(result)
        defender_actions = _save_defender_actions(result, iteration_dir)

        sub_results.append(SubResult(result.iterations,
                                     str(i),
                                     time_taken,
                                     result.zero_sum_nash_val,
                                     result.attacker_value,
                                     result.defender_value,
                                     attacker_actions,
                                     defender_actions))
        i += 1

    return Setup(cur_epochs, sub_results)


def _collect_attacker_actions(result) -> List[AttackerAction]:
    """Return the attacker actions ``result`` plays with non-zero probability."""
    return [AttackerAction(action, prob)
            for action, prob in zip(result.ordered_actions_p1, result.probs_p1)
            if prob != 0]


def _save_defender_actions(result, out_dir: str) -> List[DefenderAction]:
    """Describe and persist every defender model played with non-zero probability.

    Each such model's ``state_dict`` is written to ``<out_dir>/<id>.pt``.
    """
    defender_actions = []
    for p2_action, p2_prob in zip(result.ordered_actions_p2, result.probs_p2):
        if p2_prob == 0:
            continue
        defender_actions.append(DefenderAction(p2_action.id,
                                               p2_prob,
                                               p2_action.final_loss,
                                               p2_action.final_fp_cost))
        torch.save(p2_action.model.state_dict(),
                   os.path.join(out_dir, f'{p2_action.id}.pt'))
    return defender_actions


def get_root_conf(conf_of_conf: dict) -> RootConfig:
    """Build a ``RootConfig`` for a non-interactive experiment run.

    Debug mode and plotting are switched off. The model configuration is
    set up from the ``'i_a'``, ``'i_d'``, ``'benign_ratio'`` and
    ``'data_file'`` entries of ``conf_of_conf``.

    :param conf_of_conf: the ``conf`` section of the experiment YAML.
    :raises KeyError: if one of the required keys is missing.
    """
    conf = RootConfig()
    conf.debug = False
    conf.plot_conf.plot_enabled = False

    conf.model_conf.set_ia_id_benign_ration(conf_of_conf['i_a'],
                                            conf_of_conf['i_d'],
                                            conf_of_conf['benign_ratio'])
    conf.model_conf.set_data_file(conf_of_conf['data_file'])
    return conf


Martin Řepa's avatar
Martin Řepa committed
124
def main(experiment_conf, base_dir):
    """Run the learning-epochs experiment and pickle its results.

    The epoch counts come from ``experiment_conf['epochs']``, which holds
    ``lower_bound``, ``upper_bound`` and ``number_of_steps``. For each epoch
    count:

    - a folder ``<base_dir>/<epochs>epochs`` is created, and
    - ``exec_new_setup`` collects ``experiment_conf['experiments_per_setup']``
      sub-results into it.

    At the end, the ``Experiment`` is pickled to ``<base_dir>/data`` and the
    stripped root configuration to ``<base_dir>/model_config``.

    :raises ValueError: if ``number_of_steps`` is smaller than 1.
    :raises FileExistsError: if a setup folder already exists. Results of an
        earlier run are never overwritten.
    """
    conf = get_root_conf(experiment_conf['conf'])

    epochs_conf = experiment_conf['epochs']
    schedule = _epoch_schedule(epochs_conf['lower_bound'],
                               epochs_conf['upper_bound'],
                               epochs_conf['number_of_steps'])

    setups = []
    for cur_epochs in schedule:
        print(f'Let\'s do subexperiments for {cur_epochs} number of epochs.')

        folder = f'{base_dir}/{cur_epochs}epochs'
        os.mkdir(folder)
        print(f'Result is gonna be stored in {folder}.')

        conf.model_conf.defender_conf.nn_conf.epochs = cur_epochs
        setup = exec_new_setup(conf, folder,
                               experiment_conf['experiments_per_setup'])
        setups.append(setup)

    done_experiment = Experiment(setups)
    print('Experiment done.')

    # Save the result data
    data_file = f'{base_dir}/data'
    print(f'Saving result data to {data_file} file.')
    with open(data_file, 'wb') as file:
        pickle.dump(done_experiment, file)
    print('File saved.\n')

    # Save the model config just to be sure
    model_config_file = f'{base_dir}/model_config'
    print(f'Saving model config to {model_config_file} file.')
    # These attributes are cleared before pickling. Presumably they hold
    # callables (e.g. lambdas) that pickle cannot serialize — TODO confirm.
    conf.__attrs_post_init__ = None
    conf.model_conf.attacker_torch_utility = None
    conf.model_conf.attacker_utility = None
    conf.model_conf.defender_conf.nn_conf.loss_function = None
    with open(model_config_file, 'wb') as file:
        pickle.dump(conf, file)
    print('File saved.')


def _epoch_schedule(lower, upper, steps):
    """Return up to ``steps`` evenly spaced epoch counts from lower to upper.

    - Both bounds are included when ``steps > 1``.
    - Values use exact integer arithmetic, so no rounding error accumulates
      and ``upper`` is always reached.
    - Duplicates, possible when the range is narrower than ``steps``, are
      dropped while order is kept.
    - An empty list is returned when ``lower > upper``.

    :raises ValueError: if ``steps`` is smaller than 1.
    """
    if steps < 1:
        raise ValueError(f'number_of_steps must be at least 1, got {steps}')
    if lower > upper:
        return []
    if steps == 1:
        return [lower]
    span = upper - lower
    values = (lower + (k * span) // (steps - 1) for k in range(steps))
    return list(dict.fromkeys(values))
Martin Řepa's avatar
Martin Řepa committed
167 168 169 170 171 172 173 174 175 176 177


if __name__ == "__main__":
    import contextlib
    import traceback

    print('Starting learning epochs experiment')
    experiment_conf = get_configuration()

    base_dir = experiment_conf["legacy_folder"]
    if not os.path.exists(base_dir):
        print(f'Creating base dir {base_dir}')
        os.makedirs(base_dir)

    log_file = f'{base_dir}/log'
    # Line buffering makes progress visible in the log while the experiment
    # is still running. The redirect_* context managers restore the original
    # streams on exit, so nothing is later written to the closed log file.
    with open(log_file, 'a', buffering=1, encoding='utf-8') as log, \
            contextlib.redirect_stdout(log), \
            contextlib.redirect_stderr(log):
        try:
            main(experiment_conf, base_dir)
        except BaseException:
            # Record the crash in the log. The exception is re-raised and
            # also reported on the restored console stderr.
            traceback.print_exc()
            raise