RGB Training Example

import os; os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "garbage_collection_threshold:0.5"
import io, time, torch, datasets, PIL.Image, numpy as np, matplotlib.pyplot as plt, fastprogress, IPython.display
from einops.layers.torch import Rearrange
from types import SimpleNamespace
from torchvision.transforms.v2.functional import pil_to_tensor, to_pil_image
from torchvision.transforms import Compose, Resize, RandomCrop, CenterCrop, ColorJitter
from livecodec.codec import AutoCodecND, latent_to_pil, pil_to_latent
from timm.optim import Adan
device = 'cuda:1'
dataset = datasets.DatasetDict({
    'train': datasets.load_dataset("danjacobellis/LSDIR",split='train'),
    'validation': datasets.load_dataset("danjacobellis/kodak",split='validation')
})
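
For reference, each sample exposes its picture as a PIL image (the collate function below relies on this); a quick illustrative peek:

print(dataset['train'][0]['image'].size)
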
config = SimpleNamespace()
config.F = 8
config.latent_dim = 48
config.input_channels = 3
config.encoder_depth = 4
config.encoder_kernel_size = 1
config.decoder_depth = 8
config.lightweight_encode = True
config.lightweight_decode = False
config.freeze_encoder_after = 0.7
config.λ = 3e-2
config.epochs = 150
config.batch_size = 4
config.total_steps = config.epochs * (dataset['train'].num_rows // config.batch_size)
config.min_aspect = 0.5
config.max_aspect = 2
config.min_size = 240
config.max_size = 640
config.min_scale = 1.0
config.max_scale = 2.0
config.max_lr = 3e-4
config.min_lr = 1e-8
config.lr_pow = 2
config.num_workers = 12
config.save_checkpoint_name = 'f8c48.pth'
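
These settings imply a 4× nominal reduction before entropy coding: each F×F patch of 3 input channels (192 values) maps to latent_dim = 48 channels, stored at 8 bits each. A minimal sanity-check sketch:

values_per_patch = config.F**2 * config.input_channels   # 8·8·3 = 192
nominal_cr = values_per_patch / config.latent_dim         # 4.0
print(values_per_patch, nominal_cr)

The lossless WebP pass applied to the latents later raises the measured ratio to roughly 6.5× on Kodak (see the validation output below).
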
model = AutoCodecND(
    dim=2,
    input_channels=config.input_channels,
    J = int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth = config.encoder_depth,
    encoder_kernel_size = config.encoder_kernel_size,
    decoder_depth = config.decoder_depth,
    lightweight_encode = config.lightweight_encode,
    lightweight_decode = config.lightweight_decode,
).to(device)

optimizer = Adan(model.parameters(), lr=1.0)  # base lr 1.0: LambdaLR below supplies the absolute rate
def rc_sched(i_step, config):
    # Raised-cosine schedule: min_lr at the start and end, max_lr at mid-training
    t = i_step / config.total_steps
    return (config.max_lr - config.min_lr) * (1 - ((np.cos(np.pi*t))**(2*config.lr_pow))) + config.min_lr
schedule = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda i_step: rc_sched(i_step, config)
)
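
Since the base learning rate is pinned to 1.0, LambdaLR treats rc_sched's output as the absolute learning rate. A short sketch to visualize the curve (rc_sched already accepts NumPy arrays):

steps = np.linspace(0, config.total_steps, 1000)
plt.plot(steps, rc_sched(steps, config)); plt.xlabel('step'); plt.ylabel('learning rate')
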
print(sum(p.numel() for p in model.parameters())/1e6)  # parameter count in millions
4.931184
def train_collate_fn(batch):
    # Sample one random aspect ratio and size per batch, rounded to multiples of F
    # so the F-fold downsampling in the encoder divides evenly
    aspect = np.random.uniform(config.min_aspect, config.max_aspect)
    h = np.random.uniform(config.min_size, config.max_size)
    w = h * aspect
    h = int((config.F*(np.rint(h/config.F))).clip(config.min_size, config.max_size))
    w = int((config.F*(np.rint(w/config.F))).clip(config.min_size, config.max_size))
    x = []
    for i_sample, sample in enumerate(batch):
        transform = Compose([
            Resize(
                int(max(h,w)*np.random.uniform(config.min_scale,config.max_scale)),
                interpolation=PIL.Image.Resampling.BICUBIC,
                max_size=1+int(config.max_scale*config.max_size)
            ),
            RandomCrop((h,w),pad_if_needed=True,padding_mode='symmetric'),
            ColorJitter(0.4,0.0,0.4,0.0),
        ])
        xi = pil_to_tensor(transform(sample['image'].convert("RGB"))).unsqueeze(0)
        x.append(xi)
    x = torch.cat(x)
    return x.to(torch.float)/127.5 - 1.0
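
A minimal check that the collate function yields encoder-compatible batches (both spatial dimensions multiples of F, values normalized to [-1, 1]), assuming a few training samples are loaded:

xb = train_collate_fn([dataset['train'][i] for i in range(config.batch_size)])
print(xb.shape, xb.shape[-2] % config.F, xb.shape[-1] % config.F, xb.min().item(), xb.max().item())
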

def enc_dec(sample):
    img = sample['image']
    buff = io.BytesIO()
    x = pil_to_tensor(img).to(torch.float).to(device).unsqueeze(0)/127.5 - 1.0
    t0 = time.time()
    with torch.inference_mode():
        z = model.encode(x)
        z = model.quantize.compand(z).round()
    # Pack the quantized latent into an 8-bit image; lossless WebP serves as the entropy coder
    latent_to_pil(z.cpu().detach(),n_bits=8,C=3)[0].save(buff, format='WEBP', lossless=True)
    enc_time = time.time() - t0
    size_bytes = len(buff.getbuffer())
    t0 = time.time()
    z = pil_to_latent([PIL.Image.open(buff)],N=config.latent_dim,n_bits=8,C=3).to(device)
    with torch.inference_mode():
        xhat = model.decode(z).clamp(-1,1)
    dec_time = time.time() - t0
    x_01 = x/2 + 0.5
    xhat_01 = xhat/2 + 0.5
    psnr = -10*torch.nn.functional.mse_loss(x_01, xhat_01).log10().item()
    # Drop non-metric columns so datasets.map keeps only the numbers below
    del sample['path']
    del sample['image']
    return {
        'psnr': psnr,                      # reconstruction quality (dB)
        'cr': x.numel()/size_bytes,        # compression ratio vs. raw 8-bit RGB
        'enc': (x.numel()/3e6)/enc_time,   # encode throughput (megapixels/s)
        'dec': (x.numel()/3e6)/dec_time    # decode throughput (megapixels/s)
    }
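
An illustrative single-image round trip with a trained model (enc_dec deletes 'path' and 'image' from its argument, so pass a copy; throughputs are in megapixels per second):

metrics = enc_dec(dict(dataset['validation'][0]))
print(metrics)  # {'psnr': ..., 'cr': ..., 'enc': ..., 'dec': ...}
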
learning_rates = [optimizer.param_groups[0]['lr']]
mb = fastprogress.master_bar(range(config.epochs))
log_mse_losses = []
rate_losses = []
valid_metrics = []
global_step = 0
for i_epoch in mb:
    model.train()
    dataloader_train = torch.utils.data.DataLoader(
            dataset['train'],
            batch_size=config.batch_size,
            num_workers=config.num_workers,
            drop_last=True,
            shuffle=True,
            collate_fn=train_collate_fn
        )
    pb = fastprogress.progress_bar(dataloader_train, parent=mb)
    for i_batch, x in enumerate(pb):
        x = x.to(device)
        if i_epoch/config.epochs <= config.freeze_encoder_after:
            # Early phase: soft quantization with additive noise, encoder still training
            x_hat, rate = model(x)
        else:
            # Late phase: encoder frozen, decoder fine-tuned on hard-rounded latents
            with torch.no_grad():
                z = model.quantize.compand(model.encode(x))
                z_quant = z.round()
            rate = z.std().log2()  # rate proxy: log2 of the companded latent's std
            x_hat = model.decode(z_quant)

        # Loss and optimizer
        mse_loss = torch.nn.functional.mse_loss(x, x_hat).log10()  # log10 of MSE
        log_mse_losses.append(mse_loss.item())
        rate_losses.append(rate.item())
        total_loss = mse_loss + config.λ * rate
        total_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        schedule.step()
        learning_rates.append(optimizer.param_groups[0]['lr'])
        # +6.02 dB = 10·log10(4): converts MSE on [-1,1] signals to PSNR on the [0,1] scale
        pb.comment = (f"PSNR: {-10*log_mse_losses[-1]+6.02:.3g}, R: {rate:.2g}, LR: {learning_rates[-1]:.2g}")
        global_step += 1

    valid = dataset['validation'].map(enc_dec)
    valid_metrics.append({key: torch.tensor(valid[key]).mean().item() for key in valid.features.keys()})
    mb.main_bar.comment = ";  ".join([f'{key}: {valid_metrics[-1][key]:0.4g}' for key in valid_metrics[-1]])
    
    torch.save({
        'i_epoch': i_epoch,
        'log_mse_losses': log_mse_losses,
        'learning_rates': learning_rates,
        'valid_metrics': {k: [d[k] for d in valid_metrics] for k in valid_metrics[0].keys()},
        'config': config,
        'state_dict': model.state_dict()
    }, config.save_checkpoint_name)
display(IPython.display.HTML(mb.main_bar.progress))
display(IPython.display.HTML(pb.progress))
N = 501; plt.plot(np.convolve(log_mse_losses, N*[1/N], mode='valid'))  # 501-step moving average
plt.ylim([-4,-2])
100.00% [150/150 59:28:38<00:00 psnr: 39.9; cr: 6.535; enc: 7.878; dec: 40.39]
100.00% [21247/21247 19:16<00:00 PSNR: 32.5, R: 3.8, LR: 1e-08]
(-4.0, -2.0)
[Figure: 501-step moving average of the log10 MSE training loss]
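
The per-epoch validation metrics collected in valid_metrics can be plotted the same way, e.g. PSNR over epochs (illustrative):

plt.plot([m['psnr'] for m in valid_metrics]); plt.xlabel('epoch'); plt.ylabel('validation PSNR (dB)')
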
torch.tensor([[v[key] for key in v.keys()] for v in valid_metrics])[-1]  # final epoch: [psnr, cr, enc, dec]
tensor([39.8998,  6.5354,  7.8779, 40.3941])
torch.save({
    'valid_metrics': {k: [d[k] for d in valid_metrics] for k in valid_metrics[0].keys()},
    'config': config,
    'state_dict': model.state_dict()
}, '../../../hf/liveaction/lsdir_f8c48.pth')
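
To reuse a saved checkpoint, a minimal reload sketch mirroring the constructor call above (on recent PyTorch versions weights_only=False may be needed because the checkpoint stores a SimpleNamespace config):

ckpt = torch.load(config.save_checkpoint_name, map_location=device, weights_only=False)
cfg = ckpt['config']
model = AutoCodecND(
    dim=2,
    input_channels=cfg.input_channels,
    J=int(np.log2(cfg.F)),
    latent_dim=cfg.latent_dim,
    encoder_depth=cfg.encoder_depth,
    encoder_kernel_size=cfg.encoder_kernel_size,
    decoder_depth=cfg.decoder_depth,
    lightweight_encode=cfg.lightweight_encode,
    lightweight_decode=cfg.lightweight_decode,
).to(device)
model.load_state_dict(ckpt['state_dict'])
model.eval()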