# perceptron_ffnn/GenericTorchMlpNetwork.py

import torch
import torch.nn as nn
from typing import List, Generator
from custom_types import TrainingBatch, EvaluationResults, DEVICE
from tqdm import tqdm
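
# custom_types (imported above) is a sibling module in this repo whose exact
# contents are not shown here. The code below only assumes that a
# TrainingBatch unpacks as an (inputs, labels) pair, that EvaluationResults
# accepts the keyword arguments total, correct and accumulated_loss, and that
# DEVICE is a torch device string such as "cpu" or "cuda".
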
class GenericTorchMlpClassifier(nn.Module):
    def __init__(self, dims_per_layer: List[int], learning_rate: float):
        super().__init__()
        # nn.ModuleList (rather than a plain Python list) registers the layers
        # with the module, so self.to(...) moves them to the target device and
        # self.parameters() can see them.
        self.layers = nn.ModuleList(
            [
                nn.Linear(in_dim, out_dim)
                for in_dim, out_dim in zip(dims_per_layer, dims_per_layer[1:])
            ]
        )
        self.loss_fn = nn.CrossEntropyLoss()
        # Move the parameters to the target device before handing them to the
        # optimiser; self.parameters() now covers every registered layer.
        self.to(torch.device(DEVICE))
        self.optimiser = torch.optim.Adam(self.parameters(), lr=learning_rate)
    def forward(self, input_batch: List[List[float]]) -> torch.Tensor:
        # Build the input tensor on the same device as the weights.
        x = torch.tensor(input_batch, dtype=torch.float, device=torch.device(DEVICE))
        # Sigmoid activations on the hidden layers; the final layer returns raw
        # logits, which is what nn.CrossEntropyLoss expects.
        for layer in self.layers[:-1]:
            x = torch.sigmoid(layer(x))
        return self.layers[-1](x)
    def training_epoch(self, training_data: Generator[TrainingBatch, None, None]) -> None:
        self.train()
        for x, y in tqdm(training_data):
            # The model outputs unnormalised logits; CrossEntropyLoss applies
            # the softmax internally.
            logits = self(x)
            targets = torch.tensor(y, dtype=torch.long, device=torch.device(DEVICE))
            self.optimiser.zero_grad()
            self.loss_fn(logits, targets).backward()
            self.optimiser.step()
    def evaluate(self, evaluation_data: Generator[TrainingBatch, None, None]) -> EvaluationResults:
        self.eval()
        accumulated_loss = 0.0
        total = 0
        total_correctly_classified = 0
        # no_grad() skips gradient tracking, so evaluation runs faster and the
        # accumulated loss does not retain a computation graph per batch.
        with torch.no_grad():
            for x, y in tqdm(evaluation_data):
                logits = self(x)
                targets = torch.tensor(y, dtype=torch.long, device=torch.device(DEVICE))
                # The predicted class is the index of the highest logit.
                predictions = torch.argmax(logits, dim=1)
                total += len(targets)
                total_correctly_classified += (predictions == targets).sum().item()
                accumulated_loss += self.loss_fn(logits, targets).item()
        return EvaluationResults(
            total=total,
            correct=total_correctly_classified,
            accumulated_loss=accumulated_loss,
        )
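

# A minimal usage sketch. The toy batch generator, the layer sizes and the
# attribute access on EvaluationResults are illustrative assumptions, not
# part of the original module.
if __name__ == "__main__":
    import random

    def toy_batches(n_batches: int = 10, batch_size: int = 8) -> Generator[TrainingBatch, None, None]:
        # Random 4-feature vectors; the label depends only on the first feature.
        for _ in range(n_batches):
            xs = [[random.random() for _ in range(4)] for _ in range(batch_size)]
            ys = [int(x[0] > 0.5) for x in xs]
            yield xs, ys  # assumes TrainingBatch is an (inputs, labels) pair

    model = GenericTorchMlpClassifier(dims_per_layer=[4, 16, 2], learning_rate=1e-3)
    for epoch in range(3):
        model.training_epoch(toy_batches())
        results = model.evaluate(toy_batches())
        # Assumes EvaluationResults exposes its fields as attributes.
        print(f"epoch {epoch}: {results.correct}/{results.total} correct")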