Commit 49a8c25a authored by Martin Řepa's avatar Martin Řepa

Add configurable i_a and i_d

parent 4e74b2d9
......@@ -17,4 +17,4 @@ matplotlib = "*"
torch = "*"
[requires]
python_version = "3.6"
python_version = "3.7"
{
"_meta": {
"hash": {
"sha256": "7d6b47e2ff2ba43cd8f50922d385f306e1876ba3f4a3e67e3c74c60ee152787f"
"sha256": "45798a30df6e2c3a6012896cf83a70b58a21f44ca2a3ee42af00d15432d1917a"
},
"pipfile-spec": 6,
"requires": {
"python_version": "3.6"
"python_version": "3.7"
},
"sources": [
{
......@@ -18,9 +18,9 @@
"default": {
"absl-py": {
"hashes": [
"sha256:8718189e4bd6013bf79910b9d1cb0a76aecad8ce664f78e1144980fabdd2cd23"
"sha256:b943d1c567743ed0455878fcd60bc28ac9fae38d129d1ccfad58079da00b8951"
],
"version": "==0.7.0"
"version": "==0.7.1"
},
"astor": {
"hashes": [
......@@ -39,11 +39,11 @@
},
"attrs": {
"hashes": [
"sha256:10cbf6e27dbce8c30807caf056c8eb50917e0eaafe86347671b57254006c3e69",
"sha256:ca4be454458f9dec299268d472aaa5a11f67a4ff70093396e1ceae9c76cf4bbb"
"sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79",
"sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399"
],
"index": "pypi",
"version": "==18.2.0"
"version": "==19.1.0"
},
"cycler": {
"hashes": [
......@@ -184,22 +184,22 @@
},
"matplotlib": {
"hashes": [
"sha256:16aa61846efddf91df623bbb4598e63be1068a6b6a2e6361cc802b41c7a286eb",
"sha256:1975b71a33ac986bb39b6d5cfbc15c7b1f218f1134efb4eb3881839d6ae69984",
"sha256:2b222744bd54781e6cc0b717fa35a54e5f176ba2ced337f27c5b435b334ef854",
"sha256:317643c0e88fad55414347216362b2e229c130edd5655fea5f8159a803098468",
"sha256:4269ce3d1b897d46fc3cc2273a0cc2a730345bb47e4456af662e6fca85c89dd7",
"sha256:65214fd668975077cdf8d408ccf2b2d6bdf73b4e6895a79f8e99ce4f0b43fcdb",
"sha256:74bc213ab8a92d86a0b304d9359d1e1d14168d4c6121b83862c9d8a88b89a738",
"sha256:88949be0db54755995dfb0210d0099a8712a3c696c860441971354c3debfc4af",
"sha256:8e1223d868be89423ec95ada5f37aa408ee64fe76ccb8e4d5f533699ba4c0e4a",
"sha256:9fa00f2d7a552a95fa6016e498fdeb6d74df537853dda79a9055c53dfc8b6e1a",
"sha256:c27fd46cab905097ba4bc28d5ba5289930f313fb1970c9d41092c9975b80e9b4",
"sha256:c94b792af431f6adb6859eb218137acd9a35f4f7442cea57e4a59c54751c36af",
"sha256:f4c12a01eb2dc16693887a874ba948b18c92f425c4d329639ece6d3bb8e631bb"
"sha256:1ae6549976b6ceb6ee426272a28c0fc9715b3e3669694d560c8f661c5b39e2c5",
"sha256:4d4250bf508dd07cca3b43888097f873cadb66eec6ac63dbbfb798798ec07af2",
"sha256:53af2e01d7f1700ed2b64a9091bc865360c9c4032f625451c4589a826854c787",
"sha256:63e498067d32d627111cd1162cae1621f1221f9d4c6a9745dd7233f29de581b6",
"sha256:7169a34971e398dd58e87e173f97366fd88a3fa80852704530433eb224a8ca57",
"sha256:91c54d6bb9eeaaff965656c5ea6cbdcbf780bad8462ac99b30b451548194746f",
"sha256:aeef177647bb3fccfe09065481989d7dfc5ac59e9367d6a00a3481062cf651e4",
"sha256:cf8ae10559a78aee0409ede1e9d4fda03895433eeafe609dd9ed67e45f552db0",
"sha256:d51d0889d1c4d51c51a9822265c0494ea3e70a52bdd88358e0863daca46fa23a",
"sha256:de5ccd3500247f85fe4f9fad90f80a8bd397e4f110a4c33fabf95f07403e8372",
"sha256:e1d33589e32f482d0a7d1957bf473d43341115d40d33f578dad44432e47df7b7",
"sha256:e8d1939262aa6b36d0c51f50a50a43a04b9618d20db31e6c0192b1463067aeef",
"sha256:e918d51b1fda82a65fdf52d2f3914b2246481cc2a9cd10e223e6be6078916ff3"
],
"index": "pypi",
"version": "==3.0.2"
"version": "==3.0.3"
},
"mock": {
"hashes": [
......@@ -239,57 +239,59 @@
},
"pandas": {
"hashes": [
"sha256:02c830f951f3dc8c3164e2639a8961881390f7492f71a7835c2330f54539ad57",
"sha256:179015834c72a577486337394493cc2969feee9a04a2ea09f50c724e4b52ab42",
"sha256:3894960d43c64cfea5142ac783b101362f5008ee92e962392156a3f8d1558995",
"sha256:435821cb2501eabbcee7e83614bd710940dc0cf28b5afbc4bdb816c31cec71af",
"sha256:8294dea9aa1811f93558702856e3b68dd1dfd7e9dbc8e0865918a07ee0f21c2c",
"sha256:844e745ab27a9a01c86925fe776f9d2e09575e65f0bf8eba5090edddd655dffc",
"sha256:a08d49f5fa2a2243262fe5581cb89f6c0c7cc525b8d6411719ab9400a9dc4a82",
"sha256:a435c251246075337eb9fdc4160fd15c8a87cc0679d8d61fb5255d8d5a12f044",
"sha256:a799f03c0ec6d8687f425d7d6c075e8055a9a808f1ba87604d91f20507631d8d",
"sha256:aea72ce5b3a016b578cc05c04a2f68d9cafacf5d784b6fe832e66381cb62c719",
"sha256:c145e94c6da2af7eaf1fd827293ac1090a61a9b80150bebe99f8966a02378db9",
"sha256:c8a7b470c88c779301b73b23cabdbbd94b83b93040b2ccffa409e06df23831c0",
"sha256:c9e31b36abbd7b94c547d9047f13e1546e3ba967044cf4f9718575fcb7b81bb6",
"sha256:d960b7a03c33c328c723cfc2f8902a6291645f4efa0a5c1d4c5fa008cdc1ea77",
"sha256:da21fae4c173781b012217c9444f13c67449957a4d45184a9718268732c09564",
"sha256:db26c0fea0bd7d33c356da98bafd2c0dfb8f338e45e2824ff8f4f3e61b5c5f25",
"sha256:dc296c3f16ec620cfb4daf0f672e3c90f3920ece8261b2760cd0ebd9cd4daa55",
"sha256:e8da67cb2e9333ec30d53cfb96e27a4865d1648688e5471699070d35d8ab38cf",
"sha256:fb4f047a63f91f22aade4438aaf790400b96644e802daab4293e9b799802f93f",
"sha256:fef9939176cba0c2526ebeefffb8b9807543dc0954877b7226f751ec1294a869"
"sha256:071e42b89b57baa17031af8c6b6bbd2e9a5c68c595bc6bf9adabd7a9ed125d3b",
"sha256:17450e25ae69e2e6b303817bdf26b2cd57f69595d8550a77c308be0cd0fd58fa",
"sha256:17916d818592c9ec891cbef2e90f98cc85e0f1e89ed0924c9b5220dc3209c846",
"sha256:2538f099ab0e9f9c9d09bbcd94b47fd889bad06dc7ae96b1ed583f1dc1a7a822",
"sha256:366f30710172cb45a6b4f43b66c220653b1ea50303fbbd94e50571637ffb9167",
"sha256:42e5ad741a0d09232efbc7fc648226ed93306551772fc8aecc6dce9f0e676794",
"sha256:4e718e7f395ba5bfe8b6f6aaf2ff1c65a09bb77a36af6394621434e7cc813204",
"sha256:4f919f409c433577a501e023943e582c57355d50a724c589e78bc1d551a535a2",
"sha256:4fe0d7e6438212e839fc5010c78b822664f1a824c0d263fd858f44131d9166e2",
"sha256:5149a6db3e74f23dc3f5a216c2c9ae2e12920aa2d4a5b77e44e5b804a5f93248",
"sha256:627594338d6dd995cfc0bacd8e654cd9e1252d2a7c959449228df6740d737eb8",
"sha256:83c702615052f2a0a7fb1dd289726e29ec87a27272d775cb77affe749cca28f8",
"sha256:8c872f7fdf3018b7891e1e3e86c55b190e6c5cee70cab771e8f246c855001296",
"sha256:90f116086063934afd51e61a802a943826d2aac572b2f7d55caaac51c13db5b5",
"sha256:a3352bacac12e1fc646213b998bce586f965c9d431773d9e91db27c7c48a1f7d",
"sha256:bcdd06007cca02d51350f96debe51331dec429ac8f93930a43eb8fb5639e3eb5",
"sha256:c1bd07ebc15285535f61ddd8c0c75d0d6293e80e1ee6d9a8d73f3f36954342d0",
"sha256:c9a4b7c55115eb278c19aa14b34fcf5920c8fe7797a09b7b053ddd6195ea89b3",
"sha256:cc8fc0c7a8d5951dc738f1c1447f71c43734244453616f32b8aa0ef6013a5dfb",
"sha256:d7b460bc316064540ce0c41c1438c416a40746fd8a4fb2999668bf18f3c4acf1"
],
"index": "pypi",
"version": "==0.24.1"
"version": "==0.24.2"
},
"pbr": {
"hashes": [
"sha256:a7953f66e1f82e4b061f43096a4bcc058f7d3d41de9b94ac871770e8bdd831a2",
"sha256:d717573351cfe09f49df61906cd272abaa759b3e91744396b804965ff7bff38b"
"sha256:8257baf496c8522437e8a6cfe0f15e00aedc6c0e0e7c9d55eeeeab31e0853843",
"sha256:8c361cc353d988e4f5b998555c88098b9d5964c2e11acf7b0d21925a66bb5824"
],
"version": "==5.1.2"
"version": "==5.1.3"
},
"protobuf": {
"hashes": [
"sha256:10394a4d03af7060fa8a6e1cbf38cea44be1467053b0aea5bbfcb4b13c4b88c4",
"sha256:1489b376b0f364bcc6f89519718c057eb191d7ad6f1b395ffd93d1aa45587811",
"sha256:1931d8efce896981fe410c802fd66df14f9f429c32a72dd9cfeeac9815ec6444",
"sha256:196d3a80f93c537f27d2a19a4fafb826fb4c331b0b99110f985119391d170f96",
"sha256:46e34fdcc2b1f2620172d3a4885128705a4e658b9b62355ae5e98f9ea19f42c2",
"sha256:4b92e235a3afd42e7493b281c8b80c0c65cbef45de30f43d571d1ee40a1f77ef",
"sha256:574085a33ca0d2c67433e5f3e9a0965c487410d6cb3406c83bdaf549bfc2992e",
"sha256:59cd75ded98094d3cf2d79e84cdb38a46e33e7441b2826f3838dcc7c07f82995",
"sha256:5ee0522eed6680bb5bac5b6d738f7b0923b3cafce8c4b1a039a6107f0841d7ed",
"sha256:65917cfd5da9dfc993d5684643063318a2e875f798047911a9dd71ca066641c9",
"sha256:685bc4ec61a50f7360c9fd18e277b65db90105adbf9c79938bd315435e526b90",
"sha256:92e8418976e52201364a3174e40dc31f5fd8c147186d72380cbda54e0464ee19",
"sha256:9335f79d1940dfb9bcaf8ec881fb8ab47d7a2c721fb8b02949aab8bbf8b68625",
"sha256:a7ee3bb6de78185e5411487bef8bc1c59ebd97e47713cba3c460ef44e99b3db9",
"sha256:ceec283da2323e2431c49de58f80e1718986b79be59c266bb0509cbf90ca5b9e",
"sha256:fcfc907746ec22716f05ea96b7f41597dfe1a1c088f861efb8a0d4f4196a6f10"
],
"version": "==3.6.1"
"sha256:03666634d038e35d90155756914bc3a6316e8bcc0d300f3ee539e586889436b9",
"sha256:049d5900e442d4cc0fd2afd146786b429151e2b29adebed28e6376026ab0ee0b",
"sha256:0eb9e62a48cc818b1719b5035042310c7e4f57b01f5283b32998c68c2f1c6a7c",
"sha256:255d10c2c9059964f6ebb5c900a830fc8a089731dda94a5cc873f673193d208b",
"sha256:358cc59e4e02a15d3725f204f2eb5777fc10595e2d9a9c4c8d82292f49af6d41",
"sha256:41f1b737d5f97f1e2af23d16fac6c0b8572f9c7ea73054f1258ca57f4f97cb80",
"sha256:6a5129576a2cf925cd100e06ead5f9ae4c86db70a854fb91cedb8d680112734a",
"sha256:80722b0d56dcb7ca8f75f99d8dadd7c7efd0d2265714d68f871ed437c32d82b3",
"sha256:88a960e949ec356f7016d84f8262dcff2b842fca5355b4c1be759f5c103b19b3",
"sha256:97872686223f47d95e914881cb0ca46e1bc622562600043da9edddcb54f2fe1e",
"sha256:a1df9d22433ab44b7c7e0bd33817134832ae8a8f3d93d9b9719fc032c5b20e96",
"sha256:ad385fbb9754023d17be14dd5aa67efff07f43c5df7f93118aef3c20e635ea19",
"sha256:b2d5ee7ba5c03b735c02e6ae75fd4ff8c831133e7ca078f2963408dc7beac428",
"sha256:c8c07cd8635d45b28ec53ee695e5ac8b0f9d9a4ae488a8d8ee168fe8fc75ba43",
"sha256:d44ebc9838b183e8237e7507885d52e8d08c48fdc953fd4a7ee3e56cb9d20977",
"sha256:dff97b0ee9256f0afdfc9eaa430736cdcdc18899d9a666658f161afd137cf93d",
"sha256:e47d248d614c68e4b029442de212bdd4f6ae02ae36821de319ae90314ea2578c",
"sha256:e650b521b429fed3d525428b1401a40051097a5a92c30076c91f36b31717e087"
],
"version": "==3.7.0"
},
"pulp": {
"hashes": [
......@@ -321,36 +323,36 @@
},
"scikit-learn": {
"hashes": [
"sha256:05d061606657af85365b5f71484e3362d924429edde17a90068960843ad597f5",
"sha256:071317afbb5c67fa493635376ddd724b414290255cbf6947c1155846956e93f7",
"sha256:0d03aaf19a25e59edac3099cda6879ba05129f0fa1e152e23b728ccd36104f57",
"sha256:1665ea0d4b75ef24f5f2a9d1527b7296eeabcbe3a1329791c954541e2ebde5a2",
"sha256:24eccb0ff31f84e88e00936c09197735ef1dcabd370aacb10e55dbc8ee464a78",
"sha256:27b48cabacce677a205e6bcda1f32bdc968fbf40cd2aa0a4f52852f6997fce51",
"sha256:2c51826b9daa87d7d356bebd39f8665f7c32e90e3b21cbe853d6c7f0d6b0d23b",
"sha256:3116299d392bd1d054655fa2a740e7854de87f1d573fa85503e64494e52ac795",
"sha256:3771861abe1fd1b2bbeaec7ba8cfca58fdedd75d790f099960e5332af9d1ff7a",
"sha256:473ba7d9a5eaec47909ee83d74b4a3be47a44505c5189d2cab67c0418cd030f1",
"sha256:621e2c91f9afde06e9295d128cb15cb6fc77dc00719393e9ec9d47119895b0d4",
"sha256:645865462c383e5faad473b93145a8aee97d839c9ad1fd7a17ae54ec8256d42b",
"sha256:80e2276d4869d302e84b7c03b5bac4a67f6cd331162e62ae775a3e5855441a60",
"sha256:84d2cfe0dee3c22b26364266d69850e0eb406d99714045929875032f91d3c918",
"sha256:87ea9ace7fe811638dfc39b850b60887509b8bfc93c4006d5552fa066d04ddc7",
"sha256:a4d1e535c75881f668010e6e53dfeb89dd50db85b05c5c45af1991c8b832d757",
"sha256:a4f14c4327d2e44567bfb3a0bee8c55470f820bc9a67af3faf200abd8ed79bf2",
"sha256:a7b3c24e193e8c6eaeac075b5d0bb0a7fea478aa2e4b991f6a7b030fc4fd410d",
"sha256:ab2919aca84f1ac6ef60a482148eec0944364ab1832e63f28679b16f9ef279c8",
"sha256:b0f79d5ff74f3c68a4198ad5b4dfa891326b5ce272dd064d11d572b25aae5b43",
"sha256:bc5bc7c7ee2572a1edcb51698a6caf11fae554194aaab9a38105d9ec419f29e6",
"sha256:bc5c750d548795def79576533f8f0f065915f17f48d6e443afce2a111f713747",
"sha256:c68969c30b3b2c1fe07c1376110928eade61da4fc29c24c9f1a89435a7d08abe",
"sha256:d3b4f791d2645fe936579d61f1ff9b5dcf0c8f50db7f0245ca8f16407d7a5a46",
"sha256:dac0cd9fdd8ac6dd6108a10558e2e0ca1b411b8ea0a3165641f9ab0b4322df4e",
"sha256:eb7ddbdf33eb822fdc916819b0ab7009d954eb43c3a78e7dd2ec5455e074922a",
"sha256:ed537844348402ed53420187b3a6948c576986d0b2811a987a49613b6a26f29e",
"sha256:fcca54733e692fe03b8584f7d4b9344f4b6e3a74f5b326c6e5f5e9d2504bdce7"
],
"version": "==0.20.2"
"sha256:018f470a7e685767d84ce6fac87af59e064e87ec3cea71eaf12646f9538e293d",
"sha256:0ae00d570331b8a5c552f721167818b4739a5c855fbc76b11231ccdea2dd26ab",
"sha256:13079520dd8211967d1871e439b59818d335439672818e9683847091d0e07778",
"sha256:1c133749a526b33af2b6695d94d2cc43ba212c5aa7bd3a45619335556ced7637",
"sha256:382e7053567b7b11e862782e3de2940e2141be24e6262aa0b4a9cb7fdd61f85a",
"sha256:384df81fdba12d21063072f2cf472a7a8425a3d4fa3915faef0a88e94e07b332",
"sha256:4705073de7bbcc6b9cd2f24dc9189aa8d3935e8621d3e65546c4b7fee9a042bf",
"sha256:4f829d6c09b997e1d0a998f970cf3ff82cd6796d56148c63c29174367878d490",
"sha256:51a933224b1b11986d4c7c123e5b28eb69602899d0179e6888b7abf2ffc85265",
"sha256:63ad98c6512b52aebde9bd806ec1127e13e2a8d42a00ebdf805153819f7c2cad",
"sha256:67e15514c9df4c5354b3ecc89451f5baa0f1b62c7ed68f4d20febf9c9d9e17a6",
"sha256:75f0e0e93851b30639baabfc1a4433aabc57eef269d55ee4c6f649fb60686218",
"sha256:89609708e819342dd5c94617fd53a36187d7d6a80435ddb282f6a60b058dbe77",
"sha256:8ca274d4e91685e4547af718b6f1e9a9d4912c7a6dcb0c68925de84f81a09d2a",
"sha256:9987f3d31efc427ebf9926f703e5171552cfb3b6935f880e4f0d3a17b7f91540",
"sha256:9f3e08dbd3f2f574913faba9b48d3c24a43fcc0eb14a0e962431005434b9cfe6",
"sha256:a7a403bcea250cac37971058fca0c30b0144737a375f99d3855e5e7a34c43348",
"sha256:ad7e4e823db1271d344e0c3ce0988b2e0fecc49079eec9c818d866c38b2824bd",
"sha256:b1e9037a582e650d866324a50d2741724ea5f6c175200bef0b549d014898035a",
"sha256:b82fbd8843ead2640158b2c0946d354b66f3d49472e6790d70c4ceec35663b3f",
"sha256:b91c82bfd25145d428de99429de97d7a1c2c2658c212689fe2839b29a5251159",
"sha256:ba57b73ec7074f60bb85f953296df437784d560553d0cc04b253c43f1846ccad",
"sha256:c503802a81de18b8b4d40d069f5e363795ee44b1605f38bc104160ca3bfe2c41",
"sha256:d30e8e0dffbc299533f47044fec26c5087473cb29cf51f1995986ac8354c7b4c",
"sha256:d89b810bfb0e16a0de7f18773849bdf83dd7fd0614ae5225e5a9214cdb9be245",
"sha256:e22e1d47def2944ad7a12c09452de085587ec5baad2174683e56a42b6918a76f",
"sha256:f650ddc023c95681fccd5e297820f35de039e008265040c08188be95b3275a0f",
"sha256:f7d4b3885ad1a7a6f07719ab6b1790d9892d6d41d973e8d4543a93bb15226fb4"
],
"version": "==0.20.3"
},
"scipy": {
"hashes": [
......@@ -401,10 +403,10 @@
},
"tensorboard": {
"hashes": [
"sha256:2ecfad35284e91d7c76945245c535245ba6900b0596d5c126d5b4ae3b434fb62",
"sha256:82c9c711b76949b7b3794fc319dc3d3b0fad25f7c0c5260ec4a8371b02d23da6"
"sha256:53d8f40589c903dae65f39a799c2bc49defae3703754984d90613d26ebd714a4",
"sha256:b664fe7772be5670d8b04200342e681af7795a12cd752709aed565c06c0cc196"
],
"version": "==1.13.0"
"version": "==1.13.1"
},
"tensorflow": {
"hashes": [
......
......@@ -24,6 +24,7 @@ class Attacker:
self.conf = model_conf.attacker_conf
self.features_count = model_conf.features_count
self.utility = model_conf.attacker_utility
self.torch_utility = model_conf.attacker_torch_utility
self.actions: np.array = None
if not self.conf.use_gradient_descent:
......@@ -39,14 +40,6 @@ class Attacker:
return self._gradient_best_response(actions, probs)
else:
return self._discrete_best_response(actions, probs)
# # TMP
# optimal = self._discrete_best_response(def_actions, def_probs)
# gradient_brp = self._gradient_best_response(def_actions, def_probs)
#
# if list(map(lambda a: round(a, 2), gradient_brp)) != optimal:
# print("A je to v píči")
#
# return optimal
def _discrete_best_response(self, def_actions: List, def_probs: List) -> List:
best_rp = max(self.actions, key=lambda a1: sum(map(operator.mul, map(
......@@ -68,13 +61,9 @@ class Attacker:
# logger.debug(f'Epoch {i} in attacker best response searching')
loss = 0
for nn, prob in zip(def_actions, def_probs):
prediction = nn._limit_predict(attacker_action,
with_grad=True)
# Attacker wants to maximize its gain, but optimiser tries
# to minimize. That's why we negate the objective function
loss += -(torch.add(1, -prediction) * prob *
torch.prod(attacker_action))
loss += -(self.torch_utility(attacker_action, nn) * prob)
# Calculate gradient and update the value
optimizer.zero_grad()
......@@ -85,7 +74,7 @@ class Attacker:
attacker_action.data.clamp_(min=0.0, max=1.0)
action = [attacker_action[i].item() for i in range(self.features_count)]
action_gain = - loss.item() # Negate the loss again
action_gain = - loss.item() # Negate the loss again for correct value
return action, action_gain
def _gradient_best_response(self, def_actions: List, def_probs: List) -> List:
......
......@@ -2,7 +2,8 @@ from typing import Callable
import attr
from utility import attacker_rate_limit_utility
from utility import get_attacker_utility, get_attacker_torch_grad_utility, \
get_nn_loss_function
@attr.s
......@@ -13,6 +14,9 @@ class NeuralNetworkConfig:
# Learning rate for Adam optimiser
learning_rate = 0.5e-1
# Loss function used for training
loss_function: Callable = attr.ib(init=False)
@attr.s
class DefenderConfig:
......@@ -58,11 +62,12 @@ class AttackerConfig:
# Attention. Used only when use_gradient_descent is set to True!
epochs = 500
@attr.s
class ModelConfig:
# Name of .csv file in src/data/scored directory with scored data which will
# be used as benign data in neural network training phase
benign_data_file_name: str = attr.ib(default='test.csv') # all_benign_scored.csv
benign_data_file_name: str = attr.ib(default='test.csv')
# Number of benign records to be loaded
benign_data_count: int = attr.ib(default=1000)
......@@ -76,9 +81,25 @@ class ModelConfig:
# Defender
defender_conf: DefenderConfig = attr.ib(default=DefenderConfig())
# i_a
i_a: int = attr.ib(default=1)
# i_d
i_d: int = attr.ib(default=4)
# Function to calculate utility for attacker given the actions
# f: List[float], NeuralNetwork -> float
attacker_utility: Callable = attr.ib(default=attacker_rate_limit_utility)
attacker_utility: Callable = attr.ib(init=False)
# Attacker utility function using torch tensors with gradient property
# Used for attacker to find best response via gradient descent
attacker_torch_utility: Callable = attr.ib(init=False)
def __attrs_post_init__(self):
self.attacker_utility = get_attacker_utility(self.i_a)
self.attacker_torch_utility = get_attacker_torch_grad_utility(self.i_a)
self.defender_conf.nn_conf.loss_function = get_nn_loss_function(
self.i_a, self.i_d)
@attr.s
......@@ -95,4 +116,4 @@ class RootConfig:
if __name__ == "__main__":
pass
a = RootConfig()
import logging
from pathlib import Path
from typing import Callable
import attr
import numpy as np
......@@ -60,15 +61,6 @@ class NeuralNetwork:
self.attacker_actions = attack
self.benign_data = benign_data
def loss_function(self, x, limits, real_y, probs):
zero_sum_part = torch.sum(real_y*(1-limits)*torch.prod(x, dim=1)*probs)
fp_cost = self._fp_cost_tensor(limits, real_y, probs)
sum_loss = torch.add(zero_sum_part, fp_cost)
return sum_loss
def _fp_cost_tensor(self, limits, real_y, probs):
return torch.sum((1-real_y) * probs * torch.pow(limits, 4))
def _prepare_data(self):
defender = self.benign_data
attacker = self.attacker_actions
......@@ -100,15 +92,14 @@ class NeuralNetwork:
optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
for e in range(self.conf.epochs):
# Forward pass: compute predicted y by passing x to the model.
train_limits = self._limit_predict(self.x_train, with_grad=True)
# Forward pass: compute predicted y by passing x to the model
train_ltncies = self._latency_predict(self.x_train, with_grad=True)
# Compute loss.
loss = self.loss_function(self.x_train, train_limits, self.y_train,
self.probs_train)
# loss = self.loss_fn(train_limits, self.y_train)
# Compute loss
loss, _ = self.conf.loss_function(self.x_train, train_ltncies,
self.y_train, self.probs_train)
# Compute validation loss and report some info
# Log loss function value each 5 epochs
if e % 5 == 0:
logging.debug(f'Epoch: {e}/{self.conf.epochs},\t'
f'TrainLoss: {loss},\t')
......@@ -125,15 +116,18 @@ class NeuralNetwork:
# parameters
optimizer.step()
self.final_fp_cost = self._fp_cost_tensor(train_limits, self.y_train,
self.probs_train).item()
self.final_loss = loss
with torch.no_grad():
loss, fp_part = self.loss_function(self.x_train, train_ltncies,
self.y_train, self.probs_train)
# measuring quality of final network
self.final_loss = loss.item()
self.final_fp_cost = fp_part.item()
def _raw_predict(self, tensor: torch.Tensor):
pred = self.model(tensor)
return pred.flatten().float()
def _limit_predict(self, x: torch.Tensor, with_grad=False):
def _latency_predict(self, x: torch.Tensor, with_grad=False):
if with_grad:
raw_prediction = self._raw_predict(x)
else:
......@@ -143,30 +137,30 @@ class NeuralNetwork:
# The same as lambda p: 0 if p < 0.5 else (p - 0.5) * 2
# TODO try to use e.g. sigmoid
clamped = raw_prediction.clamp(min=0.5, max=1)
limit = torch.mul(torch.add(clamped, -0.5), 2)
return limit
latency = torch.mul(torch.add(clamped, -0.5), 2)
return latency
def predict_single_limit(self, input, return_tensor=False):
def predict_single_latency(self, input, return_tensor=False):
in_type = type(input)
if in_type == list or in_type == tuple or \
in_type == np.array or in_type == np.ndarray:
input = torch.tensor(input).float()
if return_tensor:
return self._limit_predict(input)[0]
return self._latency_predict(input)[0]
else:
return self._limit_predict(input)[0].item()
return self._latency_predict(input)[0].item()
def setup_loger(conf):
def setup_loger(debug: bool):
log_format = ('%(asctime)-15s\t%(name)s:%(levelname)s\t'
'%(module)s:%(funcName)s:%(lineno)s\t%(message)s')
level = logging.DEBUG if conf.base_conf.debug else logging.INFO
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(level=level, format=log_format)
if __name__ == '__main__':
setup_loger(RootConfig())
setup_loger(True)
benign_x, _ = np_arrays_from_scored_csv(
Path('all_benign_scored.csv'), 0, 500)
malicious_x, _ = np_arrays_from_scored_csv(
......@@ -182,7 +176,8 @@ if __name__ == '__main__':
malicious_y = np.ones(len(malicious_unique_x))
malicious_data = FormattedData(malicious_unique_x, probs_malicious, malicious_y)
nn = NeuralNetwork()
conf = RootConfig()
nn = NeuralNetwork(conf.model_conf.nn_loss_function)
nn.set_data(benign_data, malicious_data)
nn.train()
import functools
import operator
from typing import List
from typing import TYPE_CHECKING
import numpy as np
# Hack to avoid cycle imports while using type checking
# The TYPE_CHECKING constant is always False at runtime
import torch
if TYPE_CHECKING:
from src.neural_networks.network import NeuralNetwork
def attacker_rate_limit_utility(attacker_features: List[float], defender_network: 'NeuralNetwork'):
pred = defender_network.predict_single_limit(attacker_features)
return functools.reduce(operator.mul, attacker_features, 1) * (1 - pred)
def get_attacker_utility(i_a: int):
def attacker_utility(attacker_features: List[float], nn: 'NeuralNetwork'):
pred = nn.predict_single_latency(attacker_features)
return np.product(attacker_features) * (1 - pred) ** i_a
return attacker_utility
def get_attacker_torch_grad_utility(i_a: int):
def attacker_utility(attacker_features: torch.tensor, nn: 'NeuralNetwork'):
latency = nn._latency_predict(attacker_features, with_grad=True)
return torch.pow(torch.add(1, -latency), i_a) \
* torch.prod(attacker_features)
return attacker_utility
def get_nn_loss_function(i_a: int, i_d: int):
def loss_function(x, latencies, real_y, probs):
zero_sum_part = torch.sum(real_y*((1-latencies)**i_a)*torch.prod(x, dim=1)*probs)
fp_cost = torch.sum((1-real_y) * probs * torch.pow(latencies, i_d))
return torch.add(zero_sum_part, fp_cost), fp_cost
return loss_function
......@@ -43,7 +43,7 @@ class Plotter:
for nn, prob in self.defenders:
if prob == 0:
continue
pred = nn.predict_single_limit(point)
pred = nn.predict_single_latency(point)
if pred:
sum_prob += prob
......@@ -63,7 +63,7 @@ class Plotter:
plt.xlabel('entropy')
plt.ylabel('length')
for point in points:
pred = neural_network.predict_single_limit(point)
pred = neural_network.predict_single_latency(point)
red = pred
green = 1-pred
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment