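"""Defender neural network.

A small PyTorch feed-forward model that learns to separate benign samples
from attacker actions; training data comes from scored CSV files loaded via
src.data.loader.
"""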
import itertools
import logging
import os
import time
from pathlib import Path

import attr
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim

from src.config import NeuralNetworkConfig, RootConfig
from src.data.loader import np_matrix_from_scored_csv

logger = logging.getLogger(__name__)
# Select the device via the 'device' env var; fall back to CPU when CUDA is
# unavailable so the module also runs on machines without a GPU
DEVICE = torch.device(os.environ.get(
    'device', 'cuda' if torch.cuda.is_available() else 'cpu'))

@attr.s
class FormattedData:
    """Unique samples with their empirical probabilities and labels."""
    unique_x: np.ndarray = attr.ib()
    probs_x: np.ndarray = attr.ib()
    y: np.ndarray = attr.ib()


@attr.s
class BenignData:
    """Unique benign samples with their occurrence counts and labels."""
    unique_x: np.ndarray = attr.ib()
    counts: np.ndarray = attr.ib()
    y: np.ndarray = attr.ib()


class OrderCounter:
    """Hands out increasing ids, used to number NeuralNetwork instances."""
    order = 0

    @staticmethod
    def next():
        OrderCounter.order += 1
        return OrderCounter.order


class SoftClip(nn.Module):
    """SoftClipping activation function.

    https://arxiv.org/pdf/1810.11509.pdf
    """
    def __init__(self, p=50.0):
        super().__init__()
        self.p = p

    def forward(self, x):
        # f(x) = (1/p) * log((1 + e^(p*x)) / (1 + e^(p*(x - 1))))
        first_pow = torch.exp(torch.mul(x, self.p))
        second_pow = torch.exp(torch.mul(torch.add(x, -1.0), self.p))

        first_div = torch.add(first_pow, 1.0)
        second_div = torch.add(second_pow, 1.0)

        division = torch.div(first_div, second_div)

        second_part_log = torch.log(division)
        first_part = 1.0 / self.p

        if torch.isnan(second_part_log).any():
            logger.warning('NaN values in SoftClip forward pass')

        return torch.mul(second_part_log, first_part)
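
# Behaviour sketch (illustrative values computed from the formula above):
# with the default steepness p=50 SoftClip acts almost like a hard clamp
# to [0, 1], e.g. SoftClip(50)(torch.tensor([-1.0, 0.5, 2.0])) returns
# approximately tensor([0.0000, 0.5000, 1.0000]).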


class NeuralNetwork:
    def __init__(self, input_features=2,
                 nn_conf: NeuralNetworkConfig = NeuralNetworkConfig()):
        self.model = nn.Sequential(
            nn.Linear(input_features, 5),
            nn.ReLU(),
            nn.Linear(5, 5),
            nn.ReLU(),
            nn.Linear(5, 5),
            nn.ReLU(),
            nn.Linear(5, 1),
            nn.Tanh(),
            SoftClip(50)
            # nn.Sigmoid()
        ).to(DEVICE)
        self._set_weights()
        self.conf = nn_conf
        self.id = OrderCounter.next()

        # Variables used for loss function
        self.attacker_actions: FormattedData = None
        self.benign_data: BenignData = None

        # Quality metrics recorded after the last training epoch
        self.final_loss = None
        self.final_fp_cost = None

        # -------------- <TMP> --------------------
        # self.max_constant = constant
        # self.cur_value = .0
        # self.step = .05
        # self.edge_epoch = (self.conf.epochs*0.2)
        # self.incr_each = self.edge_epoch / (self.max_constant/.05)

    # def get_cur_coefficient(self, epoch) -> float:
    #     return self.max_constant
        # if epoch > self.edge_epoch:
        #     return self.max_constant
        # return .01
        # if self.cur_value < self.max_constant and epoch % self.incr_each == 0:
        #     self.cur_value += self.step
        # return self.cur_value
    # -------------- </TMP> --------------------

    def __str__(self):
        return f'Neural network id:{self.id}, final loss: {self.final_loss}'

    def set_data(self, benign_data: BenignData, attack: FormattedData):
        self.attacker_actions = attack
        self.benign_data = benign_data

    def _prepare_data(self):
        defender = self.benign_data
        attacker = self.attacker_actions
        x = np.concatenate((defender.unique_x, attacker.unique_x), axis=0)
        y = np.concatenate((defender.y, attacker.y), axis=0)
        probs = np.concatenate((defender.counts/np.sum(defender.counts),
                                attacker.probs_x), axis=0)

        # Labels aligned with get_train_batch: batch_size benign samples
        # (label 0) followed by all attacker actions with their own labels
        self.train_y = torch.cat((torch.zeros(self.conf.batch_size).float(),
                                  torch.tensor(attacker.y).float())).to(DEVICE)

        self.all_x = torch.tensor(x).float().to(DEVICE)
        self.all_y = torch.tensor(y).float().to(DEVICE)
        self.all_probs = torch.tensor(probs).float().to(DEVICE)

    def _set_weights(self):
        def init_weights(m):
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)
                m.bias.data.fill_(0.0)
        self.model.apply(init_weights)

    def train(self):
        self._prepare_data()
        self._train()

    def get_train_batch(self):
        # Sample batch_size benign points (by index, with replacement) and
        # append all attacker actions; benign counts are renormalised so the
        # benign part of the batch forms a probability distribution
        batch_idxs = np.random.choice(len(self.benign_data.unique_x),
                                      self.conf.batch_size)
        current_batch_samples = np.sum(self.benign_data.counts[batch_idxs])

        batch_x_np = np.concatenate((self.benign_data.unique_x[batch_idxs],
                                     self.attacker_actions.unique_x), axis=0)
        batch_probs_np = np.concatenate(
            (self.benign_data.counts[batch_idxs] / current_batch_samples,
             self.attacker_actions.probs_x), axis=0)

        batch_x = torch.tensor(batch_x_np).float().to(DEVICE)
        batch_probs = torch.tensor(batch_probs_np).float().to(DEVICE)
        return batch_x, batch_probs
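
    # Shape sketch (with batch_size B benign rows and A attacker actions):
    # batch_x is (B + A, input_features), batch_probs is (B + A,); the benign
    # slice of batch_probs sums to 1 while the attacker slice keeps the mass
    # given in attacker_actions.probs_x.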

    def _train(self):
        learning_rate = self.conf.learning_rate
        optimizer = optim.Adam(self.model.parameters(), lr=learning_rate)

        for e in range(self.conf.epochs):
            batch_x, batch_probs = self.get_train_batch()

            # Forward pass: compute predicted y by passing x to the model
            train_ltncies = self.latency_predict(batch_x, with_grad=True)

            # Compute loss
            loss, _ = self.conf.loss_function(batch_x, train_ltncies,
                                              self.train_y, batch_probs)

            # Log the loss value every 50 epochs
            if e % 50 == 0:
                logger.debug(f'Epoch: {e}/{self.conf.epochs},\t'
                             f'TrainLoss: {loss}')

            # Before the backward pass, use the optimizer object to zero all of
            # the gradients for the variables it will update
            optimizer.zero_grad()

            # Backward pass: compute gradient of the loss with respect to model
            # parameters
            loss.backward()

            # Calling the step function on an Optimizer makes an update to its
            # parameters
            optimizer.step()

        with torch.no_grad():
            train_ltncies = self.latency_predict(self.all_x)
            loss, fp_part = self.conf.loss_function(self.all_x, train_ltncies,
                                                    self.all_y, self.all_probs)
            logger.debug(f'Final loss of this nn: {loss}\tfp_part is: {fp_part}')
            # Record the quality of the final network
            self.final_loss = loss.item()
            self.final_fp_cost = fp_part.item()

    def plot_tmp(self):
        # todo remove me
        if self.id == 1:
            return
        try:
            actions = self.actions
        except AttributeError:
            self.plotted = []
            self.plotting = plt.subplots()
            one_axis = np.linspace(0, 1, 101)  # [0.00, 0.01, 0.02, ..., 0.99, 1.00]
            generator = itertools.product(*itertools.repeat(one_axis, 2))
            actions = torch.tensor(np.array(list(generator))).float().to(DEVICE)
            self.actions = actions
        finally:
            # Remove all lines from previous iteration plotting
            for item in self.plotted:
                item.remove()
            self.plotted = []
            res = self.latency_predict(actions).cpu().numpy().reshape((101, 101), order='F')
            self.plotted.append(self.plotting[1].imshow(res, cmap='Reds', vmin=0, vmax=1, origin='lower'))
            self.plotting[0].canvas.draw()
            plt.pause(0.000001)
            time.sleep(1)

    def _raw_predict(self, tensor: torch.Tensor):
        pred = self.model(tensor)
        return pred.flatten().float()

    def latency_predict(self, x: torch.Tensor, with_grad=False):
        if with_grad:
            raw_prediction = self._raw_predict(x)
        else:
            with torch.no_grad():
                raw_prediction = self._raw_predict(x)

        return raw_prediction
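
    # Usage sketch (hypothetical call): for a float tensor x of shape
    # (N, input_features) on DEVICE, latency_predict(x) returns a flat
    # tensor of N clipped predictions; pass with_grad=True only when the
    # result feeds a backward pass, as in _train above.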

    def predict_single_latency(self, input, return_tensor=False):
        # Accept plain sequences and numpy arrays as well as tensors
        if isinstance(input, (list, tuple, np.ndarray)):
            input = torch.tensor(input).float().to(DEVICE)

        if return_tensor:
            return self.latency_predict(input)[0]
        else:
            return self.latency_predict(input)[0].item()


def setup_logger(debug: bool):
    log_format = ('%(asctime)-15s\t%(name)s:%(levelname)s\t'
                  '%(module)s:%(funcName)s:%(lineno)s\t%(message)s')
    level = logging.DEBUG if debug else logging.INFO
    logging.basicConfig(level=level, format=log_format)


if __name__ == '__main__':
    setup_logger(True)
    benign_x, _ = np_matrix_from_scored_csv(
        Path('all_benign_scored.csv'), 0, 500)
    malicious_x, _ = np_matrix_from_scored_csv(
        Path('scored_malicious.csv'), 1, 400)

    benign_unique_x, counts = np.unique(benign_x, axis=0, return_counts=True)
    benign_y = np.zeros(len(benign_unique_x))
    # set_data expects BenignData carrying raw occurrence counts; empirical
    # probabilities are derived from them in _prepare_data/get_train_batch
    benign_data = BenignData(benign_unique_x, counts, benign_y)

    malicious_unique_x, counts = np.unique(malicious_x, axis=0, return_counts=True)
    # Normalise by the total number of samples (not unique rows) so the
    # probabilities sum to 1, mirroring the benign data handling above
    probs_malicious = counts / len(malicious_x)
    malicious_y = np.ones(len(malicious_unique_x))
    malicious_data = FormattedData(malicious_unique_x, probs_malicious, malicious_y)

    conf = RootConfig()
    # Named 'network' to avoid shadowing torch.nn imported above
    network = NeuralNetwork(2, conf.model_conf.defender_conf.nn_conf)
    network.set_data(benign_data, malicious_data)
    network.train()