Supervised Fine-tuning
Wrapper for downstream classification and regression tasks using pretrained encoders as backbones.
SupervisedModule
Core module combining a pretrained backbone with a classification/regression head. Supports linear probing, full fine-tuning, gradual unfreezing, and supervised-from-scratch training.
Supervised-training wrapper and reusable head for downstream tasks.
Provides a single SupervisedModule (LightningModule) that owns
the train/val loop, optimizer, logging, and freeze logic. Model-specific
concerns are injected via a BatchAdapter, a representation
function, and an nn.Module head.
- class chronocratic.models.supervised.supervised.BatchAdapter(*args, **kwargs)
Bases: Protocol
Strategy: decode a model-specific batch tuple into encoder inputs and targets.
Each model's DataLoader yields a different batch format. The adapter normalizes it into
((encoder_inputs, ...), targets), so SupervisedModule never sees model-specific tuple shapes.
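The method that BatchAdapter declares is not shown in this reference, so the sketch below assumes an adapter is simply a callable that takes the raw batch and returns `((encoder_inputs, ...), targets)`. `CallableBatchAdapter`, `PlainBatchAdapter`, `MaskedBatchAdapter`, `split_batch`, and the `(x, y, padding_mask)` batch layout are all hypothetical, chosen only to show two loaders with different tuple shapes being normalized to the same structure.

```python
from typing import Any, Protocol, Sequence, Tuple


class CallableBatchAdapter(Protocol):
    """Hypothetical adapter shape: raw batch -> ((encoder_inputs, ...), targets)."""

    def __call__(self, batch: Sequence[Any]) -> Tuple[Tuple[Any, ...], Any]: ...


class PlainBatchAdapter:
    """Hypothetical adapter for a DataLoader that yields (x, y)."""

    def __call__(self, batch: Sequence[Any]) -> Tuple[Tuple[Any, ...], Any]:
        x, y = batch
        return (x,), y


class MaskedBatchAdapter:
    """Hypothetical adapter for a DataLoader that yields (x, y, padding_mask)."""

    def __call__(self, batch: Sequence[Any]) -> Tuple[Tuple[Any, ...], Any]:
        x, y, padding_mask = batch
        # Regroup so the encoder receives (x, padding_mask) and the targets come last.
        return (x, padding_mask), y


def split_batch(adapter: CallableBatchAdapter, batch: Sequence[Any]) -> Tuple[Tuple[Any, ...], Any]:
    """The training loop only ever sees this normalized pair."""
    encoder_inputs, targets = adapter(batch)
    return encoder_inputs, targets


print(split_batch(PlainBatchAdapter(), ("x", "y")))           # (('x',), 'y')
print(split_batch(MaskedBatchAdapter(), ("x", "y", "mask")))  # (('x', 'mask'), 'y')
```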
- class chronocratic.models.supervised.supervised.FlattenLinearHead(in_features: int, num_outputs: int)
Bases: Module
Flatten a representation across all non-batch dims, then apply a single linear layer.
Reused by every model whose representation is a tensor of shape
(B, ...). Series2Vec representations are already (B, 2*rep), so the flatten is a no-op there.
- Parameters:
in_features – Flattened representation size (backbone.representation_dim).
num_outputs – Number of classes (classification) or targets (regression).
- forward(reps: Tensor) → Tensor
Compute logits from a representation tensor.
- Parameters:
reps – Representation of shape (B, ...).
- Returns:
Logits of shape (B, num_outputs).
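A minimal sketch of the flatten-then-linear behaviour described above, written as a standalone `nn.Module`. `FlattenLinearHeadSketch` is a hypothetical re-creation for illustration, not the library's source, and the (4, 10, 16) and (4, 64) shapes are made up. It shows why `in_features` must equal the product of the non-batch dims, and why an already-flat (B, 2*rep) input passes through the flatten unchanged.

```python
import torch
from torch import nn


class FlattenLinearHeadSketch(nn.Module):
    """Illustrative re-creation of FlattenLinearHead (hypothetical, not the library code)."""

    def __init__(self, in_features: int, num_outputs: int) -> None:
        super().__init__()
        # Collapse every dim after the batch dim: (B, d1, d2, ...) -> (B, d1 * d2 * ...).
        self.flatten = nn.Flatten(start_dim=1)
        self.linear = nn.Linear(in_features, num_outputs)

    def forward(self, reps: torch.Tensor) -> torch.Tensor:
        return self.linear(self.flatten(reps))


# A 3-D representation such as (B, seq_len, d_model) = (4, 10, 16): in_features = 10 * 16.
head = FlattenLinearHeadSketch(in_features=10 * 16, num_outputs=3)
logits = head(torch.randn(4, 10, 16))
print(logits.shape)  # torch.Size([4, 3])

# An already-flat (B, 2*rep) representation is left as-is by nn.Flatten(start_dim=1).
flat_head = FlattenLinearHeadSketch(in_features=2 * 32, num_outputs=1)
print(flat_head(torch.randn(4, 64)).shape)  # torch.Size([4, 1])
```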
- class chronocratic.models.supervised.supervised.RepresentationBackbone(*args, **kwargs)
Bases: Protocol
A backbone that can report the flattened feature size of its representation.
- Implementations:
- property representation_dim: int
Flattened feature size handed to a downstream head.
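Because RepresentationBackbone is a Protocol, any backbone that exposes a `representation_dim` property satisfies it structurally, with no inheritance required. The sketch below mirrors that with a stand-in protocol. `HasRepresentationDim` and `ToyBackbone` are hypothetical, and the `runtime_checkable` decorator is an assumption added only so `isinstance` can be demonstrated; whether the real protocol is runtime-checkable is not stated here.

```python
import math
from typing import Protocol, Tuple, runtime_checkable


@runtime_checkable
class HasRepresentationDim(Protocol):
    """Hypothetical stand-in mirroring RepresentationBackbone: only the property is required."""

    @property
    def representation_dim(self) -> int: ...


class ToyBackbone:
    """Hypothetical backbone whose per-sample representation has shape rep_shape."""

    def __init__(self, rep_shape: Tuple[int, ...]) -> None:
        self.rep_shape = rep_shape

    @property
    def representation_dim(self) -> int:
        # Flattened size handed to the head = product of all non-batch dims.
        return math.prod(self.rep_shape)


backbone = ToyBackbone(rep_shape=(10, 16))
print(isinstance(backbone, HasRepresentationDim))  # True
print(backbone.representation_dim)                 # 160
```

The value feeds straight into the head, e.g. `FlattenLinearHead(backbone.representation_dim, num_outputs)`.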
- class chronocratic.models.supervised.supervised.SupervisedModule(backbone: nn.Module, head: nn.Module, representation_fn: Callable[..., torch.Tensor], batch_adapter: BatchAdapter, loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], *, learning_rate: float = 0.001, weight_decay: float = 0.0, freeze_backbone: bool = True, sync_dist: bool = False)
Bases: LightningModule
Generic supervised-training wrapper for labeled downstream tasks.
Trains a backbone + head on labels for classification or regression. Owns the train/val loop, optimizer, logging, and (static) freeze. Everything model-specific is injected.
The four supervised modes are configuration, not subclasses:
Mode                 | backbone state | freeze_backbone | Trainer callback
---------------------|----------------|-----------------|-----------------
Linear probe         | pretrained     | True            | none
Full fine-tune       | pretrained     | False           | none
Gradual unfreeze     | pretrained     | False           | BackboneUnfreeze
Supervised (scratch) | fresh / random | False           | none
The only difference between “fine-tune” and “supervised from scratch” is whether the injected backbone was pretrained: same class, same call. The from-scratch path (a freshly constructed, un-pretrained backbone with
freeze_backbone=False) replaces the old TS-TCC SUPERVISED training mode.
- Parameters:
backbone – A (possibly pretrained) model exposing whatever representation_fn needs.
head – Maps a representation tensor to (B, num_outputs) (e.g. FlattenLinearHead).
representation_fn – (backbone, *encoder_inputs) -> Tensor. Must be differentiable. MUST NOT route through encode() (that path is inference-mode / offline only).
batch_adapter – Decodes the batch tuple into ((encoder_inputs, ...), targets).
loss_fn – (predictions, targets) -> scalar.
learning_rate – Adam learning rate.
weight_decay – Adam weight decay.
freeze_backbone – Freeze backbone params (linear probe). Set False for full fine-tuning, for supervised-from-scratch (fresh backbone), or when a gradual-unfreeze callback owns freezing (see BackboneUnfreeze). Never let both the bool and a callback flip requires_grad.
sync_dist – Sync logged metrics across processes.
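A static freeze of this kind usually amounts to turning off `requires_grad` on every backbone parameter at construction. The sketch below assumes that is what `freeze_backbone=True` does (the real module may do more, for example switching the backbone to eval mode); `freeze`, the toy `nn.Sequential` backbone, and the layer sizes are hypothetical. It also shows why the bool and a callback should not both own freezing: each simply overwrites `requires_grad`, so whichever runs last silently wins.

```python
import torch
from torch import nn


def freeze(module: nn.Module) -> None:
    """Static freeze (hypothetical helper): no gradients flow into these params."""
    for p in module.parameters():
        p.requires_grad_(False)


backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
head = nn.Linear(16, 3)

freeze_backbone = True  # linear probe
if freeze_backbone:
    freeze(backbone)

# Only the head remains trainable: 16 * 3 weights + 3 biases.
n_trainable = sum(
    p.numel() for m in (backbone, head) for p in m.parameters() if p.requires_grad
)
print(n_trainable)  # 51
```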
- property backbone: Module
The wrapped backbone (read-only access for fine-tuning callbacks).
- forward(*encoder_inputs: Tensor) → Tensor
Compute representations with representation_fn and run them through the head.
- Parameters:
encoder_inputs – Model-specific tensors passed to representation_fn.
- Returns:
Predictions of shape (B, num_outputs).
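Put differently, `forward(*encoder_inputs)` computes `head(representation_fn(backbone, *encoder_inputs))`. The sketch below walks that data flow with a hypothetical `ToyEncoder` and `toy_representation_fn`; both names and all shapes are assumptions. It calls the backbone directly with autograd enabled, which is the point of the rule that representation_fn must not route through `encode()`: tensors produced under inference mode cannot carry gradients back into the backbone.

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Hypothetical backbone mapping (B, T, C) -> (B, T, D)."""

    def __init__(self, in_channels: int, d_model: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_channels, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


def toy_representation_fn(backbone: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Differentiable call into the backbone: no torch.no_grad / inference_mode here.
    return backbone(x)


backbone = ToyEncoder(in_channels=3, d_model=8)
head = nn.Sequential(nn.Flatten(start_dim=1), nn.Linear(20 * 8, 5))

x = torch.randn(4, 20, 3)
preds = head(toy_representation_fn(backbone, x))  # what forward(x) computes
print(preds.shape)  # torch.Size([4, 5])

preds.sum().backward()
print(backbone.proj.weight.grad is not None)  # True: gradients reach the backbone
```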
- training_step(batch: tuple, _batch_idx: int) → Tensor
Compute and log the training loss for one batch.
- Parameters:
batch – Raw batch tuple from the DataLoader.
_batch_idx – Index of this batch within the epoch (unused).
- Returns:
Scalar training loss.
- validation_step(batch: tuple, _batch_idx: int) → Tensor
Compute and log the validation loss for one batch.
- Parameters:
batch – Raw batch tuple from the DataLoader.
_batch_idx – Index of this batch within the epoch (unused).
- Returns:
Scalar validation loss.
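Both steps follow the same path: the batch adapter splits the raw tuple, forward produces predictions, and loss_fn reduces them to a scalar. The sketch below reproduces that shared path outside Lightning under stated assumptions: `shared_step`, the toy linear backbone, and the choice of `nn.CrossEntropyLoss` are all hypothetical, and the `self.log(...)` calls the real steps make are omitted because their metric names are not documented here. Validation is shown under `torch.no_grad()` because Lightning runs validation with gradients disabled.

```python
import torch
from torch import nn

torch.manual_seed(0)

backbone = nn.Linear(6, 8)
head = nn.Linear(8, 3)
loss_fn = nn.CrossEntropyLoss()  # assumed loss for a 3-class toy task


def representation_fn(bb: nn.Module, x: torch.Tensor) -> torch.Tensor:
    return bb(x)


def batch_adapter(batch):
    x, y = batch
    return (x,), y


def shared_step(batch) -> torch.Tensor:
    # 1) decode the raw batch, 2) forward through backbone + head, 3) scalar loss.
    encoder_inputs, targets = batch_adapter(batch)
    preds = head(representation_fn(backbone, *encoder_inputs))
    return loss_fn(preds, targets)


batch = (torch.randn(4, 6), torch.tensor([0, 2, 1, 0]))
train_loss = shared_step(batch)    # training_step: gradients enabled
with torch.no_grad():
    val_loss = shared_step(batch)  # validation_step: Lightning disables gradients
print(train_loss.ndim, train_loss.requires_grad, val_loss.requires_grad)  # 0 True False
```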
- configure_optimizers() → Optimizer
Return Adam over the trainable parameters of this module.
Uses a generator expression so frozen backbone params are excluded automatically. Compatible with gradual-unfreeze callbacks that add param groups later via BaseFinetuning.unfreeze_and_add_param_group().
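The sketch below shows the two plain-PyTorch mechanics this method relies on. First, `torch.optim.Adam` accepts a generator, so filtering on `requires_grad` keeps frozen params out of the optimizer. Second, previously frozen params can be registered later as a new param group with `Optimizer.add_param_group`, which is the kind of step a gradual-unfreeze callback performs. The toy modules, the `nn.ModuleDict` wrapper, and the 1e-4 group learning rate are assumptions for illustration only.

```python
import torch
from torch import nn

backbone = nn.Linear(8, 16)
head = nn.Linear(16, 3)
for p in backbone.parameters():
    p.requires_grad_(False)  # frozen at construction (linear probe)

model = nn.ModuleDict({"backbone": backbone, "head": head})

# Generator expression: only params with requires_grad=True reach the optimizer.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3, weight_decay=0.0
)
print(len(optimizer.param_groups), sum(len(g["params"]) for g in optimizer.param_groups))  # 1 2

# Later, an unfreeze step re-enables the backbone and adds it as a separate group.
for p in backbone.parameters():
    p.requires_grad_(True)
optimizer.add_param_group({"params": list(backbone.parameters()), "lr": 1e-4})
print(len(optimizer.param_groups), sum(len(g["params"]) for g in optimizer.param_groups))  # 2 4
```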
Factory Functions
Convenience constructors for common backbone-head combinations.
Factory constructors for SupervisedModule.
Each factory wires the correct FlattenLinearHead, batch adapter,
representation function, and loss function for a given backbone, so
callers don't hand-assemble four collaborators.
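Two of the four collaborators depend only on `task` and `num_outputs`. The sketch below shows that wiring under explicit assumptions: the mapping of `'classification'` to `nn.CrossEntropyLoss` and `'regression'` to `nn.MSELoss` is a conventional choice, not something this reference confirms, and `pick_loss` and `wire_head` are hypothetical helpers rather than factory internals. The head stands in for FlattenLinearHead built from `backbone.representation_dim`.

```python
import torch
from torch import nn


def pick_loss(task: str) -> nn.Module:
    # Assumed task -> loss mapping; the real factories may choose differently.
    if task == "classification":
        return nn.CrossEntropyLoss()
    if task == "regression":
        return nn.MSELoss()
    raise ValueError(f"task must be 'classification' or 'regression', got {task!r}")


def wire_head(representation_dim: int, num_outputs: int) -> nn.Module:
    # Equivalent of FlattenLinearHead: flatten non-batch dims, then one linear layer.
    return nn.Sequential(nn.Flatten(start_dim=1), nn.Linear(representation_dim, num_outputs))


head = wire_head(representation_dim=64, num_outputs=1)
loss_fn = pick_loss("regression")
preds = head(torch.randn(4, 64))            # shape (4, 1)
loss = loss_fn(preds, torch.randn(4, 1))    # regression targets shaped like preds
print(type(loss_fn).__name__, preds.shape)  # MSELoss torch.Size([4, 1])
```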
- chronocratic.models.supervised.factory.make_series2vec_supervised(backbone: Series2Vec, *, num_outputs: int, task: str = 'classification', freeze_backbone: bool = True, learning_rate: float = 0.001, weight_decay: float = 0.0, sync_dist: bool = False) → SupervisedModule
Build a SupervisedModule for a Series2Vec backbone.
- Parameters:
backbone – A Series2Vec instance with representation_dim.
num_outputs – Number of classes (classification) or targets (regression).
task – 'classification' or 'regression'.
freeze_backbone – Freeze backbone params at construction (linear probe).
learning_rate – Adam LR.
weight_decay – Adam weight decay.
sync_dist – Sync logged metrics across processes.
- Returns:
Configured SupervisedModule ready for training.
Example
Regression from scratch on a fresh backbone:
module = make_series2vec_supervised(Series2Vec(...), num_outputs=1, task='regression', freeze_backbone=False)
- chronocratic.models.supervised.factory.make_tst_supervised(backbone: TST, *, num_outputs: int, task: str = 'classification', freeze_backbone: bool = True, learning_rate: float = 0.001, weight_decay: float = 0.0, sync_dist: bool = False) → SupervisedModule
Build a SupervisedModule for a TST backbone.
- Parameters:
backbone – A TST instance with representation_dim.
num_outputs – Number of classes (classification) or targets (regression).
task – 'classification' or 'regression'.
freeze_backbone – Freeze backbone params at construction (linear probe).
learning_rate – Adam LR.
weight_decay – Adam weight decay.
sync_dist – Sync logged metrics across processes.
- Returns:
Configured SupervisedModule ready for training.
Example
Linear probe on a pretrained backbone (default):
module = make_tst_supervised(pretrained_tst, num_outputs=5)
Supervised from scratch (fresh backbone, trained end-to-end):
module = make_tst_supervised(TST(...), num_outputs=5, freeze_backbone=False)
- chronocratic.models.supervised.factory.make_tstcc_supervised(backbone: TSTCC, *, num_outputs: int, task: str = 'classification', freeze_backbone: bool = True, learning_rate: float = 0.001, weight_decay: float = 0.0, sync_dist: bool = False) → SupervisedModule
Build a SupervisedModule for a TS-TCC backbone.
With freeze_backbone=False on a fresh (un-pretrained) TSTCC, this is the explicit replacement for the removed TSTCCTrainingMode.SUPERVISED.
- Parameters:
backbone – A TSTCC instance with representation_dim.
num_outputs – Number of classes (classification) or targets (regression).
task – 'classification' or 'regression'.
freeze_backbone – Freeze backbone params at construction (linear probe).
learning_rate – Adam LR.
weight_decay – Adam weight decay.
sync_dist – Sync logged metrics across processes.
- Returns:
Configured SupervisedModule ready for training.
Example
Supervised from scratch (old SUPERVISED mode):
module = make_tstcc_supervised(TSTCC(...), num_outputs=6, freeze_backbone=False)