Source code for pynas.core.generic_unet

import torch
import torch.nn as nn
import torch.nn.functional as F
import configparser
import inspect

from ..blocks import convolutions, pooling, activations


class UNetDecoder(nn.Module):
    """
    A PyTorch implementation of a U-Net decoder module.

    This class implements the decoder part of the U-Net architecture, which
    reconstructs the output from the bottleneck features by progressively
    upsampling them and combining them with skip connections from the encoder.

    Attributes:
        num_stages (int): The number of decoding stages, equal to the number of skip connections.
        up_convs (nn.ModuleList): A list of upsampling blocks (bilinear upsampling followed by a 1x1 convolution).
        conv_blocks (nn.ModuleList): A list of convolutional blocks for processing the
            concatenated upsampled and skip-connection features.
        out_conv (nn.Sequential): The final 1x1 convolution and upsampling that produce the output.

    Args:
        encoder_shapes (list of torch.Size): A list of shapes of the encoder features in the order
            [skip0, skip1, ..., skip_(N-1), bottleneck]. Each shape is expected to be a `torch.Size` object.
        num_classes (int, optional): The number of output classes. Default is 2.
        output_shape (tuple of int, optional): Target (height, width) for the final upsampling.
            Default is None.

    Methods:
        forward(encoder_features):
            Performs the forward pass of the decoder.

            Args:
                encoder_features (list of torch.Tensor): A list of encoder feature maps in the
                    order [skip0, skip1, ..., skip_(N-1), bottleneck]. The number of feature
                    maps must match the number of stages + 1.

            Returns:
                torch.Tensor: The output tensor after decoding.
    """
    def __init__(self, encoder_shapes, num_classes=2, output_shape=None):
        super(UNetDecoder, self).__init__()
        # Expecting encoder_shapes as a list of torch.Size objects in order:
        # [skip0, skip1, ..., skip_(N-1), bottleneck].
        # The number of decoding stages equals the number of skip connections.
        self.num_stages = len(encoder_shapes) - 1

        # Build upsampling and convolution blocks dynamically.
        self.up_convs = nn.ModuleList()
        self.conv_blocks = nn.ModuleList()

        # Initial number of channels comes from the bottleneck.
        in_channels = encoder_shapes[-1][1]

        # Iterate over skip connections in reverse order.
        for i in range(self.num_stages - 1, -1, -1):
            skip_channels = encoder_shapes[i][1]
            # Store the expected output size for each upsampling operation.
            out_h, out_w = encoder_shapes[i][2], encoder_shapes[i][3]
            # Upsampling + 1x1 conv in place of a transposed convolution.
            self.up_convs.append(
                nn.Sequential(
                    nn.Upsample(size=(out_h, out_w), mode='bilinear', align_corners=False),
                    nn.Conv2d(in_channels, skip_channels, kernel_size=1)
                )
            )
            # The conv block takes the concatenated tensor (upsampled + skip) as input.
            self.conv_blocks.append(
                nn.Sequential(
                    nn.Conv2d(skip_channels * 2, skip_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(skip_channels),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(skip_channels, skip_channels, kernel_size=3, padding=1),
                    nn.BatchNorm2d(skip_channels),
                    nn.ReLU(inplace=True)
                )
            )
            # The skip channels become the input channels for the next stage.
            in_channels = skip_channels

        # Final 1x1 convolution followed by upsampling to match the input height and width.
        self.out_conv = nn.Sequential(
            nn.Conv2d(in_channels, num_classes, kernel_size=1),
            nn.Upsample(size=output_shape, mode='bilinear', align_corners=False)
        )
    def forward(self, encoder_features, verbose=False):
        # encoder_features: list in order [skip0, skip1, ..., skip_(N-1), bottleneck]
        if verbose:
            print("=" * 100)
            print(f"{'Encoder Feature Shapes':^100}")
            for f in encoder_features:
                print(f.shape)
            print("=" * 100)
        if len(encoder_features) != self.num_stages + 1:
            raise ValueError(f"Expected {self.num_stages + 1} encoder features, but got {len(encoder_features)}.")

        x = encoder_features[-1]  # Start with the bottleneck.
        # For each decoding stage, use the corresponding skip connection in reverse order.
        for i in range(self.num_stages):
            # Skip connection index: from the last skip down to the first.
            skip = encoder_features[self.num_stages - 1 - i]
            x = self.up_convs[i](x)
            # Match spatial dimensions with the skip connection, in case an upsampling mismatch occurs.
            if x.shape[2:] != skip.shape[2:]:
                x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
            x = torch.cat([x, skip], dim=1)
            x = self.conv_blocks[i](x)

        output = self.out_conv(x)
        return output
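
# A minimal usage sketch for UNetDecoder (not part of the original module).
# The three encoder shapes below are hypothetical; the demo is wrapped in a
# function so that importing this module does not execute it.
def _demo_unet_decoder():
    shapes = [torch.Size([1, 16, 128, 128]),   # skip0
              torch.Size([1, 32, 64, 64]),     # skip1
              torch.Size([1, 64, 32, 32])]     # bottleneck
    decoder = UNetDecoder(shapes, num_classes=2, output_shape=(256, 256))
    feats = [torch.randn(*s) for s in shapes]
    out = decoder(feats)
    # Two decoding stages, then a final upsample to the requested output shape.
    assert out.shape == (1, 2, 256, 256)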
class GenericUNetNetwork(nn.Module):
    """
    GenericUNetNetwork is a PyTorch-based implementation of a generic U-Net architecture.

    This class allows for flexible construction of U-Net models by parsing layer
    configurations and dynamically building the encoder and decoder components.

    Args:
        parsed_layers (list): A list of layer configurations for building the encoder.
        input_channels (int): Number of input channels for the input tensor. Default is 3.
        input_height (int): Height of the input tensor. Default is 256.
        input_width (int): Width of the input tensor. Default is 256.
        num_classes (int): Number of output classes for the segmentation task. Default is 2.
        MaxParams (int): Maximum allowed number of parameters for the model. Default is 200,000,000.
        encoder_only (bool): If True, only the encoder is built and the decoder is
            replaced by nn.Identity. Default is False.

    Attributes:
        encoder (nn.ModuleList): A list of layers forming the encoder part of the U-Net.
        decoder (nn.Module): The decoder part of the U-Net (nn.Identity when encoder_only=True).
        encoder_shapes (list): A list of shapes of the encoder outputs for use in the decoder.
        total_params (int): Total number of parameters in the model.
        config (ConfigParser): Configuration parser for reading additional settings from 'config.ini'.

    Methods:
        __init__(self, parsed_layers, input_channels=3, input_height=256, input_width=256, num_classes=2, MaxParams=200_000_000, encoder_only=False):
            Initializes the GenericUNetNetwork with the given parameters and builds the encoder and decoder.
        encoder_forward(self, x, features_only=True):
            Performs a forward pass through the encoder and optionally returns only the encoder features.
        _encoder_shapes_tracing(self):
            Runs a dummy forward pass through the encoder to determine the shapes of the encoder outputs.
        _build_encoder(self, parsed_layers):
            Builds the encoder component of the U-Net model from the parsed layer configurations.
        _build_decoder(self):
            Builds the decoder component of the U-Net model using the encoder shapes and number of output classes.
        forward(self, x):
            Defines the forward pass of the model, passing the input through the encoder and decoder.
        get_activation_fn(activation):
            Retrieves the specified activation function from the `activations` module.
    """
    def __init__(self, parsed_layers, input_channels=3, input_height=256, input_width=256,
                 num_classes=2, MaxParams=200_000_000, encoder_only=False):
        super(GenericUNetNetwork, self).__init__()
        self.config = configparser.ConfigParser()
        self.config.read('config.ini')

        self.parsed_layers = parsed_layers
        self.num_classes = num_classes
        self.total_params = 0
        self.max_params = MaxParams

        self.input_channels = input_channels
        self.input_height = input_height
        self.input_width = input_width

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # Records all encoder output shapes: one (batch, channels, height, width) per encoder layer.
        self.encoder_shapes = []

        # Encoder building:
        self._build_encoder(parsed_layers)

        if encoder_only:
            self.decoder = nn.Identity()  # Set decoder to the identity function.
        else:
            # Decoder building:
            if self.encoder is not None:
                self._build_decoder()
            else:
                raise ValueError("Error building encoder. Decoder not built.")
    def encoder_forward(self, x, features_only=True):
        encoder_outputs = []
        conv_layer_types = tuple(list_convolution_layers())
        for idx, layer in enumerate(self.encoder):
            x = layer(x)
            # Record the output of every convolutional block as a potential skip connection.
            if isinstance(layer, conv_layer_types):
                encoder_outputs.append(x.clone())
        if features_only:
            return encoder_outputs
        return x
    def _encoder_shapes_tracing(self):
        """
        Runs a dummy forward pass through the encoder to determine output shapes.

        This method generates a random tensor with shape
        (1, input_channels, input_height, input_width) and passes it through the
        encoder with features_only=True to obtain the output tensors. It then
        collects the shapes of all output tensors for network architecture analysis.

        Returns:
            list: A list containing the shapes of the encoder outputs. For tensor
                outputs, their shapes are directly included. For list outputs, a list
                of their individual tensor shapes is included.
        """
        dummy_input = torch.randn(1, self.input_channels, self.input_height, self.input_width)
        output = self.encoder_forward(dummy_input, features_only=True)
        output_shapes = []
        for o in output:
            if isinstance(o, torch.Tensor):
                output_shapes.append(o.shape)
            elif isinstance(o, list):
                output_shapes.append([item.shape for item in o])
        return output_shapes
    def _build_encoder(self, parsed_layers, verbose=False):
        """
        Builds the encoder part of the U-Net model based on the parsed layer configurations.

        Args:
            parsed_layers (list): A list of layer configurations to be used for building the
                encoder. Each configuration specifies the type and parameters of the layer.

        Raises:
            AssertionError: If the output dimensions (height or width) of any layer are zero or negative.
            AssertionError: If the total number of parameters exceeds the maximum allowed (`self.max_params`).
            Exception: Any other error raised during encoder construction is re-raised after optional logging.

        Notes:
            - The method iterates through the parsed layers, constructs each layer using the
              `build_layer` function, and appends it to the encoder (`self.encoder`).
            - The current dimensions (`self.current_channels`, `self.current_height`,
              `self.current_width`) are updated after each layer is built.
            - Skip connections are recorded for layers that produce activation features
              intended for use in the decoder.
            - The total number of parameters is tracked and validated against the maximum allowed limit.
            - The shapes of the encoder layers are traced and stored in `self.encoder_shapes`
              for use in the decoder.
        """
        if verbose:
            print("=" * 100)
            print(f"{'Building U-Net':^100}")
            print("=" * 100)

        self.encoder = nn.ModuleList()
        self.current_channels = self.input_channels
        self.current_height = self.input_height
        self.current_width = self.input_width
        if verbose:
            print(f"Initial Input Dimensions: Channels={self.current_channels}, "
                  f"Height={self.current_height}, Width={self.current_width}")

        try:
            for idx, layer in enumerate(parsed_layers):
                result = build_layer(layer, self.config, self.current_channels,
                                     self.current_height, self.current_width,
                                     idx, self.get_activation_fn)
                layer_inst, self.current_channels, self.current_height, self.current_width = result

                # Assert that the output dimensions are valid.
                assert self.current_height > 0 and self.current_width > 0, \
                    f"Invalid output dimensions: height ({self.current_height}) or width ({self.current_width}) is not positive."

                self.encoder.append(layer_inst)
                self.total_params += sum(p.numel() for p in layer_inst.parameters())
                assert self.total_params <= self.max_params, \
                    f"Exceeded parameter limit. P: {self.total_params:,} > M: {self.max_params:,}"

            # Trace the shapes to be used in the decoder.
            self.encoder_shapes = self._encoder_shapes_tracing()

            if verbose:
                print("_" * 100)
                print(f"{'U-Net Encoder Built Successfully!':^100}")
                print("_" * 100)
                print(f"{'- Total Encoder Parameters:':<25} {self.total_params:,}")
                print("=" * 100)
        except Exception as e:
            if verbose:
                print(f"Error building encoder: {e}")
            raise
    def _build_decoder(self, verbose=False):
        """
        Builds the decoder component of the U-Net model.

        This method initializes the decoder using the `UNetDecoder` class, passing the
        encoder shapes, the number of output classes, and the target output shape as
        parameters. It also updates the total parameter count by summing the number of
        parameters in the decoder.

        Attributes:
            self.decoder (UNetDecoder): The decoder instance for the U-Net model.
            self.total_params (int): The total number of parameters in the model,
                updated to include the decoder's parameters.
        """
        output_shape = (self.input_height, self.input_width)
        self.decoder = UNetDecoder(self.encoder_shapes, self.num_classes, output_shape)
        self.total_params += sum(p.numel() for p in self.decoder.parameters())
        if verbose:
            print(f"{'U-Net Decoder Built Successfully!':^100}")
            print("_" * 100)
            print(f"{'- Total Parameters:':<25} {self.total_params:,}")
            print("=" * 100)
    def forward(self, x):
        """
        Defines the forward pass of the model.

        Args:
            x (torch.Tensor): Input tensor to the model.

        Returns:
            torch.Tensor: Output tensor after passing through the encoder and decoder.
        """
        x = self.encoder_forward(x)  # Here x is a list of encoder outputs.
        x = self.decoder(x)
        return x
    @staticmethod
    def get_activation_fn(activation):
        """
        Retrieves the specified activation function from the `activations` module.

        Args:
            activation (str): The name of the activation function to retrieve.

        Returns:
            Callable: The activation function corresponding to the given name. If the
                specified activation function is not found, defaults to `activations.ReLU`.
        """
        return getattr(activations, activation, activations.ReLU)
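
# A minimal end-to-end sketch for GenericUNetNetwork (not part of the original
# module). It assumes `convolutions.ConvBnAct`, `pooling.MaxPool`, and
# `activations.ReLU` accept the arguments that `build_layer` passes; every key
# is given explicitly so no defaults are read from 'config.ini'.
def _demo_generic_unet():
    parsed_layers = [
        {'layer_type': 'ConvBnAct', 'kernel_size': 3, 'stride': 1, 'padding': 1,
         'out_channels_coefficient': 8, 'activation': 'ReLU'},   # 3 -> 24 channels
        {'layer_type': 'MaxPool', 'kernel_size': 2, 'stride': 2},  # 64x64 -> 32x32
        {'layer_type': 'ConvBnAct', 'kernel_size': 3, 'stride': 1, 'padding': 1,
         'out_channels_coefficient': 2, 'activation': 'ReLU'},   # 24 -> 48 channels
    ]
    model = GenericUNetNetwork(parsed_layers, input_channels=3,
                               input_height=64, input_width=64, num_classes=2)
    out = model(torch.randn(1, 3, 64, 64))
    # The decoder upsamples back to the input resolution.
    assert out.shape == (1, 2, 64, 64)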
def list_convolution_layers():
    """
    Retrieves a list of all classes defined in the `convolutions` module.

    This function uses the `inspect` module to dynamically inspect the `convolutions`
    module and collect all objects that are classes.

    Returns:
        list: A list of class objects defined in the `convolutions` module.
    """
    return [obj for name, obj in inspect.getmembers(convolutions, inspect.isclass)]
def build_layer(layer, config, current_channels, current_height, current_width, idx, get_activation_fn):
    """
    Builds a neural network layer based on the provided configuration.

    Args:
        layer (dict): A dictionary containing the layer configuration. Must include the
            key 'layer_type', which specifies the type of layer to build.
        config (dict): A dictionary containing default configurations for various layer types.
        current_channels (int): The number of input channels to the layer.
        current_height (int): The height of the input tensor to the layer.
        current_width (int): The width of the input tensor to the layer.
        idx (int): The index of the layer in the model (used for debugging or logging purposes).
        get_activation_fn (callable): A function that takes an activation name (str) and
            returns the corresponding activation function.

    Returns:
        tuple: A tuple containing:
            - layer_inst (nn.Module): The instantiated layer object.
            - current_channels (int): The number of output channels after the layer.
            - current_height (int): The height of the output tensor after the layer.
            - current_width (int): The width of the output tensor after the layer.

    Raises:
        ValueError: If the 'layer_type' in the layer dictionary is unknown or unsupported.

    Supported Layer Types:
        - 'ConvAct', 'ConvBnAct', 'ConvSE': Convolutional layers with optional batch
          normalization and activation.
        - 'MBConv', 'MBConvNoRes': MobileNetV2-style inverted residual blocks.
        - 'CSPConvBlock', 'CSPMBConvBlock': Cross Stage Partial blocks for convolution or MBConv.
        - 'DenseNetBlock': DenseNet-style block with concatenated outputs.
        - 'ResNetBlock': ResNet-style residual block.
        - 'AvgPool', 'MaxPool': Pooling layers (average or max pooling).
        - 'Dropout': Dropout layer for regularization.
    """
    lt = layer['layer_type']

    if lt in ['ConvAct', 'ConvBnAct', 'ConvSE']:
        kernel_size, stride, padding, out_channels, new_height, new_width = parse_conv_params(
            layer, config, lt, current_channels, current_height, current_width
        )
        conv_cls = {
            'ConvAct': convolutions.ConvAct,
            'ConvBnAct': convolutions.ConvBnAct,
            'ConvSE': convolutions.ConvSE,
        }[lt]
        layer_inst = conv_cls(
            in_channels=current_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            activation=get_activation_fn(layer['activation']),
        )
        current_channels, current_height, current_width = out_channels, new_height, new_width
        assert new_height > 0 and new_width > 0, \
            f"Invalid output dimensions: height ({new_height}) or width ({new_width}) is not positive."

    elif lt in ['MBConv', 'MBConvNoRes']:
        expansion_factor = int(layer.get('expansion_factor', config[lt]['default_expansion_factor']))
        dw_kernel_size = int(layer.get('dw_kernel_size', config[lt]['default_dw_kernel_size']))
        conv_cls = {'MBConv': convolutions.MBConv, 'MBConvNoRes': convolutions.MBConvNoRes}[lt]
        layer_inst = conv_cls(
            in_channels=current_channels,
            out_channels=current_channels,
            expansion_factor=expansion_factor,
            dw_kernel_size=dw_kernel_size,
            activation=get_activation_fn(layer['activation']),
        )

    elif lt in ['CSPConvBlock', 'CSPMBConvBlock']:
        num_blocks = int(layer.get('num_blocks', config[lt]['default_num_blocks']))
        # Get out_channels from the conv params helper even if we don't use all values.
        _, _, _, out_channels, _, _ = parse_conv_params(
            layer, config, lt, current_channels, current_height, current_width
        )
        if lt == 'CSPConvBlock':
            layer_inst = convolutions.CSPConvBlock(
                in_channels=current_channels,
                num_blocks=num_blocks,
                activation=get_activation_fn(layer['activation']),
            )
        else:
            expansion_factor = int(layer.get('expansion_factor', config[lt]['default_expansion_factor']))
            dw_kernel_size = int(layer.get('dw_kernel_size', config[lt]['default_dw_kernel_size']))
            layer_inst = convolutions.CSPMBConvBlock(
                in_channels=current_channels,
                expansion_factor=expansion_factor,
                dw_kernel_size=dw_kernel_size,
                num_blocks=num_blocks,
                activation=get_activation_fn(layer['activation']),
            )
        current_channels = out_channels

    elif lt == 'DenseNetBlock':
        out_channels_coeff = float(layer.get('out_channels_coefficient',
                                             config['DenseNetBlock']['default_out_channels_coefficient']))
        out_channels = int(current_channels * out_channels_coeff)
        layer_inst = convolutions.DenseNetBlock(
            in_channels=current_channels,
            out_channels=out_channels,
            activation=get_activation_fn(layer['activation']),
        )
        # DenseNet blocks concatenate their output with the input.
        current_channels += out_channels

    elif lt == 'ResNetBlock':
        _ = float(layer.get('out_channels_coefficient',
                            config['ResNetBlock']['default_out_channels_coefficient']))
        layer_inst = convolutions.ResNetBlock(
            in_channels=current_channels,
            out_channels=current_channels,
            activation=get_activation_fn(layer['activation']),
        )

    elif lt in ['AvgPool', 'MaxPool']:
        pool_cls = pooling.AvgPool if lt == 'AvgPool' else pooling.MaxPool
        kernel_size = int(layer.get('kernel_size', config[lt]['default_kernel_size']))
        stride = int(layer.get('stride', config[lt]['default_stride']))
        layer_inst = pool_cls(kernel_size=kernel_size, stride=stride)
        current_height = ((current_height - kernel_size) // stride) + 1
        current_width = ((current_width - kernel_size) // stride) + 1

    elif lt == 'Dropout':
        dropout_rate = float(layer.get('dropout_rate', config['Dropout']['default_dropout_rate']))
        layer_inst = nn.Dropout(p=dropout_rate)

    else:
        raise ValueError(f"Unknown layer type: {lt}")

    return layer_inst, current_channels, current_height, current_width
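
# A minimal usage sketch for build_layer (not part of the original module):
# building a single pooling layer. The empty config dict is never consulted
# because the layer dict supplies every parameter explicitly.
def _demo_build_layer():
    layer = {'layer_type': 'MaxPool', 'kernel_size': 2, 'stride': 2}
    inst, channels, height, width = build_layer(
        layer, config={}, current_channels=16, current_height=64,
        current_width=64, idx=0,
        get_activation_fn=GenericUNetNetwork.get_activation_fn,
    )
    # Pooling leaves the channel count unchanged; spatial dims follow
    # ((64 - 2) // 2) + 1 = 32.
    assert (channels, height, width) == (16, 32, 32)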
def parse_conv_params(layer, config, key, current_channels, current_height, current_width):
    """
    Parses convolutional layer parameters and calculates output dimensions.

    This function extracts parameters for a convolutional layer from the provided
    configuration, calculates the output dimensions, and returns all necessary values
    for setting up a convolutional layer.

    Args:
        layer (dict): Dictionary containing layer-specific configuration parameters.
        config (dict): Dictionary containing default configuration parameters.
        key (str): Key to access specific configurations within the config dictionary.
        current_channels (int): Number of input channels for the current layer.
        current_height (int): Height of the input feature map.
        current_width (int): Width of the input feature map.

    Returns:
        tuple: A tuple containing:
            - kernel_size (int): Size of the convolutional kernel.
            - stride (int): Stride of the convolution.
            - padding (int): Padding added to the input feature map.
            - out_channels (int): Number of output channels.
            - new_height (int): Height of the output feature map after convolution.
            - new_width (int): Width of the output feature map after convolution.
    """
    kernel_size = int(layer.get('kernel_size', config[key]['default_kernel_size']))
    stride = int(layer.get('stride', config[key]['default_stride']))
    padding = int(layer.get('padding', config[key]['default_padding']))
    out_channels_coeff = float(layer.get('out_channels_coefficient',
                                         config[key]['default_out_channels_coefficient']))

    out_channels = int(current_channels * out_channels_coeff)
    # Standard convolution output size: ((in - kernel + 2 * padding) // stride) + 1.
    new_height = ((current_height - kernel_size + 2 * padding) // stride) + 1
    new_width = ((current_width - kernel_size + 2 * padding) // stride) + 1

    return kernel_size, stride, padding, out_channels, new_height, new_width
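
# A worked example of the output-size arithmetic above (not part of the
# original module). A 3x3 kernel with stride 2 and padding 1 halves a 256x256
# map, since ((256 - 3 + 2 * 1) // 2) + 1 = 128. The config dict is unused
# because the layer dict provides every key.
def _demo_parse_conv_params():
    layer = {'kernel_size': 3, 'stride': 2, 'padding': 1, 'out_channels_coefficient': 2}
    k, s, p, out_c, h, w = parse_conv_params(
        layer, config={}, key='ConvBnAct',
        current_channels=16, current_height=256, current_width=256,
    )
    assert (k, s, p) == (3, 2, 1)
    assert out_c == 32           # int(16 * 2)
    assert (h, w) == (128, 128)  # spatial dims halved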