Supervised Fine-tuning

Wrapper for downstream classification and regression tasks using pretrained encoders as backbones.

SupervisedModule

Core module combining a pretrained backbone with a classification/regression head. Supports linear probing, full fine-tuning, gradual unfreezing, and supervised-from-scratch training.

Supervised-training wrapper and reusable head for downstream tasks.

Provides a single SupervisedModule (LightningModule) that owns the train/val loop, optimizer, logging, and freeze logic. Model-specific concerns are injected via a BatchAdapter, a representation function, and an nn.Module head.

class chronocratic.models.supervised.supervised.BatchAdapter(*args, **kwargs)

Bases: Protocol

Strategy: decode a model-specific batch tuple into encoder inputs + targets.

Each model's DataLoader yields a differently shaped batch. The adapter normalizes it into ((encoder_inputs, ...), targets), so SupervisedModule never sees model-specific tuple shapes.
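The protocol's members are not listed here, so the following is a minimal sketch that assumes an adapter is a callable taking the raw batch and returning ((encoder_inputs, ...), targets). The protocol name BatchAdapterLike and both adapter classes are hypothetical, and the real protocol method may be named differently.

```python
from typing import Any, Protocol, Sequence, Tuple, runtime_checkable


@runtime_checkable
class BatchAdapterLike(Protocol):
    """Hypothetical stand-in for BatchAdapter; the real method name is not documented."""

    def __call__(self, batch: Sequence[Any]) -> Tuple[Tuple[Any, ...], Any]:
        ...


class InputsMaskTargetAdapter:
    """Hypothetical adapter for a DataLoader that yields (x, padding_mask, y)."""

    def __call__(self, batch: Sequence[Any]) -> Tuple[Tuple[Any, ...], Any]:
        x, padding_mask, y = batch
        # Everything except the label becomes an encoder input.
        return (x, padding_mask), y


class InputsTargetAdapter:
    """Hypothetical adapter for a DataLoader that yields (x, y)."""

    def __call__(self, batch: Sequence[Any]) -> Tuple[Tuple[Any, ...], Any]:
        x, y = batch
        # Wrap the single input in a 1-tuple so the output shape is uniform.
        return (x,), y
```

Either adapter lets the training loop unpack `encoder_inputs, targets = adapter(batch)` without knowing how many tensors the model's DataLoader produces.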

class chronocratic.models.supervised.supervised.FlattenLinearHead(in_features: int, num_outputs: int)

Bases: Module

Flatten a representation across all non-batch dims, then a single linear layer.

Reused by every model whose representation is a tensor of shape (B, ...). Series2Vec representations are already (B, 2*rep), so the flatten is a no-op there.

Parameters:
  • in_features – Flattened representation size (backbone.representation_dim).

  • num_outputs – Number of classes (classification) or targets (regression).

forward(reps: Tensor) -> Tensor

Compute logits from a representation tensor.

Parameters:

reps – Representation of shape (B, ...).

Returns:

Logits of shape (B, num_outputs).
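A minimal PyTorch sketch of the flatten-then-linear behavior described above. FlattenLinearHeadSketch is a hypothetical re-creation for illustration, not the library's source; it assumes the flatten step is equivalent to nn.Flatten(start_dim=1).

```python
import torch
from torch import nn


class FlattenLinearHeadSketch(nn.Module):
    """Hypothetical re-creation of the documented behavior; not the library source."""

    def __init__(self, in_features: int, num_outputs: int) -> None:
        super().__init__()
        # Collapse every dim after the batch dim: (B, d1, d2, ...) -> (B, d1*d2*...).
        self.flatten = nn.Flatten(start_dim=1)
        self.linear = nn.Linear(in_features, num_outputs)

    def forward(self, reps: torch.Tensor) -> torch.Tensor:
        return self.linear(self.flatten(reps))


# A (B, T, D) representation needs in_features = T * D.
head = FlattenLinearHeadSketch(in_features=16 * 8, num_outputs=5)
logits = head(torch.randn(4, 16, 8))
print(logits.shape)  # torch.Size([4, 5])

# An already-2-D representation passes through the flatten unchanged.
flat = torch.randn(4, 64)
assert torch.equal(head.flatten(flat), flat)
```

in_features must equal the product of the non-batch dims, which is the number backbone.representation_dim is expected to report.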

class chronocratic.models.supervised.supervised.RepresentationBackbone(*args, **kwargs)

Bases: Protocol

A backbone that can report the flattened feature size of its representation.

Implementations:
property representation_dim: int

Flattened feature size handed to a downstream head.
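The protocol requires only representation_dim. A sketch of a conforming backbone, assuming a hypothetical ToySequenceBackbone whose representation has shape (B, seq_len, d_model); a real backbone would also be an nn.Module, which is omitted here to keep the structural check in focus.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class RepresentationBackboneLike(Protocol):
    """Mirror of the documented protocol: only representation_dim is required."""

    @property
    def representation_dim(self) -> int:
        ...


class ToySequenceBackbone:
    """Hypothetical backbone whose representation has shape (B, seq_len, d_model)."""

    def __init__(self, seq_len: int, d_model: int) -> None:
        self.seq_len = seq_len
        self.d_model = d_model

    @property
    def representation_dim(self) -> int:
        # Flattened size of one sample's representation, i.e. the head's in_features.
        return self.seq_len * self.d_model


backbone = ToySequenceBackbone(seq_len=16, d_model=8)
assert isinstance(backbone, RepresentationBackboneLike)
print(backbone.representation_dim)  # 128
```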

class chronocratic.models.supervised.supervised.SupervisedModule(backbone: nn.Module, head: nn.Module, representation_fn: Callable[..., torch.Tensor], batch_adapter: BatchAdapter, loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], *, learning_rate: float = 0.001, weight_decay: float = 0.0, freeze_backbone: bool = True, sync_dist: bool = False)

Bases: LightningModule

Generic supervised-training wrapper for labeled downstream tasks.

Trains a backbone + head on labels for classification or regression. Owns the train/val loop, optimizer, logging, and (static) freeze. Everything model-specific is injected.

The four supervised modes are configuration, not subclasses:

Mode                  backbone state   freeze_backbone   Trainer callback
Linear probe          pretrained       True              none
Full fine-tune        pretrained       False             none
Gradual unfreeze      pretrained       False             BackboneUnfreeze
Supervised (scratch)  fresh / random   False             none

“Fine-tune” and “supervised from scratch” differ only in whether the injected backbone was pretrained: same class, same call. The from-scratch path (a freshly constructed, un-pretrained backbone with freeze_backbone=False) replaces the old TS-TCC SUPERVISED training mode.
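Because the modes are configuration rather than subclasses, they can be written down as data. The sketch below encodes the mode table as three switches (ModeConfig and the dict keys are hypothetical names, not library API) and checks that full fine-tuning and from-scratch training differ only in whether the backbone was pretrained.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class ModeConfig:
    pretrained_backbone: bool
    freeze_backbone: bool
    needs_unfreeze_callback: bool  # e.g. a BackboneUnfreeze Trainer callback


# Hypothetical mode keys mirroring the table; these are not library constants.
SUPERVISED_MODES = {
    "linear_probe": ModeConfig(True, True, False),
    "full_finetune": ModeConfig(True, False, False),
    "gradual_unfreeze": ModeConfig(True, False, True),
    "scratch": ModeConfig(False, False, False),
}


def differing_fields(a: ModeConfig, b: ModeConfig) -> set:
    """Names of the switches on which two modes disagree."""
    return {f.name for f in fields(a) if getattr(a, f.name) != getattr(b, f.name)}


print(differing_fields(SUPERVISED_MODES["full_finetune"], SUPERVISED_MODES["scratch"]))
# {'pretrained_backbone'}
```

No mode combines freeze_backbone=True with an unfreeze callback, which matches the rule that only one mechanism should toggle requires_grad.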

Parameters:
  • backbone – A (possibly pretrained) model; it is passed as the first argument to representation_fn.

  • head – Maps a representation tensor to (B, num_outputs) (e.g. FlattenLinearHead).

  • representation_fn – Callable (backbone, *encoder_inputs) -> Tensor. Must be differentiable. MUST NOT route through encode() (that path is inference-mode / offline only).

  • batch_adapter – Decodes the batch tuple into ((encoder_inputs, ...), targets).

  • loss_fn – Callable (predictions, targets) -> scalar loss tensor.

  • learning_rate – Adam LR.

  • weight_decay – Adam weight decay.

  • freeze_backbone – Freeze backbone params (linear probe). Set False for full fine-tuning, for supervised-from-scratch training (fresh backbone), or when a gradual-unfreeze callback owns freezing (see BackboneUnfreeze). Never let both this flag and a callback toggle requires_grad.

  • sync_dist – Sync logged metrics across processes.
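The sketch below shows collaborators that match the documented call signatures. Everything in it is an illustrative assumption: the toy nn.Sequential backbone, the cross-entropy loss, and the freeze() helper, which only approximates what freeze_backbone=True is described as doing.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical toy backbone: maps (B, T, 3) to (B, T, 8).
backbone = nn.Sequential(nn.Linear(3, 8), nn.ReLU())


def representation_fn(backbone: nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Plain forward call: keeps the autograd graph, unlike an inference-mode encode().
    return backbone(x)


def loss_fn(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # Classification example: logits (B, K) and integer class labels (B,).
    return F.cross_entropy(predictions, targets)


def freeze(module: nn.Module) -> None:
    # Assumed effect of freeze_backbone=True: no gradients for any backbone param.
    for p in module.parameters():
        p.requires_grad_(False)


freeze(backbone)
reps = representation_fn(backbone, torch.randn(2, 5, 3))
print(reps.shape, reps.requires_grad)  # torch.Size([2, 5, 8]) False
```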

property backbone: Module

The wrapped backbone (read-only access for finetuning callbacks).

forward(*encoder_inputs: Tensor) -> Tensor

Compute representations with representation_fn and run them through the head.

Parameters:

encoder_inputs – Model-specific tensors passed to representation_fn.

Returns:

Predictions of shape (B, num_outputs).
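forward composes the two injected pieces: representation_fn, then the head. A Lightning-free sketch of that composition follows. ForwardSketch is a hypothetical class, and the toy backbone, head, and lambda are illustrative only.

```python
from typing import Callable

import torch
from torch import nn


class ForwardSketch(nn.Module):
    """Hypothetical stand-in for SupervisedModule.forward, without Lightning."""

    def __init__(self, backbone: nn.Module, head: nn.Module,
                 representation_fn: Callable[..., torch.Tensor]) -> None:
        super().__init__()
        self.backbone = backbone
        self.head = head
        self.representation_fn = representation_fn

    def forward(self, *encoder_inputs: torch.Tensor) -> torch.Tensor:
        # 1) model-specific representation, 2) shared head -> (B, num_outputs).
        reps = self.representation_fn(self.backbone, *encoder_inputs)
        return self.head(reps)


model = ForwardSketch(
    backbone=nn.Linear(3, 8),  # (B, 5, 3) -> (B, 5, 8)
    head=nn.Sequential(nn.Flatten(start_dim=1), nn.Linear(5 * 8, 4)),
    representation_fn=lambda bb, x: bb(x),
)
out = model(torch.randn(2, 5, 3))
print(out.shape)  # torch.Size([2, 4])
```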

training_step(batch: tuple, _batch_idx: int) -> Tensor

Compute and log the training loss for one batch.

Parameters:
  • batch – Raw batch tuple from the DataLoader.

  • _batch_idx – Index of this batch within the epoch (unused).

Returns:

Scalar training loss.

validation_step(batch: tuple, _batch_idx: int) -> Tensor

Compute and log the validation loss for one batch.

Parameters:
  • batch – Raw batch tuple from the DataLoader.

  • _batch_idx – Index of this batch within the epoch (unused).

Returns:

Scalar validation loss.
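Both steps amount to decode, predict, and score. Below is a hedged sketch of that shared logic using a hypothetical shared_step helper. Logging is omitted; in Lightning it would typically be a self.log(...) call with sync_dist passed through. The metric names the real module logs are not documented here.

```python
import torch
import torch.nn.functional as F
from torch import nn


def shared_step(model: nn.Module, batch_adapter, loss_fn, batch: tuple) -> torch.Tensor:
    """Hypothetical helper: the logic training_step and validation_step are assumed to share."""
    encoder_inputs, targets = batch_adapter(batch)  # model-specific decode
    predictions = model(*encoder_inputs)            # (B, num_outputs)
    return loss_fn(predictions, targets)            # scalar


model = nn.Linear(3, 2)  # toy model standing in for SupervisedModule.forward
batch = (torch.randn(4, 3), torch.tensor([0, 1, 1, 0]))  # (x, y) from a DataLoader
loss = shared_step(model, lambda b: ((b[0],), b[1]), F.cross_entropy, batch)
print(loss.dim())  # 0
```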

configure_optimizers() -> Optimizer

Return Adam over the trainable parameters of this module.

Uses a generator expression that keeps only parameters with requires_grad=True, so frozen backbone params are excluded automatically. Compatible with gradual-unfreeze callbacks that add param groups later via BaseFinetuning.unfreeze_and_add_param_group().
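A sketch of the optimizer construction described above. A hypothetical two-part ModuleDict stands in for SupervisedModule. After the backbone is frozen, only the head's parameters reach Adam.

```python
import torch
from torch import nn

# Hypothetical stand-in for a backbone + head module.
model = nn.ModuleDict({"backbone": nn.Linear(8, 8), "head": nn.Linear(8, 3)})

# Static freeze (linear probe).
for p in model["backbone"].parameters():
    p.requires_grad_(False)

# Only trainable params are handed to Adam.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-3,
    weight_decay=0.0,
)
n_opt = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
print(n_opt)  # 27  (head weight 3*8 + bias 3)
```

The generator is consumed once, at construction. Parameters unfrozen later are therefore not picked up automatically, which is why a gradual-unfreeze callback has to register them through unfreeze_and_add_param_group().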

Factory Functions

Convenience constructors for common backbone-head combinations.

Factory constructors for SupervisedModule.

Each factory wires the correct FlattenLinearHead, batch adapter, representation function, and loss function for a given backbone — so callers don’t hand-assemble four collaborators.
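This documentation does not say which loss each task selects. As an assumption only, a factory might map the task string to cross-entropy for classification and mean squared error for regression. select_loss is a hypothetical helper, not part of the library.

```python
from typing import Callable

import torch
import torch.nn.functional as F

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def select_loss(task: str) -> LossFn:
    """Hypothetical task -> loss mapping; the factories' actual choice is not documented."""
    if task == "classification":
        return F.cross_entropy  # logits (B, K), integer labels (B,)
    if task == "regression":
        return F.mse_loss  # predictions and targets of equal shape
    raise ValueError(f"unknown task: {task!r}")


print(select_loss("regression")(torch.zeros(2, 1), torch.ones(2, 1)).item())  # 1.0
```

Rejecting unknown task strings early surfaces typos at construction time rather than mid-training.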

chronocratic.models.supervised.factory.make_series2vec_supervised(backbone: Series2Vec, *, num_outputs: int, task: str = 'classification', freeze_backbone: bool = True, learning_rate: float = 0.001, weight_decay: float = 0.0, sync_dist: bool = False) -> SupervisedModule

Build a SupervisedModule for a Series2Vec backbone.

Parameters:
  • backbone – A Series2Vec instance with representation_dim.

  • num_outputs – Number of classes (classification) or targets (regression).

  • task – 'classification' or 'regression'.

  • freeze_backbone – Freeze backbone params at construction (linear probe).

  • learning_rate – Adam LR.

  • weight_decay – Adam weight decay.

  • sync_dist – Sync logged metrics across processes.

Returns:

Configured SupervisedModule ready for training.

Example

Regression from scratch on a fresh backbone:

module = make_series2vec_supervised(
    Series2Vec(...), num_outputs=1, task='regression', freeze_backbone=False
)

chronocratic.models.supervised.factory.make_tst_supervised(backbone: TST, *, num_outputs: int, task: str = 'classification', freeze_backbone: bool = True, learning_rate: float = 0.001, weight_decay: float = 0.0, sync_dist: bool = False) -> SupervisedModule

Build a SupervisedModule for a TST backbone.

Parameters:
  • backbone – A TST instance with representation_dim.

  • num_outputs – Number of classes (classification) or targets (regression).

  • task – 'classification' or 'regression'.

  • freeze_backbone – Freeze backbone params at construction (linear probe).

  • learning_rate – Adam LR.

  • weight_decay – Adam weight decay.

  • sync_dist – Sync logged metrics across processes.

Returns:

Configured SupervisedModule ready for training.

Example

Linear probe on a pretrained backbone (default):

module = make_tst_supervised(pretrained_tst, num_outputs=5)

Supervised from scratch — fresh backbone, train end-to-end:

module = make_tst_supervised(
    TST(...), num_outputs=5, freeze_backbone=False
)

chronocratic.models.supervised.factory.make_tstcc_supervised(backbone: TSTCC, *, num_outputs: int, task: str = 'classification', freeze_backbone: bool = True, learning_rate: float = 0.001, weight_decay: float = 0.0, sync_dist: bool = False) -> SupervisedModule

Build a SupervisedModule for a TS-TCC backbone.

With freeze_backbone=False on a fresh (un-pretrained) TSTCC, this is the explicit replacement for the removed TSTCCTrainingMode.SUPERVISED.

Parameters:
  • backbone – A TSTCC instance with representation_dim.

  • num_outputs – Number of classes (classification) or targets (regression).

  • task – 'classification' or 'regression'.

  • freeze_backbone – Freeze backbone params at construction (linear probe).

  • learning_rate – Adam LR.

  • weight_decay – Adam weight decay.

  • sync_dist – Sync logged metrics across processes.

Returns:

Configured SupervisedModule ready for training.

Example

Supervised from scratch (old SUPERVISED mode):

module = make_tstcc_supervised(
    TSTCC(...), num_outputs=6, freeze_backbone=False
)