First commit

This commit is contained in:
2020-06-23 22:53:47 +02:00
commit 368d134228
17 changed files with 308 additions and 0 deletions

77
import_data.py Normal file
View File

@@ -0,0 +1,77 @@
from PIL import Image
from typing import Union, List, Generator, Callable
BOX_SHADING = " ░▒▓██"
IMAGE_ROW_SIZE = 28
IMAGE_COL_SIZE = 28
IMAGE_SIZE = IMAGE_ROW_SIZE*IMAGE_COL_SIZE
def show_picture(img_bytes: Union[bytes, List[int]]):
img = Image.new("RGB", (IMAGE_ROW_SIZE, IMAGE_COL_SIZE), "black")
pixels = img.load()
for i in range(IMAGE_ROW_SIZE):
for j in range(IMAGE_COL_SIZE):
pixel = img_bytes[IMAGE_ROW_SIZE*i + j]
pixels[j, i] = (pixel, pixel, pixel)
img.show()
def print_img_to_console(img: Union[bytes, List[int]]):
for row_start in range(0, IMAGE_SIZE, IMAGE_ROW_SIZE):
print("".join([BOX_SHADING[pixel // 51]*2 for pixel in img[row_start:row_start + IMAGE_ROW_SIZE]]))
print()
def read_labels(file_location: str):
with open(file_location, 'rb') as img_file:
img_data = img_file.read()
num_items = int.from_bytes(img_data[4:8], byteorder="big")
for i in range(8, num_items):
yield int.from_bytes(img_data[i:i + 1], byteorder="big")
def read_imgs(file_location: str, as_bytes=False):
with open(file_location, 'rb') as img_file:
img_data = img_file.read()
num_items = int.from_bytes(img_data[4:8], byteorder="big")
num_rows = int.from_bytes(img_data[8:12], byteorder="big")
num_cols = int.from_bytes(img_data[12:16], byteorder="big")
img_size = num_rows*num_cols
start_byte = 16
if as_bytes:
for end_byte in range(start_byte + img_size, num_items*img_size, img_size):
yield img_data[start_byte:end_byte]
start_byte = end_byte
else:
for end_byte in range(start_byte + img_size, num_items*img_size, img_size):
yield [pixel for pixel in img_data[start_byte:end_byte]]
start_byte = end_byte
def read_img_lbl_pairs(imgs_file: str, lbls_file: str):
for img, label in zip(read_imgs(imgs_file), read_labels(lbls_file)):
yield img, label
def test_x_y(num: int = -1) -> Callable[[], Generator]:
if num == -1:
num = 9992
def generator():
for i, (img, lbl) in zip(range(num), read_img_lbl_pairs("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte")):
yield img, lbl
return generator
def train_x_y(num: int = -1) -> Callable[[], Generator]:
if num == -1:
num = 60000
def generator():
for i, (img, lbl) in zip(range(num), read_img_lbl_pairs("train-images.idx3-ubyte", "train-labels.idx1-ubyte")):
yield img, lbl
return generator