# Source listing metadata (from file viewer): 51 lines, 1.6 KiB, Python.
from MlpNetwork import train_and_test_multiclass_perceptron
|
|
from mlp_network import train_and_test_neural_network
|
|
from import_data import show_picture, get_test_data_generator, get_training_data_generator, IMAGE_SIZE
|
|
from GenericTorchMlpNetwork import GenericTorchMlpClassifier
|
|
import argparse
|
|
|
|
|
|
def main():
    """Train a GenericTorchMlpClassifier, then evaluate it once on the test data.

    Reads command-line options via ``get_args()``. Builds a classifier with
    layer sizes ``[IMAGE_SIZE, 200, 80, 10]`` and the requested learning rate.
    Runs ``num_epochs`` training epochs and prints the evaluation summary:
    correct / total, accumulated loss, average loss and accuracy.
    """
    # Argument passed to both data-generator factories. Presumably the batch
    # size -- TODO confirm against import_data.
    batch_size = 20

    args = get_args()
    # NOTE(review): args.num_training_samples is parsed by get_args() but is
    # not used here; wire it into the training data once the generator API
    # supports limiting the sample count.

    classifier = GenericTorchMlpClassifier(
        dims_per_layer=[IMAGE_SIZE, 200, 80, 10],
        learning_rate=args.learning_rate,
    )

    for epoch in range(args.num_epochs):
        print(f"Begin training epoch {epoch + 1}.")
        # The factory returns a callable; calling it yields the data to iterate
        # for this epoch.
        classifier.training_epoch(get_training_data_generator(batch_size)())

    results = classifier.evaluate(get_test_data_generator(batch_size)())

    total = results.total
    # Average over every evaluated sample, not just the correctly classified
    # ones. Dividing by `correct` skewed the figure and crashed with
    # ZeroDivisionError when nothing was correct.
    # Assumes accumulated_loss sums over samples -- TODO confirm against
    # GenericTorchMlpClassifier.evaluate.
    # An empty test set reports NaN instead of raising.
    if total:
        average_loss = results.accumulated_loss / total
        accuracy = 100 * float(results.correct) / float(total)
    else:
        average_loss = float("nan")
        accuracy = float("nan")

    print(f"Evaluation results: {results.correct} / {total}",
          f"Accumulated loss = {results.accumulated_loss:.3f}",
          f"Average loss = {average_loss:.3f}",
          f"Accuracy = {accuracy:.2f}%",
          sep="\n", end="\n\n")
|
|
|
|
|
|
def get_args():
    """Build the command-line interface for a training run and parse ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with these fields:

        * ``num_epochs`` (int, default 5; also ``-e``)
        * ``learning_rate`` (float, default 0.001)
        * ``num_training_samples`` (int, default -1)
    """
    # (flags, keyword options) for each argument, registered in this order so
    # that --help output lists them the same way.
    option_specs = (
        (
            ("--num_epochs", "-e"),
            {"type": int, "default": 5,
             "help": "Number of training epochs to undertake."},
        ),
        (
            ("--learning_rate",),
            {"type": float, "default": 0.001,
             "help": "Learning rate for the optimiser."},
        ),
        (
            ("--num_training_samples",),
            {"type": int, "default": -1,
             "help": "Number of samples to train with (default = all)."},
        ),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
|
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when run as a script, not when imported as a module.
    main()
|