Source code for pynas.train.mean_squared_error

import torch
from torchmetrics import Metric

class MeanSquaredError(Metric):
    """Mean squared error accumulated over successive ``update`` calls.

    Two states are registered, both with ``dist_reduce_fx="sum"``:

    * ``sum_squared_error`` -- running sum of squared differences.
    * ``total`` -- running count of elements that contributed to that sum.

    ``compute`` returns the ratio of the two.
    """

    def __init__(self, **kwargs):
        """Register the accumulator states.

        Args:
            **kwargs: Optional keyword arguments forwarded unchanged to
                ``Metric.__init__``. Constructing the metric with no
                arguments behaves exactly as before.
        """
        super().__init__(**kwargs)
        self.add_state(
            "sum_squared_error",
            default=torch.tensor(0.0),
            dist_reduce_fx="sum",
        )
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Fold one batch of predictions and targets into the running state.

        Args:
            preds: Predicted values.
            target: Ground-truth values. Must be broadcast-compatible with
                ``preds``, since the two are subtracted element-wise.
        """
        squared_error = (preds - target) ** 2
        # Count the elements that were actually summed. When ``preds`` and
        # ``target`` have different (broadcastable) shapes, the difference can
        # hold more elements than ``target``; counting ``target.numel()`` would
        # then divide the sum by too small a number and inflate the result.
        self.sum_squared_error += torch.sum(squared_error)
        self.total += squared_error.numel()

    def compute(self) -> torch.Tensor:
        """Return the accumulated squared error divided by the element count.

        Returns:
            ``sum_squared_error / total``. No guard is applied for an empty
            accumulator, so calling this before any ``update`` divides the
            initial ``0.0`` by the initial ``0``.
        """
        return self.sum_squared_error / self.total