Augmentation Framework
Polymorphic augmentation producer contract: models accept any augmentation through a unified interface, eliminating enum-based branching.
Base Types
Protocols, typed view-sets, and abstract base classes for the augmentation system.
Abstract base classes for augmentation strategies.
This module defines the shared augmentation hierarchy extracted from the
monolithic strategies.py. It contains only protocols, abstract base classes,
and typed view-set dataclasses (~350 lines). Concrete implementations live in
per-model augmentation files.
- Exported symbols:
  - Augmentation: Structural protocol for primitive transforms.
  - AugmentationProducer[V]: Protocol for typed view-set production.
  - AugmentationTrainingStrategy: Abstract training-loss interface.
  - TrainableAugmentationProducer: Abstract trainable augmentation (nn.Module).
  - SingleView / ViewPair / AlignedPair: Typed view-set dataclasses.
- class chronocratic.models.augmentation.base.AlignedPair(first: Tensor, second: Tensor, overlap_length: int)
Bases: ViewPair
A pair of augmented views with a known aligned region.
Extends ViewPair with an overlap_length field so that consumers (e.g. TS2Vec) can slice embeddings to the aligned span without a metadata dict. AlignedPair is-a ViewPair (Liskov substitution), so a producer returning AlignedPair satisfies any slot expecting ViewPair.
- Parameters:
first – First augmented tensor.
second – Second augmented tensor.
overlap_length – Number of time steps over which the two views align. For non-crop augmentations this equals the full sequence length, making the alignment slice a no-op.
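A consumer such as TS2Vec might use overlap_length to trim per-timestep embeddings before computing a contrastive loss. The sketch below redeclares the dataclasses locally so it runs on its own, and the helper aligned_embeddings is hypothetical. It assumes the overlap sits at the end of the first view and the start of the second, which is how TS2Vec's random cropping lays out the shared span; this page does not confirm that convention.

```python
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass
class ViewPair:
    first: torch.Tensor
    second: torch.Tensor


@dataclass
class AlignedPair(ViewPair):
    # Local stand-in mirroring AlignedPair(first, second, overlap_length).
    overlap_length: int


def aligned_embeddings(
    z1: torch.Tensor, z2: torch.Tensor, overlap_length: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: slice (batch, time, dim) embeddings to the shared span.

    Assumes the overlap is the tail of z1 and the head of z2 (TS2Vec-style).
    Requires overlap_length >= 1, because z1[:, -0:] would return all of z1.
    """
    if overlap_length < 1:
        raise ValueError("overlap_length must be at least 1")
    return z1[:, -overlap_length:], z2[:, :overlap_length]
```

When overlap_length equals the full sequence length, both slices return the whole tensor, which is the no-op case for non-crop augmentations.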
- class chronocratic.models.augmentation.base.Augmentation(*args, **kwargs)
Bases: Protocol
Structural protocol for model-agnostic augmentation primitives.
Implement this protocol to create a primitive transform that accepts a tensor and returns a transformed tensor of the same shape.
Examples
Jitter, Scaling, and Permutation, all of which are shared across all models.
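Because the protocol is structural, a new primitive only needs a matching __call__; it never imports or subclasses Augmentation. The sketch below assumes the protocol consists of a single tensor-to-tensor __call__ (consistent with the primitives that satisfy it via __call__) and redeclares it locally as runtime-checkable so the example is self-contained. TimeReversal is a hypothetical primitive, not part of the library.

```python
from typing import Protocol, runtime_checkable

import torch


@runtime_checkable
class Augmentation(Protocol):
    """Local stand-in: assumes a single tensor-to-tensor __call__."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...


class TimeReversal:
    """Hypothetical primitive: reverse the time axis, preserving shape."""

    def __init__(self, time_dim: int = 1) -> None:
        self.time_dim = time_dim  # 1 for (batch, time, channels)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.flip(x, dims=(self.time_dim,))
```

isinstance(TimeReversal(), Augmentation) is True here, although runtime checks on protocols only confirm that __call__ exists, not that its signature matches.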
- class chronocratic.models.augmentation.base.AugmentationProducer(*args, **kwargs)
Bases: Protocol, Generic
Assembles the view set a model’s loss requires from a batch.
A producer wraps one or more Augmentation primitives and returns a typed view set (SingleView, ViewPair, or AlignedPair). This is the object injected into a model at construction time. V is covariant (it appears only in return position), so an AugmentationProducer[AlignedPair] can be used wherever an AugmentationProducer[ViewPair] is expected.
This is a structural Protocol; concrete classes satisfy it by having the correct produce signature, not by inheriting from it.
- produce(x: torch.Tensor) → V
Produce the model’s view set from a batch.
- Parameters:
x – Input tensor of shape (batch, time, channels).
- Returns:
A typed view set (SingleView, ViewPair, or AlignedPair).
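The covariance rule can be illustrated with a small, self-contained sketch. The dataclasses and the protocol are local stand-ins mirroring the signatures on this page; NoisyAlignedProducer, make_producer, and contrastive_views are hypothetical names. Passing the result of make_producer(), typed as AugmentationProducer[AlignedPair], into a parameter typed AugmentationProducer[ViewPair] is accepted by a static checker such as mypy only because V is declared covariant.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol, TypeVar

import torch


@dataclass
class ViewPair:
    first: torch.Tensor
    second: torch.Tensor


@dataclass
class AlignedPair(ViewPair):
    overlap_length: int


V = TypeVar("V", covariant=True)  # appears only in return position


class AugmentationProducer(Protocol[V]):
    def produce(self, x: torch.Tensor) -> V: ...


class NoisyAlignedProducer:
    """Hypothetical producer: no inheritance, just a matching produce()."""

    def produce(self, x: torch.Tensor) -> AlignedPair:
        first = x + 0.1 * torch.randn_like(x)
        second = x + 0.1 * torch.randn_like(x)
        return AlignedPair(first, second, overlap_length=x.shape[1])


def make_producer() -> AugmentationProducer[AlignedPair]:
    return NoisyAlignedProducer()


def contrastive_views(
    producer: AugmentationProducer[ViewPair], x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    # Only the ViewPair fields are used, so any AlignedPair producer fits here.
    pair = producer.produce(x)
    return pair.first, pair.second
```

Declaring V invariant instead would make the contrastive_views(make_producer(), x) call a type error, even though it would still run.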
- class chronocratic.models.augmentation.base.AugmentationTrainingStrategy(training_ratio_step: int = 1)
Bases: ABC
Defines how a trainable augmentation network is optimized.
Subclass to create a new training strategy. Implement compute_loss() to define the loss function. Override should_train() for epoch-gated schedules.
- abstractmethod compute_loss(x_embeddings: Tensor, aug_x_embeddings: Tensor, augmentation_factor: Tensor) → Tensor
Compute the augmentation network loss.
- Parameters:
x_embeddings – Encodings of the original data.
aug_x_embeddings – Encodings of the augmented data.
augmentation_factor – Learned augmentation weights/factors.
- Returns:
Scalar loss tensor requiring gradients.
- should_train(epoch: int, batch_idx: int) → bool
Determine whether aug-network training should run this step.
Default: train when epoch % training_ratio_step == 0. The default ignores batch_idx, so every batch of a qualifying epoch trains.
- Parameters:
epoch – Current training epoch.
batch_idx – Current batch index within the epoch.
- Returns:
True if the augmentation network should be trained this step.
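A concrete strategy might look like the sketch below. The base class is redeclared locally from the documented signatures; CosineAgreementStrategy, its reg_weight parameter, and the loss formula are hypothetical and chosen only to show the shape of the interface (a scalar, differentiable loss built from the two embedding sets and the learned factors). WarmupStrategy shows an epoch-gated should_train() override that still respects training_ratio_step.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F


class AugmentationTrainingStrategy(ABC):
    """Local stand-in mirroring the documented interface."""

    def __init__(self, training_ratio_step: int = 1) -> None:
        self.training_ratio_step = training_ratio_step

    @abstractmethod
    def compute_loss(
        self,
        x_embeddings: torch.Tensor,
        aug_x_embeddings: torch.Tensor,
        augmentation_factor: torch.Tensor,
    ) -> torch.Tensor: ...

    def should_train(self, epoch: int, batch_idx: int) -> bool:
        # Documented default: every training_ratio_step-th epoch.
        return epoch % self.training_ratio_step == 0


class CosineAgreementStrategy(AugmentationTrainingStrategy):
    """Hypothetical loss: views should agree, factors are pulled towards 1."""

    def __init__(self, training_ratio_step: int = 1, reg_weight: float = 0.1) -> None:
        super().__init__(training_ratio_step)
        self.reg_weight = reg_weight

    def compute_loss(self, x_embeddings, aug_x_embeddings, augmentation_factor):
        disagreement = 1.0 - F.cosine_similarity(x_embeddings, aug_x_embeddings, dim=-1)
        penalty = (augmentation_factor - 1.0).pow(2)
        return disagreement.mean() + self.reg_weight * penalty.mean()


class WarmupStrategy(CosineAgreementStrategy):
    """Epoch-gated override: no aug-network training during the first epochs."""

    def __init__(self, warmup_epochs: int, **kwargs) -> None:
        super().__init__(**kwargs)
        self.warmup_epochs = warmup_epochs

    def should_train(self, epoch: int, batch_idx: int) -> bool:
        return epoch >= self.warmup_epochs and super().should_train(epoch, batch_idx)
```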
- class chronocratic.models.augmentation.base.Reseedable(*args, **kwargs)
Bases: Protocol
Protocol for producers that accept an external RNG for determinism.
Implement reseed() to allow Seeded to inject a fresh np.random.Generator before each produce() call.
- reseed(rng: np.random.Generator) → None
Replace the internal RNG with rng.
- Parameters:
rng – A seeded np.random.Generator instance.
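A producer that draws its noise from NumPy rather than torch can opt into external seeding by adding reseed(). In the sketch below, SingleView and Reseedable are local stand-ins, and NumpyJitterProducer with its sigma parameter is hypothetical. It only shows that handing the producer two identically seeded generators yields identical views; how Seeded actually constructs the generator it injects is not specified here.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol, runtime_checkable

import numpy as np
import torch


@dataclass
class SingleView:
    view: torch.Tensor


@runtime_checkable
class Reseedable(Protocol):
    def reseed(self, rng: np.random.Generator) -> None: ...


class NumpyJitterProducer:
    """Hypothetical producer whose noise comes from an internal NumPy Generator."""

    def __init__(self, sigma: float = 0.1) -> None:
        self.sigma = sigma
        self._rng = np.random.default_rng()  # unseeded until reseed() is called

    def reseed(self, rng: np.random.Generator) -> None:
        # Swap in an externally controlled generator.
        self._rng = rng

    def produce(self, x: torch.Tensor) -> SingleView:
        noise = self._rng.normal(0.0, self.sigma, size=tuple(x.shape))
        return SingleView(x + torch.as_tensor(noise, dtype=x.dtype, device=x.device))
```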
- class chronocratic.models.augmentation.base.SingleView(view: Tensor)
Bases: object
A single augmented view returned by a producer.
Used by models that need only one transformed copy of the input (e.g. AutoTCL with a neural-network augmentation).
- Parameters:
view – Augmented tensor of shape (batch, time, channels).
- class chronocratic.models.augmentation.base.TrainableAugmentationProducer(training_strategy: AugmentationTrainingStrategy)
Bases: Module, ABC
A trainable augmentation producer with learnable parameters.
Combines the nn.Module lifecycle (parameters, state_dict) with a training strategy for the augmentation network. This is a nominal ABC (not a Protocol) because it must be runtime-checkable via isinstance() to gate the trainable path. TrainableAugmentationProducer structurally satisfies AugmentationProducer[SingleView] (it has produce(x) -> SingleView), so it type-checks in any SingleView slot.
- Parameters:
training_strategy – Strategy for computing the augmentation loss and determining training frequency.
- abstractmethod produce(x: Tensor) → SingleView
Return an augmented view produced by this module’s learnable augmentation network.
- Parameters:
x – Input time-series tensor of shape (batch, time, channels).
- Returns:
A single augmented view wrapped in SingleView.
- abstractmethod train_step(x: Tensor, encoder: Module, batch_idx: int) → Tensor | None
Run one augmentation-network training step.
Subclasses define their own training loop. The base provides configure_optimizer() and should_train_augmentation(); the composed _training_strategy provides compute_loss().
- Parameters:
x – Original input data.
encoder – The main encoder module to compute embeddings.
batch_idx – Current batch index within the epoch.
- Returns:
Loss tensor if training should run this step, otherwise None.
- configure_optimizer(lr: float) → AdamW
Return an AdamW optimizer over this module’s parameters.
- Parameters:
lr – Learning rate for the augmentation network optimizer.
- Returns:
AdamW optimizer for this module’s parameters.
- should_train_augmentation(epoch: int, batch_idx: int) → bool
Check whether the aug-network should train this step.
Delegates to the composed training strategy to avoid exposing the private _training_strategy attribute.
- Parameters:
epoch – Current training epoch.
batch_idx – Current batch index within the epoch.
- Returns:
True if the augmentation network should be trained this step.
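The pieces of TrainableAugmentationProducer can be seen together in a minimal sketch. Everything here is a local stand-in or hypothetical: EveryNEpochs plays the role of an AugmentationTrainingStrategy subclass, LearnedChannelScaling learns a per-channel log-scale, and its train_step() returns the loss while leaving backward() and the optimizer step to the caller, which is one possible reading of "subclasses define their own training loop".

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class SingleView:
    view: torch.Tensor


class EveryNEpochs:
    """Hypothetical stand-in for an AugmentationTrainingStrategy subclass."""

    def __init__(self, training_ratio_step: int = 1) -> None:
        self.training_ratio_step = training_ratio_step

    def should_train(self, epoch: int, batch_idx: int) -> bool:
        return epoch % self.training_ratio_step == 0

    def compute_loss(self, x_emb, aug_emb, factor):
        # Hypothetical loss: embedding agreement plus a small L2 on the factor.
        return (x_emb - aug_emb).pow(2).mean() + 0.01 * factor.pow(2).mean()


class TrainableAugmentationProducer(nn.Module, ABC):
    """Local stand-in mirroring the documented base class."""

    def __init__(self, training_strategy) -> None:
        super().__init__()
        self._training_strategy = training_strategy

    @abstractmethod
    def produce(self, x: torch.Tensor) -> SingleView: ...

    @abstractmethod
    def train_step(self, x: torch.Tensor, encoder: nn.Module, batch_idx: int) -> torch.Tensor | None: ...

    def configure_optimizer(self, lr: float) -> torch.optim.AdamW:
        return torch.optim.AdamW(self.parameters(), lr=lr)

    def should_train_augmentation(self, epoch: int, batch_idx: int) -> bool:
        return self._training_strategy.should_train(epoch, batch_idx)


class LearnedChannelScaling(TrainableAugmentationProducer):
    """Hypothetical trainable augmentation: a learned per-channel log-scale."""

    def __init__(self, n_channels: int, training_strategy) -> None:
        super().__init__(training_strategy)
        self.log_scale = nn.Parameter(torch.zeros(n_channels))

    def produce(self, x: torch.Tensor) -> SingleView:
        # x is (batch, time, channels); the scale broadcasts over the last axis.
        return SingleView(x * self.log_scale.exp())

    def train_step(self, x, encoder, batch_idx):
        # Returns the loss; the caller is assumed to run backward() and step.
        aug = self.produce(x).view
        return self._training_strategy.compute_loss(encoder(x), encoder(aug), self.log_scale)
```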
- class chronocratic.models.augmentation.base.ViewPair(first: Tensor, second: Tensor)
Bases: object
Two augmented views returned by a producer.
Used by models that need a pair of views for contrastive or consistency losses (e.g. CoST query/key, TS-TCC weak/strong).
- Parameters:
first – First augmented tensor.
second – Second augmented tensor.
Producers
Concrete augmentation producers that generate single views, view pairs, and aligned pairs.
Shared augmentation producer combinators.
Each class wraps one or more Augmentation primitives and assembles
a typed ViewSet result. These are generic combinators — they import
nothing model-specific.
- Exported symbols:
  - SingleViewProducer: wraps one Augmentation, returns SingleView.
  - IndependentPairProducer: applies one Augmentation twice, returns ViewPair.
  - RolePairProducer: applies two Augmentations, returns ViewPair.
  - FullOverlapProducer: applies one Augmentation twice, returns AlignedPair with overlap_length == input time dimension.
- class chronocratic.models.augmentation.producers.FullOverlapProducer(*, aug: Augmentation, time_dim: int = 1)
Bases: object
Apply one Augmentation twice and return an AlignedPair.
Sets overlap_length to the full time dimension of the input, indicating that the two views are completely aligned (no cropping offset).
Satisfies AugmentationProducer[AlignedPair] structurally. Because AlignedPair is-a ViewPair, this also satisfies AugmentationProducer[ViewPair] via covariance.
- Parameters:
aug – The augmentation primitive to apply (called twice independently).
time_dim – The time dimension index in the input tensor. Defaults to 1 for (batch, time, channels). Use -1 for (batch, channels, time).
- produce(x: torch.Tensor) → AlignedPair
Produce two aligned augmented views with full overlap.
- Parameters:
x – Input tensor.
- Returns:
AlignedPair with overlap_length equal to the size of x along time_dim.
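A minimal sketch of the documented behaviour, assuming the augmentation is any tensor-to-tensor callable and the view-set dataclasses look like their signatures in Base Types. This is not the library source; it shows why time_dim matters: overlap_length is read from x.shape[time_dim], so a (batch, channels, time) input needs time_dim=-1.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable

import torch


@dataclass
class ViewPair:
    first: torch.Tensor
    second: torch.Tensor


@dataclass
class AlignedPair(ViewPair):
    overlap_length: int


class FullOverlapProducer:
    def __init__(self, *, aug: Callable[[torch.Tensor], torch.Tensor], time_dim: int = 1) -> None:
        self.aug = aug
        self.time_dim = time_dim

    def produce(self, x: torch.Tensor) -> AlignedPair:
        # Two separate calls give two independent draws over the same time span,
        # so the whole sequence is the aligned region.
        return AlignedPair(
            first=self.aug(x),
            second=self.aug(x),
            overlap_length=x.shape[self.time_dim],
        )
```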
- class chronocratic.models.augmentation.producers.IndependentPairProducer(*, aug: Augmentation)
Bases: object
Apply one Augmentation twice and return a ViewPair.
Each call to aug produces an independent draw, so with a stochastic primitive the two views differ even though they use the same primitive.
Satisfies AugmentationProducer[ViewPair] structurally.
- Parameters:
aug – The augmentation primitive to apply (called twice independently).
- class chronocratic.models.augmentation.producers.RolePairProducer(*, first: Augmentation, second: Augmentation)
Bases: object
Apply two different Augmentations and return a ViewPair.
Useful when each view has a distinct role (e.g. weak/strong in TS-TCC).
Satisfies AugmentationProducer[ViewPair] structurally.
- Parameters:
first – Augmentation for the first view.
second – Augmentation for the second view.
- class chronocratic.models.augmentation.producers.SingleViewProducer(*, aug: Augmentation)
Bases: object
Wrap one Augmentation and return a SingleView.
Satisfies AugmentationProducer[SingleView] structurally.
- Parameters:
aug – The augmentation primitive to apply.
- produce(x: torch.Tensor) → SingleView
Produce a single augmented view from x.
- Parameters:
x – Input tensor of shape (batch, time, channels).
- Returns:
SingleView containing the augmented tensor.
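SingleViewProducer, IndependentPairProducer, and RolePairProducer differ only in how many primitives they hold and which view set they return. The sketch below is written from the documented constructors and assumes IndependentPairProducer and RolePairProducer expose the same produce(x) method as the other producers (the protocol requires it, though it is not listed for them on this page). The dataclasses are local stand-ins and Aug is a hypothetical alias for the Augmentation protocol.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable

import torch

Aug = Callable[[torch.Tensor], torch.Tensor]  # stand-in for the Augmentation protocol


@dataclass
class SingleView:
    view: torch.Tensor


@dataclass
class ViewPair:
    first: torch.Tensor
    second: torch.Tensor


class SingleViewProducer:
    def __init__(self, *, aug: Aug) -> None:
        self.aug = aug

    def produce(self, x: torch.Tensor) -> SingleView:
        return SingleView(self.aug(x))


class IndependentPairProducer:
    def __init__(self, *, aug: Aug) -> None:
        self.aug = aug

    def produce(self, x: torch.Tensor) -> ViewPair:
        # Same primitive, two separate calls: two independent random draws.
        return ViewPair(self.aug(x), self.aug(x))


class RolePairProducer:
    def __init__(self, *, first: Aug, second: Aug) -> None:
        self.first = first
        self.second = second

    def produce(self, x: torch.Tensor) -> ViewPair:
        # Each view gets its own role-specific primitive (e.g. weak vs strong).
        return ViewPair(self.first(x), self.second(x))
```

Keeping the role assignment in the producer means the model only ever reads pair.first and pair.second, without knowing which primitive made which view.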
Primitives
Individual augmentation operations (Jitter, Scaling, Permutation) and composition.
Model-agnostic augmentation primitives.
Extracted from tstcc/augmentations.py and reshaped to satisfy the
model-agnostic Augmentation Protocol. Each primitive accepts a tensor and
returns a transformed tensor of the same shape.
Shared across all models. Imports nothing model-specific.
- Exported symbols:
  - Jitter, JitterParameters: Additive Gaussian noise.
  - Scaling, ScalingParameters: Per-channel multiplicative scaling.
  - Permutation, PermutationParameters: Time-segment permutation.
  - ComposeAugmentation: Chain primitives sequentially.
- class chronocratic.models.augmentation.primitives.ComposeAugmentation(augmentations: list[Augmentation])
Bases: object
Apply a sequence of augmentations one after another.
Analogous to torchvision.transforms.Compose. Each augmentation’s output is fed as input to the next.
Satisfies the Augmentation Protocol via __call__.
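The chaining behaviour can be sketched in a few lines, assuming a primitive is any tensor-to-tensor callable (Aug is a hypothetical alias for the protocol). This mirrors the documented constructor rather than the library source. Order matters: doubling then adding one maps 1 to 3, whereas the reverse order would give 4.

```python
from __future__ import annotations

from typing import Callable, Sequence

import torch

Aug = Callable[[torch.Tensor], torch.Tensor]  # stand-in for the Augmentation protocol


class ComposeAugmentation:
    def __init__(self, augmentations: Sequence[Aug]) -> None:
        self.augmentations = list(augmentations)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Feed each primitive's output into the next, left to right.
        for aug in self.augmentations:
            x = aug(x)
        return x


double_then_shift = ComposeAugmentation([lambda t: t * 2.0, lambda t: t + 1.0])
print(double_then_shift(torch.ones(3)))  # tensor([3., 3., 3.])
```

Because ComposeAugmentation itself has a tensor-to-tensor __call__, a composed chain can be passed anywhere a single primitive is expected, including inside another ComposeAugmentation.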
- class chronocratic.models.augmentation.primitives.Jitter(params: JitterParameters | None = None)
Bases: object
Add elementwise Gaussian noise with standard deviation sigma.
Satisfies the Augmentation Protocol via __call__.
- class chronocratic.models.augmentation.primitives.JitterParameters(sigma: float = 0.1, p: float = 1.0)
Bases: object
Parameters for Jitter.
- Parameters:
sigma – Standard deviation of the additive Gaussian noise.
p – Probability of applying the transform; 1.0 means always.
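One way Jitter and its p parameter could fit together is sketched below. The dataclass mirrors the documented fields; the gating rule (one coin flip per call, applied to the whole batch) is an assumption, since the page does not say whether p is evaluated per call or per sample.

```python
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass
class JitterParameters:
    sigma: float = 0.1
    p: float = 1.0


class Jitter:
    def __init__(self, params: JitterParameters | None = None) -> None:
        self.params = params if params is not None else JitterParameters()

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Assumed gating: one coin flip per call. torch.rand is in [0, 1),
        # so p=1.0 always applies and p=0.0 never does.
        if torch.rand(()).item() >= self.params.p:
            return x
        return x + self.params.sigma * torch.randn_like(x)
```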
- class chronocratic.models.augmentation.primitives.Permutation(params: PermutationParameters | None = None)
Bases: object
Split each sample’s time axis into segments and permute them.
Satisfies the Augmentation Protocol via __call__.
- class chronocratic.models.augmentation.primitives.PermutationParameters(max_segments: int = 5, time_dim: int = -1)
Bases: object
Parameters for Permutation.
- Parameters:
max_segments – Exclusive upper bound on the number of segments to split each sample into. The actual number is drawn uniformly from [1, max_segments) per sample; a draw of 1 leaves that sample unchanged.
time_dim – Dimension index of the time axis. Defaults to -1 for the (B, C, T) convention used by TS-TCC.
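The permutation can be sketched in torch as follows. It follows the documented parameters: the segment count is drawn per sample from [1, max_segments), and the time axis is located via time_dim. Two details are assumptions: segments are equal-length (TS-TCC's original code also has a random-split mode), and the batch sits on dim 0. The same time-index order is applied to every channel of a sample, so channels stay in step.

```python
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass
class PermutationParameters:
    max_segments: int = 5
    time_dim: int = -1


class Permutation:
    def __init__(self, params: PermutationParameters | None = None) -> None:
        self.params = params if params is not None else PermutationParameters()

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        t_dim = self.params.time_dim % x.dim()
        xt = x.movedim(t_dim, -1)  # put time last; batch assumed on dim 0
        out = xt.clone()
        n_steps = xt.shape[-1]
        for i in range(xt.shape[0]):
            # Uniform draw from [1, max_segments); randint's upper bound is exclusive.
            # Requires max_segments >= 2 so the range is non-empty.
            n_segs = int(torch.randint(1, self.params.max_segments, (1,)).item())
            if n_segs > 1:
                chunks = torch.tensor_split(torch.arange(n_steps), n_segs)
                order = torch.randperm(len(chunks)).tolist()
                idx = torch.cat([chunks[j] for j in order])
                out[i] = xt[i][..., idx]  # same order for every channel
        return out.movedim(-1, t_dim)
```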
- class chronocratic.models.augmentation.primitives.Scaling(params: ScalingParameters | None = None)
Bases: object
Multiply data by a per-channel Gaussian scale factor.
Satisfies the Augmentation Protocol via __call__.
- class chronocratic.models.augmentation.primitives.ScalingParameters(sigma: float = 0.1, mean: float = 1.0, p: float = 1.0, per_sample: bool = False, channel_dim: int = 1)
Bases: object
Parameters for Scaling.
- Parameters:
sigma – Standard deviation of the per-channel Gaussian scale factor.
mean – Mean of the per-channel scale factor.
p – Probability of applying the transform.
per_sample – If True, draw an independent factor for each sample in the batch. If False, the factor is shared across the batch.
channel_dim – Dimension index of the channel axis. Defaults to 1 for the (B, C, T) convention used by TS-TCC.
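A sketch of the documented scaling behaviour follows. It draws factors with the given mean and sigma and shapes them so they vary only along the channel axis (plus the batch axis when per_sample is true). It assumes the batch sits on dim 0 and that p gates the whole call with a single coin flip; both are readings of the parameter descriptions, not the library source.

```python
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass
class ScalingParameters:
    sigma: float = 0.1
    mean: float = 1.0
    p: float = 1.0
    per_sample: bool = False
    channel_dim: int = 1


class Scaling:
    def __init__(self, params: ScalingParameters | None = None) -> None:
        self.params = params if params is not None else ScalingParameters()

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        prm = self.params
        if torch.rand(()).item() >= prm.p:  # assumed: one coin flip per call
            return x
        shape = [1] * x.dim()
        c_dim = prm.channel_dim % x.dim()
        shape[c_dim] = x.shape[c_dim]  # one factor per channel
        if prm.per_sample:
            shape[0] = x.shape[0]  # assumed batch on dim 0: one factor set per sample
        factor = prm.mean + prm.sigma * torch.randn(tuple(shape), dtype=x.dtype, device=x.device)
        return x * factor  # broadcasts over the remaining (time) axis
```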
Decorators
Wrappers that add cross-cutting behavior (e.g., deterministic seeding) to augmentations.
Augmentation producer decorators.
Provides decorators that wrap AugmentationProducer instances to add
cross-cutting capabilities (e.g., deterministic seeding) without modifying
the producer’s own code.
- Exported symbols:
Seeded: Deterministic wrapper for stateless producers.
- class chronocratic.models.augmentation.decorators.Seeded(*, inner: AugmentationProducer[V], seed: int)
Bases: Generic
Deterministic decorator wrapping a stateless AugmentationProducer.
Uses torch.random.fork_rng() and torch.manual_seed() so that inner randomness is isolated from the outer process random state. Repeated calls with the same seed and the same input produce identical output tensors.
Constraint: must NOT wrap TrainableAugmentationProducer. Trainable producers have their own parameterised state; seeding at the producer level is not meaningful and may hide bugs.
- Parameters:
inner – A stateless AugmentationProducer to wrap.
seed – Fixed integer seed applied before every produce() call.
- produce(x: torch.Tensor) → V
Produce a view set with deterministic randomness.
Isolates the inner producer’s random state from the outer process context using torch.random.fork_rng().
- Parameters:
x – Input tensor of shape (batch, time, channels).
- Returns:
The typed view set produced by the inner producer, generated under a reproducible random seed.
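The isolation described for Seeded can be sketched as follows. AugmentationProducer and Reseedable are local stand-ins, and two details are assumptions: Reseedable producers receive np.random.default_rng(seed) before each call (the Reseedable description says Seeded injects a fresh generator but not how it is built), and fork_rng() is used with its defaults, which save and restore the CPU RNG and the RNG of any visible CUDA devices.

```python
from __future__ import annotations

from typing import Generic, Protocol, TypeVar, runtime_checkable

import numpy as np
import torch

V = TypeVar("V", covariant=True)


class AugmentationProducer(Protocol[V]):
    def produce(self, x: torch.Tensor) -> V: ...


@runtime_checkable
class Reseedable(Protocol):
    def reseed(self, rng: np.random.Generator) -> None: ...


class Seeded(Generic[V]):
    def __init__(self, *, inner: AugmentationProducer[V], seed: int) -> None:
        self.inner = inner
        self.seed = seed

    def produce(self, x: torch.Tensor) -> V:
        if isinstance(self.inner, Reseedable):
            # Assumed: a fresh NumPy generator built from the same seed on every call.
            self.inner.reseed(np.random.default_rng(self.seed))
        # fork_rng() restores the outer torch RNG state when the block exits.
        with torch.random.fork_rng():
            torch.manual_seed(self.seed)
            return self.inner.produce(x)
```

Because the outer state is restored, wrapping a producer in Seeded does not shift the random stream seen by the rest of the training loop.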
Trainable Support
Utilities for augmentations with learnable parameters.
Centralized helpers for trainable augmentation producers.
These functions are the only place in the codebase that branch on
isinstance(..., TrainableAugmentationProducer). Models should
call these helpers instead of checking the type themselves, keeping the
model code branchless on the augmentation type.
- Exported symbols:
  - maybe_train_augmentation
  - maybe_configure_augmentation_optimizer
- chronocratic.models.augmentation.trainable_support.maybe_configure_augmentation_optimizer(augmentation: AugmentationProducer[Any], *, lr: float) → Optimizer | None
Configure an optimizer for the augmentation network if it is trainable.
For stateless producers this function returns None immediately. For trainable producers it delegates to configure_optimizer().
This is the sole code path in the codebase that uses isinstance(..., TrainableAugmentationProducer) for optimizer configuration.
- Parameters:
augmentation – The augmentation producer to optionally configure.
lr – Learning rate for the augmentation network optimizer.
- Returns:
An optimizer for the augmentation network parameters, or None if the producer is not trainable.
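The branching this helper centralizes amounts to a single isinstance check. In the sketch below TrainableAugmentationProducer is reduced to a local stand-in that keeps only configure_optimizer(), and LearnedShift and Stateless are hypothetical producers used to exercise both paths. A model would call the helper once when building its optimizers and skip the augmentation optimizer when it gets None back.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import nn


class TrainableAugmentationProducer(nn.Module):
    """Stand-in for the documented ABC: only what the helper touches."""

    def configure_optimizer(self, lr: float) -> torch.optim.AdamW:
        return torch.optim.AdamW(self.parameters(), lr=lr)


def maybe_configure_augmentation_optimizer(
    augmentation: Any, *, lr: float
) -> torch.optim.Optimizer | None:
    # The one place that branches on the producer's type.
    if not isinstance(augmentation, TrainableAugmentationProducer):
        return None
    return augmentation.configure_optimizer(lr)


class LearnedShift(TrainableAugmentationProducer):
    """Hypothetical trainable producer with a single learnable offset."""

    def __init__(self) -> None:
        super().__init__()
        self.shift = nn.Parameter(torch.zeros(()))


class Stateless:
    """Hypothetical stateless producer (no parameters, no optimizer)."""

    def produce(self, x: torch.Tensor) -> torch.Tensor:
        return x
```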
- chronocratic.models.augmentation.trainable_support.maybe_train_augmentation(augmentation: AugmentationProducer[Any], *, x: torch.Tensor, encoder: nn.Module, epoch: int, batch_idx: int) → torch.Tensor | None
Run one augmentation-network training step if the producer is trainable.
For stateless producers this function returns None immediately. For trainable producers it checks should_train_augmentation() and delegates to train_step() when the strategy permits.
Mode management: sets encoder to eval mode and augmentation to train mode during the forward pass, then, in a finally block, puts encoder back into train mode and augmentation into eval mode.
This is the sole code path in the codebase that uses isinstance(..., TrainableAugmentationProducer) for the training loop.
- Parameters:
augmentation – The augmentation producer to optionally train.
x – Original input data.
encoder – The main encoder module to compute embeddings.
epoch – Current training epoch.
batch_idx – Current batch index within the epoch.
- Returns:
Loss tensor if the producer is trainable and should train this step, otherwise None.
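The gating and mode handling of maybe_train_augmentation can be sketched as below. The stand-in base keeps only should_train_augmentation() and train_step(), gating on epoch % training_ratio_step as the default strategy does; LearnedShift is hypothetical and records the modes it sees so the effect of the eval/train switch is visible. Checking the strategy before switching modes is an assumption about ordering.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import nn


class TrainableAugmentationProducer(nn.Module):
    """Stand-in base: only the two methods the helper calls."""

    def __init__(self, training_ratio_step: int = 1) -> None:
        super().__init__()
        self.training_ratio_step = training_ratio_step

    def should_train_augmentation(self, epoch: int, batch_idx: int) -> bool:
        return epoch % self.training_ratio_step == 0

    def train_step(self, x: torch.Tensor, encoder: nn.Module, batch_idx: int) -> torch.Tensor | None:
        raise NotImplementedError


def maybe_train_augmentation(
    augmentation: Any, *, x: torch.Tensor, encoder: nn.Module, epoch: int, batch_idx: int
) -> torch.Tensor | None:
    if not isinstance(augmentation, TrainableAugmentationProducer):
        return None  # stateless producers: nothing to train
    if not augmentation.should_train_augmentation(epoch, batch_idx):
        return None
    encoder.eval()
    augmentation.train()
    try:
        return augmentation.train_step(x, encoder, batch_idx)
    finally:
        # Runs even if train_step raises, leaving the usual modes in place.
        encoder.train()
        augmentation.eval()


class LearnedShift(TrainableAugmentationProducer):
    """Hypothetical producer that records the modes seen during train_step."""

    def __init__(self, training_ratio_step: int = 1) -> None:
        super().__init__(training_ratio_step)
        self.shift = nn.Parameter(torch.zeros(()))
        self.seen_modes = None

    def train_step(self, x, encoder, batch_idx):
        self.seen_modes = (encoder.training, self.training)
        return (encoder(x + self.shift) - encoder(x)).pow(2).mean()
```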