from PIL import Image from typing import Union, List, Generator, Callable BOX_SHADING = " ░▒▓██" IMAGE_ROW_SIZE = 28 IMAGE_COL_SIZE = 28 IMAGE_SIZE = IMAGE_ROW_SIZE*IMAGE_COL_SIZE def show_picture(img_bytes: Union[bytes, List[int]]): img = Image.new("RGB", (IMAGE_ROW_SIZE, IMAGE_COL_SIZE), "black") pixels = img.load() for i in range(IMAGE_ROW_SIZE): for j in range(IMAGE_COL_SIZE): pixel = img_bytes[IMAGE_ROW_SIZE*i + j] pixels[j, i] = (pixel, pixel, pixel) img.show() def print_img_to_console(img: Union[bytes, List[int]]): for row_start in range(0, IMAGE_SIZE, IMAGE_ROW_SIZE): print("".join([BOX_SHADING[pixel // 51]*2 for pixel in img[row_start:row_start + IMAGE_ROW_SIZE]])) print() def read_labels(file_location: str): with open(file_location, 'rb') as img_file: img_data = img_file.read() num_items = int.from_bytes(img_data[4:8], byteorder="big") for i in range(8, num_items): yield int.from_bytes(img_data[i:i + 1], byteorder="big") def read_imgs(file_location: str, as_bytes=False): with open(file_location, 'rb') as img_file: img_data = img_file.read() num_items = int.from_bytes(img_data[4:8], byteorder="big") num_rows = int.from_bytes(img_data[8:12], byteorder="big") num_cols = int.from_bytes(img_data[12:16], byteorder="big") img_size = num_rows*num_cols start_byte = 16 if as_bytes: for end_byte in range(start_byte + img_size, num_items*img_size, img_size): yield img_data[start_byte:end_byte] start_byte = end_byte else: for end_byte in range(start_byte + img_size, num_items*img_size, img_size): yield [pixel for pixel in img_data[start_byte:end_byte]] start_byte = end_byte def read_img_lbl_pairs(imgs_file: str, lbls_file: str): for img, label in zip(read_imgs(imgs_file), read_labels(lbls_file)): yield img, label def test_x_y(num: int = -1) -> Callable[[], Generator]: if num == -1: num = 9992 def generator(): for i, (img, lbl) in zip(range(num), read_img_lbl_pairs("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte")): yield img, lbl return generator def train_x_y(num: int = -1) -> Callable[[], Generator]: if num == -1: num = 60000 def generator(): for i, (img, lbl) in zip(range(num), read_img_lbl_pairs("train-images.idx3-ubyte", "train-labels.idx1-ubyte")): yield img, lbl return generator