import torch
import torch.nn as nn
from typing import List, Generator

from custom_types import TrainingBatch, EvaluationResults, DEVICE
from tqdm import tqdm


class GenericTorchMlpClassifier(nn.Module):
    def __init__(self, dims_per_layer: List[int], learning_rate: float):
        super().__init__()
        # Register the layers in an nn.ModuleList: a plain Python list is
        # invisible to .parameters() and .to(), so the layers would never be
        # moved to DEVICE or updated by the optimiser.
        self.layers = nn.ModuleList(
            nn.Linear(dims_per_layer[i - 1], layer_dims)
            for i, layer_dims in enumerate(dims_per_layer[1:], 1)
        )
        self.loss_fn = nn.CrossEntropyLoss()
        # With the layers registered, self.parameters() covers all of them,
        # so per-layer parameter groups are unnecessary.
        self.optimiser = torch.optim.Adam(self.parameters(), lr=learning_rate)
        self.to(torch.device(DEVICE))

    def forward(self, input_batch: List[List[float]]) -> torch.Tensor:
        # Build the input tensor directly on the model's device.
        x = torch.tensor(input_batch, dtype=torch.float, device=torch.device(DEVICE))
        # Sigmoid activations on the hidden layers; the final layer returns
        # raw logits, which is what nn.CrossEntropyLoss expects.
        for layer in self.layers[:-1]:
            x = torch.sigmoid(layer(x))
        return self.layers[-1](x)

    def training_epoch(self, training_data: Generator[TrainingBatch, None, None]) -> None:
        self.train()
        for x, y in tqdm(training_data):
            # The model outputs logits, not probabilities; CrossEntropyLoss
            # applies log-softmax internally.
            logits = self(x)
            targets = torch.tensor(y, dtype=torch.long, device=torch.device(DEVICE))
            self.optimiser.zero_grad()
            self.loss_fn(logits, targets).backward()
            self.optimiser.step()

    def evaluate(self, evaluation_data: Generator[TrainingBatch, None, None]) -> EvaluationResults:
        self.eval()
        accumulated_loss = 0.0
        total = 0
        total_correctly_classified = 0
        # Disable gradient tracking: nothing here calls backward(), and this
        # stops the accumulated loss from pinning the computation graph.
        with torch.no_grad():
            for x, y in tqdm(evaluation_data):
                logits = self(x)
                targets = torch.tensor(y, dtype=torch.long, device=torch.device(DEVICE))
                predictions = torch.argmax(logits, dim=1)
                total += len(targets)
                # .item() converts the 0-d result tensors to plain Python numbers.
                total_correctly_classified += (predictions == targets).sum().item()
                accumulated_loss += self.loss_fn(logits, targets).item()
        return EvaluationResults(
            total=total,
            correct=total_correctly_classified,
            accumulated_loss=accumulated_loss,
        )
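

# A minimal usage sketch, assuming DEVICE is a torch device string such as
# "cpu", TrainingBatch is a (features, labels) pair of nested lists, and
# EvaluationResults accepts the total / correct / accumulated_loss keyword
# arguments used above. toy_batches is a hypothetical stand-in data source,
# included only to show the intended call sequence.
if __name__ == "__main__":
    import random

    def toy_batches(n_batches: int, batch_size: int, n_features: int, n_classes: int):
        # Yield (features, labels) pairs shaped like TrainingBatch.
        for _ in range(n_batches):
            x = [[random.random() for _ in range(n_features)] for _ in range(batch_size)]
            y = [random.randrange(n_classes) for _ in range(batch_size)]
            yield x, y

    # 4 input features -> one 16-unit hidden layer -> 3 output classes.
    model = GenericTorchMlpClassifier(dims_per_layer=[4, 16, 3], learning_rate=1e-3)
    model.training_epoch(toy_batches(n_batches=10, batch_size=8, n_features=4, n_classes=3))
    print(model.evaluate(toy_batches(n_batches=5, batch_size=8, n_features=4, n_classes=3)))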