Added a generic PyTorch implementation of a multi-layer linear NN.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from PIL import Image
|
||||
from typing import Union, List, Generator, Callable
|
||||
from typing import Tuple, Union, List, Generator, Callable
|
||||
from custom_types import TrainingBatch
|
||||
|
||||
|
||||
BOX_SHADING = " ░▒▓██"
|
||||
@@ -24,7 +25,7 @@ def print_img_to_console(img: Union[bytes, List[int]]):
|
||||
print()
|
||||
|
||||
|
||||
def read_labels(file_location: str):
|
||||
def read_labels(file_location: str) -> int:
|
||||
with open(file_location, 'rb') as img_file:
|
||||
img_data = img_file.read()
|
||||
num_items = int.from_bytes(img_data[4:8], byteorder="big")
|
||||
@@ -32,7 +33,7 @@ def read_labels(file_location: str):
|
||||
yield int.from_bytes(img_data[i:i + 1], byteorder="big")
|
||||
|
||||
|
||||
def read_imgs(file_location: str, as_bytes=False):
|
||||
def read_imgs(file_location: str, as_bytes=False) -> List[int]:
|
||||
with open(file_location, 'rb') as img_file:
|
||||
img_data = img_file.read()
|
||||
num_items = int.from_bytes(img_data[4:8], byteorder="big")
|
||||
@@ -50,28 +51,42 @@ def read_imgs(file_location: str, as_bytes=False):
|
||||
start_byte = end_byte
|
||||
|
||||
|
||||
def read_img_lbl_pairs(imgs_file: str, lbls_file: str):
|
||||
def read_img_lbl_pairs(imgs_file: str, lbls_file: str) -> Generator[Tuple[List[int], int], None, None]:
    """Yield ``(image, label)`` pairs read in lockstep from two files.

    Args:
        imgs_file: Path handed to ``read_imgs``.
        lbls_file: Path handed to ``read_labels``.

    Yields:
        One ``(img, label)`` tuple per item. Iteration ends as soon as either
        underlying reader is exhausted (``zip`` semantics).
    """
    # This function contains ``yield``, so calling it returns a generator;
    # the return annotation therefore describes the generator, not one pair.
    for img, label in zip(read_imgs(imgs_file), read_labels(lbls_file)):
        yield img, label
|
||||
|
||||
|
||||
def test_x_y(num: int = -1) -> Callable[[], Generator]:
|
||||
def get_test_data_generator(batch_size: int = 1, num: int = -1) -> Callable[[], Generator[TrainingBatch, None, None]]:
    """Build a factory for batched iteration over the test data files.

    Args:
        batch_size: Number of samples collected into each yielded batch.
        num: Maximum number of samples to read; ``-1`` selects the default
            of 9992.

    Returns:
        A zero-argument callable. Each call returns a fresh generator that
        yields ``(images, labels)`` batches as two parallel lists. When the
        number of samples read is not a multiple of ``batch_size``, the last
        batch is shorter instead of being dropped.
    """
    if num == -1:
        # NOTE(review): default kept unchanged from the original code;
        # confirm 9992 is the intended sample count for this file.
        num = 9992

    def generator():
        accum_x, accum_y = [], []
        pairs = read_img_lbl_pairs("data/t10k-images.idx3-ubyte", "data/t10k-labels.idx1-ubyte")
        for i, (img, lbl) in zip(range(num), pairs):
            accum_x.append(img)
            accum_y.append(lbl)
            if (i + 1) % batch_size == 0:
                yield accum_x, accum_y
                accum_x, accum_y = [], []
        # Flush the leftover samples. The original tested ``i == num`` inside
        # the loop, which can never hold because ``i`` comes from
        # ``range(num)``, so a trailing partial batch was silently lost.
        if accum_x:
            yield accum_x, accum_y

    return generator
|
||||
|
||||
|
||||
def train_x_y(num: int = -1) -> Callable[[], Generator]:
|
||||
def get_training_data_generator(batch_size: int = 1, num: int = -1) -> Callable[[], Generator[TrainingBatch, None, None]]:
    """Build a factory for batched iteration over the training data files.

    Args:
        batch_size: Number of samples collected into each yielded batch.
        num: Maximum number of samples to read; ``-1`` selects the default
            of 60000.

    Returns:
        A zero-argument callable. Each call returns a fresh generator that
        yields ``(images, labels)`` batches as two parallel lists. When the
        number of samples read is not a multiple of ``batch_size``, the last
        batch is shorter instead of being dropped.
    """
    if num == -1:
        num = 60000

    def generator():
        accum_x, accum_y = [], []
        pairs = read_img_lbl_pairs("data/train-images.idx3-ubyte", "data/train-labels.idx1-ubyte")
        for i, (img, lbl) in zip(range(num), pairs):
            accum_x.append(img)
            accum_y.append(lbl)
            if (i + 1) % batch_size == 0:
                yield accum_x, accum_y
                accum_x, accum_y = [], []
        # Flush the leftover samples. The original tested ``i == num`` inside
        # the loop, which can never hold because ``i`` comes from
        # ``range(num)``, so a trailing partial batch was silently lost.
        if accum_x:
            yield accum_x, accum_y

    return generator
|
||||
|
||||
Reference in New Issue
Block a user