Source code for pynas.blocks.heads
# Description: Contains custom head layers for neural networks.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple
class Dropout(nn.Module):
"""
Dropout layer for regularization in neural networks.
Args:
p (float, optional): Probability of an element to be zeroed. Default: 0.5
inplace (bool, optional): If set to True, will do this operation in-place. Default: False
"""
[docs]
def __init__(self, p=0.5, inplace=False):
super(Dropout, self).__init__()
self.p = p
self.inplace = inplace
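
# Usage sketch (illustrative, not part of the library's examples): during
# training roughly `p` of the activations are zeroed and the survivors are
# rescaled by 1/(1 - p); in eval mode the layer is an identity.
#
#     drop = Dropout(p=0.3)
#     drop.train()
#     y = drop(torch.randn(8, 16))  # ~30% of entries zeroed, rest scaled by 1/0.7
#     drop.eval()
#     y = drop(torch.randn(8, 16))  # returned unchanged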
class MultiInputClassifier(nn.Module):
"""
A PyTorch module for a multi-input classifier that processes multiple input tensors
with different shapes and combines their features for classification.
Args:
input_shapes (List[Tuple[int, ...]]): A list of shapes for each input tensor,
excluding the batch dimension. Each shape can be either (C, H, W) for spatial
inputs or (D,) for flat vector inputs.
common_dim (int, optional): The dimension to which all inputs are projected.
Defaults to 256.
mlp_depth (int, optional): The number of layers in the final MLP classifier.
Defaults to 2.
mlp_hidden_dim (int, optional): The number of hidden units in each MLP layer.
Defaults to 512.
num_classes (int, optional): The number of output classes for classification.
Defaults to 10.
use_adaptive_pool (bool, optional): Whether to apply adaptive average pooling
for spatial inputs. Defaults to True.
pool_size (Tuple[int, int], optional): The target size for adaptive pooling
if it is used. Defaults to (4, 4).
Attributes:
projections (nn.ModuleList): A list of projection modules for each input tensor.
These modules transform the inputs to the common dimension.
flatten (nn.Flatten): A module to flatten the projected tensors.
total_input_dim (int): The total input dimension after concatenating all
projected tensors.
classifier (nn.Sequential): The MLP classifier that processes the concatenated
features and outputs class probabilities.
Methods:
forward(inputs: List[torch.Tensor]) -> torch.Tensor:
Processes the input tensors, projects them to a common dimension, concatenates
their features, and passes them through the MLP classifier to produce the
output logits.
Raises:
ValueError: If an input shape is not supported (e.g., not (C, H, W) or (D,)).
"""
    def __init__(
        self,
        input_shapes: List[Tuple[int, ...]],  # shapes of each input tensor (excluding batch dim)
        common_dim: int = 256,                # project all inputs to this dim
        mlp_depth: int = 2,                   # depth of final MLP
        mlp_hidden_dim: int = 512,            # hidden units in MLP
        num_classes: int = 10,                # number of output classes
        use_adaptive_pool: bool = True,       # apply adaptive pooling to spatial inputs
        pool_size: Tuple[int, int] = (4, 4),  # target size if pooling is used
    ):
        super().__init__()
        self.projections = nn.ModuleList()
        self.flatten = nn.Flatten(start_dim=1)
        self.total_input_dim = 0

        for shape in input_shapes:
            in_dim = shape[0]
            if len(shape) == 3:  # spatial input (C, H, W)
                proj = nn.Sequential(
                    nn.AdaptiveAvgPool2d(pool_size) if use_adaptive_pool else nn.Identity(),
                    nn.Conv2d(in_dim, common_dim, kernel_size=1),
                )
                # Without pooling, the spatial extent stays (H, W) rather than pool_size.
                h, w = pool_size if use_adaptive_pool else (shape[1], shape[2])
                out_dim = common_dim * h * w
            elif len(shape) == 1:  # flat vector (D,)
                proj = nn.Sequential(
                    nn.Linear(in_dim, common_dim),
                )
                out_dim = common_dim
            else:
                raise ValueError(f"Unsupported input shape: {shape}")
            self.projections.append(proj)
            self.total_input_dim += out_dim

        # Build the MLP classifier head.
        layers = []
        in_dim = self.total_input_dim
        for _ in range(mlp_depth - 1):
            layers.append(nn.Linear(in_dim, mlp_hidden_dim))
            layers.append(nn.ReLU())
            in_dim = mlp_hidden_dim
        layers.append(nn.Linear(in_dim, num_classes))
        self.classifier = nn.Sequential(*layers)
    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        if len(inputs) != len(self.projections):
            raise ValueError(
                f"Expected {len(self.projections)} inputs, got {len(inputs)}."
            )
        projected = []
        for x, proj in zip(inputs, self.projections):
            x = self.flatten(proj(x))
            projected.append(x)
        return self.classifier(torch.cat(projected, dim=1))
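
# Usage sketch (the shapes below are illustrative assumptions): one spatial
# input (C, H, W) and one flat input (D,), batch size 4. With the defaults
# (common_dim=256, pool_size=(4, 4)), total_input_dim = 256*4*4 + 256 = 4352.
#
#     model = MultiInputClassifier(input_shapes=[(64, 32, 32), (128,)],
#                                  num_classes=10)
#     logits = model([torch.randn(4, 64, 32, 32), torch.randn(4, 128)])
#     # logits.shape == torch.Size([4, 10])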
class Classifier:
    """
    Attaches a dynamically sized classification head to an encoder.

    The head's input width is inferred by running a dummy tensor through the
    encoder, so any encoder that maps (B, C, H, W) inputs to 4D feature maps
    can be used.

    Args:
        encoder (nn.Module): Feature extractor used as the backbone.
        dm: Data module exposing ``num_classes`` and ``input_shape`` attributes.
        verbose (bool, optional): If True, prints shape information during
            construction. Defaults to False.

    Attributes:
        model (nn.Sequential): The encoder followed by the constructed head.
        valid_model (bool): Whether a dummy forward pass through the full model
            succeeded.
    """
def __init__(self, encoder, dm, verbose=False):
# Validate that dm has the necessary attributes.
if not hasattr(dm, "num_classes") or not hasattr(dm, "input_shape"):
raise ValueError("dm must have 'num_classes' and 'input_shape' attributes.")
self.encoder = encoder
self.num_classes = dm.num_classes
self.input_shape = dm.input_shape
self.verbose = verbose
if self.verbose:
print(f"Input shape: {self.input_shape}")
# Verify that encoder has parameters.
try:
next(self.encoder.parameters())
except StopIteration:
raise ValueError("Encoder appears to have no parameters.")
except Exception as e:
raise ValueError("Provided encoder does not follow expected API.") from e
# Validate input_shape is a tuple and properly dimensioned.
if not isinstance(self.input_shape, tuple):
raise TypeError("input_shape must be a tuple.")
        if len(self.input_shape) == 3:
            if self.verbose:
                print("Adding batch dimension to input shape.")
                print(f"Original input shape: {self.input_shape}")
            self.input_shape = (1,) + self.input_shape
            if self.verbose:
                print(f"Updated input shape: {self.input_shape}")
elif len(self.input_shape) != 4:
raise ValueError("input_shape must be of length 3 or 4.")
self.head_layer = self.build_head(input_shape=self.input_shape)
self.model = nn.Sequential(
self.encoder,
self.head_layer
)
self.valid_model = self.dummy_test()
def build_head(self, input_shape=(1, 2, 256, 256)):
# Get the device from the encoder's parameters.
try:
device = next(self.encoder.parameters()).device
except Exception as e:
raise ValueError("Unable to determine device from encoder parameters.") from e
        # Run a dummy input through the encoder to infer the feature shape.
        dummy = torch.randn(*input_shape, device=device)
        try:
            with torch.no_grad():
                features = self.encoder(dummy)
        except Exception as e:
            raise RuntimeError("Error when running dummy input through encoder.") from e
if not isinstance(features, torch.Tensor):
raise TypeError("Encoder output should be a torch.Tensor.")
if self.verbose:
print("Feature map shape from the feature extractor:", features.shape)
        # The pooling head below expects 4D feature maps (B, C, H, W).
        if features.dim() != 4:
            raise ValueError(
                f"Encoded features must be 4D (B, C, H, W), got {features.dim()}D."
            )
        # Determine the number of channels from the dummy output.
        feature_channels = features.shape[1]
# Build the head layer.
head_layer = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
nn.Flatten(),
nn.Linear(feature_channels, self.num_classes)
)
if self.verbose:
print("Constructed head layer:", head_layer)
return head_layer
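
    # Shape-inference sketch (hypothetical numbers): an encoder emitting
    # (B, 512, 8, 8) features yields the head
    # AdaptiveAvgPool2d((1, 1)) -> Flatten -> Linear(512, num_classes).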
    def dummy_test(self):
        try:
            device = next(self.encoder.parameters()).device
            dummy = torch.randn(*self.input_shape, device=device)
            output = self.model(dummy)
            if not isinstance(output, torch.Tensor):
                raise TypeError("Output of the model should be a torch.Tensor.")
            if output.shape[0] != dummy.shape[0]:
                raise ValueError("Batch size mismatch between input and output.")
            if self.verbose:
                print("Network test passed. Output shape from the model:", output.shape)
            return True
        except Exception as e:
            if self.verbose:
                print("An error occurred during dummy_test:", e)
            return False
def forward(self, x):
if not isinstance(x, torch.Tensor):
raise TypeError("Input must be a torch.Tensor.")
return self.model(x)
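
# End-to-end sketch (the encoder and the data-module stand-in below are
# hypothetical, written only to satisfy the expected `dm` API):
#
#     class _DummyDM:
#         num_classes = 10
#         input_shape = (2, 3, 64, 64)  # (B, C, H, W)
#
#     encoder = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU())
#     clf = Classifier(encoder, _DummyDM(), verbose=True)
#     assert clf.valid_model                            # dummy forward pass succeeded
#     logits = clf.forward(torch.randn(2, 3, 64, 64))   # shape (2, 10)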