from typing import Generator, List

import torch
import torch.nn as nn
from tqdm import tqdm

from custom_types import DEVICE, EvaluationResults, TrainingBatch


class GenericTorchMlpClassifier(nn.Module):
    """Multi-layer perceptron classifier with sigmoid hidden activations.

    The network is a stack of ``nn.Linear`` layers. Every layer except the
    last is followed by a sigmoid; the last layer emits raw logits, which is
    what ``nn.CrossEntropyLoss`` expects. The model owns its loss function
    and its Adam optimiser, and lives on ``DEVICE``.
    """

    def __init__(self, dims_per_layer: List[int], learning_rate: float):
        """Build the layers, loss and optimiser.

        Args:
            dims_per_layer: Width of every layer boundary, input first and
                output last; ``len(dims_per_layer) - 1`` linear layers are
                created.
            learning_rate: Learning rate handed to Adam.

        Raises:
            ValueError: If ``dims_per_layer`` has fewer than two entries, as
                no layer can be built from it.
        """
        super().__init__()
        if len(dims_per_layer) < 2:
            raise ValueError(
                "dims_per_layer needs at least an input and an output size, "
                f"got {len(dims_per_layer)} value(s)"
            )

        # nn.ModuleList (not a plain list) registers the layers as
        # submodules, so .to(), .parameters() and .state_dict() see them.
        self.layers = nn.ModuleList(
            [
                nn.Linear(in_dims, out_dims)
                for in_dims, out_dims in zip(dims_per_layer, dims_per_layer[1:])
            ]
        )
        self.loss_fn = nn.CrossEntropyLoss()

        # Move the parameters before the optimiser captures them.
        self.to(torch.device(DEVICE))
        # One parameter group per layer, in layer order.
        self.optimiser = torch.optim.Adam(
            params=[{"params": layer.parameters()} for layer in self.layers],
            lr=learning_rate,
        )

    def forward(self, input_batch: List[List[float]]) -> torch.Tensor:
        """Return the logits for a batch of inputs.

        Args:
            input_batch: Anything ``torch.as_tensor`` can convert to a float
                tensor. Presumably shaped (batch, dims_per_layer[0]) -
                TODO confirm against the data pipeline.

        Returns:
            The raw output of the last linear layer (no softmax applied).
        """
        # The input must live on the same device as the layers.
        x = torch.as_tensor(
            input_batch, dtype=torch.float, device=torch.device(DEVICE)
        )
        for layer in self.layers[:-1]:
            x = torch.sigmoid(layer(x))
        return self.layers[-1](x)

    def training_epoch(self, training_data: Generator[TrainingBatch, None, None]) -> None:
        """Run one optimisation step per batch yielded by ``training_data``.

        Args:
            training_data: Iterable of ``(inputs, class_indices)`` pairs.
        """
        self.train(True)
        device = torch.device(DEVICE)
        for x, y in tqdm(training_data):
            # Call the module (not .forward) so registered hooks run.
            logits = self(x)
            targets = torch.as_tensor(y, dtype=torch.long, device=device)
            self.optimiser.zero_grad()
            self.loss_fn(logits, targets).backward()
            self.optimiser.step()

    def evaluate(self, evaluation_data: Generator[TrainingBatch, None, None]) -> EvaluationResults:
        """Score the model on ``evaluation_data`` without updating it.

        Args:
            evaluation_data: Iterable of ``(inputs, class_indices)`` pairs.

        Returns:
            An ``EvaluationResults`` holding the number of examples seen, the
            number classified correctly, and the sum of the per-batch losses.
            ``correct`` and ``accumulated_loss`` are 0-dim tensors once at
            least one batch was processed, and stay the Python values ``0``
            and ``0.0`` when ``evaluation_data`` is empty.
        """
        self.eval()
        device = torch.device(DEVICE)
        accumulated_loss = 0.0
        total = 0
        total_correctly_classified = 0

        # no_grad: evaluation needs no autograd graph, and without it every
        # accumulated loss tensor would keep its batch's graph alive.
        with torch.no_grad():
            for x, y in tqdm(evaluation_data):
                logits = self(x)
                targets = torch.as_tensor(y, dtype=torch.long, device=device)
                predictions = torch.argmax(logits, dim=1)
                total += len(targets)
                # Tensor.sum() counts the matches in one vectorised call.
                total_correctly_classified += (predictions == targets).sum()
                accumulated_loss += self.loss_fn(logits, targets)

        return EvaluationResults(
            total=total,
            correct=total_correctly_classified,
            accumulated_loss=accumulated_loss,
        )