import torch
import torch.nn as nn
from .activations import ReLU
# Classic Conv
class ConvAct(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, activation=ReLU):
super().__init__(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
activation(),
)
class ConvBnAct(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, activation=ReLU):
super().__init__(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
nn.BatchNorm2d(num_features=out_channels),
activation(),
)
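# Illustrative usage of the classic conv blocks (the shapes below are assumptions, not part of the module):
#   y = ConvAct(in_channels=3, out_channels=16)(torch.randn(1, 3, 32, 32))     # conv + activation -> (1, 16, 32, 32)
#   y = ConvBnAct(in_channels=3, out_channels=16)(torch.randn(1, 3, 32, 32))   # conv + BN + activation -> (1, 16, 32, 32)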
class SEBlock(nn.Module):
def __init__(self, in_channels, reduction=16):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
reduced_channels = max(in_channels // reduction, 1) # Ensure at least 1 output feature
self.fc = nn.Sequential(
nn.Linear(in_channels, reduced_channels, bias=False),
nn.ReLU(inplace=True),
nn.Linear(reduced_channels, in_channels, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
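# Illustrative usage of the squeeze-and-excitation block (example shapes are assumptions):
#   se = SEBlock(in_channels=64, reduction=16)
#   out = se(torch.randn(2, 64, 28, 28))   # per-channel rescaling; output shape unchanged: (2, 64, 28, 28)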
class ConvSE(nn.Sequential):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, activation=ReLU):
super().__init__(
ConvBnAct(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, activation=activation),
SEBlock(reduction=16, in_channels=out_channels),
)
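# Illustrative usage (assumed shapes): ConvSE chains ConvBnAct with an SEBlock on its output channels.
#   y = ConvSE(in_channels=3, out_channels=32)(torch.randn(1, 3, 64, 64))   # -> (1, 32, 64, 64)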
# MBConv (inverted residual bottleneck)
class MBConv(nn.Module):
def __init__(self, in_channels, out_channels, dw_kernel_size=3, expansion_factor=4, activation=ReLU):
expanded_channels = in_channels * expansion_factor
super().__init__()
self.steps = nn.Sequential(
# Narrow to wide
ConvBnAct(in_channels, expanded_channels, kernel_size=1, stride=1, padding=0, activation=activation),
# Wide to wide (depthwise convolution)
nn.Conv2d(expanded_channels, expanded_channels, kernel_size=dw_kernel_size, stride=1, padding=dw_kernel_size // 2, groups=expanded_channels),  # 'same' padding for odd kernel sizes, so the residual add keeps its shape
nn.BatchNorm2d(expanded_channels),
activation(),
# Wide to narrow
ConvBnAct(expanded_channels, out_channels, kernel_size=1, stride=1, padding=0, activation=activation)
)
def forward(self, x):
res = x
x = self.steps(x)
x = torch.add(x, res)
return x
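# Illustrative usage (assumed shapes). The residual add requires out_channels == in_channels and an
# unchanged spatial size, so the channel expansion happens only inside the block:
#   block = MBConv(in_channels=32, out_channels=32, expansion_factor=4)
#   y = block(torch.randn(1, 32, 56, 56))   # -> (1, 32, 56, 56)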
class MBConvNoRes(nn.Module):
def __init__(self, in_channels, out_channels, dw_kernel_size=3, expansion_factor=4, activation=ReLU):
expanded_channels = in_channels * expansion_factor
super().__init__()
self.steps = nn.Sequential(
# Narrow to wide
ConvBnAct(in_channels, expanded_channels, kernel_size=1, stride=1, padding=0, activation=activation),
# Wide to wide (depthwise convolution)
nn.Conv2d(expanded_channels, expanded_channels, kernel_size=dw_kernel_size, stride=1, padding=dw_kernel_size // 2, groups=expanded_channels),  # 'same' padding for odd kernel sizes
nn.BatchNorm2d(expanded_channels),
activation(),
# Wide to narrow
ConvBnAct(expanded_channels, out_channels, kernel_size=1, stride=1, padding=0, activation=activation)
)
def forward(self, x):
x = self.steps(x)
return x
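# Illustrative usage (assumed shapes). Without the residual add, the output channel count may differ from the input:
#   y = MBConvNoRes(in_channels=32, out_channels=64)(torch.randn(1, 32, 56, 56))   # -> (1, 64, 56, 56)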
# CSP Conv
class CSPConvBlock(nn.Module):
def __init__(self, in_channels, num_blocks=1, activation=ReLU):
super().__init__()
# Split the channels between the two paths; the shortcut path takes the remainder so odd in_channels are handled
self.main_channels = in_channels // 2
self.shortcut_channels = in_channels-self.main_channels
# Main path (processed part)
self.main_path = nn.Sequential(
*[ConvBnAct(
in_channels=self.main_channels,
out_channels=self.main_channels,
activation=activation,
) for _ in range(num_blocks)],
)
# Shortcut path is just a passthrough
self.shortcut_path = nn.Identity()
# Final 1x1 convolution after merging
self.final_transition = nn.Sequential(
nn.Conv2d(
in_channels=self.main_channels+self.shortcut_channels,
out_channels=in_channels,
kernel_size=1,
),
nn.BatchNorm2d(in_channels),
activation(),
)
def forward(self, x):
# Split the input tensor into the main and shortcut paths along the channel dimension
main_data = x[:, :self.main_channels, :, :]
shortcut_data = x[:, self.main_channels:, :, :]
main_data = self.main_path(main_data)
shortcut_data = self.shortcut_path(shortcut_data)
# Concatenating the main and shortcut paths
combined = torch.cat(tensors=(main_data, shortcut_data), dim=1)
out = self.final_transition(combined)
return out
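# Illustrative usage (assumed shapes). The split/merge keeps the channel count unchanged:
#   block = CSPConvBlock(in_channels=48, num_blocks=2)   # 24 channels processed, 24 passed through
#   y = block(torch.randn(1, 48, 40, 40))                # -> (1, 48, 40, 40)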
class CSPMBConvBlock(nn.Module):
def __init__(self, in_channels, num_blocks=1, dw_kernel_size=3, expansion_factor=4, activation=ReLU):
super().__init__()
# Split the channels between the two paths; the shortcut path takes the remainder so odd in_channels are handled
self.main_channels = in_channels // 2
self.shortcut_channels = in_channels-self.main_channels
# Main path (processed part)
self.main_path = nn.Sequential(
*[MBConv(
in_channels=self.main_channels,
out_channels=self.main_channels,
expansion_factor=expansion_factor,
activation=activation,
dw_kernel_size=dw_kernel_size,
) for _ in range(num_blocks)],
)
# Shortcut path is just a passthrough
self.shortcut_path = nn.Identity()
# Final 1x1 convolution after merging
self.final_transition = nn.Sequential(
nn.Conv2d(
in_channels=self.main_channels+self.shortcut_channels,
out_channels=in_channels,
kernel_size=1,
),
nn.BatchNorm2d(in_channels),
activation(),
)
def forward(self, x):
# Split the input tensor into the main and shortcut paths along the channel dimension
main_data = x[:, :self.main_channels, :, :]
shortcut_data = x[:, self.main_channels:, :, :]
main_data = self.main_path(main_data)
shortcut_data = self.shortcut_path(shortcut_data)
# Concatenating the main and shortcut paths
combined = torch.cat(tensors=(main_data, shortcut_data), dim=1)
out = self.final_transition(combined)
return out
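# Illustrative usage (assumed shapes). Same split/merge scheme as CSPConvBlock, with MBConv blocks on the main path:
#   y = CSPMBConvBlock(in_channels=32, num_blocks=1)(torch.randn(1, 32, 40, 40))   # -> (1, 32, 40, 40)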
# DenseNetBlock
class DenseNetBlock(nn.Module):
"""
Basic DenseNet block composed of one 3x3 conv with a residual connection.
The residual connection is performed by concatenating the input and the output along the channel dimension.
"""
def __init__(self, in_channels, out_channels, activation=ReLU):
super().__init__()
self.block = nn.Sequential(
nn.BatchNorm2d(in_channels),
activation(),
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
)
def forward(self, x):
res = x
x = self.block(x)
return torch.cat([res, x], dim=1)
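# Illustrative usage (assumed shapes). Concatenation grows the channel dimension by out_channels:
#   block = DenseNetBlock(in_channels=64, out_channels=32)
#   y = block(torch.randn(1, 64, 14, 14))   # -> (1, 96, 14, 14)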
# ResNetBlock and derivatives (https://paperswithcode.com/method/resnext)
class ResNetBasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, reduction_factor=4, activation=ReLU):
super().__init__()
assert out_channels == in_channels
reduced_channels = in_channels // reduction_factor
if reduced_channels == 0:
reduced_channels = in_channels
self.steps = nn.Sequential(
# Reduce channels with a 1x1 convolution (bottleneck)
ConvBnAct(in_channels, reduced_channels, kernel_size=1, stride=1, padding=0, activation=activation),
# Process at reduced width with a 3x3 convolution
ConvBnAct(reduced_channels, reduced_channels, kernel_size=3, stride=1, padding=1, activation=activation),
# Restore the original channel count with a 1x1 convolution
ConvBnAct(reduced_channels, out_channels, kernel_size=1, stride=1, padding=0, activation=activation)
)
def forward(self, x):
x = self.steps(x)
return x
class ResNetBlock(nn.Module):
def __init__(self, in_channels, out_channels, reduction_factor=4, activation=ReLU):
super().__init__()
assert out_channels == in_channels
self.main_path = ResNetBasicBlock(
in_channels=in_channels,
out_channels=out_channels,
reduction_factor=reduction_factor,
activation=activation,
)
def forward(self, x):
res = x
x = self.main_path(x)
x = torch.add(x, res)
return x
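# Illustrative usage (assumed shapes). The assertion requires out_channels == in_channels:
#   block = ResNetBlock(in_channels=64, out_channels=64, reduction_factor=4)   # 1x1 down to 16, 3x3, 1x1 back to 64
#   y = block(torch.randn(1, 64, 28, 28))   # -> (1, 64, 28, 28)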
class Upsample(nn.Module):
def __init__(self, scale_factor=2, mode='nearest'):
super().__init__()
self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode)
def forward(self, x):
return self.upsample(x)
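# Illustrative usage (assumed shapes). Nearest-neighbour upsampling doubles the spatial size by default:
#   y = Upsample(scale_factor=2)(torch.randn(1, 16, 20, 20))   # -> (1, 16, 40, 40)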
"""
class ResNextBlock(nn.Module):
def __init__(self, in_channels, out_channels, reduction_factor=4, cardinality=1, activation=ReLU):
super().__init__()
assert out_channels == in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.cardinality = cardinality
# Works only when in_channels is divisible by cardinality
self.is_divisible = in_channels % cardinality == 0
if self.is_divisible:
self.parallel_channels = self.in_channels // self.cardinality
# ModuleList so the parallel paths are registered as submodules
self.parallel_paths = nn.ModuleList([
ResNetBasicBlock(
in_channels=self.parallel_channels,
out_channels=self.parallel_channels,
reduction_factor=reduction_factor,
activation=activation,
)
for _ in range(self.cardinality)
])
def forward(self, x):
res = x
if not self.is_divisible:
print("Cardinality was not divisible, using Identity instead of ResNextBlock.")
return res
# Process each channel group with its own path, then reassemble and add the residual
outputs = []
for i, path in enumerate(self.parallel_paths):
start = i * self.parallel_channels
outputs.append(path(x[:, start:start + self.parallel_channels, :, :]))
x = torch.cat(outputs, dim=1)
x = torch.add(x, res)
return x
"""