Added generic pytorch implementation for multi layer linear NN.
This commit is contained in:
50
GenericTorchMlpNetwork.py
Normal file
50
GenericTorchMlpNetwork.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import List, Generator
|
||||
from custom_types import TrainingBatch, EvaluationResults, DEVICE
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class GenericTorchMlpClassifier(nn.Module):
|
||||
def __init__(self, dims_per_layer: List[int], learning_rate: float):
|
||||
super(GenericTorchMlpClassifier, self).__init__()
|
||||
self.layers = []
|
||||
for i, layer_dims in enumerate(dims_per_layer[1:], 1):
|
||||
self.layers.append(nn.Linear(dims_per_layer[i - 1], layer_dims))
|
||||
self.loss_fn = nn.CrossEntropyLoss()
|
||||
self.optimiser = torch.optim.Adam(params=[{"params": layer.parameters()} for layer in self.layers], lr=learning_rate)
|
||||
self.to(torch.device(DEVICE))
|
||||
|
||||
def forward(self, input_batch: List[int]) -> torch.Tensor:
|
||||
x = torch.tensor(input_batch, dtype=torch.float)
|
||||
for layer in self.layers[:-1]:
|
||||
x = torch.sigmoid(layer(x))
|
||||
x = self.layers[-1](x)
|
||||
return x
|
||||
|
||||
def training_epoch(self, training_data: Generator[TrainingBatch, None, None]) -> None:
|
||||
self.train(True)
|
||||
for x, y in tqdm(training_data):
|
||||
prediction_probs, targets = self.forward(x), torch.tensor(y, dtype=torch.long, device=torch.device(DEVICE))
|
||||
self.optimiser.zero_grad()
|
||||
self.loss_fn(prediction_probs, targets).backward()
|
||||
self.optimiser.step()
|
||||
|
||||
def evaluate(self, evaluation_data: Generator[TrainingBatch, None, None]) -> EvaluationResults:
|
||||
self.train(False)
|
||||
accumulated_loss = 0.0
|
||||
total = 0
|
||||
total_correctly_classified = 0
|
||||
for x, y in tqdm(evaluation_data):
|
||||
prediction_probs, targets = self.forward(x), torch.tensor(y, device=torch.device(DEVICE))
|
||||
predictions = torch.argmax(prediction_probs, dim=1)
|
||||
total += len(targets)
|
||||
total_correctly_classified += sum(predictions == targets)
|
||||
accumulated_loss += self.loss_fn(prediction_probs, targets)
|
||||
return EvaluationResults(
|
||||
total=total,
|
||||
correct=total_correctly_classified,
|
||||
accumulated_loss=accumulated_loss
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user