Added generic pytorch implementation for multi layer linear NN.
This commit is contained in:
53
main.py
53
main.py
@@ -1,5 +1,50 @@
|
||||
import torch
|
||||
from multiclass_perceptron import train_and_test_multiclass_perceptron
|
||||
from import_data import show_picture, test_x_y
|
||||
from MlpNetwork import train_and_test_multiclass_perceptron
|
||||
from mlp_network import train_and_test_neural_network
|
||||
from import_data import show_picture, get_test_data_generator, get_training_data_generator, IMAGE_SIZE
|
||||
from GenericTorchMlpNetwork import GenericTorchMlpClassifier
|
||||
import argparse
|
||||
|
||||
train_and_test_multiclass_perceptron()
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
classifier = GenericTorchMlpClassifier(
|
||||
dims_per_layer=[IMAGE_SIZE, 200, 80, 10],
|
||||
learning_rate=args.learning_rate,
|
||||
)
|
||||
for i in range(args.num_epochs):
|
||||
print(f"Begin training epoch {i + 1}.")
|
||||
classifier.training_epoch(get_training_data_generator(20)())
|
||||
results = classifier.evaluate(get_test_data_generator(20)())
|
||||
print(f"Evaluation results: {results.correct} / {results.total}",
|
||||
f"Accumulated loss = {results.accumulated_loss:.3f}",
|
||||
f"Average loss = {results.accumulated_loss / results.correct:.3f}",
|
||||
f"Accuracy = {100 * float(results.correct) / float(results.total):.2f}%",
|
||||
sep="\n", end="\n\n")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--num_epochs",
|
||||
"-e",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of training epochs to undertake."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="Learning rate for the optimiser."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_training_samples",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Number of samples to train with (default = all)."
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user