Files
perceptron_ffnn/main.py

51 lines
1.6 KiB
Python

from MlpNetwork import train_and_test_multiclass_perceptron
from mlp_network import train_and_test_neural_network
from import_data import show_picture, get_test_data_generator, get_training_data_generator, IMAGE_SIZE
from GenericTorchMlpNetwork import GenericTorchMlpClassifier
import argparse
def main():
    """Train the MLP classifier and report test-set results after every epoch.

    Reads command-line options via ``get_args()``, builds a
    ``GenericTorchMlpClassifier`` with layer sizes ``[IMAGE_SIZE, 200, 80, 10]``
    and the requested learning rate, then for each of ``args.num_epochs``
    epochs runs one training pass followed by an evaluation pass, printing
    the correct/total count, accumulated loss, average loss and accuracy.
    """
    args = get_args()
    # NOTE(review): ``args.num_training_samples`` is parsed by get_args() but
    # never used here; presumably it should limit the training data — confirm
    # against import_data.get_training_data_generator before wiring it in.
    classifier = GenericTorchMlpClassifier(
        dims_per_layer=[IMAGE_SIZE, 200, 80, 10],
        learning_rate=args.learning_rate,
    )
    for epoch in range(1, args.num_epochs + 1):
        print(f"Begin training epoch {epoch}.")
        # get_*_data_generator(20) returns a callable; invoking it yields the
        # data handed to the classifier for this epoch.
        classifier.training_epoch(get_training_data_generator(20)())
        results = classifier.evaluate(get_test_data_generator(20)())

        total = results.total
        # Average over every evaluated sample, not only the correctly
        # classified ones (the old code divided by ``results.correct``).
        # Assumes accumulated_loss sums per-sample losses — TODO confirm.
        # Both ratios are guarded so an empty test set prints "nan" instead
        # of raising ZeroDivisionError.
        average_loss = results.accumulated_loss / total if total else float("nan")
        accuracy = 100 * float(results.correct) / float(total) if total else float("nan")

        print(f"Evaluation results: {results.correct} / {results.total}",
              f"Accumulated loss = {results.accumulated_loss:.3f}",
              f"Average loss = {average_loss:.3f}",
              f"Accuracy = {accuracy:.2f}%",
              sep="\n", end="\n\n")
def get_args(argv=None):
    """Parse the command-line options for the training script.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; passing
            an explicit list lets callers and tests parse options directly.

    Returns:
        An ``argparse.Namespace`` with:
          * ``num_epochs`` (int, default 5; short flag ``-e``),
          * ``learning_rate`` (float, default 0.001),
          * ``num_training_samples`` (int, default -1, meaning "all").
    """
    parser = argparse.ArgumentParser(
        description="Train and evaluate an MLP classifier."
    )
    parser.add_argument(
        "--num_epochs",
        "-e",
        type=int,
        default=5,
        help="Number of training epochs to undertake."
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=0.001,
        help="Learning rate for the optimiser."
    )
    parser.add_argument(
        "--num_training_samples",
        type=int,
        default=-1,
        help="Number of samples to train with (default = all)."
    )
    # argparse treats argv=None as "use sys.argv[1:]", preserving the
    # original no-argument behaviour.
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Run the training loop only when this file is executed as a script;
    # importing the module defines the functions without side effects.
    main()