From 88508a0d07a05ee3b01648454a07ad125988d003 Mon Sep 17 00:00:00 2001 From: Daniel Ledda Date: Sat, 11 Jul 2020 12:02:22 +0200 Subject: [PATCH] Added generic pytorch implementation for multi layer linear NN. --- .idea/MNIST.iml | 2 +- .idea/inspectionProfiles/Project_Default.xml | 12 ++++ .idea/misc.xml | 2 +- GenericTorchMlpNetwork.py | 50 +++++++++++++++++ multiclass_perceptron.py => MlpNetwork.py | 12 ++-- .../GenericTorchMlpNetwork.cpython-38.pyc | Bin 0 -> 2627 bytes __pycache__/MlpNetwork.cpython-38.pyc | Bin 0 -> 4945 bytes __pycache__/custom_types.cpython-38.pyc | Bin 0 -> 912 bytes __pycache__/import_data.cpython-38.pyc | Bin 0 -> 3592 bytes __pycache__/mlp_network.cpython-38.pyc | Bin 0 -> 4983 bytes .../multiclass_perceptron.cpython-36.pyc | Bin 4897 -> 0 bytes custom_types.py | 22 ++++++++ import_data.py | 35 ++++++++---- main.py | 53 ++++++++++++++++-- mlp_network.py | 13 ++--- 15 files changed, 171 insertions(+), 30 deletions(-) create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 GenericTorchMlpNetwork.py rename multiclass_perceptron.py => MlpNetwork.py (85%) create mode 100644 __pycache__/GenericTorchMlpNetwork.cpython-38.pyc create mode 100644 __pycache__/MlpNetwork.cpython-38.pyc create mode 100644 __pycache__/custom_types.cpython-38.pyc create mode 100644 __pycache__/import_data.cpython-38.pyc create mode 100644 __pycache__/mlp_network.cpython-38.pyc delete mode 100644 __pycache__/multiclass_perceptron.cpython-36.pyc create mode 100644 custom_types.py diff --git a/.idea/MNIST.iml b/.idea/MNIST.iml index 85c7612..3772c7c 100644 --- a/.idea/MNIST.iml +++ b/.idea/MNIST.iml @@ -4,7 +4,7 @@ - + diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..cc52ec8 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,12 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index 207156a..c38566b 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -3,5 +3,5 @@ - + \ No newline at end of file diff --git a/GenericTorchMlpNetwork.py b/GenericTorchMlpNetwork.py new file mode 100644 index 0000000..6145330 --- /dev/null +++ b/GenericTorchMlpNetwork.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +from typing import List, Generator +from custom_types import TrainingBatch, EvaluationResults, DEVICE +from tqdm import tqdm + + +class GenericTorchMlpClassifier(nn.Module): + def __init__(self, dims_per_layer: List[int], learning_rate: float): + super(GenericTorchMlpClassifier, self).__init__() + self.layers = [] + for i, layer_dims in enumerate(dims_per_layer[1:], 1): + self.layers.append(nn.Linear(dims_per_layer[i - 1], layer_dims)) + self.loss_fn = nn.CrossEntropyLoss() + self.optimiser = torch.optim.Adam(params=[{"params": layer.parameters()} for layer in self.layers], lr=learning_rate) + self.to(torch.device(DEVICE)) + + def forward(self, input_batch: List[int]) -> torch.Tensor: + x = torch.tensor(input_batch, dtype=torch.float) + for layer in self.layers[:-1]: + x = torch.sigmoid(layer(x)) + x = self.layers[-1](x) + return x + + def training_epoch(self, training_data: Generator[TrainingBatch, None, None]) -> None: + self.train(True) + for x, y in tqdm(training_data): + prediction_probs, targets = self.forward(x), torch.tensor(y, dtype=torch.long, device=torch.device(DEVICE)) + self.optimiser.zero_grad() + self.loss_fn(prediction_probs, targets).backward() + self.optimiser.step() + + def evaluate(self, evaluation_data: Generator[TrainingBatch, None, None]) -> EvaluationResults: + self.train(False) + accumulated_loss = 0.0 + total = 0 + total_correctly_classified = 0 + for x, y in tqdm(evaluation_data): + prediction_probs, targets = self.forward(x), torch.tensor(y, device=torch.device(DEVICE)) + predictions = torch.argmax(prediction_probs, dim=1) + total += len(targets) + total_correctly_classified += sum(predictions == targets) + accumulated_loss += self.loss_fn(prediction_probs, targets) + return EvaluationResults( + total=total, + correct=total_correctly_classified, + accumulated_loss=accumulated_loss + ) + + diff --git a/multiclass_perceptron.py b/MlpNetwork.py similarity index 85% rename from multiclass_perceptron.py rename to MlpNetwork.py index 15f8200..df901a1 100644 --- a/multiclass_perceptron.py +++ b/MlpNetwork.py @@ -1,5 +1,5 @@ from typing import Tuple, List, Callable, Generator -from import_data import test_x_y, train_x_y, IMAGE_SIZE, print_img_to_console, show_picture +from import_data import get_test_data_generator, get_training_data_generator, IMAGE_SIZE, print_img_to_console import numpy as np @@ -44,15 +44,15 @@ class MulticlassPerceptron: return self.classifiers[classifier_index].get_normalised_weight_array() -def train_and_test_multiclass_perceptron(iterations: int = 5, training_inputs: int = 5000, test_inputs: int = 1000): +def train_and_test_multiclass_perceptron(iterations: int = 5, training_inputs: int = -1, test_inputs: int = -1): print("Loading data") - training_data_gen = train_x_y(training_inputs) + training_data_gen = get_training_data_generator(training_inputs) print("Begin training model!") model = MulticlassPerceptron(IMAGE_SIZE, 10) model.train(training_data_gen, iterations) print("Model successfully trained.") print("Testing model...") - test_data = list(test_x_y(test_inputs)()) + test_data = list(get_test_data_generator(test_inputs)()) n_correct = sum(model.prediction(x) == y for x, y in test_data) accuracy = n_correct / len(test_data) print(f"Accuracy: {accuracy} ({n_correct} correctly classified out of {len(test_data)} total test inputs.)") @@ -60,8 +60,8 @@ def train_and_test_multiclass_perceptron(iterations: int = 5, training_inputs: i print_img_to_console(model.view_for_classifier(i)) -def get_trained_digit_model(iterations: int = 5, training_inputs: int = 5000, test_inputs: int = 1000): - training_data_gen = train_x_y(training_inputs) +def make_trained_digit_model(iterations: int = 5, training_inputs: int = 5000, test_inputs: int = 1000): + training_data_gen = get_training_data_generator(training_inputs) model = MulticlassPerceptron(IMAGE_SIZE, 10) model.train(training_data_gen, iterations) return model diff --git a/__pycache__/GenericTorchMlpNetwork.cpython-38.pyc b/__pycache__/GenericTorchMlpNetwork.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94af4141ff4a62393c770e07bee0f7552ef748dd GIT binary patch literal 2627 zcmah~TaO$^6|U-g&t*NfAo5yZIUzt=pq=oN0J1EI*F;G6DoPAOEuyB~RkPdUzRXtF zB-vHZ(>e+-mf$@PGmm-Z-_$El{tHM5zEd-6I|f6I`b^#Xob!F>>)(yX1BUjuFXQYF z5o3R+!OP9V;8%FM6h<(?E0%8mdF`YQr@m9UwU>IepZa#}RzV%6VI8GW9jCFK^Qu8T zNC%vK$b>I~Crkv=nGY}AbR_sEED4|Ces+{N#=Bp3+Qc8qMyg!5Dsj#Dp~}mqY>q$3 zb#Y>@-uo=Cmbot5=406{E8Usk@ZRr#d-q<#ji*mVUC?DOHwuGVr(h&YxnQXyUFq#J z=?ng~^A($M;RyH1C=GV=;XD%F*N%19OJg~ZLlKH_=I%56M_7r);E9)xUS5%7F~rJ9 zjIpx!74ukn1>x-+I1g8&>kaI;S2^a}z9bmDe45 z1KX2Yst(C<4&F7s&C54TO>oCQ{xm&lYdNi?5c%{(i0NTlENj{5Zu*gGk8w+l+)W?k z#Yx%7TlaU#oFc)6Txiuc*=*J{(^nullKS&jJw8~RnZfM}EG*i3apwxTPW%a&!_MD( z<;)8h53&p#)me6MyK0NP>h7R(c}e1^5w0*!rIM@0@0M^y5sW6fJZ1tLV`pO7AQ%GP z1o>hi8)2NLF~NPrk}EU0tJ1(H!@=6H{8ToRuEIO%%)ylrLEtEp)=hc z4C8w|;5Yb9KEQMS*8fisFddf7VyUx7ww?ne^-?tjk>#>c!ihL{@pOkUn$3C7VI9Gr zIxEm{?vnNp#X;|DD_66%*So!Q0ZOi%)0=h&R%r(N9gxgD0=s;n&ld98zt9q2eIIw5 zK+C3UmGNg)n``w3&4pcgT(_l2JoN*N)sJbCf1!vE+guPAZ9+^qYt`qu64x-(SwZC! zehtqaKY#P}3E7<@)f?-b#APDaxq1f|eJdY?7xBgB;{+X)AWEBx59QA1_X}*UDQVx{Cx&UJstOPA3X+@G4^`UAFzdev~e&7+4T_A$Gixi zx*&RV^qXV&Wt$>FPWTp;q3s(N<&6(31ViVQ35!-KS!grK3kd15g7%1vR3#aypW(!R z!}a*RKLk!3+^xJmOEr_!=yLti?iK(tFx`h(C1KC3aIyJ z>-%JWPWy*1=f%mN9$G=!mWNAgBxvdvWGF4_0~j;jiEAON z>QZ2>sQi$&SgEC&r>C5{NDzc(~&@_su4=Foa;NxdqQ}kvtDwdsY u>x@*d+xlT$AW5;nn&5#A&tCgdd0AJ;3lrVmCg2WH$qHO#llmP#a{mpJ+K?Fl literal 0 HcmV?d00001 diff --git a/__pycache__/MlpNetwork.cpython-38.pyc b/__pycache__/MlpNetwork.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1f2da2d3dbdf44ad9e1138b315b756122c9a967 GIT binary patch literal 4945 zcmbtY&2JmW6`$D;a`_>OmT61=NW!#EV77^rv}xKVsvOsL9jBJ!)@~1rri(RaC^O#W zQnSn0B9=e_8L)e3fnIuv4M;~_`p5LR$M)a~y|n19D3Jc%EGbd892e*kJ3Bi&yYoJN z@AuyD7c(k)Sjb%-8mx+CoS~m^9-E>h$`5R}CXP*w&4?PtW@TN> z-ZfMzYg-K8GHkigF&9-a^T2W|q9$e^Ft>_%b&MVpa~Pcw$6=F`Sz~n&B;9u4$MFrR+OnUhsP|$i zKc13O)TAj4e9;sZWU|Dpv>{VL%T$y_QGmQf;ozC0rDJYMR74fBl*J5IuMb~4YW2kf zL7u*S>FnSC^8V=yEuNJF>8qXx56QMUWkTv(I4|Lh&!I`!=X~2px#06`2X<`0nhPvt zo|zhGEy1w1xwgkKch5l6vNNOCcMHnzi70d}x@N_Dvb%mKi4~BI>6sPFpp%umz5XEa z;_guD2fIui$Mb*RTDlX3aw(8P_)B+$Sh^gw2chgG@zM~Cg{=K*Jl=Q&s-CM6j^&4=e-VztG|YkL%*kiDis z9NZ4PzTZ`GR_jM`H|c&MJwRUu&?+J1cLCD2UWHb|QNc+|YtuePCt<0%#V@nZem}Aj z7&x_rG0(naPmC?jSpuVPG5tLFl`*n$7DmOCZJTLfhmV}Jm>LMmPlVAkggGjujep$BuKhsYvnbH_Y*bFm@%fZ}ZsD}kQiCOh?V_;Pd0HUHHrPBl)!I0};b@l@)`oM} zCW;NCEJNc zP2M!{c4j9;DFn5c7G=E3=lDrp=5^lW79YOxGMDT3+l1!HLH#05?Yz_IBy5|*lm@P* z;4q_(j~xTiV63S#n9;IwMNJWnm2YBn*U$=QHc&JOpOIX>h8ZVu#(Hf&eDh_J=C*B< zGbRW1i3qi*=h4aO55ls9LK)W`eg-r?k7R(W-C&o2GYDQ-JMgMj)~&)XRI!T5>Y2_ipHCrDoM zh9cjpZ&C9mH3S*8M9qOpx~D^-d_eyM-q8_)5}4yNZld#$J>(CK4q^Z~)&Vj=UIeE{hKKB{bHj6>D&Rd5F{Xgm zpsJ7X!Zv%v9`POH5U{k6Ekk`9Q#0=E=2GT7qM2h~SM?UAyh5KPAvL((JH6 z2YNL-k;-qILvwL)5o*4QrPL3pAvO=IWNWfVR{P4E$S4f3_;bgWDbCayT)>zf{|KGo(P4Iwua7Vk62@J*3C>xhuV18wf=3;g zi?E+Z);aWOd}6Mn>)eCokd5jG@I=ukh;JM(6vc*{_w<)<$AZ*oEV^q zEG>iAbyyA0tkO+09<fN_UYdH+jdaf2k&aS3dYW8oHpSk7VJL~oa8wq9n6%u36&&%WPRE0m*jPPE|k z;rF#F9(HBTgYd}Ua^*lrJ$-`@P7IW(H*hp#lp!>SH#VbA zbJ%DmQQ`+pLP|47%3`Zk%`*$tA{uLo0FxIongj@ynHdkl%nW2N=Y$eXAuP8X4M+7c z4P2v{$0mN+|MS2t?f=o>mU{TBK<^M-#~%x?(^qb$Pv|og581i@#U{iAF+~OAj{cRg zub^HHdy6V|7FDip;OX8YMBCoq;Ne4&K|w_xzl1ZU?2}i~h=`oQjgh5u_&5?$)Yi0< z(3PU4jxZQ9v(Cr+)=;0~tLmrJ9FF&FA@uJ`Pb-NMT6EX34c2%e9v|mv2m2#va1$d` zmxkCuc^;5N+*6bi;n6@;`nQIXBlt zd!791+9)^sn|h~|141(#*-?v3CcaTrVGt__v6V@4>mfC_4Kn_8OHWGE#(v zGgeU(v$#oE(oja5l%>%+pRp0dKyl|YG9!3y%*ZmH5mvk)GLiHXkra}T3EMw>ZD-g<5A!iWq#XJD&<;#Enmvsuj*pI2T&CPPn?& zmh-Q~Oz`HQex8-9wt-TCl%iZoSsu1k-+*?ud}G>kBbuhx=jCfrt!fSTk3kq60v(tN zI%itVmrJe7s<*CfxV!1VKI(2vh81!Y8y>lccB3KkfBqigwY&WN;Y%VrfLzK1S)#~N z>UWN=gzATW-h%Cwvp#j$pz@=tH3ll@r=r3>dLtm>tB@|Y!6*UzU6%bRNDGwdJPF>&}d%!Uc=uOBAW zeS^eohQLto9$J2UVniJneI{JB8b}%nFd7Q5t=e*NCF)`NSax`H_Uhp1Ur2qvx1Nfp W_@@yYz~{(%9`7d8bea`W7XJd9P}N5O literal 0 HcmV?d00001 diff --git a/__pycache__/import_data.cpython-38.pyc b/__pycache__/import_data.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ed9e070e35e0f22a31c271540b0445fe762201e GIT binary patch literal 3592 zcmd5<&vV>J74B|HGn(=EMGoKO zG#a<$+Qg+fInBmCvbu9wd{%Q^@iexL$ABP_r3Sp zkA{b<49_1fm78FY|oAy&xuN2DJpwq zGiHTuRPid1ZLcaFZwRHPhE#2wdBe(9Bg#>u=h&J!$7Jahmu2a`u)VPpBNdEPI zMNZ1eHS1sxQz99i z{c;C?{R@TqN)WZ1tIt!Fv|xt8Q-Vb021+7PrEHf=o{H336H@G1J)ZGhdyn-*>ZD?u zFR{-*>{*a4BhU5hjHM#8##v@hvU9BGWaeG7GLhLCv@EiX`Hgp>rL@P$usF1};#m)F zU-z8V&~H6$l(oRhxz$l$uid`cd~olx=EG(0=3)EBy*owQ4r5>Dj`llk<=M%4 z{6(XriGaC<6MCg?@JxkC&V!sk$sN<$XL=kH{`F{KJ&x2usHF54)}>rnid&nJ>ZHlS zeI2*8A4Pz2;jZ6W4?1f0j`B5OyUUc(9__WEOZQiIygCq!T)0jDI6pP{@WIizSQHXCB?2S zEor}C8OvaeR7|qpg7iyloA2>$IMPN9);Pr0S?^Th!#<+_0jBh6b70xf^?Q$-4}W=WX?f-LzR1P=WqlSS2GpW@2hF^?7=o&; zIO<-VLOZchm6%{#m%2LWq)j4V8aG>UCy7HfzZk|XKTNJdIQ*>P=uxAo-~icQ6QHmo zX~^v-anNZ9gQ8c_Ks`x)ou1st_-X5r0TN<5gF)uF{x+nB$cNX0P&H{QDS2B-D_F5} z%tQnmnW%+w;365o+7^GjZ=8c0=MYaocqybsrjfREelOrAFkEJB3z%}YXQyTkoWho+ z!i2@VU%_-6AAZJu#6DpU*yF-Q$}73ht9=5o@1n}B)u6r4^9uQJtfkU@GV{KZ+i_PR zPG|+cun4B=nvSEw*A1S4^3v+^jr_Ae5C6ips`)CVG|3LVtz6B$TsC&)wMSe01vU z*ibk(Le(eM_D^+o3nRt}Y1BBOODrlJ z?+E1S0B1@H zsTGHih$>02rM#+fG4rM8s*W_$iH+rz!8QmH!xVNvk_iySR6E zCnW_VE8%Kbi26GiJVq(<0V>7;--D2{lxHvv@($vkaj*KXMGp>#oP&&G9I9?I4PU`J z)>hCpYZ|^bZ2e;>(LbSztdKi~;3p4NKRi?uduw+6Kqs_s13u%^T>p#|(0=;=#zPBL zS1Au0;E>HWqpzQZ%zs#t88mi~ZJ!EVlj3*$OZFv4)}>hQlCvNZBbqN_R^WnLbpF9Z zQHct9V3MI~;}xvOD5R%+2f(C4aw)d)BK|xkfKV(-*t_tkJtPh!Wh=9dgk>k$4x2#i z!geI3ngYl-dN*b=u|u1kc{92(V_U+ACHw^ym6 zbgxZPsrcs^*nXZ#mC%saw|!-iGW2cG~tWrr)i7DAYP_;Vxf8bBx#Bq*rfGs${m4bxd>(OX;8DA z!?U(i@$X$`?5{L9|5zBjg*W*Z8o>k)S(oE$gvQ9|8l2|L&>UG^Yh-urk<)cXm2PF^ zc3r(@h1HSQ^+vUBO^@xcK5BFuy3fOAw<-DXl73sl`NR|z;Xbjtmlv~G)$>*1VSYt6 zXY;;#k2vPgQ5OvYOYWu=mQ(VH_+J(86xQw~=x0zTGt529%1*Q$r7HdyfcT;YO z^#_Lfr`$YdppV~UosCm)&GtGxuPND!mFR^*l636c+Y3fgyf+z#GPjG5b$5`YxqB-J z!{9!qt8dFlsvwP3$H=etW$LFY7(|1p?~5P}{C>HdZ|Fstq)Q7pwsJR4k~@>AW9HSd zl48(HrO2&*9E#jgGM%WXXAq6^hx#VoTX>Tenv`WMC;Kp*p1oiZr5e zZa+GZO6I0~Jbud6Iu;5}&gicsgZ?NU2og%-qG8Il^}q$!{cc(^5=leToWn5;-wSco#VaGu!RDrVg+CMutB;fg9EgC}dly6`Zr zi8{uOa}W*D#7wi8De#F)h&@G2xm3(ZM`KvQ8LOfV_?8~80pq&Z!1%JbCRh4QuI?M4 z#}Qk=a7A1NHEppCEj^!Kx^rhwPE-)?6~2?(cY{NzjTmyj`tLu#bN}5pDNrP-n)DDFWbT%JmU|(&sgHOhds>zd6g>{!g^l0GATu-OvkbQE z^FHjzH&Qc$EeZB(P<+g9vk%_*r7^XpcE*OriIv(zC#{^g8PBXOmf3w{#`V}f`wHV@ zbH+J)_<|rygg-&h&sqepx4GB3W zNL#}1I1DTpr zD}~%oOz6Ov8OM;QZ)Midni>zASiua!9=SuuR$s$Mb*Lfl%bSJ7{QX!x3Y1V^DWF~< zfL95?gvF@q)U!4B%Z%6er4;4%uLCa8g#lPz!O!JK+sh1}-#{6%RN;7%_DmosA7Y(M zPP0yW0MlT9i%m@(B+Q{Xw20v;n9WSz&a9G6r{GX))f926NGmyiPkj|%2!Yt^;$RpC z)F$L(^-avEZx9sHV=@WTB|a6gM)F_8)E4x5bZwcri{iYXN(z0%ljf($v6L1N_7Lu7 z3l*7#F*ZIf@DfMKpV0h*sLSOWZmZ zR`xzOtx^e-kYiH_{-8bO8UKi(IBg3S8G`@LcnBYXIZ+(d#f4S~?83PIs`heJ6=!B{$J2_nv-+=k!xF>QZ@l zrOs4#4oh>xf~uv`4qe3Mbvl+#^+y9CLKu>c_lgRy zPGi~r-XfIb=Uq?eBqx9`Z&B-uB8AZQynSxqi!$9b=X=U=Z9ay7eTWo&4vEgL-PJi3 zJ@SS2y>5NJNskJL_#RQErCnVJ@iyiX@<7I~7#?q;jH@G!w2f_yJjB;FKic|26`)+( z!`2G{h`fO+LLGpH=x9G6+lN;aSz70mJ>(Q*736OOSXe)K1+s|1bpRelxjM=DAxDR~=7hq1=VkkeQ_@H2*o;)=E6nb*ECO zlQ`4Xsx_&8M8mt(yhF{!5TkyGxu4-psP1R{8e|74;)n44^5@C#Sz9_Kn}^p>z{wps zj(Z2mg{Vmpb5RZb00U%bct6!Vl&o<5Fom(fe$W1ZS&k%j%&@jy)}})ZOGqYs=!%aJbg07XXXTPw2Cjqxgw%ul4=h^ZijQCh$T(ulqi3 zZDFyZC^Z$e7UjU)B%-NO$nE_w4pOyC>y+%Yn=49ZI_Ii4sP`r{nnzNnB6%w^XGNMl z$IhZ4Q#)9A5pP0MF87Sh+DfJ6teBhT=1Qxo*Lzr6JcN^8?Hxzxq`0x+zGn7GgVwP* zWvL--_ybQ(LD_>%sp^6U2xo?WG%^D^eW9qwFmGq5UCCFVSt?NM2gWgJ_O-M!bkpk4 z6P1r)3s^_y5V%ha>tq{R>(#wM%f!u3fQ(a;Qva1(+yD5>>sQ_!{GH}WgeaY==>dwB z17vX>%MtB%u#mTwudv-}q2R~Ql5nZqB;)?b@(dEs{5FkINvQ=7ot{^aogs*o+ z`X*EU4X_Uotw*LyN(zK^eXDJ9bwJx;fFt}z;`^N|xs5V^d{|f!8C~wAhhxyFAULMT zlNE{Y%!cyXU^I?ZQ4}c(kXlH2y*EkHc;sWdOth`4`?RP{^B2^kJF}uc0@Salp}bq< zEV^uJU1;*u&uMrTmKH{uyEh7Nxk*?2ger6woNXgO)wye~Tbq_=ww<=swwt%yy379$ D7UFHi literal 0 HcmV?d00001 diff --git a/__pycache__/multiclass_perceptron.cpython-36.pyc b/__pycache__/multiclass_perceptron.cpython-36.pyc deleted file mode 100644 index 518a3e5889070e6239a582063674ecb3b4316cb3..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 4897 zcmbtYOOM;u6~31biJ~+bc`SSUj*>co)!NgfX_MA5#&&GObsJAz*lhq}Qi7w)Q8Xq+ z%_SAj2oz9220UGK(XJZAF1qL@+y0a;xoj8dpqq4S6iB~wsmIt>?E=-{US5*-aUS3K z&f&bfuwWng{Oe!+slwRb*@2&l@fEcAAv(bX4_L?@9%_yj>W&^7j=|}<7MNkdDTGC* z7+Q`MmYh;(JGOeK2W6*nlL2vzmWM zRPSr@*Z$&&VK4nIYaZH%B&=y<=DmI|@H6A1PMl=cH7^Lf4cwP*`dwdoNhF(QW+i@{ zxI6A{R!XGT>8hc9>-N=~H{83ooEuraCp+E5?SySNiQHDyjU(()i9d}VxV=s*=}UjD zMF&oPQ-k#I#Agl{%+a767aEk)pqwu1zJWhe7*Ng>1t?MwNBm-&`PQa(ocV^Z9&4f` zY`iJ?wr?C~qWqW(Qrd9JqAC_2GN*!hH9S2emhf~z9ES2ov-<6RkaSvs7sq#e+46gd zjJoqB{P6(^Mb+1ZhTU{w;8X(+Fdn~tv3uA#+e(Isq)Z)qtPe3^}4zdEdV zjHS%gQw_Z#7|zw#_BrP6Yv|xmt=r2A((8&S%nZ6_#Rq<;{b>@*1)8B}M(hWhnceC3 z`iUEN2EKanC6h<-{^`%JU0nY~?8|svcwIMMzaF*vq2Ep7^*b`!@>}_!Exj=8bldAQ z?bgGo-QC{Q@GHID%yQjMx0AT8g#%&(9pfkX;M^{m++ZVcdtOJz zS+y6%ouu=b?*iC<0F@HLbq5%2sZ+45j8;TbHvKaOJAX`As*m_}_Q~&uMgrTWhR~MT z7wm~P;+*Xa5?Fi$dt*kht_@AJ!myaKEj=xa_|Qs=sfLhS6Ixdj=CG7nLcedyjkHA9 zUQgY#sHg7P&`$06JXRZMW?HJVq_9;KMw_Pv@@kzelShs1^B=Qq`(scAd|Fr=EZv$Y z)9CiY4PRbt$SUNQrMKHG%Q`(jLETB}PN8!wtaizk8Si9#C*!-B-IKmRw8OTUsjTf3TalE&gMv-o zEn{_NDjoq%vY3uyyvmpOF>do3Z*YST-h7qY)%q5}nSOL7N>L6ZK`AI;N)1WnntduE~2sWZ(waxQwnD$kkk*KlU%-z859GtI-3vPdX=QPZCm6K`q7ne zr$jx6LC$#)mL(L&xQ_Tqkn}RB9#^x@uH%XF_jD6Fn>M+N7L#|{fX645AqQ;vi2qL8 zHty!33Kr2p%}p)SR^FCpG4(lxp4yiJP=6~5d#hwQrB{Ise)cN8+J548BN=*F?h7~P z4_85Efc}f>#w&r?p%hYj^I8kQ^2W=<>xFt zr7t>NAGf%H!6WvFKhidVe56wgm;h-Ji1suWiB_e9XP`uqQ|!m2jDk{ah{;=QkL~f1 z_6(eqa0rp}n3{2SEM3kTiYKNz$_tosnzkjl0#fhH_0fdD=Ya<2oS}2dO7G5Q2Pp?Z ztIi>#$oHwEl%;jm=2rw~c|m zva$k|8#qq>h`QIPBZM{01H?>%%5R}_?AbmCkf#hmt7y4G8{>7bZIusB%r%E1SVY9P zG5-agCA2B9Bp|IvY!eBRkVcF<0@j{I7%?6pKp~8>G;!Qs&S{)tk0=!vBB$|z2MEBr z6nNo=@K)c#>llpas-{ZNBQg)mI_60j>x%AIP z+(oQxlI&>zE#78HdXBA-E$;ifh@su4 zRMdXL-+9KIWPX(-$s7cklpm*98=FJRqk#x^BHb#+&p|#r{8`6;;3A%K>Qod$96b?p zetkJ-5drl*I&7RIs`4E=?t9d|OC6DRPXFJh(Q|o$Od~05^1MLSl9#EQlN=<~qXo2> zvXaHCN9y{C!jg{6KyUVmgZB3jPW%7Bj&K$Qf3~P17DNJDqL_=bvifSvZ>Ah((T->; zg8za+f&zqiDnX5e!d1t;DsdAn(S<(JeuGd&S*oXOq+u3{5;dE?Kxvhk6ezn;DGcr_ zDlRC9BDlbyjDx3UY9_d&qGnjYdMhn}9ZRW+-p=1(RDM#V{dUeID3BGOE)9!`jkC)s z-q5@4Ka5$kvNouE6nTP3ldAo}kq`ZLr`wnqpn;6-2d}FL8=SmNH;uU8YWZ=z*$;x< ze4;N_2DN+mu%&I(3NY<7Ajv+|$h9h`qP{u?hUY^KGN~$y3jALZn`yM>e9(#-QFvqQnatbVwsd$V#(W$#VkMh#>-`4haIMgm0?)NWvnW>@EmKviv_j=9I&FOS zDu9q7$I02k$54gKpW+FXqapp int: with open(file_location, 'rb') as img_file: img_data = img_file.read() num_items = int.from_bytes(img_data[4:8], byteorder="big") @@ -32,7 +33,7 @@ def read_labels(file_location: str): yield int.from_bytes(img_data[i:i + 1], byteorder="big") -def read_imgs(file_location: str, as_bytes=False): +def read_imgs(file_location: str, as_bytes=False) -> List[int]: with open(file_location, 'rb') as img_file: img_data = img_file.read() num_items = int.from_bytes(img_data[4:8], byteorder="big") @@ -50,28 +51,42 @@ def read_imgs(file_location: str, as_bytes=False): start_byte = end_byte -def read_img_lbl_pairs(imgs_file: str, lbls_file: str): +def read_img_lbl_pairs(imgs_file: str, lbls_file: str) -> Tuple[List[int], int]: for img, label in zip(read_imgs(imgs_file), read_labels(lbls_file)): yield img, label -def test_x_y(num: int = -1) -> Callable[[], Generator]: +def get_test_data_generator(batch_size: int = 1, num: int = -1) -> Callable[[], Generator[TrainingBatch, None, None]]: if num == -1: num = 9992 def generator(): - for i, (img, lbl) in zip(range(num), read_img_lbl_pairs("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte")): - yield img, lbl + accum_x, accum_y = [], [] + for i, (img, lbl) in zip(range(num), read_img_lbl_pairs("data/t10k-images.idx3-ubyte", "data/t10k-labels.idx1-ubyte")): + accum_x.append(img) + accum_y.append(lbl) + if (i + 1) % batch_size == 0: + yield accum_x, accum_y + accum_x, accum_y = [], [] + elif i == num: + yield accum_x, accum_y return generator -def train_x_y(num: int = -1) -> Callable[[], Generator]: +def get_training_data_generator(batch_size: int = 1, num: int = -1) -> Callable[[], Generator[TrainingBatch, None, None]]: if num == -1: num = 60000 def generator(): - for i, (img, lbl) in zip(range(num), read_img_lbl_pairs("train-images.idx3-ubyte", "train-labels.idx1-ubyte")): - yield img, lbl + accum_x, accum_y = [], [] + for i, (img, lbl) in zip(range(num), read_img_lbl_pairs("data/train-images.idx3-ubyte", "data/train-labels.idx1-ubyte")): + accum_x.append(img) + accum_y.append(lbl) + if (i + 1) % batch_size == 0: + yield accum_x, accum_y + accum_x, accum_y = [], [] + elif i == num: + yield accum_x, accum_y return generator diff --git a/main.py b/main.py index 7ed1437..535a762 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,50 @@ -import torch -from multiclass_perceptron import train_and_test_multiclass_perceptron -from import_data import show_picture, test_x_y +from MlpNetwork import train_and_test_multiclass_perceptron +from mlp_network import train_and_test_neural_network +from import_data import show_picture, get_test_data_generator, get_training_data_generator, IMAGE_SIZE +from GenericTorchMlpNetwork import GenericTorchMlpClassifier +import argparse -train_and_test_multiclass_perceptron() \ No newline at end of file + +def main(): + args = get_args() + classifier = GenericTorchMlpClassifier( + dims_per_layer=[IMAGE_SIZE, 200, 80, 10], + learning_rate=args.learning_rate, + ) + for i in range(args.num_epochs): + print(f"Begin training epoch {i + 1}.") + classifier.training_epoch(get_training_data_generator(20)()) + results = classifier.evaluate(get_test_data_generator(20)()) + print(f"Evaluation results: {results.correct} / {results.total}", + f"Accumulated loss = {results.accumulated_loss:.3f}", + f"Average loss = {results.accumulated_loss / results.correct:.3f}", + f"Accuracy = {100 * float(results.correct) / float(results.total):.2f}%", + sep="\n", end="\n\n") + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--num_epochs", + "-e", + type=int, + default=5, + help="Number of training epochs to undertake." + ) + parser.add_argument( + "--learning_rate", + type=float, + default=0.001, + help="Learning rate for the optimiser." + ) + parser.add_argument( + "--num_training_samples", + type=int, + default=-1, + help="Number of samples to train with (default = all)." + ) + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/mlp_network.py b/mlp_network.py index a6d9ab9..6a14c7e 100644 --- a/mlp_network.py +++ b/mlp_network.py @@ -1,15 +1,11 @@ import numpy as np from recordclass import recordclass from typing import NamedTuple, Tuple, List, Callable, Generator -from import_data import train_x_y, test_x_y +from import_data import get_training_data_generator, get_test_data_generator +from custom_types import LossFun import sys -class LossFun(NamedTuple): - exec: Callable[[np.array, np.array], float] - deriv: Callable[[np.array, np.array], np.array] - - def sum_squares_loss_func(predicted: np.array, gold: np.array) -> float: return sum((predicted - gold) ** 2) @@ -106,10 +102,11 @@ class FFNeuralNetwork: def train_and_test_neural_network(): model = FFNeuralNetwork([28**2, 100, 10], sum_squares_loss, 0.0001) training_data_gen = train_x_y(1000) - test_data = test_x_y(10)() + test_data = get_test_data_generator(10)() model.train(training_data_gen, 5) for test_datum, label in test_data: - print(model.feed_forward(test_datum), label) + prediction = model.feed_forward(test_datum) + print(prediction, label, label == prediction) np.set_printoptions(threshold=sys.maxsize) print(model.layers[0].weights)