First commit
This commit is contained in:
77
import_data.py
Normal file
77
import_data.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from PIL import Image
|
||||
from typing import Union, List, Generator, Callable
|
||||
|
||||
|
||||
# Glyphs used to draw a pixel in the console, ordered from blank to solid.
# A glyph is picked with ``pixel // 51``, so intensities 0-255 map onto
# indices 0-5 (the solid block appears twice to cover index 5, i.e. 255).
BOX_SHADING = " ░▒▓██"

# Dimensions of a single image, in pixels.
IMAGE_ROW_SIZE = 28
IMAGE_COL_SIZE = 28

# Total number of pixels in one flattened image.
IMAGE_SIZE = IMAGE_COL_SIZE * IMAGE_ROW_SIZE
|
||||
|
||||
|
||||
def show_picture(img_bytes: Union[bytes, List[int]]):
    """Display a flattened grayscale image through ``Image.show()``.

    ``img_bytes`` is read at index ``IMAGE_ROW_SIZE * row + col``; every
    intensity is replicated across the three RGB channels of the pixel.
    """
    canvas = Image.new("RGB", (IMAGE_ROW_SIZE, IMAGE_COL_SIZE), "black")
    canvas_pixels = canvas.load()

    for row in range(IMAGE_ROW_SIZE):
        row_offset = IMAGE_ROW_SIZE * row
        for col in range(IMAGE_COL_SIZE):
            shade = img_bytes[row_offset + col]
            # Pixel access is addressed as (x, y), i.e. (column, row).
            canvas_pixels[col, row] = (shade, shade, shade)

    canvas.show()
|
||||
|
||||
|
||||
def print_img_to_console(img: Union[bytes, List[int]]):
    """Print a flattened grayscale image to stdout using shading glyphs.

    Each pixel becomes two identical glyphs from ``BOX_SHADING`` (selected by
    ``pixel // 51``); one text line is printed per image row, followed by a
    single blank line.
    """
    for offset in range(0, IMAGE_SIZE, IMAGE_ROW_SIZE):
        row_pixels = img[offset:offset + IMAGE_ROW_SIZE]
        line = "".join(BOX_SHADING[value // 51] * 2 for value in row_pixels)
        print(line)
    print()
|
||||
|
||||
|
||||
def read_labels(file_location: str):
    """Yield every label stored in an IDX label file as an ``int``.

    The file is read as an 8-byte header, whose bytes 4-8 hold the
    big-endian item count, followed by one unsigned byte per label.

    Args:
        file_location: Path of the label file to read.

    Yields:
        Each label, in file order. If the file holds fewer bytes than the
        header announces, only the labels actually present are yielded.
    """
    header_size = 8

    with open(file_location, 'rb') as label_file:
        raw = label_file.read()

    num_items = int.from_bytes(raw[4:8], byteorder="big")

    # Labels start right after the header, so the data ends at
    # header_size + num_items. Using num_items alone as the end offset
    # would skip the final 8 labels.
    yield from raw[header_size:header_size + num_items]
|
||||
|
||||
|
||||
def read_imgs(file_location: str, as_bytes=False):
    """Yield every image stored in an IDX image file as a flat pixel sequence.

    The file is read as a 16-byte header followed by the pixel data. Header
    bytes 4-8, 8-12 and 12-16 hold the big-endian image count, row count and
    column count respectively; each image occupies ``rows * cols`` bytes.

    Args:
        file_location: Path of the image file to read.
        as_bytes: When true, yield each image as a ``bytes`` slice; otherwise
            yield a ``list`` of ``int`` pixel values.

    Yields:
        One flattened image per item announced in the header, in file order.
        Iteration stops early if the file ends before a complete image.
    """
    header_size = 16

    with open(file_location, 'rb') as img_file:
        img_data = img_file.read()

    num_items = int.from_bytes(img_data[4:8], byteorder="big")
    num_rows = int.from_bytes(img_data[8:12], byteorder="big")
    num_cols = int.from_bytes(img_data[12:16], byteorder="big")
    img_size = num_rows * num_cols

    # Index by image number so the header offset is always accounted for.
    # Bounding the loop by num_items * img_size without the header would
    # drop the final image.
    for index in range(num_items):
        start_byte = header_size + index * img_size
        image = img_data[start_byte:start_byte + img_size]
        if len(image) < img_size:
            # Truncated file: do not hand out a partial image.
            break
        yield image if as_bytes else list(image)
|
||||
|
||||
|
||||
def read_img_lbl_pairs(imgs_file: str, lbls_file: str):
    """Yield ``(image, label)`` tuples by pairing both files entry by entry.

    Images come from ``read_imgs(imgs_file)`` and labels from
    ``read_labels(lbls_file)``; iteration ends as soon as either runs out.
    """
    images = read_imgs(imgs_file)
    labels = read_labels(lbls_file)
    yield from zip(images, labels)
|
||||
|
||||
|
||||
def test_x_y(num: int = -1) -> Callable[[], Generator]:
    """Build a factory of generators over the test set's (image, label) pairs.

    Pairs are read from ``t10k-images.idx3-ubyte`` and
    ``t10k-labels.idx1-ubyte`` in the current working directory.

    Args:
        num: Maximum number of pairs each generator yields. The default
            ``-1`` means "no cap": every pair the files provide is yielded.

    Returns:
        A zero-argument callable; each call returns a fresh generator that
        re-reads the files from the start.
    """
    def generator():
        pairs = read_img_lbl_pairs("t10k-images.idx3-ubyte", "t10k-labels.idx1-ubyte")
        if num == -1:
            # No cap requested. A fixed default of 9992 would silently drop
            # samples whenever more pairs than that are available.
            yield from pairs
            return
        for _, pair in zip(range(num), pairs):
            yield pair

    return generator
|
||||
|
||||
|
||||
def train_x_y(num: int = -1) -> Callable[[], Generator]:
    """Build a factory of generators over the training set's (image, label) pairs.

    Pairs are read from ``train-images.idx3-ubyte`` and
    ``train-labels.idx1-ubyte`` in the current working directory. ``num``
    caps how many pairs each generator yields; the default ``-1`` is
    replaced by a cap of 60000.

    Returns a zero-argument callable; each call returns a fresh generator.
    """
    limit = 60000 if num == -1 else num

    def generator():
        pairs = read_img_lbl_pairs("train-images.idx3-ubyte", "train-labels.idx1-ubyte")
        for _, pair in zip(range(limit), pairs):
            yield pair

    return generator
|
||||
Reference in New Issue
Block a user