"""Train a GenericTorchMlpClassifier for a number of epochs, then evaluate it on test data."""

import argparse

# NOTE(review): the next three imported names (train_and_test_multiclass_perceptron,
# train_and_test_neural_network, show_picture) are not used in this script.
from MlpNetwork import train_and_test_multiclass_perceptron
from mlp_network import train_and_test_neural_network
from import_data import show_picture, get_test_data_generator, get_training_data_generator, IMAGE_SIZE
from GenericTorchMlpNetwork import GenericTorchMlpClassifier

# Argument passed to import_data's generator factories for both training and test data.
# NOTE(review): presumably this is the batch size -- confirm against import_data.
_DATA_GENERATOR_ARG = 20


def main():
    """Parse CLI arguments, train for ``--num_epochs`` epochs, then print test-set results."""
    args = get_args()
    classifier = GenericTorchMlpClassifier(
        dims_per_layer=[IMAGE_SIZE, 200, 80, 10],
        learning_rate=args.learning_rate,
    )
    for i in range(args.num_epochs):
        print(f"Begin training epoch {i + 1}.")
        # The factory returns a callable; calling it builds fresh training data for this epoch.
        classifier.training_epoch(get_training_data_generator(_DATA_GENERATOR_ARG)())
    results = classifier.evaluate(get_test_data_generator(_DATA_GENERATOR_ARG)())
    print(_format_results(results), end="\n\n")


def _format_results(results):
    """Build the multi-line evaluation summary printed by ``main``.

    Only the ``correct``, ``total`` and ``accumulated_loss`` attributes of
    ``results`` (the value returned by ``classifier.evaluate``) are read.
    When ``total`` is zero, the average loss and accuracy are reported as NaN
    instead of raising ZeroDivisionError.
    """
    total = results.total
    if total:
        # Average over every evaluated sample, not only the correctly classified ones.
        # NOTE(review): assumes accumulated_loss sums one loss term per sample -- confirm
        # against GenericTorchMlpClassifier.evaluate.
        average_loss = results.accumulated_loss / total
        accuracy = 100 * float(results.correct) / float(total)
    else:
        average_loss = accuracy = float("nan")
    return "\n".join([
        f"Evaluation results: {results.correct} / {total}",
        f"Accumulated loss = {results.accumulated_loss:.3f}",
        f"Average loss = {average_loss:.3f}",
        f"Accuracy = {accuracy:.2f}%",
    ])


def get_args():
    """Define and parse this script's command-line arguments.

    Returns:
        argparse.Namespace with ``num_epochs``, ``learning_rate`` and
        ``num_training_samples``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_epochs", "-e", type=int, default=5,
        help="Number of training epochs to undertake."
    )
    parser.add_argument(
        "--learning_rate", type=float, default=0.001,
        help="Learning rate for the optimiser."
    )
    # TODO(review): parsed but not currently read by main(); wire it into the
    # training data generator once import_data supports limiting the sample count.
    parser.add_argument(
        "--num_training_samples", type=int, default=-1,
        help="Number of samples to train with (default = all)."
    )
    return parser.parse_args()


if __name__ == "__main__":
    main()