Training#
The script train_rae_progressive.py at the repository root is the frozen
version used to produce the checkpoint reported in the paper. It imports the
model, ops, and quantizer definitions from src/compressors/frappe/, so it
must be run from the repository root.
Example command#
The example below trains a small 9-channel, 3-scale model with 1 epoch per channel for each of the single-channel and merged-decoder phases. The remaining hyperparameters use the script defaults (decoder_dim=768, decoder_arch=CCCCCC, encoder_arch=SC8, LSDIR train / Kodak validation, crop size 480, batch size 1).
    python train_rae_progressive.py \
        --device cuda:0 \
        --ps 32 32 32 16 16 16 8 8 8 \
        --epochs_single 1 \
        --epochs_merged 1 \
        --save_checkpoint_name checkpoint_example.pth
This is much lighter than the full paper configuration
(ps = [32×4, 16×3, 8×3, 4×3], i.e. a 13-channel, 4-scale model,
with epochs_single=2, epochs_merged=4, plus a scaled-down decoder_dim
and a per-channel lam / sc_max_lr schedule). It is intended for
sanity-checking the pipeline rather than matching the reported
rate-distortion numbers.
What runs#
For each of the 9 channels the script performs two phases:
1. Single-channel residual training. A new AutoencoderSingleChannel
   (one encoder filter at the current patch size + a full decoder) is
   trained to reconstruct the residual left by the currently merged
   model. The rate-distortion loss is
   log10(MSE) + lam · target_power^rpe · log2(std(z)).
2. Merged-decoder retraining. The new channel’s encoder filter is
   merged into the MergedAutoencoder, all encoders are frozen and their
   latents quantized (round() to int8), and only the unified decoder is
   fine-tuned.
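The loss and the quantization step from the two phases can be sketched in plain Python as follows. This is an illustration only, not the code in train_rae_progressive.py: the function names are hypothetical, lam, target_power and rpe are treated as plain scalars, the population standard deviation is used because the estimator is not specified here, and clamping to the int8 range is an assumption.

```python
import math
import statistics


def rate_distortion_loss(residual, recon, latent, lam, target_power, rpe):
    """Evaluate log10(MSE) + lam * target_power**rpe * log2(std(z)).

    All inputs are flat lists of floats; the real script presumably works
    on tensors, which is an assumption of this sketch.
    """
    mse = sum((r - x) ** 2 for r, x in zip(residual, recon)) / len(residual)
    distortion = math.log10(mse)
    # Population std of the latent acts as the rate proxy (assumed estimator).
    rate = math.log2(statistics.pstdev(latent))
    return distortion + lam * target_power ** rpe * rate


def quantize_int8(latent):
    """Round each latent value; clamping to [-128, 127] is an assumption."""
    return [max(-128, min(127, round(z))) for z in latent]


if __name__ == "__main__":
    loss = rate_distortion_loss(
        residual=[1.0, -1.0, 0.5, -0.5],
        recon=[0.0, 0.0, 0.0, 0.0],
        latent=[2.0, -2.0, 2.0, -2.0],
        lam=0.1, target_power=4.0, rpe=0.5,
    )
    print(loss)
    print(quantize_int8([0.4, -1.6, 300.0, -300.0]))
```

Because the rate term is log2(std(z)), a latent with smaller spread lowers the loss, while lam · target_power^rpe sets how strongly that is traded against distortion.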
The decoder operates at decoder_ps = max(ps) = 32, and finer latents
(ps=16, ps=8) are patchified up to that resolution before concatenation,
so the decoder input channel count grows from 1 (ch0) to
1+1+1+4+4+4+16+16+16 = 63 (ch8) without changing the decoder
architecture.
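The channel arithmetic above can be checked with a few lines of Python. The sketch assumes that patchifying a latent of patch size p onto the decoder_ps grid folds each (decoder_ps / p) × (decoder_ps / p) block into channels, which is consistent with the 1, 4 and 16 contributions quoted above; the helper name is hypothetical.

```python
def decoder_input_channels(ps, decoder_ps=None):
    """Return the cumulative decoder input channel count after each channel.

    ps is the per-channel patch-size list passed via --ps. Each channel
    contributes (decoder_ps // p) ** 2 decoder input channels (assumed
    space-to-depth style patchify).
    """
    if decoder_ps is None:
        decoder_ps = max(ps)
    counts = []
    total = 0
    for p in ps:
        factor = decoder_ps // p
        total += factor * factor
        counts.append(total)
    return counts


if __name__ == "__main__":
    ps = [32, 32, 32, 16, 16, 16, 8, 8, 8]
    print(decoder_input_channels(ps))
    # prints [1, 2, 3, 7, 11, 15, 31, 47, 63]
```

The last entry, 63, is the decoder input width once all 9 channels of the example configuration have been merged.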
After every channel, the script appends a snapshot to
checkpoint_example.pth and prints the validation PSNR and compression
ratio on Kodak. When all channels are done it computes BD-rate vs the
bundled AVIF reference points and writes a results JSON next to the
checkpoint.
Resuming#
Channel-level resumption is supported via --resume_checkpoint and
--resume_channels N (skips the first N channels and continues from
there). The new run’s flags must match the saved config for the resumed
channels — ps, decoder_ps, encoder_arch, decoder_arch,
decoder_dim, decoder_kernel_size, decoder_mlp_ratio,
decoder_layerscale, and input_channels are all asserted at startup.
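A check of this kind might look like the sketch below. It is not the script's actual code: the function name and the dict-based config layout are hypothetical, and comparing only the first N entries of ps is one interpretation of "for the resumed channels".

```python
# Config keys listed above that must agree between the saved and new run.
RESUME_KEYS = [
    "decoder_ps", "encoder_arch", "decoder_arch", "decoder_dim",
    "decoder_kernel_size", "decoder_mlp_ratio", "decoder_layerscale",
    "input_channels",
]


def check_resume_config(saved, current, resume_channels):
    """Assert that the new run is compatible with the saved config.

    saved and current are plain dicts (assumed layout); resume_channels
    is the N passed to --resume_channels.
    """
    saved_ps = list(saved["ps"][:resume_channels])
    current_ps = list(current["ps"][:resume_channels])
    # Only the channels being skipped need identical patch sizes (assumption).
    assert saved_ps == current_ps, f"ps mismatch: {saved_ps} vs {current_ps}"
    for key in RESUME_KEYS:
        assert saved[key] == current[key], (
            f"{key} mismatch: saved {saved[key]!r}, current {current[key]!r}"
        )
```

Failing fast at startup avoids loading snapshot weights into a decoder whose shape no longer matches them.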
Using the resulting checkpoint#
The checkpoint stores per-channel-count snapshots in
merged_decoder_weights[n]. To rebuild a MergedAutoencoder for the
first n+1 channels, instantiate it with the saved config and load
the encoder/decoder state dicts from the snapshot — the same pattern
used by compressors.frappe.model.load_progressive_model for
safetensors weights from the Hugging Face hub (see the inference
example).