Source code for pynas.train.custom_iou

import torch
    


def calculate_iou(logits, targets, num_classes=4, eps=1e-6):
    """Compute the mean Intersection-over-Union per sample.

    Predictions are taken as the argmax of ``logits`` over dim 1 (the class
    axis). For every class id in ``range(num_classes)``, the IoU between the
    predicted mask and the target mask is computed per sample. The per-class
    values are then averaged.

    Args:
        logits: Score tensor with the class axis at dim 1, e.g. (B, C, H, W)
            or (B, C, D, H, W).
        targets: Either a one-hot / per-class score tensor with the same
            rank as ``logits`` (collapsed with argmax over dim 1), or an
            integer label map with one fewer dimension (used as-is).
        num_classes: Number of class ids to evaluate (0 .. num_classes-1).
        eps: Smoothing constant added to both intersection and union. As a
            consequence, a class absent from both prediction and target
            scores ``1.0`` for that sample.

    Returns:
        Tensor of shape (B,) holding the mean IoU over classes for each sample.
    """
    with torch.no_grad():
        preds = torch.argmax(logits, dim=1)
        # Only one-hot style targets carry a class axis. Integer label maps
        # (rank one lower than logits) must not be argmax-ed, or a spatial
        # axis would be collapsed instead.
        if targets.dim() == logits.dim():
            targets = torch.argmax(targets, dim=1)

        # Flatten every non-batch dim so any spatial rank (2-D, 3-D, ...) is
        # reduced correctly. For (B, H, W) this matches summing over (1, 2).
        batch = preds.shape[0]
        preds = preds.reshape(batch, -1)
        targets = targets.reshape(batch, -1)

        per_class = []
        for cls in range(num_classes):
            pred_mask = preds == cls
            target_mask = targets == cls
            intersection = (pred_mask & target_mask).float().sum(dim=1)
            union = (pred_mask | target_mask).float().sum(dim=1)
            per_class.append((intersection + eps) / (union + eps))

        # Stack to (num_classes, B), then average over classes -> (B,)
        return torch.stack(per_class).mean(dim=0)