RGB Image Quality Evaluation on Kodak#


Model

bpp

PSNR

LPIPS (dB)

DISTS (dB)

SSIM

LiVeAction f16c12 half res. (LSDIR)

0.0432

24.2336

2.7625

6.2970

0.7194

JPEG 2000 (CR 400:1)

0.0597

23.4048

2.1388

5.1450

0.6378

Cosmos di16×16 (Proprietary)

0.0625

21.7743

5.3784

10.6189

0.6449

JPEG 2000 (CR 250:1)

0.0957

24.3492

2.4231

5.7073

0.6842

JPEG 2000 (CR 160:1)

0.1495

25.3227

2.7599

6.3334

0.7324

LiVeAction f16c12 (LSDIR)

0.1550

27.0384

4.2153

8.8654

0.8600

LiVeAction f16c12 (LSDIR) + FLUX

0.1507

23.2454

4.5325

9.4840

0.6429

JPEG 2000 (CR 125:1)

0.1916

25.8947

2.9724

6.7353

0.7586

Balle 2018 Hyperprior (MSE loss)

0.2110

27.2377

3.9050

7.9973

0.8258

JPEG 2000 (CR 100:1)

0.2393

26.4546

3.1854

7.1317

0.7829

Cosmos di8×8 (Proprietary)

0.2500

25.9193

7.7112

13.2647

0.8558

JPEG 2000 (CR 70:1)

0.3423

27.4229

3.5425

7.7950

0.8182

JPEG 2000 (CR 40:1)

0.5985

29.1910

4.2655

9.1768

0.8712

WaLLoC f8c12 (LSDIR)

0.6171

30.5576

6.5138

13.2437

0.9501

LiVeAction f16c48 (LSDIR)

0.6615

31.1669

6.5692

13.6699

0.9571

LiVeAction f16c48 v2 (LSDIR)

0.6921

31.2534

6.5843

13.7425

0.9582

LiVeAction f16c48 EM (LSDIR)

0.8334

31.1914

6.8621

14.1159

0.8434

JPEG 2000 (CR 20:1)

1.1984

32.0019

5.4387

11.3646

0.9262

Balle 2018 (Q=8)

1.6630

37.9544

9.4908

17.5278

0.9143

WaLLoC f8c48 (LSDIR)

2.5436

37.3370

11.674

18.2942

0.9873

JPEG 2000 (CR 8:1)

2.9980

37.1560

8.2494

16.1444

0.9746

LiVeAction f8c48 (LSDIR)

3.7140

39.8998

14.904

24.2401

0.9959

LiVeAction f16c192 (LSDIR)

3.9113

40.2877

15.290

25.5340

0.9962

PSNR vs bpp#


import bjontegaard as bd
import matplotlib.pyplot as plt
# Rate-distortion operating points on Kodak. Each codec has two parallel
# lists: r_* holds the rate in bits per pixel, d_* the matching PSNR in dB.
r_j2k = [0.0597, 0.0957, 0.1495, 0.1916, 0.2393, 0.3423, 0.5985, 1.1984, 2.9980]
d_j2k = [23.4048, 24.3492, 25.3227, 25.8947, 26.4546, 27.4229, 29.1910, 32.0019, 37.1560]
r_cosmos, d_cosmos = [0.0625, 0.2500], [21.7743, 25.9193]
r_live, d_live = [0.04322, 0.1550, 0.6615, 3.9113], [24.234, 27.0384, 31.1669, 40.2877]
r_walloc, d_walloc = [0.6171, 2.5436], [30.5576, 37.3370]

plt.rcParams["font.family"] = "serif"
plt.figure(figsize=(3, 3), dpi=180)

# JPEG 2000 is the reference curve (solid line); drawing it with semilogx
# is what gives the figure its logarithmic rate axis.
plt.semilogx(r_j2k, d_j2k, marker='o', linestyle='-', label='JPEG 2000')
# The learned codecs are overlaid on the same axes with dotted lines.
for rates, quality, name in (
    (r_cosmos, d_cosmos, 'Cosmos'),
    (r_walloc, d_walloc, 'WaLLoC'),
    (r_live, d_live, 'LiVeAction'),
):
    plt.plot(rates, quality, marker='o', linestyle=':', label=name)
# Disabled: extra marker for the entropy-model variant.
# plt.plot(0.8334, 31.19, marker='*', linestyle=':', label='Conventional entropy model\ninstead of simplified rate loss')

plt.xlabel("Rate [bits per pixel]")
plt.ylabel("PSNR [dB]")
plt.title("Rate vs PSNR (Kodak)")
plt.grid(True, alpha=0.4)
plt.legend(loc='best')
plt.tight_layout()

# BD-rate of each codec against the JPEG 2000 anchor. The same
# interpolation settings are used for all three comparisons.
bd_options = dict(method='pchip', require_matching_points=False, min_overlap=0.15)
bd_cosmos = bd.bd_rate(r_j2k, d_j2k, r_cosmos, d_cosmos, **bd_options)
bd_live = bd.bd_rate(r_j2k, d_j2k, r_live, d_live, **bd_options)
bd_walloc = bd.bd_rate(r_j2k, d_j2k, r_walloc, d_walloc, **bd_options)
print(
    f"BD-Rate (Cosmos vs JPEG 2000):  {bd_cosmos:+.4f} %",
    f"BD-Rate (WaLLoC vs JPEG 2000):  {bd_walloc:+.4f} %",
    f"BD-Rate (LiVeAction vs JPEG 2000): {bd_live:+.4f} %",
    sep="\n",
)

plt.savefig('rate_psnr.svg')
BD-Rate (Cosmos vs JPEG 2000):  +49.6106 %
BD-Rate (WaLLoC vs JPEG 2000):  -27.6079 %
BD-Rate (LiVeAction vs JPEG 2000): -36.3974 %
../_images/8b83b5f603966202edb475855059f50785190cde33b25ba10ae9358c1cf1ba8e.png

SSIM vs bpp#


import bjontegaard as bd
import matplotlib.pyplot as plt
# Rate-distortion operating points on Kodak. Each codec has two parallel
# lists: r_* holds the rate in bits per pixel, d_* the matching SSIM.
r_j2k  = [0.0597, 0.0957, 0.1495, 0.1916, 0.2393, 0.3423, 0.5985, 1.1984, 2.9980]
d_j2k  = [0.6378, 0.6842, 0.7324, 0.7586, 0.7829, 0.8182, 0.8712, 0.9262, 0.9746]
r_cosmos = [0.0625, 0.2500]
d_cosmos = [0.6449, 0.8558]
r_live = [0.0432, 0.1550, 0.6615, 3.9113]
d_live = [0.7194, 0.8600, 0.9571, 0.9962]
r_walloc = [0.6171, 2.5436]
d_walloc = [0.9501, 0.9873]

plt.rcParams["font.family"] = "serif"
plt.figure(figsize=(3, 3), dpi=180)
# JPEG 2000 is the reference curve; semilogx gives the log-scaled rate axis.
plt.semilogx(r_j2k, d_j2k, marker='o', linestyle='-', label='JPEG 2000')
plt.plot(r_cosmos, d_cosmos, marker='o', linestyle=':', label='Cosmos')
plt.plot(r_walloc, d_walloc, marker='o', linestyle=':', label='WaLLoC')
plt.plot(r_live,   d_live,   marker='o', linestyle=':', label='LiVeAction')
plt.xlabel("Rate [bits per pixel]")
plt.ylabel("SSIM")
plt.title("Rate vs SSIM (Kodak)")
plt.grid(True, alpha=0.4)
# The curves carry labels, so show them; this matches the PSNR and DISTS
# figures, which both include a legend.
plt.legend(loc='best')
plt.tight_layout()

# BD-rate of each codec against the JPEG 2000 anchor.
bd_cosmos = bd.bd_rate(r_j2k, d_j2k, r_cosmos, d_cosmos,
                       method='pchip', require_matching_points=False, min_overlap=0.15)
bd_live   = bd.bd_rate(r_j2k, d_j2k, r_live,   d_live,
                       method='pchip', require_matching_points=False, min_overlap=0.15)
bd_walloc = bd.bd_rate(r_j2k, d_j2k, r_walloc, d_walloc,
                       method='pchip', require_matching_points=False, min_overlap=0.15)
print(f"BD-Rate (Cosmos vs JPEG 2000):  {bd_cosmos:+.4f} %")
print(f"BD-Rate (WaLLoC vs JPEG 2000):  {bd_walloc:+.4f} %")
print(f"BD-Rate (LiVeAction vs JPEG 2000): {bd_live:+.4f} %")
plt.savefig('rate_ssim.svg')
BD-Rate (Cosmos vs JPEG 2000):  -29.9439 %
BD-Rate (WaLLoC vs JPEG 2000):  -57.5230 %
BD-Rate (LiVeAction vs JPEG 2000): -70.0758 %
../_images/1f5d8ffde3ed54796311cca1fe2cdee4eee44a3891fd68aef4da9683a4484274.png

DISTS vs bpp#


import bjontegaard as bd
import matplotlib.pyplot as plt
# Rate-distortion operating points on Kodak. Each codec has two parallel
# lists: r_* holds the rate in bits per pixel, d_* the matching DISTS score
# expressed in dB.
r_j2k  = [0.0597, 0.0957, 0.1495, 0.1916, 0.2393, 0.3423, 0.5985, 1.1984, 2.9980]
d_j2k  = [5.1450, 5.7073, 6.3334, 6.7353, 7.1317, 7.7950, 9.1768, 11.3646, 16.1444]
r_cosmos = [0.0625, 0.2500]
d_cosmos = [10.6189, 13.2647]
r_live = [0.0432, 0.1550, 0.6615, 3.9113]
d_live = [6.2970, 8.8654, 13.6699, 25.5340]
r_walloc = [0.6171, 2.5436]
d_walloc = [13.2437, 18.2942]

plt.rcParams["font.family"] = "serif"
plt.figure(figsize=(3, 3), dpi=180)
# JPEG 2000 is the reference curve; semilogx gives the log-scaled rate axis.
plt.semilogx(r_j2k, d_j2k, marker='o', linestyle='-', label='JPEG 2000')
plt.plot(r_cosmos, d_cosmos, marker='o', linestyle=':', label='Cosmos')
plt.plot(r_walloc, d_walloc, marker='o', linestyle=':', label='WaLLoC')
plt.plot(r_live,   d_live,   marker='o', linestyle=':', label='LiVeAction')
plt.xlabel("Rate [bits per pixel]")
plt.ylabel("DISTS [dB]")
plt.title("Rate vs DISTS (Kodak)")
plt.grid(True, alpha=0.4)
plt.legend(loc='best')
plt.tight_layout()

# BD-rate of each codec against the JPEG 2000 anchor.
bd_cosmos = bd.bd_rate(r_j2k, d_j2k, r_cosmos, d_cosmos,
                       method='pchip', require_matching_points=False, min_overlap=0.15)
bd_live   = bd.bd_rate(r_j2k, d_j2k, r_live,   d_live,
                       method='pchip', require_matching_points=False, min_overlap=0.15)
bd_walloc = bd.bd_rate(r_j2k, d_j2k, r_walloc, d_walloc,
                       method='pchip', require_matching_points=False, min_overlap=0.15)
print(f"BD-Rate (Cosmos vs JPEG 2000):  {bd_cosmos:+.4f} %")
print(f"BD-Rate (WaLLoC vs JPEG 2000):  {bd_walloc:+.4f} %")
print(f"BD-Rate (LiVeAction vs JPEG 2000): {bd_live:+.4f} %")
# Save before showing: once show() has displayed the figure, a following
# savefig() writes a new, empty figure instead of this plot.
plt.savefig('rate_dists.svg')
plt.show()
BD-Rate (Cosmos vs JPEG 2000):  -90.8796 %
BD-Rate (WaLLoC vs JPEG 2000):  -61.7135 %
BD-Rate (LiVeAction vs JPEG 2000): -69.0857 %
../_images/2d9f27e7c90ec03f47e7d7711587572bde3ada6adcbe67ea4cac0c68083a9e35.png
<Figure size 640x480 with 0 Axes>

Quality evaluation#


import pillow_jpls, torch, io, datasets, PIL.Image,  numpy as np
from huggingface_hub import hf_hub_download
from types import SimpleNamespace
from piq import LPIPS, DISTS, SSIMLoss
from livecodec.codec import AutoCodecND, latent_to_pil, pil_to_latent
from torchvision.transforms.v2.functional import to_pil_image, pil_to_tensor, resize
# Evaluation setup: pick the device, load the Kodak dataset, rebuild the
# codec from its published checkpoint, and create the perceptual metrics.
device = "cuda:0"
dataset = datasets.load_dataset("danjacobellis/kodak")

checkpoint_file = hf_hub_download(repo_id="danjacobellis/liveaction", filename="lsdir_f16c48.pth")
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary
# objects, so only load checkpoints from a trusted source. It is presumably
# needed because 'config' is stored as an object accessed by attribute.
checkpoint = torch.load(checkpoint_file, map_location="cpu", weights_only=False)
config = checkpoint['config']

# Architecture hyper-parameters all come from the config stored in the
# checkpoint; J is derived as log2 of config.F.
codec_kwargs = dict(
    dim=2,
    input_channels=config.input_channels,
    J=int(np.log2(config.F)),
    latent_dim=config.latent_dim,
    encoder_depth=config.encoder_depth,
    encoder_kernel_size=config.encoder_kernel_size,
    decoder_depth=config.decoder_depth,
    lightweight_encode=config.lightweight_encode,
    lightweight_decode=config.lightweight_decode,
)
codec = AutoCodecND(**codec_kwargs).to(device).to(torch.float)
codec.load_state_dict(checkpoint['state_dict'])
codec.eval()

# Perceptual / structural metrics, created in the same order as before and
# placed on the same device as the codec.
lpips_loss, dists_loss, ssim_loss = (
    metric_cls().to(device) for metric_cls in (LPIPS, DISTS, SSIMLoss)
)
/home/dgj335/.local/lib/python3.11/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/home/dgj335/.local/lib/python3.11/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
def evaluate_quality(sample, resample_factor=1.0):
    """Compress one image with the codec and score the reconstruction.

    The image is encoded with the module-level ``codec``, the rounded latent
    is stored as a JPEG-LS file in memory to measure its size, and the
    decoded result is compared with the original.

    Args:
        sample: Dataset row whose ``'image'`` entry is a PIL image.
        resample_factor: Scale applied to the image height and width before
            encoding. When it is not 1, the decoded image is resized back to
            the original size before the metrics are computed.

    Returns:
        dict with keys ``'pixels'`` (width * height of the original image),
        ``'bpp'`` (compressed bits per original pixel), ``'PSNR'``,
        ``'LPIPS_dB'``, ``'DISTS_dB'`` (each ``-10 * log10`` of the
        respective error/loss) and ``'SSIM'`` (``1 - SSIMLoss``).
    """
    img = sample['image']
    # Map pixel values from [0, 255] to [-1, 1] and add a batch axis of 1.
    x_orig = pil_to_tensor(img).to(device).unsqueeze(0).to(torch.float) / 127.5 - 1.0
    orig_size = tuple(x_orig.shape[-2:])
    if resample_factor == 1:
        x_resize = x_orig
    else:
        new_size = tuple(int(resample_factor * s) for s in orig_size)
        x_resize = resize(x_orig, new_size, interpolation=PIL.Image.Resampling.BICUBIC)

    with torch.no_grad():
        z = codec.encode(x_resize)
        latent = codec.quantize.compand(z).round()

    # latent_to_pil / pil_to_latent are project helpers; presumably they pack
    # the rounded latent into 8-bit images and back — confirm in
    # livecodec.codec. Only the first returned image is written; the batch
    # holds a single image here.
    jpls = latent_to_pil(latent.cpu(), n_bits=8, C=3)
    buff = io.BytesIO()
    jpls[0].save(buff, format='JPEG-LS')
    size_bytes = len(buff.getbuffer())

    latent_decoded = pil_to_latent(jpls, N=config.latent_dim, n_bits=8, C=3).to(device).to(torch.float)
    with torch.no_grad():
        x_hat = codec.decode(latent_decoded)
        if resample_factor == 1:
            x_hat = x_hat.clamp(-1, 1)
        else:
            x_hat = resize(x_hat, orig_size, interpolation=PIL.Image.Resampling.BICUBIC).clamp(-1, 1)

    # Metrics are computed on images rescaled from [-1, 1] to [0, 1]. Both
    # tensors are moved to the configured device (not a hard-coded "cuda")
    # so they sit where the metric modules were created.
    x_orig_01 = (x_orig / 2 + 0.5).to(device)
    x_hat_01 = (x_hat / 2 + 0.5).to(device)

    # The rate is normalised by the ORIGINAL pixel count, so downsampling
    # before encoding shows up as a lower bpp.
    pixels = img.width * img.height
    bpp = 8 * size_bytes / pixels

    mse = torch.nn.functional.mse_loss(x_orig_01[0], x_hat_01[0])
    PSNR = -10 * mse.log10().item()
    LPIPS_dB = -10 * np.log10(lpips_loss(x_orig_01, x_hat_01).item())
    DISTS_dB = -10 * np.log10(dists_loss(x_orig_01, x_hat_01).item())
    SSIM = 1 - ssim_loss(x_orig_01, x_hat_01).item()

    return {
        'pixels': pixels,
        'bpp': bpp,
        'PSNR': PSNR,
        'LPIPS_dB': LPIPS_dB,
        'DISTS_dB': DISTS_dB,
        'SSIM': SSIM,
    }
results_dataset = dataset['validation'].map(evaluate_quality)
print("mean\n---")
for metric in [
    'pixels',
    'bpp',
    'PSNR',
    'LPIPS_dB',
    'DISTS_dB',
    'SSIM',
]:
    μ = np.mean(results_dataset[metric])
    print(f"{metric}: {μ}")
mean
---
pixels: 393216.0
bpp: 0.6921293470594619
PSNR: 31.25336805979411
LPIPS_dB: 6.584292992454675
DISTS_dB: 13.742558663102976
SSIM: 0.9581804027160009