23 lines
445 B
Python
23 lines
445 B
Python
import torch
|
|
from typing import Tuple, NamedTuple, List, Callable
|
|
import numpy as np
|
|
|
|
# Type alias for one batch of training data, as a 2-tuple:
#   [0] a list of float vectors (presumably one feature vector per sample),
#   [1] a list of ints (presumably one label per sample) — confirm against
#       the code that builds batches.
TrainingBatch = Tuple[
    List[List[float]],
    List[int],
]
|
|
|
|
|
|
class LossFun(NamedTuple):
    """Immutable pair of callables describing a loss function.

    Both callables take two NumPy arrays. This type does not fix their
    argument order; presumably (predictions, targets) — confirm against
    callers.

    The annotations use ``np.ndarray`` (the array *type*). ``np.array`` is a
    factory function and is not a valid type for static type checkers.
    """

    # Takes the two arrays and returns a scalar float (the loss value).
    # The field name ``exec`` only shadows the builtin as an attribute name;
    # it is kept for backward compatibility with existing callers.
    exec: Callable[[np.ndarray, np.ndarray], float]

    # Takes the two arrays and returns an array; by its name, presumably the
    # derivative (gradient) of the loss — TODO confirm against implementations.
    deriv: Callable[[np.ndarray, np.ndarray], np.ndarray]
|
|
|
|
|
|
class EvaluationResults(NamedTuple):
    """Immutable summary of an evaluation pass.

    Positional field order: ``total``, ``correct``, ``accumulated_loss``.
    """

    # Presumably the number of samples evaluated — confirm with the producer.
    total: int
    # Presumably how many of those samples were predicted correctly.
    correct: int
    # Summed loss; whether summed per sample or per batch is not shown here.
    accumulated_loss: float
|
|
|
|
|
|
# Module-level torch device string: the first CUDA device ("cuda:0") when
# torch reports CUDA as available, otherwise "cpu". Decided once, at import.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|