Training Module

The training module provides utilities for model training, loss functions, and evaluation metrics.

Custom IoU

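pynas.train.custom_iou.calculate_iou(logits, targets, num_classes=4)

Calculates intersection over union (IoU) between the predictions encoded by logits and the ground-truth targets, over num_classes classes.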
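The signature does not say how per-class scores are aggregated. The framework-free sketch below makes three assumptions: logits holds one score per class for each pixel, the predicted label is the arg-max class, and the result is the mean IoU over classes that appear in either the prediction or the target. The helper name mean_iou_sketch is hypothetical and is not part of pynas.

```python
def mean_iou_sketch(logits, targets, num_classes=4):
    """Mean IoU over classes present in predictions or targets (hypothetical helper)."""
    # Arg-max over the class scores gives the predicted label for each pixel.
    preds = [max(range(num_classes), key=lambda c: scores[c]) for scores in logits]
    ious = []
    for c in range(num_classes):
        intersection = sum(1 for p, t in zip(preds, targets) if p == c and t == c)
        union = sum(1 for p, t in zip(preds, targets) if p == c or t == c)
        if union > 0:  # skip classes absent from both prediction and target
            ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else 0.0


logits = [[2, 1, 0, 0], [0, 3, 0, 0], [0, 0, 1, 0], [0, 0, 0, 5]]
targets = [0, 1, 3, 3]
print(mean_iou_sketch(logits, targets))  # per-class IoU 1, 1, 0, 0.5 -> 0.625
```

Skipping classes that are absent from both sides is only one convention. Other implementations score such classes as 1 or NaN, and this choice changes the mean.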

Losses

class pynas.train.losses.CategoricalCrossEntropyLoss

Bases: Module

__init__()

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(logits, targets)

Computes the categorical cross-entropy loss between logits and targets.

Note

Call the module instance (for example, loss_fn(logits, targets)) rather than forward() directly. Calling the instance runs any registered hooks, while calling forward() silently skips them.
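For reference, categorical cross-entropy on raw logits is the mean over samples of -log softmax(logits)[target]. The stdlib sketch below assumes that targets are integer class indices (not one-hot vectors) and that the reduction is a mean. The class documentation states neither, and the function name is hypothetical.

```python
import math


def categorical_cross_entropy_sketch(logits, targets):
    """Mean of -log softmax(scores)[target] over samples (hypothetical helper)."""
    total = 0.0
    for scores, target in zip(logits, targets):
        m = max(scores)  # subtract the max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_sum_exp - scores[target]  # equals -log softmax(scores)[target]
    return total / len(targets)


print(categorical_cross_entropy_sketch([[0.0, 0.0]], [0]))  # log(2), about 0.6931
```

With class-index targets, no class weights, and the default mean reduction, torch.nn.functional.cross_entropy computes the same quantity. That makes it a useful cross-check for a custom loss.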

class pynas.train.losses.FocalLoss(alpha=1, gamma=2, reduction='mean')

Bases: Module

__init__(alpha=1, gamma=2, reduction='mean')

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(logits, targets)

Computes the focal loss between logits and targets.

Note

Call the module instance (for example, loss_fn(logits, targets)) rather than forward() directly. Calling the instance runs any registered hooks, while calling forward() silently skips them.
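In its standard form, focal loss down-weights well-classified examples: FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The stdlib sketch below implements that multi-class form. It assumes that FocalLoss applies alpha as a single scalar and accepts 'mean', 'sum', and 'none' for reduction; the signature confirms neither. The function name is hypothetical.

```python
import math


def focal_loss_sketch(logits, targets, alpha=1.0, gamma=2.0, reduction="mean"):
    """Multi-class focal loss on raw logits (hypothetical helper)."""
    losses = []
    for scores, target in zip(logits, targets):
        m = max(scores)  # stable log-sum-exp
        log_pt = scores[target] - (m + math.log(sum(math.exp(s - m) for s in scores)))
        pt = math.exp(log_pt)  # probability assigned to the true class
        # (1 - pt) ** gamma shrinks the loss of confidently correct samples
        losses.append(-alpha * (1.0 - pt) ** gamma * log_pt)
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "none":
        return losses
    raise ValueError(f"unsupported reduction: {reduction!r}")


print(focal_loss_sketch([[0.0, 0.0]], [0]))  # 0.25 * log(2), about 0.1733
```

With gamma=0 and alpha=1 the expression reduces to ordinary cross-entropy, which makes a convenient sanity check.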

Mean Squared Error

class pynas.train.mean_squared_error.MeanSquaredError

Bases: Metric

__init__()

Initializes internal Module state, shared by both nn.Module and ScriptModule.

update(preds, target)

Updates the metric's state variables with a batch of predictions and targets.

Parameters:
  • preds – Predicted values.

  • target – Ground-truth values.

compute()

Computes the final metric value from the accumulated state.

When running with a distributed backend, the state variables are synchronized automatically before the value is computed.
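Metrics of this kind usually keep running sums instead of storing every prediction. update() adds the batch's squared error and element count to the state, and compute() divides the two. The plain-Python class below sketches that pattern and assumes MeanSquaredError follows it; the class name is hypothetical. In a real torchmetrics Metric, the two states would be registered with self.add_state(..., dist_reduce_fx="sum") so that they can be summed across processes before compute() runs.

```python
class MeanSquaredErrorSketch:
    """Running-sum MSE metric (hypothetical stand-in for the torchmetrics pattern)."""

    def __init__(self):
        self.sum_squared_error = 0.0
        self.total = 0

    def update(self, preds, target):
        if len(preds) != len(target):
            raise ValueError("preds and target must have the same length")
        # Accumulate this batch's squared error and number of elements.
        self.sum_squared_error += sum((p - t) ** 2 for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        # Mean over every element passed to update() so far.
        return self.sum_squared_error / self.total


metric = MeanSquaredErrorSketch()
metric.update([1.0, 2.0], [1.0, 4.0])  # squared errors 0 and 4
metric.update([0.0], [1.0])            # squared error 1
print(metric.compute())                # 5 / 3, about 1.6667
```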

Early Stopping

class pynas.train.my_early_stopping.TrainEarlyStopping(monitor, min_delta=0.0, patience=3, verbose=False, mode='min', strict=True, check_finite=True, stopping_threshold=None, divergence_threshold=None, check_on_train_epoch_end=None, log_rank_zero_only=False)

Bases: EarlyStopping

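Parameters:
  • monitor (str) – Name of the logged quantity to monitor.

  • min_delta (float) – Minimum change in the monitored quantity that counts as an improvement; an absolute change less than or equal to min_delta counts as no improvement.

  • patience (int) – Number of checks with no improvement after which training is stopped.

  • verbose (bool) – Whether to print status messages.

  • mode (str) – Either 'min' or 'max'. In 'min' mode, training stops when the monitored quantity stops decreasing; in 'max' mode, when it stops increasing.

  • strict (bool) – Whether to raise an error if monitor is not found among the logged metrics.

  • check_finite (bool) – Whether to stop training when the monitored quantity becomes NaN or infinite.

  • stopping_threshold (float | None) – Stop training immediately once the monitored quantity reaches this value.

  • divergence_threshold (float | None) – Stop training as soon as the monitored quantity becomes worse than this value.

  • check_on_train_epoch_end (bool | None) – Whether to run the check at the end of the training epoch; if False, it runs at the end of validation.

  • log_rank_zero_only (bool) – Whether to log the callback's status only from rank 0.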
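The interaction between the monitored value, mode, min_delta, and patience can be sketched in a few lines. The class below is a simplified, hypothetical model of the parent EarlyStopping bookkeeping for a single monitored value. It ignores check_finite, stopping_threshold, divergence_threshold, and distributed synchronization, and it is not the Lightning implementation.

```python
import math


class PatienceSketch:
    """Simplified early-stopping bookkeeping (hypothetical, not Lightning's code)."""

    def __init__(self, min_delta=0.0, patience=3, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.min_delta = abs(min_delta)
        self.patience = patience
        self.mode = mode
        self.best = math.inf if mode == "min" else -math.inf
        self.wait_count = 0

    def _improved(self, value):
        # An improvement must beat the best value by more than min_delta.
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def step(self, value):
        """Record one check of the monitored value; return True to stop."""
        if self._improved(value):
            self.best = value
            self.wait_count = 0
            return False
        self.wait_count += 1
        return self.wait_count >= self.patience


tracker = PatienceSketch(patience=2, mode="min")
print([tracker.step(v) for v in [1.0, 0.9, 0.95, 0.92]])  # [False, False, False, True]
```

The counter resets on every real improvement, so patience counts consecutive checks without progress, not the total number of bad checks.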

on_validation_end(trainer, pl_module)

Called when the validation loop ends.

on_train_epoch_end(trainer, pl_module)

Called when the training epoch ends.

To access all batch outputs at the end of the epoch, cache the step outputs as an attribute of the LightningModule and read them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Fitness Evaluation

class pynas.train.myFit.FitnessEvaluator(alpha=1.0, beta=1.0, gamma=1.0, delta=0.1, epsilon=0.01, lambda_=5.0, target_fps=120.0)

Bases: object

Computes fitness scores that balance inference speed, measured in frames per second (FPS), against prediction quality, measured by the Matthews correlation coefficient (MCC).

__init__(alpha=1.0, beta=1.0, gamma=1.0, delta=0.1, epsilon=0.01, lambda_=5.0, target_fps=120.0)

Initializes the fitness evaluator with tunable parameters.

Parameters:
  • alpha (float) – Weight for FPS in the weighted sum formula.

  • beta (float) – Weight for MCC in the weighted sum formula.

  • gamma (float) – Exponential scaling factor for the MCC penalty.

  • delta (float) – Bias term that keeps MCC from reaching zero.

  • epsilon (float) – Small constant for numerical stability of the logarithm.

  • lambda_ (float) – Steepness of the sigmoid in the inverse MCC penalty.

  • target_fps (float) – Target frame rate, in FPS.

static rebound_metrics(fps, mcc, target_fps=120.0)

weighted_sum_exponential(fps, mcc)

Computes fitness using weighted sum with exponential MCC penalty.

multiplicative_penalty(fps, mcc)

Computes fitness using FPS multiplied by MCC with a bias.

logarithmic_penalty(fps, mcc)

Computes fitness using logarithmic MCC scaling.

inverse_mcc_penalty(fps, mcc)

Computes fitness by scaling FPS using a sigmoid function of MCC.

stepwise_penalty(fps, mcc)

Computes fitness with a stepwise penalty for negative MCC.
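The method descriptions give the shape of each strategy but not its exact formula. The sketch below writes one plausible formula per method, only to show how the constructor parameters could enter. Every formula, the normalization of FPS by target_fps, and the class name FitnessSketch are assumptions, not the pynas implementation. MCC is assumed to lie in [-1, 1].

```python
import math


class FitnessSketch:
    """Hypothetical FPS/MCC fitness formulas; pynas's actual ones may differ."""

    def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, delta=0.1,
                 epsilon=0.01, lambda_=5.0, target_fps=120.0):
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        self.delta, self.epsilon, self.lambda_ = delta, epsilon, lambda_
        self.target_fps = target_fps

    def _fps_ratio(self, fps):
        # Assumption: FPS is expressed relative to the target frame rate.
        return fps / self.target_fps

    def weighted_sum_exponential(self, fps, mcc):
        # Weighted sum; the MCC term decays exponentially as MCC falls below 1.
        return (self.alpha * self._fps_ratio(fps)
                + self.beta * math.exp(-self.gamma * (1.0 - mcc)))

    def multiplicative_penalty(self, fps, mcc):
        # delta keeps the product from collapsing to zero when MCC is 0.
        return self._fps_ratio(fps) * (mcc + self.delta)

    def logarithmic_penalty(self, fps, mcc):
        # Shift MCC to [0, 2]; epsilon keeps the log finite at MCC = -1.
        return self._fps_ratio(fps) + math.log(1.0 + mcc + self.epsilon)

    def inverse_mcc_penalty(self, fps, mcc):
        # A sigmoid of MCC with steepness lambda_ scales the FPS term.
        return self._fps_ratio(fps) / (1.0 + math.exp(-self.lambda_ * mcc))

    def stepwise_penalty(self, fps, mcc):
        # Step at MCC = 0: negative MCC yields zero fitness regardless of FPS.
        if mcc < 0:
            return 0.0
        return self._fps_ratio(fps) * mcc


evaluator = FitnessSketch()
print(evaluator.inverse_mcc_penalty(120.0, 0.0))  # 0.5: sigmoid(0) halves the FPS ratio
print(evaluator.stepwise_penalty(240.0, -0.2))    # 0.0
```

Under these assumed formulas, the variants differ mainly in how hard they punish poor MCC. The stepwise variant rejects any negative-MCC candidate outright, while the weighted sum still lets a fast model score well.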