93 lines
3.3 KiB
Python
93 lines
3.3 KiB
Python
from PIL import Image
|
|
from typing import Tuple, Union, List, Generator, Callable
|
|
from custom_types import TrainingBatch
|
|
|
|
|
|
# Glyphs ordered from empty to solid; print_img_to_console picks one with
# `intensity // 51`, which maps 0..255 onto indices 0..5.
BOX_SHADING = " ░▒▓██"

# Image geometry used by the display helpers.
IMAGE_ROW_SIZE = 28  # pixels in one row (row length / image width)
IMAGE_COL_SIZE = 28  # pixels in one column (number of rows / image height)
IMAGE_SIZE = IMAGE_ROW_SIZE * IMAGE_COL_SIZE  # pixels in one flattened image
|
|
|
|
|
|
def show_picture(img_bytes: Union[bytes, List[int]]):
    """Render a flattened grayscale image with PIL and open it in the default viewer.

    Args:
        img_bytes: Row-major pixel intensities, IMAGE_ROW_SIZE values per row and
            IMAGE_COL_SIZE rows. Each value is copied into all three RGB channels,
            so values are expected to be ints in 0..255 (as bytes indexing yields).
    """
    # PIL sizes are (width, height): each row is IMAGE_ROW_SIZE pixels wide and
    # there are IMAGE_COL_SIZE rows. The loops below follow the same convention
    # (the previous version swapped the two ranges, which only worked because
    # the image is square).
    img = Image.new("RGB", (IMAGE_ROW_SIZE, IMAGE_COL_SIZE), "black")
    pixels = img.load()
    for y in range(IMAGE_COL_SIZE):
        row_start = IMAGE_ROW_SIZE * y
        for x in range(IMAGE_ROW_SIZE):
            level = img_bytes[row_start + x]
            # PixelAccess is indexed as [x, y].
            pixels[x, y] = (level, level, level)
    img.show()
|
|
|
|
|
|
def print_img_to_console(img: Union[bytes, List[int]]):
    """Draw a flattened image in the terminal using block-shading glyphs.

    Each row of IMAGE_ROW_SIZE intensities becomes one printed line. Every pixel
    is rendered as two copies of BOX_SHADING[intensity // 51], so the output looks
    roughly square. A blank line is printed after the image.

    Args:
        img: At least IMAGE_SIZE row-major intensities in 0..255.
    """
    for offset in range(0, IMAGE_SIZE, IMAGE_ROW_SIZE):
        row = img[offset:offset + IMAGE_ROW_SIZE]
        glyphs = (BOX_SHADING[value // 51] * 2 for value in row)
        print("".join(glyphs))
    print()
|
|
|
|
|
|
def read_labels(file_location: str) -> Generator[int, None, None]:
    """Yield every label stored in an IDX1 label file.

    Layout read here:
        - Bytes 0-3 hold the magic number, which is not checked.
        - Bytes 4-7 hold the big-endian item count.
        - Each following byte is one label.

    Args:
        file_location: Path to the label file.

    Yields:
        Each label as an int in 0..255, in file order. Yields fewer than the
        declared count if the file is truncated.
    """
    with open(file_location, 'rb') as lbl_file:
        lbl_data = lbl_file.read()
    num_items = int.from_bytes(lbl_data[4:8], byteorder="big")
    # Labels begin after the 8-byte header. The end must be offset by the header
    # too; `range(8, num_items)` dropped the last 8 labels. Iterating a bytes
    # slice already yields ints.
    yield from lbl_data[8:8 + num_items]
|
|
|
|
|
|
def read_imgs(file_location: str, as_bytes=False) -> Generator[Union[bytes, List[int]], None, None]:
    """Yield every image stored in an IDX3 image file.

    Layout read here:
        - Bytes 0-3 hold the magic number, which is not checked.
        - Bytes 4-7 hold the big-endian image count.
        - Bytes 8-11 hold the row count; bytes 12-15 hold the column count.
        - From byte 16 on, images are stored back to back, one byte per pixel.

    Args:
        file_location: Path to the image file.
        as_bytes: If True, yield each image as a `bytes` slice; otherwise yield a
            list of int pixel values.

    Yields:
        One flattened image (rows * cols pixels) per stored item, in file order.
    """
    with open(file_location, 'rb') as img_file:
        img_data = img_file.read()
    num_items = int.from_bytes(img_data[4:8], byteorder="big")
    num_rows = int.from_bytes(img_data[8:12], byteorder="big")
    num_cols = int.from_bytes(img_data[12:16], byteorder="big")
    img_size = num_rows * num_cols
    header_size = 16
    # Iterate by image index so exactly num_items images are produced; the old
    # end-offset range ignored the header and always skipped the last image.
    for index in range(num_items):
        start = header_size + index * img_size
        chunk = img_data[start:start + img_size]
        yield chunk if as_bytes else list(chunk)
|
|
|
|
|
|
def read_img_lbl_pairs(imgs_file: str, lbls_file: str) -> Generator[Tuple[List[int], int], None, None]:
    """Yield (image, label) pairs read in lockstep from an image file and a label file.

    Images come from read_imgs(imgs_file) as lists of int pixels. Labels come
    from read_labels(lbls_file).

    Iteration stops at the shorter of the two streams (zip semantics). A count
    mismatch between the files therefore truncates silently instead of raising.
    """
    yield from zip(read_imgs(imgs_file), read_labels(lbls_file))
|
|
|
|
|
|
def _make_batch_generator(imgs_file: str, lbls_file: str, batch_size: int,
                          num: int) -> Callable[[], Generator[TrainingBatch, None, None]]:
    """Build a zero-argument factory yielding (images, labels) batches from two files.

    This is the shared implementation behind get_test_data_generator and
    get_training_data_generator.

    Args:
        imgs_file: Image file path, forwarded to read_img_lbl_pairs.
        lbls_file: Label file path, forwarded to read_img_lbl_pairs.
        batch_size: Number of pairs per batch; must be >= 1.
        num: Maximum number of pairs to read; -1 means every available pair.

    Returns:
        A callable. Each call starts a fresh pass over the files and returns a
        generator of (list_of_images, list_of_labels) tuples. The final batch
        is smaller than batch_size when the pair count is not a multiple of it.

    Raises:
        ValueError: If batch_size is less than 1.
    """
    from itertools import islice

    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    # islice(..., None) reads everything. Other negative values read nothing,
    # as zip(range(num), ...) did before.
    limit = None if num == -1 else max(num, 0)

    def generator():
        accum_x, accum_y = [], []
        for img, lbl in islice(read_img_lbl_pairs(imgs_file, lbls_file), limit):
            accum_x.append(img)
            accum_y.append(lbl)
            if len(accum_x) == batch_size:
                yield accum_x, accum_y
                accum_x, accum_y = [], []
        # Flush the trailing partial batch. The old `elif i == num` check could
        # never be true because i stops at num - 1.
        if accum_x:
            yield accum_x, accum_y

    return generator


def get_test_data_generator(batch_size: int = 1, num: int = -1) -> Callable[[], Generator[TrainingBatch, None, None]]:
    """Return a factory of (images, labels) batches read from the data/t10k-* IDX files.

    Args:
        batch_size: Pairs per batch (>= 1); the last batch may be smaller.
        num: Maximum number of pairs to read; -1 (default) reads all of them.
    """
    return _make_batch_generator("data/t10k-images.idx3-ubyte", "data/t10k-labels.idx1-ubyte",
                                 batch_size, num)


def get_training_data_generator(batch_size: int = 1, num: int = -1) -> Callable[[], Generator[TrainingBatch, None, None]]:
    """Return a factory of (images, labels) batches read from the data/train-* IDX files.

    Args:
        batch_size: Pairs per batch (>= 1); the last batch may be smaller.
        num: Maximum number of pairs to read; -1 (default) reads all of them.
    """
    return _make_batch_generator("data/train-images.idx3-ubyte", "data/train-labels.idx1-ubyte",
                                 batch_size, num)
|