Remote Sensing Applications

This example demonstrates using PyNAS for remote sensing applications, including satellite imagery analysis, land cover classification, and change detection.

Overview

Remote sensing applications have unique requirements:

  • Multi-spectral Data: Handling different spectral bands

  • High Resolution: Processing large images efficiently

  • Temporal Analysis: Analyzing time series data

  • Domain Adaptation: Adapting to different sensors and regions

Satellite Image Classification

Land Cover Classification

Classify land cover types from Sentinel-2 imagery:

import torch
import torch.nn as nn
import numpy as np
from pynas.core.population import Population
from pynas.opt.evo import GeneticAlgorithm
from pynas.blocks.convolutions import ConvBlock

class RemoteSensingDataset(torch.utils.data.Dataset):
    """
    Dataset of satellite image patches with per-pixel land cover labels.

    Args:
        images (numpy.ndarray): Satellite images [N, H, W, C] of any numeric
            dtype; each item is cast to float32.
        labels (numpy.ndarray): Integer land cover labels [N, H, W]; each item
            is cast to int64.
        transform: Optional callable applied to the image tensor only. The
            label is NOT transformed, so spatial augmentations (flips, crops)
            would misalign image and mask.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        # One past the largest label id, so ids can index per-class arrays
        # even when a class is absent from this split (counting unique values
        # would under-count and silently drop the highest ids).
        self.num_classes = int(np.max(labels)) + 1 if np.size(labels) else 0

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Cast (and copy) in NumPy first: this also covers dtypes such as
        # uint16 rasters that older torch versions cannot wrap directly.
        image = torch.from_numpy(np.array(self.images[idx], dtype=np.float32))
        image = image.permute(2, 0, 1)  # [H, W, C] -> [C, H, W]
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))

        if self.transform is not None:
            image = self.transform(image)

        return image, label

class MultiSpectralBlock(nn.Module):
    """
    Custom block for multi-spectral data processing.

    Each of the first ``num_bands`` input channels is convolved on its own
    (3x3, padding 1) into ``max(1, out_channels // num_bands)`` feature maps.
    The concatenated maps are then fused by a 1x1 convolution into
    ``out_channels`` channels, followed by BatchNorm and ReLU.

    Args:
        in_channels: Declared input channel count. It is not used to size any
            layer; only the first ``num_bands`` input channels are read.
        out_channels: Number of output channels.
        num_bands: Number of input channels processed band-by-band.
    """

    def __init__(self, in_channels, out_channels, num_bands=13):
        super().__init__()
        self.num_bands = num_bands
        # Feature maps per band; at least one so small out_channels still work.
        self.band_channels = max(1, out_channels // num_bands)

        # Band-specific processing
        self.band_conv = nn.ModuleList([
            nn.Conv2d(1, self.band_channels, 3, padding=1)
            for _ in range(num_bands)
        ])

        # Cross-band fusion. Its input width must equal the concatenated band
        # features (num_bands * band_channels), which is generally not
        # out_channels when out_channels is not a multiple of num_bands.
        self.fusion_conv = nn.Conv2d(num_bands * self.band_channels, out_channels, 1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        if x.shape[1] < self.num_bands:
            raise ValueError(
                f"expected at least {self.num_bands} input channels, got {x.shape[1]}"
            )

        # Process each band separately, then concatenate along channels.
        fused = torch.cat(
            [conv(x[:, i:i + 1]) for i, conv in enumerate(self.band_conv)],
            dim=1,
        )

        # Apply fusion convolution
        return self.relu(self.bn(self.fusion_conv(fused)))

def create_landcover_fitness_function(dataset_loader, device='cuda'):
    """
    Create fitness function for land cover classification.

    The returned callable evaluates ``individual.phenotype`` in eval mode
    without gradients on every ``(images, labels)`` batch of
    ``dataset_loader``. It combines pixel accuracy, mean IoU and a bounded
    model-size efficiency term into one score.

    Args:
        dataset_loader: Iterable of ``(images, labels)`` batches whose
            ``dataset`` attribute exposes ``num_classes``. Model outputs must
            be per-pixel logits ``[B, K, H, W]`` matching ``labels``
            ``[B, H, W]`` spatially, as required by ``nn.CrossEntropyLoss``.
        device: Device the model and batches are moved to.

    Returns:
        ``fitness_function(individual) -> float``. It also stores
        ``individual.fitness`` and ``individual.metrics``. It raises
        ``ValueError`` if the loader yields no labelled pixels.
    """

    def fitness_function(individual):
        num_classes = dataset_loader.dataset.num_classes
        model = individual.phenotype.to(device)
        model.eval()

        total_loss = 0.0
        num_batches = 0
        correct_pixels = 0
        total_pixels = 0
        # Per-class pixel counts of (pred == c AND label == c) and
        # (pred == c OR label == c); their ratio is that class's IoU.
        class_intersection = [0] * num_classes
        class_union = [0] * num_classes

        criterion = nn.CrossEntropyLoss()

        with torch.no_grad():
            for images, labels in dataset_loader:
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                total_loss += criterion(outputs, labels).item()
                num_batches += 1

                # Pixel-wise accuracy
                predicted = outputs.argmax(dim=1)
                correct_pixels += (predicted == labels).sum().item()
                total_pixels += labels.numel()

                # Per-class intersection / union
                for c in range(num_classes):
                    pred_c = predicted == c
                    label_c = labels == c
                    class_intersection[c] += (pred_c & label_c).sum().item()
                    class_union[c] += (pred_c | label_c).sum().item()

        if total_pixels == 0:
            raise ValueError("dataset_loader yielded no labelled pixels")

        # Calculate metrics
        pixel_accuracy = correct_pixels / total_pixels
        mean_iou = calculate_mean_iou(class_intersection, class_union)
        avg_loss = total_loss / num_batches

        # Efficiency in (0, 1]: 1 / (1 + parameters in millions). An unbounded
        # 1 / size term would let tiny models outweigh the accuracy terms and
        # would divide by zero for parameter-free models.
        model_size = sum(p.numel() for p in model.parameters())
        efficiency_score = 1.0 / (1.0 + model_size / 1e6)

        # Combined fitness score
        fitness = 0.7 * pixel_accuracy + 0.2 * mean_iou + 0.1 * efficiency_score

        individual.fitness = fitness
        individual.metrics = {
            'pixel_accuracy': pixel_accuracy,
            'mean_iou': mean_iou,
            'loss': avg_loss,
            'model_size': model_size,
            'efficiency_score': efficiency_score
        }

        return fitness

    return fitness_function

def calculate_mean_iou(class_correct, class_total):
    """
    Average per-class ratios ``class_correct[c] / class_total[c]``.

    Classes whose denominator is zero are skipped. To obtain mean IoU, pass
    per-class intersection counts as ``class_correct`` and per-class union
    counts as ``class_total``. Passing correct/ground-truth counts instead
    yields mean per-class recall, not IoU.

    Args:
        class_correct: Sequence of per-class numerators (numbers or 0-d tensors).
        class_total: Sequence of per-class denominators, same length.

    Returns:
        float: Mean of the ratios, or 0.0 if every denominator is zero.
    """
    scores = [
        float(num) / float(den)
        for num, den in zip(class_correct, class_total)
        if den > 0
    ]
    return sum(scores) / len(scores) if scores else 0.0

Change Detection

Detect changes between multi-temporal images:

class ChangeDetectionArchitecture(nn.Module):
    """
    Architecture for change detection in satellite imagery.

    Both acquisitions are encoded by the same encoder module (shared weights).
    The absolute difference of the deepest features is mapped to per-pixel
    class logits by a convolutional head, then bilinearly upsampled back to
    the input resolution.

    Args:
        num_bands: Number of spectral channels in each input image.
        num_classes: Number of output channels of the change head.
    """

    def __init__(self, num_bands=13, num_classes=2):
        super().__init__()
        self.num_bands = num_bands

        # Siamese encoder for temporal images (one module, used for both dates)
        self.encoder = self.build_encoder()

        # Change detection head; 512 input channels = last encoder stage width
        self.change_head = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, num_classes, 1)
        )

    def build_encoder(self):
        """Build four MultiSpectralBlock + 2x2 max-pool stages (64 -> 512 channels)."""
        layers = []
        in_ch = self.num_bands

        # Multi-scale feature extraction
        for out_ch in [64, 128, 256, 512]:
            layers.extend([
                MultiSpectralBlock(in_ch, out_ch, self.num_bands),
                nn.MaxPool2d(2, 2)
            ])
            in_ch = out_ch

        return nn.Sequential(*layers)

    def forward(self, x1, x2):
        """
        Forward pass for change detection.

        Args:
            x1: Image at time t1 [B, C, H, W]
            x2: Image at time t2 [B, C, H, W]

        Returns:
            Change logits [B, num_classes, H, W].
        """
        # Extract features from both images
        feat1 = self.encoder(x1)
        feat2 = self.encoder(x2)

        # Calculate feature difference
        diff = torch.abs(feat1 - feat2)

        # Detect changes
        change_map = self.change_head(diff)

        # Upsample to original resolution. Reached via the imported ``nn``
        # module because no ``F`` alias is imported for this code.
        change_map = nn.functional.interpolate(
            change_map, size=x1.shape[-2:],
            mode='bilinear', align_corners=False
        )

        return change_map

def create_change_detection_fitness(dataset_loader, device='cuda'):
    """
    Create fitness function for change detection.

    The returned callable evaluates ``individual.phenotype(img1, img2)`` in
    eval mode without gradients over ``((img1, img2), labels)`` batches. It
    scores the model by the F1 of class 1 ("changed", label 1) against
    class 0 ("unchanged", label 0), plus a bounded model-size efficiency term.

    Args:
        dataset_loader: Iterable of ``((img1, img2), labels)`` batches.
        device: Device the model and batches are moved to.

    Returns:
        ``fitness_function(individual) -> float``. It also stores
        ``individual.fitness`` and ``individual.metrics``. It raises
        ``ValueError`` if the loader yields no batches.
    """

    def fitness_function(individual):
        model = individual.phenotype.to(device)
        model.eval()

        total_loss = 0.0
        num_batches = 0
        true_positives = 0
        false_positives = 0
        false_negatives = 0

        criterion = nn.CrossEntropyLoss()

        with torch.no_grad():
            for (img1, img2), labels in dataset_loader:
                img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

                outputs = model(img1, img2)
                total_loss += criterion(outputs, labels).item()
                num_batches += 1

                predicted = outputs.argmax(dim=1)

                # Binary change detection counts
                changed_mask = labels == 1    # Changed pixels
                unchanged_mask = labels == 0  # Unchanged pixels

                true_positives += (predicted[changed_mask] == 1).sum().item()
                false_positives += (predicted[unchanged_mask] == 1).sum().item()
                false_negatives += (predicted[changed_mask] == 0).sum().item()

        if num_batches == 0:
            raise ValueError("dataset_loader yielded no batches")

        # Calculate F1 score (0 when undefined)
        predicted_pos = true_positives + false_positives
        actual_pos = true_positives + false_negatives
        precision = true_positives / predicted_pos if predicted_pos > 0 else 0
        recall = true_positives / actual_pos if actual_pos > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        avg_loss = total_loss / num_batches

        # Efficiency in (0, 1]: 1 / (1 + parameters in millions). An unbounded
        # 1 / size term would dominate F1 for small models and divide by zero
        # for parameter-free ones.
        model_size = sum(p.numel() for p in model.parameters())
        efficiency_score = 1.0 / (1.0 + model_size / 1e6)

        fitness = 0.8 * f1_score + 0.2 * efficiency_score

        individual.fitness = fitness
        individual.metrics = {
            'f1_score': f1_score,
            'precision': precision,
            'recall': recall,
            'loss': avg_loss,
            'model_size': model_size,
            'efficiency_score': efficiency_score
        }

        return fitness

    return fitness_function

Hyperspectral Image Analysis

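Process hyperspectral data with hundreds of spectral bands (224 by default below). `HyperspectralBlock` takes channels-last input of shape `[B, H, W, Bands]` and returns channels-first features of shape `[B, C, H, W]`: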

class HyperspectralBlock(nn.Module):
    """
    Block for processing hyperspectral data.

    Input is channels-last ``[B, H, W, Bands]``; output is channels-first
    ``[B, out_channels, H, W]``. Each pixel's spectrum is first projected to
    ``out_channels`` features by a shared kernel-size-1 Conv1d. Two 3x3
    conv/BN/ReLU layers and squeeze-and-excitation style channel attention
    follow.
    """

    def __init__(self, num_bands=224, out_channels=64):
        super().__init__()
        self.num_bands = num_bands

        # Spectral dimension reduction (per pixel, weights shared over pixels)
        self.spectral_conv = nn.Conv1d(num_bands, out_channels, 1)

        # Spatial convolutions
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Channel attention. The bottleneck is kept at least one unit wide;
        # out_channels // 4 == 0 would make the weights input-independent.
        hidden = max(1, out_channels // 4)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(out_channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, out_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x shape: [B, H, W, Bands]
        B, H, W, Bands = x.shape

        # reshape (not view) so non-contiguous inputs, e.g. a permuted
        # channels-first tensor, are accepted.
        x_spectral = self.spectral_conv(x.reshape(B * H * W, Bands, 1))
        x_spectral = x_spectral.reshape(B, H, W, -1)

        # Convert to [B, C, H, W] for spatial processing
        features = self.spatial_conv(x_spectral.permute(0, 3, 1, 2))

        # Apply attention: one weight per (sample, channel)
        attention_weights = self.attention(features)[:, :, None, None]
        return features * attention_weights

def create_hyperspectral_architecture(num_bands=224, num_classes=16):
    """
    Create architecture for hyperspectral image classification.

    Args:
        num_bands: Number of spectral bands of the channels-last input.
        num_classes: Number of output classes.

    Returns:
        ``nn.Module`` mapping ``[B, H, W, num_bands]`` to logits
        ``[B, num_classes]``. H and W must be at least 4 (two 2x2 poolings).
    """

    class HyperspectralNet(nn.Module):
        def __init__(self, num_bands=224, num_classes=16):
            super().__init__()

            # Multi-scale hyperspectral processing
            self.block1 = HyperspectralBlock(num_bands, 64)
            self.block2 = HyperspectralBlock(64, 128)
            self.block3 = HyperspectralBlock(128, 256)

            # Classification head
            self.classifier = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(),
                nn.Dropout(0.5),
                nn.Linear(256, 128),
                nn.ReLU(inplace=True),
                nn.Dropout(0.3),
                nn.Linear(128, num_classes)
            )

        @staticmethod
        def _to_channels_last(x):
            # Blocks emit [B, C, H, W] but consume [B, H, W, C].
            return x.permute(0, 2, 3, 1)

        def forward(self, x):
            x = self.block1(x)
            x = nn.functional.max_pool2d(x, 2)
            x = self.block2(self._to_channels_last(x))
            x = nn.functional.max_pool2d(x, 2)
            x = self.block3(self._to_channels_last(x))
            return self.classifier(x)

    return HyperspectralNet(num_bands, num_classes)

SAR Image Processing

Process Synthetic Aperture Radar (SAR) imagery:

class SARProcessor(nn.Module):
    """
    Residual feature extractor for Synthetic Aperture Radar imagery.

    A 3x3 convolution lifts the input to ``out_channels`` features. Those
    features are refined by a depthwise 5x5 conv + BatchNorm + ReLU stage
    (intended for speckle reduction) and a 3x3 convolution (intended for
    edge enhancement). The refined result is added back onto the lifted
    features.

    Args:
        in_channels: Input channels; the default of 2 is noted in the
            original as the I and Q channels.
        out_channels: Feature channels produced by every layer.
    """

    def __init__(self, in_channels=2, out_channels=64):
        super().__init__()

        # Input projection; its output is also the residual branch.
        self.complex_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        # Depthwise (groups == channels) 5x5 filter for speckle reduction.
        depthwise = nn.Conv2d(out_channels, out_channels, kernel_size=5, padding=2, groups=out_channels)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.ReLU(inplace=True)
        self.speckle_filter = nn.Sequential(depthwise, norm, act)

        # 3x3 refinement for edge enhancement.
        self.edge_enhance = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        lifted = self.complex_conv(x)
        refined = self.edge_enhance(self.speckle_filter(lifted))
        # Residual connection
        return refined + lifted

def create_sar_nas_config():
    """
    Return a fresh NAS configuration dictionary for SAR image analysis.

    The result holds the genetic-algorithm settings (population size,
    generations, mutation/crossover rates), SAR-specific architecture
    constraints, the training recipe, and the metric names to report.
    """
    # SAR-specific architecture constraints
    constraints = dict(
        max_depth=8,
        min_channels=32,
        max_channels=512,
        use_sar_blocks=True,
        speckle_reduction=True,
    )

    # Training configuration
    training = dict(
        epochs=50,
        batch_size=16,
        learning_rate=0.001,
        optimizer='AdamW',
        scheduler='CosineAnnealingLR',
    )

    return dict(
        population_size=25,
        num_generations=40,
        mutation_rate=0.15,
        crossover_rate=0.7,
        architecture_constraints=constraints,
        training=training,
        # Evaluation metrics
        metrics=['accuracy', 'f1_score', 'iou', 'model_size'],
    )

Complete Remote Sensing Example

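Here’s an end-to-end example for satellite image classification. It reuses `RemoteSensingDataset`, `MultiSpectralBlock`, and `create_landcover_fitness_function` from the sections above, and it also needs `import numpy as np`: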

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from pynas.core.population import Population
from pynas.opt.evo import GeneticAlgorithm

def run_remote_sensing_nas(num_generations=30):
    """
    Complete remote sensing NAS example.

    Loads Sentinel-2 train/validation arrays from ``.npy`` files and evolves
    a population of encoder-decoder segmentation networks with a genetic
    algorithm. Returns the fittest individual seen in any generation.

    Args:
        num_generations: Number of generations to evaluate (default 30).

    Returns:
        The individual with the highest ``fitness`` over all generations.

    Raises:
        ValueError: If ``num_generations`` is less than 1.
    """
    if num_generations < 1:
        raise ValueError("num_generations must be at least 1")

    # Load remote sensing dataset
    print("Loading remote sensing dataset...")
    train_dataset = RemoteSensingDataset(
        images=np.load('sentinel2_images_train.npy'),
        labels=np.load('landcover_labels_train.npy')
    )

    val_dataset = RemoteSensingDataset(
        images=np.load('sentinel2_images_val.npy'),
        labels=np.load('landcover_labels_val.npy')
    )

    # NOTE(review): train_loader is not consumed below. Candidates are scored
    # on val_loader without a training step here; confirm whether PyNAS
    # trains phenotypes internally before fitness evaluation.
    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

    # Define architecture space for remote sensing
    def create_rs_genome():
        """Create a random genome: 4 encoder genes followed by 3 decoder genes."""
        genome = []

        # Encoder blocks. NumPy scalars are converted to plain Python types
        # so layer constructors receive ints/floats/bools/strs.
        channels = [64, 128, 256, 512]
        for ch in channels:
            genome.append({
                'type': str(np.random.choice(['multispectral', 'conv', 'resnet'])),
                'channels': ch,
                'kernel_size': int(np.random.choice([3, 5, 7])),
                'use_attention': bool(np.random.choice([True, False])),
                'dropout_rate': float(np.random.uniform(0.1, 0.5))
            })

        # Decoder blocks for segmentation
        for ch in reversed(channels[:-1]):
            genome.append({
                'type': 'decoder',
                'channels': ch,
                'upsample_mode': str(np.random.choice(['bilinear', 'nearest'])),
                'skip_connection': True
            })

        return genome

    # Create custom architecture builder
    class RemoteSensingBuilder:
        def __init__(self, num_bands=13, num_classes=10):
            self.num_bands = num_bands
            self.num_classes = num_classes

        def build_architecture(self, genome):
            """
            Build a plain nn.Sequential encoder-decoder from ``genome``.

            The first 4 genes are encoder stages and the rest are decoder
            stages. Encoder stages are downsampled only as many times as the
            decoder upsamples, so logits come back at input resolution; this
            is required by the pixel-wise CrossEntropyLoss in the fitness
            function. Input H and W must therefore be divisible by
            2 ** len(decoder genes).

            'conv' and 'resnet' genes both build the same conv block.
            'use_attention' and 'skip_connection' are not implemented by this
            sequential builder.
            """
            encoder_genes, decoder_genes = genome[:4], genome[4:]
            layers = []
            current_channels = self.num_bands

            # Build encoder
            encoder_features = []
            for idx, gene in enumerate(encoder_genes):
                if gene['type'] == 'multispectral':
                    layer = MultiSpectralBlock(
                        current_channels, gene['channels'], num_bands=self.num_bands
                    )
                else:
                    layer = nn.Sequential(
                        nn.Conv2d(current_channels, gene['channels'],
                                  gene['kernel_size'], padding=gene['kernel_size'] // 2),
                        nn.BatchNorm2d(gene['channels']),
                        nn.ReLU(inplace=True),
                        nn.Dropout(gene['dropout_rate'])
                    )

                layers.append(layer)
                encoder_features.append(gene['channels'])
                current_channels = gene['channels']

                # One pooling per decoder upsample keeps resolutions balanced.
                if idx < len(decoder_genes):
                    layers.append(nn.MaxPool2d(2, 2))

            # Build decoder
            for i, gene in enumerate(decoder_genes):
                target_channels = encoder_features[-(i + 2)]

                # Upsampling
                layers.append(nn.Upsample(scale_factor=2, mode=gene['upsample_mode']))

                # Decoder convolution
                layers.append(nn.Conv2d(current_channels, target_channels, 3, padding=1))
                layers.append(nn.BatchNorm2d(target_channels))
                layers.append(nn.ReLU(inplace=True))

                current_channels = target_channels

            # Final classification layer
            layers.append(nn.Conv2d(current_channels, self.num_classes, 1))

            return nn.Sequential(*layers)

    # Initialize population
    builder = RemoteSensingBuilder(num_bands=13, num_classes=10)
    population = Population(
        population_size=20,
        genome_factory=create_rs_genome,
        architecture_builder=builder
    )

    # Create fitness function on whichever device is actually available
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    fitness_fn = create_landcover_fitness_function(val_loader, device=device)

    # Configure genetic algorithm
    ga = GeneticAlgorithm(
        population=population,
        fitness_function=fitness_fn,
        mutation_rate=0.1,
        crossover_rate=0.8,
        selection_method='tournament',
        tournament_size=3
    )

    # Run evolution
    print("Starting neural architecture search...")
    best_individuals = []

    for generation in range(num_generations):
        print(f"Generation {generation + 1}/{num_generations}")

        # Evaluate population
        ga.evaluate_population()

        # Get best individual
        best = ga.get_best_individual()
        best_individuals.append(best)

        print(f"Best fitness: {best.fitness:.4f}")
        print(f"Best metrics: {best.metrics}")

        # Evolve population
        if generation < num_generations - 1:  # Don't evolve on last generation
            ga.evolve()

    # Return best architecture
    final_best = max(best_individuals, key=lambda x: x.fitness)
    print(f"\nFinal best architecture fitness: {final_best.fitness:.4f}")
    print(f"Final best metrics: {final_best.metrics}")

    return final_best

if __name__ == "__main__":
    # Search for the best architecture end to end.
    winner = run_remote_sensing_nas()

    # Persist only the learned parameters of the winning phenotype.
    weights = winner.phenotype.state_dict()
    torch.save(weights, 'best_remote_sensing_model.pth')

    print("Remote sensing NAS completed successfully!")

Data Preprocessing Utilities

Utilities for handling remote sensing data:

class RemoteSensingPreprocessor:
    """
    Preprocessing utilities for remote sensing data.

    Integer inputs (e.g. uint16 digital numbers) are promoted to floating
    point before any arithmetic. Otherwise normalized values would be
    truncated to integers and differences such as ``nir - red`` could wrap
    around for unsigned types.
    """

    def __init__(self):
        # Placeholder for per-band statistics; not filled by the methods below.
        self.band_statistics = {}

    @staticmethod
    def _as_float(array):
        """Return ``array`` as floating point (float32 for small ints/float32, else float64)."""
        array = np.asarray(array)
        return array.astype(np.result_type(array.dtype, np.float32), copy=False)

    @staticmethod
    def _safe_divide(numerator, denominator):
        """Elementwise ``numerator / denominator``, giving 0 where the denominator is 0."""
        return np.divide(
            numerator, denominator,
            out=np.zeros_like(numerator),
            where=np.asarray(denominator) != 0,
        )

    def normalize_bands(self, image, method='percentile'):
        """
        Normalize spectral bands.

        Args:
            image: Multi-band image [H, W, Bands] of any numeric dtype.
            method: Normalization method:
                - 'percentile': per band, scale the 2nd-98th percentile range
                  to [0, 1] and clip.
                - 'minmax': scale the global min-max range (over all bands)
                  to [0, 1].
                - 'zscore': subtract the global mean, divide by the global std.
                A zero range or zero std yields 0 instead of NaN/inf.

        Returns:
            Floating-point array with the same shape as ``image``.

        Raises:
            ValueError: If ``method`` is not one of the names above.
        """
        data = self._as_float(image)

        if method == 'percentile':
            normalized = np.zeros_like(data)
            for band in range(data.shape[-1]):
                band_data = data[..., band]
                p2, p98 = np.percentile(band_data, [2, 98])
                scaled = self._safe_divide(band_data - p2, p98 - p2)
                normalized[..., band] = np.clip(scaled, 0, 1)

        elif method == 'minmax':
            lo, hi = data.min(), data.max()
            normalized = self._safe_divide(data - lo, hi - lo)

        elif method == 'zscore':
            normalized = self._safe_divide(data - data.mean(), data.std())

        else:
            raise ValueError(
                f"unknown normalization method {method!r}; "
                "expected 'percentile', 'minmax' or 'zscore'"
            )

        return normalized

    def calculate_indices(self, image, bands_config):
        """
        Calculate vegetation and other spectral indices.

        Args:
            image: Multi-band image [H, W, Bands].
            bands_config: Dictionary mapping band names ('red', 'nir',
                'green', 'blue') to band indices in ``image``.

        Returns:
            Dict with 'ndvi', 'ndwi' and/or 'evi' float arrays [H, W],
            depending on which bands are configured.
        """
        def band(name):
            return self._as_float(image[..., bands_config[name]])

        indices = {}

        # NDVI (Normalized Difference Vegetation Index)
        if 'red' in bands_config and 'nir' in bands_config:
            red, nir = band('red'), band('nir')
            indices['ndvi'] = (nir - red) / (nir + red + 1e-8)

        # NDWI (Normalized Difference Water Index)
        if 'green' in bands_config and 'nir' in bands_config:
            green, nir = band('green'), band('nir')
            indices['ndwi'] = (green - nir) / (green + nir + 1e-8)

        # EVI (Enhanced Vegetation Index); 0 where the denominator vanishes
        if all(name in bands_config for name in ['red', 'nir', 'blue']):
            red, nir, blue = band('red'), band('nir'), band('blue')
            indices['evi'] = self._safe_divide(
                2.5 * (nir - red), nir + 6 * red - 7.5 * blue + 1
            )

        return indices

    def create_composite_bands(self, image, indices):
        """
        Create composite image with original bands and calculated indices.

        Each index array [H, W] is appended as one extra trailing band, in
        dictionary order.
        """
        layers = [image] + [values[..., np.newaxis] for values in indices.values()]
        return np.concatenate(layers, axis=-1)

Performance Optimization for Large Images

Process satellite images too large for a single forward pass by splitting them into overlapping tiles and blending the per-tile predictions:

class TiledInference:
    """
    Handle large images through tiled processing.

    The image is cut into overlapping square tiles. Each tile goes through
    ``model``, and the softmax probability of ``target_class`` is accumulated
    with a centre-weighted mask, so tile borders count less than tile
    centres. The sum is then normalized by the accumulated weights.

    Args:
        model: Callable (typically an ``nn.Module``) mapping ``[1, C, h, w]``
            to logits ``[1, K, h, w]``; assumed to keep the spatial size.
        tile_size: Side length of each tile in pixels (> 0).
        overlap: Pixels shared by neighbouring tiles (0 <= overlap < tile_size).
        target_class: Softmax channel written to the output map.

    Raises:
        ValueError: If ``tile_size`` or ``overlap`` is out of range.
    """

    def __init__(self, model, tile_size=512, overlap=64, target_class=1):
        if tile_size <= 0:
            raise ValueError("tile_size must be positive")
        if not 0 <= overlap < tile_size:
            raise ValueError("overlap must satisfy 0 <= overlap < tile_size")
        self.model = model
        self.tile_size = tile_size
        self.overlap = overlap
        self.target_class = target_class

    def predict_large_image(self, large_image):
        """
        Predict on large image using tiled approach.

        Args:
            large_image: Large input image [H, W, C] (any numeric dtype).

        Returns:
            float32 [H, W] map of the blended ``target_class`` probability.
        """
        H, W, _ = large_image.shape

        output = np.zeros((H, W), dtype=np.float32)
        weight_map = np.zeros((H, W), dtype=np.float32)

        # Inference mode for nn.Modules (disables dropout / uses BN running
        # stats), on the model's own device. The caller's mode is restored.
        is_module = isinstance(self.model, torch.nn.Module)
        was_training = is_module and self.model.training
        device = torch.device('cpu')
        if is_module:
            self.model.eval()
            device = next((p.device for p in self.model.parameters()), device)

        try:
            for tile_info in self.calculate_tile_positions(H, W):
                tile = self.extract_tile(large_image, tile_info)
                tile_tensor = torch.as_tensor(
                    np.ascontiguousarray(tile, dtype=np.float32)
                ).permute(2, 0, 1).unsqueeze(0).to(device)

                with torch.no_grad():
                    tile_pred = torch.softmax(self.model(tile_tensor), dim=1)
                tile_pred = tile_pred.squeeze(0).cpu().numpy()

                self.merge_tile_prediction(output, weight_map, tile_pred, tile_info)
        finally:
            if was_training:
                self.model.train()

        # Normalize by weights
        return output / (weight_map + 1e-8)

    def calculate_tile_positions(self, H, W):
        """
        Return tile boxes covering an H x W image.

        Tiles start every ``tile_size - overlap`` pixels. The last tile on
        each axis is shifted back to end exactly at the border, so no tile is
        smaller than ``min(tile_size, size)`` and none is duplicated.
        """
        step = self.tile_size - self.overlap

        def axis_starts(size):
            # Start offsets along one axis (empty for a zero-length axis).
            if size <= 0:
                return []
            last = max(0, size - self.tile_size)
            return list(range(0, last, step)) + [last]

        return [
            {
                'y_start': y, 'y_end': min(y + self.tile_size, H),
                'x_start': x, 'x_end': min(x + self.tile_size, W),
            }
            for y in axis_starts(H)
            for x in axis_starts(W)
        ]

    def extract_tile(self, image, tile_info):
        """Extract tile from image."""
        return image[
            tile_info['y_start']:tile_info['y_end'],
            tile_info['x_start']:tile_info['x_end']
        ]

    def merge_tile_prediction(self, output, weight_map, tile_pred, tile_info):
        """Accumulate the weighted ``target_class`` probability of one tile in place."""
        rows = slice(tile_info['y_start'], tile_info['y_end'])
        cols = slice(tile_info['x_start'], tile_info['x_end'])

        # Higher weight in the tile centre, lower (but non-zero) at its edges
        weight_mask = self.create_weight_mask(rows.stop - rows.start, cols.stop - cols.start)

        output[rows, cols] += tile_pred[self.target_class] * weight_mask
        weight_map[rows, cols] += weight_mask

    def create_weight_mask(self, h, w):
        """
        Create an [h, w] weight mask that peaks at the centre.

        The weight falls off linearly along each axis and is strictly
        positive everywhere (the edge weight is ``1 / (centre + 1)`` per
        axis), so pixels seen by only one tile, such as the image border,
        are never zero-weighted. Works for h or w equal to 1.
        """
        y, x = np.ogrid[:h, :w]
        cy, cx = (h - 1) / 2, (w - 1) / 2

        weight_y = 1 - np.abs(y - cy) / (cy + 1)
        weight_x = 1 - np.abs(x - cx) / (cx + 1)
        return (weight_y * weight_x).astype(np.float32)

See also