API Reference#
Core Modules#
Student & Teacher U-Net–style semantic segmentation models for Martian terrain (AI4Mars).
Model families provided:
- UNet: lightweight student model used for distillation.
- AttentionUNet: deeper teacher architecture with attention-gated skip connections.
AI4Mars dataset:
- Rover images from Curiosity, Opportunity, and Spirit.
- Terrain classes: 0 = soil, 1 = bedrock, 2 = sand, 3 = big rock.
- 255 = no-label / ignored.
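As a minimal illustration of these label conventions, the no-label value can be masked out before computing statistics or losses (the tiny mask below is hypothetical):

```python
import numpy as np

# Hypothetical tiny AI4Mars-style mask: classes 0-3, 255 = no-label / ignored.
mask = np.array([[0, 1, 255],
                 [2, 3, 255]])

valid = mask != 255               # boolean mask of labelled pixels
labelled_fraction = valid.mean()  # fraction of pixels carrying a terrain label
present_classes = np.unique(mask[valid])
```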
- class martian_terrain_segmentation.models.AttentionGate(*args: Any, **kwargs: Any)#
Bases: Module
Attention gate (AG) for U-Net skip connections, from Attention U-Net.
Given encoder skip features \(x\) and decoder gating features \(g\), the gate computes:
\[\begin{aligned}\psi &= \sigma(\mathrm{Conv}(\mathrm{ReLU}(W_g g + W_x x)))\\\tilde{x} &= \psi \odot x\end{aligned}\]
so that skip features can be suppressed when irrelevant.
- Parameters:
F_g (int) – Channels of the decoder gating signal.
F_l (int) – Channels of the encoder skip feature map.
F_int (int) – Reduced intermediate dimensionality for gating.
- forward(g: torch.Tensor, x: torch.Tensor) torch.Tensor#
- Parameters:
g (torch.Tensor) – Decoder gating signal [B, F_g, H_g, W_g].
x (torch.Tensor) – Encoder skip features [B, F_l, H_x, W_x].
- Returns:
Attention-reweighted skip features [B, F_l, H_x, W_x].
- Return type:
torch.Tensor
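The gating computation can be sketched in NumPy, treating the 1×1 convolutions as per-pixel channel mixes. `W_g`, `W_x`, and `w_psi` are hypothetical weight matrices; the real module also includes batch norm and biases, omitted here:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def attention_gate(g, x, W_g, W_x, w_psi):
    """psi = sigma(Conv(ReLU(W_g g + W_x x))); x_tilde = psi * x.

    g: [F_g, H, W] gating signal, x: [F_l, H, W] skip features.
    The 1x1 convolutions are expressed as channel-mixing einsums.
    """
    a = np.maximum(np.einsum("cf,fhw->chw", W_g, g)
                   + np.einsum("cf,fhw->chw", W_x, x), 0.0)
    psi = sigmoid(np.einsum("oc,chw->ohw", w_psi, a))  # [1, H, W], values in (0, 1)
    return psi * x                                     # broadcast over skip channels

rng = np.random.default_rng(0)
F_g, F_l, F_int, H, W = 6, 4, 3, 5, 5
out = attention_gate(rng.normal(size=(F_g, H, W)),
                     rng.normal(size=(F_l, H, W)),
                     rng.normal(size=(F_int, F_g)),
                     rng.normal(size=(F_int, F_l)),
                     rng.normal(size=(1, F_int)))
```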
- class martian_terrain_segmentation.models.AttentionUNet(*args: Any, **kwargs: Any)#
Bases: Module
Teacher model: deeper Attention U-Net.
Differences from student U-Net:
- 5 encoder + 5 decoder levels (deeper)
- Attention gates on all skip connections
- Wider feature maps (larger base_channels)
- Parameters:
in_channels (int) – Image channels.
num_classes (int) – Number of segmentation classes.
base_channels (int) – Width multiplier for encoder.
bilinear (bool) – Bilinear vs transposed conv upsampling.
- forward(x: torch.Tensor) torch.Tensor#
Forward pass of the Attention U-Net.
- get_cam_layer() torch.nn.Module#
Return last decoder block for Grad-CAM.
- class martian_terrain_segmentation.models.DoubleConv(*args: Any, **kwargs: Any)#
Bases: Module
Two consecutive convolution–BatchNorm–ReLU blocks:
\[x \mapsto \mathrm{ReLU}(\mathrm{BN}(\mathrm{Conv}(x)))\]
applied twice.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of feature channels in the output.
- forward(x: torch.Tensor) torch.Tensor#
Apply the 2×(Conv–BN–ReLU) block.
- class martian_terrain_segmentation.models.Down(*args: Any, **kwargs: Any)#
Bases: Module
Downscaling (encoder) block consisting of:
- MaxPool2d(2)
- DoubleConv(in_channels, out_channels)
- Parameters:
in_channels (int) – Channels of input feature map.
out_channels (int) – Channels after convolution.
- forward(x: torch.Tensor) torch.Tensor#
- class martian_terrain_segmentation.models.OutConv(*args: Any, **kwargs: Any)#
Bases: Module
Final 1×1 convolution producing class logits.
- Parameters:
in_channels (int) – Channels of input feature map.
num_classes (int) – Number of segmentation classes.
- forward(x: torch.Tensor) torch.Tensor#
Return raw class logits of shape [B, num_classes, H, W].
- class martian_terrain_segmentation.models.UNet(*args: Any, **kwargs: Any)#
Bases: Module
Lightweight U-Net used as the student in distillation.
Architecture: 4 down → 4 up with skip connections.
- Parameters:
in_channels (int) – Input channels (1 for grayscale Navcam).
num_classes (int) – Number of segmentation classes (AI4Mars uses 4).
base_channels (int) – Width of the first convolution layer; determines model size.
bilinear (bool) – Whether to use bilinear upsampling (recommended).
- forward(x: torch.Tensor) torch.Tensor#
Forward pass of the U-Net.
- get_cam_layer() torch.nn.Module#
Return the decoder layer used as the CAM target.
This is consumed by explainability utilities such as Grad-CAM.
- class martian_terrain_segmentation.models.Up(*args: Any, **kwargs: Any)#
Bases: Module
Upsampling (decoder) block for standard U-Net.
Includes either:
- bilinear upsampling, or
- transposed convolution,
followed by DoubleConv.
- Parameters:
in_channels (int) – Total number of channels after concatenation of skip + upsampled features.
out_channels (int) – Output feature channels.
bilinear (bool) – Whether to use bilinear interpolation (preferred for lightweight model).
- forward(x_up: torch.Tensor, x_skip: torch.Tensor) torch.Tensor#
- Parameters:
x_up (torch.Tensor) – Decoder feature map (coarse).
x_skip (torch.Tensor) – Encoder skip feature map (fine).
- Returns:
Feature map after concatenation + convolution.
- Return type:
torch.Tensor
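The shape bookkeeping of this decoder step can be sketched with a nearest-neighbour 2× upsample followed by channel concatenation (illustrative shapes only; the real block uses bilinear or transposed-conv upsampling and then applies DoubleConv, omitted here):

```python
import numpy as np

x_up = np.zeros((8, 16, 16))    # coarse decoder features [C_up, H, W]
x_skip = np.zeros((4, 32, 32))  # fine encoder skip features [C_skip, 2H, 2W]

# Nearest-neighbour 2x upsampling, standing in for bilinear / transposed conv.
up = x_up.repeat(2, axis=1).repeat(2, axis=2)  # [8, 32, 32]
merged = np.concatenate([up, x_skip], axis=0)  # [12, 32, 32] -> fed to DoubleConv
```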
- class martian_terrain_segmentation.models.UpAttn(*args: Any, **kwargs: Any)#
Bases: Module
Upsampling block with attention-gated skip connections.
- Parameters:
skip_channels (int) – Channels of encoder skip features.
up_channels (int) – Channels of decoder feature map.
out_channels (int) – Channels after concatenation + DoubleConv.
bilinear (bool) – Whether to use bilinear interpolation instead of transposed conv.
- forward(x_up: torch.Tensor, x_skip: torch.Tensor) torch.Tensor#
Apply attention gate → concat with upsampled features → DoubleConv.
- martian_terrain_segmentation.models.create_teacher_unet(in_channels: int = 1, num_classes: int = 4, base_channels: int = 64, bilinear: bool = True) AttentionUNet#
Factory for the deeper teacher model.
- martian_terrain_segmentation.models.create_unet(in_channels: int = 1, num_classes: int = 4, base_channels: int = 32, bilinear: bool = True) UNet#
Factory for the lightweight student U-Net.
Knowledge distillation utilities for semantic segmentation.
This module currently provides:
SegmentationKDLoss: a pixel-wise distillation loss combining standard cross-entropy with a KL divergence term between teacher and student logits, following the classic KD formulation:
\[L = \alpha \cdot \mathrm{CE}(y, s) + (1 - \alpha)\, T^2 \cdot \mathrm{KL}\big( \mathrm{softmax}(t/T) \;\Vert\; \mathrm{softmax}(s/T) \big),\]
where s and t denote student and teacher logits respectively, and the loss is averaged over all non-ignored pixels.
- class martian_terrain_segmentation.distillation.SegmentationKDLoss(*args: Any, **kwargs: Any)#
Bases: Module
Pixel-wise knowledge distillation loss for semantic segmentation.
This loss combines a standard supervised cross-entropy term with a KL-divergence term that encourages the student to match the teacher’s softened class probabilities (as in Hinton et al., 2015).
The loss is defined as:
\[L = \alpha \cdot \mathcal{L}_{\text{CE}}(y, s) + (1 - \alpha)\, T^2 \cdot \mathcal{L}_{\text{KD}}(t, s),\]
where
- \(\mathcal{L}_{\text{CE}}\) is pixel-wise cross-entropy w.r.t. the ground truth labels \(y\).
- \(\mathcal{L}_{\text{KD}}\) is the pixel-wise KL divergence between teacher and student predictive distributions with temperature \(T\).
The KL term is computed per pixel, masked by ignore_index, and averaged over all valid pixels.
- Parameters:
ignore_index (int) – Label value in the target mask that should be ignored when computing both the CE and KD terms (e.g. unlabeled/void pixels).
alpha (float, default=0.5) – Trade-off parameter between the supervised CE loss and the distillation (KL) loss. alpha=1.0 means pure CE, alpha=0.0 means pure distillation.
T (float, default=2.0) – Temperature used to soften the teacher and student logits in the KD term. Larger values produce softer probability distributions and typically richer distillation signals.
Examples
```python
kd_loss_fn = SegmentationKDLoss(ignore_index=255, alpha=0.5, T=2.0)
# student_logits, teacher_logits: [B, C, H, W]
# targets: [B, H, W] with values in {0..C-1} or ignore_index
loss = kd_loss_fn(student_logits, teacher_logits, targets)
```
- forward(student_logits: torch.Tensor, teacher_logits: torch.Tensor, targets: torch.Tensor) torch.Tensor#
Compute the combined CE + KD loss for segmentation.
- Parameters:
student_logits (torch.Tensor) – Logits from the student network of shape [B, C, H, W].
teacher_logits (torch.Tensor) – Logits from the teacher network of shape [B, C, H, W]. These are treated as fixed targets (no gradient should flow into the teacher).
targets (torch.Tensor) – Ground-truth segmentation mask of shape [B, H, W] with integer class indices in [0, C-1] or ignore_index.
- Returns:
Scalar tensor containing the total distillation loss.
- Return type:
torch.Tensor
Notes
The returned loss is:
\[L = \alpha \cdot \mathrm{CE}(y, s) + (1 - \alpha)\, T^2 \cdot \mathrm{KL}\big( \mathrm{softmax}(t/T) \;\Vert\; \mathrm{softmax}(s/T) \big),\]
where:
- The cross-entropy term is averaged over all pixels with targets != ignore_index.
- The KL term is also averaged over valid pixels and scaled by \(T^2\) following the original KD formulation.
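The combined loss can be sketched in NumPy for a flat batch of per-pixel logits (a simplified sketch with hypothetical values; the real module operates on [B, C, H, W] tensors):

```python
import numpy as np

def softmax(z, T=1.0):
    z = z / T
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kd_loss(student_logits, teacher_logits, targets, alpha=0.5, T=2.0, ignore_index=255):
    """L = alpha * CE(y, s) + (1 - alpha) * T^2 * KL(softmax(t/T) || softmax(s/T)),
    averaged over non-ignored pixels. logits: [N, C], targets: [N]."""
    valid = targets != ignore_index
    s, t, y = student_logits[valid], teacher_logits[valid], targets[valid]
    ce = -np.log(softmax(s)[np.arange(len(y)), y]).mean()
    p_t, p_s = softmax(t, T), softmax(s, T)
    kl = (p_t * (np.log(p_t) - np.log(p_s))).sum(axis=-1).mean()
    return alpha * ce + (1 - alpha) * T ** 2 * kl

logits = np.array([[2.0, 0.0, -1.0, 0.5]])
targets = np.array([0])
# When student == teacher, the KL term vanishes and only the CE term remains.
loss = kd_loss(logits, logits, targets)
```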
Explainability utilities for Martian terrain segmentation models.
This module provides unified interfaces for:
- Grad-CAM heatmaps over decoder features
- Integrated Gradients saliency maps
- Neural PCA: class-wise PCA in feature space, following the lecture slides
The tools are written to work with the U-Net architecture defined in
models.py but are generic enough to be used with any segmentation
model that exposes a CAM target layer and a final 1×1 classifier.
- martian_terrain_segmentation.explainability.compute_class_neural_pca_features(model: torch.nn.Module, dataset, device: torch.device, class_ids: Sequence[int], max_samples_per_class: int = 200, n_components: int = 5, min_per_class: int = 10) Dict[int, Dict[str, object]]#
Compute Neural PCA for each class, following the lecture slides.
We compute:
- Feature vectors \(\phi(x)\)
- Classifier weights \(w_k\)
- Class-specific embedding:
\[\psi_k(x) = w_k \odot \phi(x)\]
Then perform PCA over the set \(\{\psi_k(x_i)\}\) for each class.
- Parameters:
model (nn.Module) – Trained segmentation network.
dataset – Dataset providing (image, mask) samples.
device (torch.device) – Device for model inference.
class_ids (sequence of int) – Class indices to process.
max_samples_per_class (int) – Maximum number of samples to use per class.
n_components (int) – Number of PCA components to retain.
min_per_class (int) – Minimum samples required to run PCA.
- Returns:
Mapping class_id -> PCA results with keys:
- mean_psi: [D]
- eigvecs: [L, D]
- eigvals: [L]
- alphas: [N, L] projection scores
- indices: dataset indices used
- psi: raw psi vectors [N, D]
- Return type:
dict
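The construction can be sketched in NumPy: form \(\psi_k = w_k \odot \phi(x)\) for hypothetical feature vectors and classifier weights, then run PCA via eigendecomposition of the covariance matrix:

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, L = 50, 8, 3                  # samples, feature dim, components (illustrative)
phi = rng.normal(size=(N, D))       # feature vectors phi(x_i)
w_k = rng.normal(size=(D,))         # classifier weights for class k

psi = phi * w_k                     # class-specific embedding psi_k(x_i)
mean_psi = psi.mean(axis=0)
cov = np.cov(psi - mean_psi, rowvar=False)      # [D, D] covariance
eigvals, eigvecs = np.linalg.eigh(cov)          # ascending eigenvalue order
order = np.argsort(eigvals)[::-1][:L]           # keep top-L components
eigvals, eigvecs = eigvals[order], eigvecs[:, order].T  # eigvecs: [L, D]
alphas = (psi - mean_psi) @ eigvecs.T           # projection scores [N, L]
```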
- martian_terrain_segmentation.explainability.explain_per_class_examples(model, dataset, device, num_examples_per_class: int = 2, ig_steps: int = 32)#
Generate combined explainability visualizations for each class:
- Input image
- Grad-CAM
- Integrated Gradients
- Parameters:
model (nn.Module) – Trained segmentation model.
dataset – Dataset providing (img, mask).
device (torch.device) – Compute device.
num_examples_per_class (int) – Number of examples per class.
ig_steps (int) – IG integration steps.
- martian_terrain_segmentation.explainability.grad_cam(model: torch.nn.Module, input_tensor: torch.Tensor, target_class: int, target_layer: torch.nn.Module) torch.Tensor#
Compute a Grad-CAM heatmap for semantic segmentation.
Grad-CAM produces a spatial importance map based on gradients flowing into a target convolutional layer:
\[\mathrm{CAM}(x) = \mathrm{ReLU} \Big( \sum_f \alpha_f \cdot A_f(x) \Big),\]
where:
- \(A_f\) are the feature maps
- \(\alpha_f = \frac{1}{HW} \sum_{i,j} \frac{\partial y_k}{\partial A_f(i,j)}\) are channel-wise weights
- \(y_k\) is the class score (here averaged over spatial pixels)
- Parameters:
model (nn.Module) – Segmentation model outputting logits of shape [B, C, H, W].
input_tensor (torch.Tensor) – Single input image of shape [1, C, H, W]. Gradients must be enabled.
target_class (int) – Class index 0..num_classes-1 for which to compute Grad-CAM.
target_layer (nn.Module) – Layer to hook for feature maps and gradients (e.g. model.get_cam_layer()).
- Returns:
Heatmap of shape [1, 1, H, W] normalized to [0, 1].
- Return type:
torch.Tensor
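The channel weighting can be sketched in NumPy given precomputed feature maps and their gradients (both random placeholders here; in practice they come from forward/backward hooks on target_layer):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(16, 8, 8))      # feature maps A_f(x), shape [F, H, W]
grads = rng.normal(size=(16, 8, 8))  # gradients dy_k / dA_f, same shape

alpha = grads.mean(axis=(1, 2))                           # channel weights alpha_f
cam = np.maximum((alpha[:, None, None] * A).sum(0), 0.0)  # ReLU(sum_f alpha_f A_f)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]
```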
- martian_terrain_segmentation.explainability.integrated_gradients(model: nn.Module, input_tensor: torch.Tensor, target_class: int, baseline: torch.Tensor | None = None, steps: int = 50) torch.Tensor#
Compute Integrated Gradients (IG) for segmentation.
Integrated Gradients approximates the path integral:
\[\mathrm{IG}(x) = (x - x_0) \times \int_{0}^{1} \frac{\partial f(x_0 + \alpha (x - x_0))}{\partial x} \, d\alpha,\]
where \(x_0\) is a baseline (typically zeros).
- Parameters:
model (nn.Module) – Segmentation model.
input_tensor (torch.Tensor) – Input image [1, C, H, W].
target_class (int) – Class index whose score to differentiate.
baseline (torch.Tensor, optional) – Baseline image [1, C, H, W]. Default: zeros.
steps (int) – Number of steps for the Riemann sum approximation.
- Returns:
Attribution map of shape [1, C, H, W].
- Return type:
torch.Tensor
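The Riemann-sum approximation can be sketched for a toy function \(f(x) = \sum_i x_i^2\) with a known analytic gradient \(\partial f / \partial x = 2x\); by IG's completeness property, the attributions should sum to \(f(x) - f(x_0)\):

```python
import numpy as np

def integrated_gradients(x, grad_f, baseline=None, steps=50):
    """Riemann-sum IG: (x - x0) * mean over alpha of grad f(x0 + alpha (x - x0))."""
    x0 = np.zeros_like(x) if baseline is None else baseline
    alphas = (np.arange(steps) + 0.5) / steps  # midpoint rule over alpha in (0, 1)
    grads = np.stack([grad_f(x0 + a * (x - x0)) for a in alphas])
    return (x - x0) * grads.mean(axis=0)

f = lambda x: (x ** 2).sum()
grad_f = lambda x: 2 * x          # analytic gradient of the toy model
x = np.array([1.0, -2.0, 3.0])
attr = integrated_gradients(x, grad_f)
# Completeness check: attr.sum() should equal f(x) - f(zeros baseline)
```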
- martian_terrain_segmentation.explainability.normalize_map(t: torch.Tensor) torch.Tensor#
Normalize tensor linearly to the [0, 1] range.
- Parameters:
t (torch.Tensor) – Input tensor.
- Returns:
Normalized tensor.
- Return type:
torch.Tensor
- martian_terrain_segmentation.explainability.show_top_neural_pca_images_for_class(neural_pca_results, dataset, class_id: int, component_idx: int = 0, top_k: int = 6)#
Visualize the top-k images that maximally activate a neural PCA component for a given class.
For PCA component \(v_\ell\), the ranking uses:
\[\alpha_\ell^{(k)}(x_i)\]
There is also a small easter egg: if class_id == -1, a special “astronaut” NPCA image is shown instead of using the neural PCA results.
- Parameters:
neural_pca_results (dict) – Output of compute_class_neural_pca_features().
dataset – Dataset that returns (image, mask).
class_id (int) – Class index to visualize. Use -1 for the astronaut easter egg.
component_idx (int) – PCA component index (0-based).
top_k (int) – Number of top activating images to display.
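The ranking itself can be sketched with hypothetical projection scores alphas (the [N, L] array returned by compute_class_neural_pca_features) using argsort:

```python
import numpy as np

# Hypothetical projection scores for N=4 samples, L=1 component.
alphas = np.array([[0.1], [2.5], [-1.0], [0.9]])
component_idx, top_k = 0, 2

# Dataset indices of the top-k most activating samples for this component.
top = np.argsort(alphas[:, component_idx])[::-1][:top_k]
```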
Optimizer helpers for model training.
This module provides:
create_optimizer
Smart optimizer selection with priority:
- Muon (if installed and explicitly enabled),
- NAdam (PyTorch’s NAdamW-like implementation),
- AdamW as a safe fallback.
create_cosine_scheduler_with_warmup
A cosine-annealing learning rate schedule with linear warmup, mathematically equivalent to the Hugging Face transformers scheduler.
Both utilities are framework-agnostic and work with any PyTorch model.
- martian_terrain_segmentation.optimizers.create_cosine_scheduler_with_warmup(optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5) torch.optim.lr_scheduler.LambdaLR#
Create a cosine-annealing LR scheduler with linear warmup.
This scheduler combines a linear warmup phase with a cosine decay phase.
Learning rate schedule
Given current step \(t\), warmup \(W\), and total steps \(T\), the schedule is:
Warmup (linear)
\[\text{lr}(t) = \frac{t}{W}, \quad 0 \le t < W\]
Cosine decay
\[\begin{aligned}\text{progress} &= \frac{t - W}{T - W}\\\text{lr}(t) &= \tfrac{1}{2}\left(1 + \cos\big( 2\pi \cdot C \cdot \text{progress} \big)\right)\end{aligned}\]
where \(C\) = num_cycles controls the number of cosine waves (0.5 = standard: decay to 0 once).
- Parameters:
optimizer (torch.optim.Optimizer) – Optimizer whose learning rate will be scheduled.
num_warmup_steps (int) – Number of linear warmup steps, typically 5–10% of total training steps.
num_training_steps (int) – Total number of steps (epochs * steps_per_epoch).
num_cycles (float, optional) – Number of cosine cycles. Default 0.5 = half-cycle (decay to 0 exactly once).
- Returns:
Scheduler that updates the LR dynamically during training.
- Return type:
torch.optim.lr_scheduler.LambdaLR
Notes
This implementation is mathematically similar to transformers.get_cosine_schedule_with_warmup.
The value returned by the lambda is multiplied with the optimizer’s base LR.
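The schedule can be sketched as a plain-Python LR multiplier, the same kind of lambda one would pass to LambdaLR (step counts below are illustrative):

```python
import math

def lr_lambda(step, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """LR multiplier: linear warmup to 1.0, then cosine decay."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)                  # linear warmup
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    # Half-cycle cosine decay for num_cycles=0.5: reaches 0 exactly at the end.
    return max(0.0, 0.5 * (1.0 + math.cos(2.0 * math.pi * num_cycles * progress)))
```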
- martian_terrain_segmentation.optimizers.create_optimizer(model: torch.nn.Module, lr: float = 0.0003, weight_decay: float = 0.01, use_muon: bool = True) torch.optim.Optimizer#
Create an optimizer for a given model with prioritized fallback logic.
The optimizers are tried in the following priority:
1. Muon (if installed and use_muon=True) – a second-order optimizer approximating natural gradient steps.
2. NAdam – PyTorch’s NAdam implementation (NAdamW-style), supporting weight decay.
3. AdamW – stable, widely used, standard fallback.
- Parameters:
model (torch.nn.Module) – Model whose trainable parameters will be optimized.
lr (float, optional) – Learning rate (default: 3e-4).
weight_decay (float, optional) – Weight decay coefficient (default: 1e-2).
use_muon (bool, optional) – Whether the user prefers to use Muon if available.
- Returns:
Constructed optimizer instance.
- Return type:
torch.optim.Optimizer
Notes
Only parameters with requires_grad=True are passed to the optimizer.
If Muon is requested but not installed, AdamW is used and a warning is printed.
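The fallback priority can be sketched as a small selection helper. This is a pure-Python sketch of the decision logic only: the real function constructs torch optimizer instances, and the availability flags here are hypothetical inputs rather than actual install checks:

```python
def choose_optimizer(use_muon=True, muon_available=False, nadam_available=True):
    """Return the optimizer name according to the documented priority."""
    if use_muon and muon_available:
        return "Muon"    # second-order optimizer, if installed and requested
    if nadam_available:
        return "NAdam"   # PyTorch NAdam (NAdamW-style)
    return "AdamW"       # safe, widely used fallback
```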