8 Commits

9 changed files with 1255 additions and 0 deletions
Split View
  1. +3
    -0
      README.md
  2. +59
    -0
      eval_dec_complexity/codec_complexity.py
  3. +431
    -0
      eval_dec_complexity/lvc_dec_complexity.py
  4. +277
    -0
      eval_dec_performance/calc_metrics.py
  5. +2
    -0
      eval_dec_performance/calc_metrics.sh
  6. +42
    -0
      eval_dec_performance/check_bpp.py
  7. +115
    -0
      eval_dec_performance/logger.py
  8. +317
    -0
      eval_dec_performance/options.py
  9. +9
    -0
      eval_dec_performance/readme.txt

+ 3
- 0
README.md View File

@@ -7,3 +7,6 @@
启动文件为start.py,该程序完成解码器拷贝及解压缩,码流文件解压缩,调用解码器对码流进行解码,以及移动解码图像到输出目录。


解码复杂度评估可使用常见模型复杂度评估工具,本示例使用deepspeed工具。
示例评估脚本详见eval_dec_complexity目录,与解码器(decoder.zip)配合使用,计算时需确保解码阶段的网络参数都参与到统计中。
本示例解码1080P图像共计777.58 GMACs,即约0.37*10^6 MACs/pixel (777.58*10^9/1920/1088=372233.07 MACs/pixel)。

+ 59
- 0
eval_dec_complexity/codec_complexity.py View File

@@ -0,0 +1,59 @@
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import sys
import torch
from deepspeed.profiling.flops_profiler import get_model_profile

from lvc_dec_complexity import LVC_exp_spy_res

# Deterministic cuDNN kernels make repeated profiling runs comparable.
torch.backends.cudnn.deterministic = True


def main(argv):
    """Profile decoder-side params/FLOPs/MACs of the LVC model.

    Loads the checkpoint, moves the model to the GPU, and runs the
    deepspeed flops profiler on a 1080p-sized dummy input.
    """
    device = "cuda:0"
    ckpt_dir = "network.pth.tar"
    net = LVC_exp_spy_res()
    # map_location keeps the load working regardless of the device the
    # checkpoint was saved from; the model is moved to `device` afterwards.
    state_dict = torch.load(ckpt_dir, map_location="cpu")
    net.load_state_dict(state_dict)
    net.to(device).eval()
    width, height = 1920, 1088
    # FIX: image tensors are laid out (N, C, H, W); the original passed
    # (1, 3, width, height), i.e. a transposed 1088x1920 input. The MAC
    # count of a fully-convolutional net is unchanged, but the shape now
    # matches the 1920x1088 frames described in the README.
    flops, macs, params = get_model_profile(net, (1, 3, height, width))

    print("params: ", params)
    print("flops: ", flops)
    print("macs: ", macs)


if __name__ == "__main__":
    main(sys.argv)

+ 431
- 0
eval_dec_complexity/lvc_dec_complexity.py View File

@@ -0,0 +1,431 @@
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import math
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from compressai.entropy_models import GaussianConditional, EntropyBottleneck
from compressai.models.google import CompressionModel, get_scale_table
from compressai.models.utils import (
conv,
deconv,
quantize_ste,
update_registered_buffers,
)
from ckbd import *
def conv3x3(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """Build a 3x3 convolution; padding=1 preserves the spatial size."""
    return nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
    )
def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """Build a pointwise (1x1) convolution."""
    return nn.Conv2d(in_ch, out_ch, stride=stride, kernel_size=1)
def subpel_conv1x1(in_ch, out_ch, r=1):
    """1x1 convolution followed by pixel shuffle for r-fold upsampling."""
    projection = nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=1, padding=0)
    shuffle = nn.PixelShuffle(r)
    return nn.Sequential(projection, shuffle)
class ResidualBlockWithStride(nn.Module):
    """Residual block whose first convolution (and shortcut) may downsample.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): stride value (default: 2)
    """

    def __init__(self, in_ch, out_ch, stride=2):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
        self.leaky_relu = nn.LeakyReLU()
        self.conv2 = conv3x3(out_ch, out_ch)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1)
        # A strided 1x1 projection keeps the shortcut shape-compatible
        # with the main path; with stride 1 the identity is used directly.
        self.downsample = conv1x1(in_ch, out_ch, stride=stride) if stride != 1 else None

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.leaky_relu(self.conv1(x))
        out = self.leaky_relu2(self.conv2(out))
        return out + shortcut
class ResidualBlockUpsample(nn.Module):
    """Residual block that upsamples via sub-pixel convolutions.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        upsample (int): upsampling factor (default: 2)
    """

    def __init__(self, in_ch, out_ch, upsample=2):
        super().__init__()
        self.subpel_conv = subpel_conv1x1(in_ch, out_ch, upsample)
        self.leaky_relu = nn.LeakyReLU()
        self.conv = conv3x3(out_ch, out_ch)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1)
        # The shortcut gets its own sub-pixel conv so it matches the
        # upsampled resolution of the main path.
        self.upsample = subpel_conv1x1(in_ch, out_ch, upsample)

    def forward(self, x):
        shortcut = self.upsample(x)
        out = self.leaky_relu(self.subpel_conv(x))
        out = self.leaky_relu2(self.conv(out))
        return out + shortcut
class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a plain identity shortcut.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        leaky_relu_slope (float): negative slope of both activations
    """

    def __init__(self, in_ch, out_ch, leaky_relu_slope=0.01):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch)
        self.leaky_relu = nn.LeakyReLU(negative_slope=leaky_relu_slope)
        self.conv2 = conv3x3(out_ch, out_ch)

    def forward(self, x):
        out = self.leaky_relu(self.conv1(x))
        out = self.leaky_relu(self.conv2(out))
        # NOTE(review): the shortcut has no projection, so in_ch must
        # equal out_ch for the addition to broadcast correctly.
        return out + x
class LVC_exp_spy_res(nn.Module):
    """LVC intra codec wrapped for decoder-side complexity profiling.

    ``forward`` exercises only the modules that run at *decoding* time
    (hyper decoder, entropy-parameter network, dual spatial context and
    the image decoder) on randomly mocked latents, so a FLOPs profiler
    such as deepspeed measures decoder complexity alone.
    """

    def __init__(self):
        super().__init__()

        class Img_Encoder(nn.Sequential):
            """Analysis transform: four stride-2 stages (16x downsampling)."""

            def __init__(self, in_planes: int, mid_planes: int = 192, out_planes: int = 192):
                super().__init__(
                    ResidualBlockWithStride(in_planes, mid_planes, stride=2),
                    ResidualBlock(mid_planes, mid_planes),
                    ResidualBlockWithStride(mid_planes, mid_planes, stride=2),
                    ResidualBlock(mid_planes, mid_planes),
                    ResidualBlockWithStride(mid_planes, mid_planes, stride=2),
                    ResidualBlock(mid_planes, mid_planes),
                    conv3x3(mid_planes, out_planes, stride=2),
                )

        class Img_Decoder(nn.Sequential):
            """Synthesis transform: four 2x upsampling stages (16x total)."""

            def __init__(self, out_planes: int, in_planes: int = 192, mid_planes: int = 192):
                super().__init__(
                    ResidualBlock(in_planes, mid_planes),
                    ResidualBlockUpsample(mid_planes, mid_planes, 2),
                    ResidualBlock(mid_planes, mid_planes),
                    ResidualBlockUpsample(mid_planes, mid_planes, 2),
                    ResidualBlock(mid_planes, mid_planes),
                    ResidualBlockUpsample(mid_planes, mid_planes, 2),
                    ResidualBlock(mid_planes, mid_planes),
                    subpel_conv1x1(mid_planes, out_planes, 2),
                )

        class Img_HyperEncoder(nn.Sequential):
            """Hyper analysis transform: two stride-2 convs (4x downsampling of y)."""

            def __init__(self, in_planes: int = 192, mid_planes: int = 192, out_planes: int = 192):
                super().__init__(
                    conv3x3(in_planes, mid_planes),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, mid_planes),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, mid_planes, stride=2),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, mid_planes),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, out_planes, stride=2),
                )

        class Img_HyperDecoder(nn.Sequential):
            """Hyper synthesis transform.

            Outputs 2 * out_planes channels at the resolution of y; the
            two halves are consumed downstream as scales and means.
            """

            def __init__(self, in_planes: int = 192, mid_planes: int = 192, out_planes: int = 192):
                super().__init__(
                    conv3x3(in_planes, mid_planes),
                    nn.LeakyReLU(),
                    subpel_conv1x1(mid_planes, mid_planes, 2),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, mid_planes * 3 // 2),
                    nn.LeakyReLU(),
                    subpel_conv1x1(mid_planes * 3 // 2, mid_planes * 3 // 2, 2),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes * 3 // 2, out_planes * 2),
                )

        class Joint_Hy_Dual(CompressionModel):
            """Entropy model: hyperprior plus dual (checkerboard) spatial context.

            Anchor samples are coded from the hyperprior alone (the
            context input is zeroed); non-anchor samples additionally see
            ``dual_context`` applied to the already-decoded anchors.
            """

            def __init__(self, planes: int = 192, mid_planes: int = 192):
                super().__init__(entropy_bottleneck_channels=mid_planes)
                self.planes = planes
                self.hyper_encoder = Img_HyperEncoder(planes, mid_planes, planes)
                self.hyper_decoder = Img_HyperDecoder(planes, mid_planes, planes)
                self.dual_context = nn.Sequential(
                    conv3x3(planes, planes),
                    nn.LeakyReLU(),
                    conv3x3(planes, planes * 3 // 2),
                    nn.LeakyReLU(),
                    conv3x3(planes * 3 // 2, planes * 2),
                )
                self.entropy_parameters = nn.Sequential(
                    conv1x1(planes * 4, planes * 3),
                    nn.LeakyReLU(),
                    conv1x1(planes * 3, planes * 2),
                    nn.LeakyReLU(),
                    conv1x1(planes * 2, planes * 2),
                )
                self.gaussian_conditional = GaussianConditional(None)

            def forward(self, y):
                """Training-time pass: returns (y_hat, likelihoods)."""
                z = self.hyper_encoder(y)
                z_hat, z_likelihoods = self.entropy_bottleneck(z)
                hyperprior = self.hyper_decoder(z_hat)
                y_anchor, y_nonanchor = dual_split(y)
                # Anchor pass: no spatial context available yet.
                params_anchor = self.entropy_parameters(torch.cat([hyperprior, torch.zeros_like(hyperprior)], dim=1))
                scales_anchor, means_anchor = params_anchor.chunk(2, 1)
                scales_anchor = dual_anchor(scales_anchor)
                means_anchor = dual_anchor(means_anchor)
                # Straight-through quantization around the predicted mean.
                y_anchor = quantize_ste(y_anchor - means_anchor) + means_anchor
                y_anchor = dual_anchor(y_anchor)
                # Non-anchor pass conditioned on the quantized anchors.
                spatialprior = self.dual_context(y_anchor)
                params_nonanchor = self.entropy_parameters(torch.cat([hyperprior, spatialprior], dim=1))
                scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
                scales_nonanchor = dual_nonanchor(scales_nonanchor)
                means_nonanchor = dual_nonanchor(means_nonanchor)
                scales = ckbd_merge(scales_anchor, scales_nonanchor)
                means = ckbd_merge(means_anchor, means_nonanchor)
                _, y_likelihoods = self.gaussian_conditional(y, scales, means)
                y_nonanchor = quantize_ste(y_nonanchor - means_nonanchor) + means_nonanchor
                y_nonanchor = dual_nonanchor(y_nonanchor)
                y_hat = y_anchor + y_nonanchor
                return y_hat, {"y": y_likelihoods, "z": z_likelihoods}

            def compress(self, y):
                """Encode y to bitstrings (anchor, non-anchor, hyper)."""
                z = self.hyper_encoder(y)
                z_string = self.entropy_bottleneck.compress(z)
                z_hat = self.entropy_bottleneck.decompress(z_string, z.size()[-2:])
                hyperprior = self.hyper_decoder(z_hat)
                y_anchor, y_nonanchor = dual_split(y)
                params_anchor = self.entropy_parameters(torch.cat([hyperprior, torch.zeros_like(hyperprior)], dim=1))
                scales_anchor, means_anchor = params_anchor.chunk(2, 1)
                # Squeeze the checkerboard halves into dense tensors for coding.
                anchor_squeeze = dual_anchor_sequeeze(y_anchor)
                scales_anchor_squeeze = dual_anchor_sequeeze(scales_anchor)
                means_anchor_squeeze = dual_anchor_sequeeze(means_anchor)
                indexes_anchor = self.gaussian_conditional.build_indexes(scales_anchor_squeeze)
                y_string_anchor = self.gaussian_conditional.compress(anchor_squeeze, indexes_anchor, means_anchor_squeeze)
                y_hat_anchor_squeeze = self.gaussian_conditional.quantize(anchor_squeeze, "dequantize", means_anchor_squeeze)
                y_hat_anchor = dual_anchor_unsequeeze(y_hat_anchor_squeeze)
                spatialprior = self.dual_context(y_hat_anchor)
                params_nonanchor = self.entropy_parameters(torch.cat([hyperprior, spatialprior], dim=1))
                scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
                nonanchor_squeeze = dual_nonanchor_sequeeze(y_nonanchor)
                scales_nonanchor_squeeze = dual_nonanchor_sequeeze(scales_nonanchor)
                means_nonanchor_squeeze = dual_nonanchor_sequeeze(means_nonanchor)
                indexes_nonanchor = self.gaussian_conditional.build_indexes(scales_nonanchor_squeeze)
                y_string_nonanchor = self.gaussian_conditional.compress(nonanchor_squeeze, indexes_nonanchor, means_nonanchor_squeeze)
                y_hat_nonanchor_squeeze = self.gaussian_conditional.quantize(nonanchor_squeeze, "dequantize", means_nonanchor_squeeze)
                y_hat_nonanchor = dual_nonanchor_unsequeeze(y_hat_nonanchor_squeeze)
                y_hat = y_hat_anchor + y_hat_nonanchor
                return y_hat, {"strings": [y_string_anchor, y_string_nonanchor, z_string], "shape": z.size()[-2:]}

            def decompress(self, strings, shape):
                """Decode y_hat from [anchor, non-anchor, hyper] bitstrings."""
                assert isinstance(strings, list) and len(strings) == 3
                z_hat = self.entropy_bottleneck.decompress(strings[2], shape)
                hyperprior = self.hyper_decoder(z_hat)
                params_anchor = self.entropy_parameters(torch.cat([hyperprior, torch.zeros_like(hyperprior)], dim=1))
                scales_anchor, means_anchor = params_anchor.chunk(2, 1)
                scales_anchor_squeeze = dual_anchor_sequeeze(scales_anchor)
                means_anchor_squeeze = dual_anchor_sequeeze(means_anchor)
                indexes_anchor = self.gaussian_conditional.build_indexes(scales_anchor_squeeze)
                y_hat_anchor_squeeze = self.gaussian_conditional.decompress(strings[0], indexes_anchor, z_hat.dtype, means_anchor_squeeze)
                y_hat_anchor = dual_anchor_unsequeeze(y_hat_anchor_squeeze)
                spatialprior = self.dual_context(y_hat_anchor)
                params_nonanchor = self.entropy_parameters(torch.cat([hyperprior, spatialprior], dim=1))
                scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
                scales_nonanchor_squeeze = dual_nonanchor_sequeeze(scales_nonanchor)
                means_nonanchor_squeeze = dual_nonanchor_sequeeze(means_nonanchor)
                indexes_nonanchor = self.gaussian_conditional.build_indexes(scales_nonanchor_squeeze)
                y_hat_nonanchor_squeeze = self.gaussian_conditional.decompress(strings[1], indexes_nonanchor, z_hat.dtype, means_nonanchor_squeeze)
                y_hat_nonanchor = dual_nonanchor_unsequeeze(y_hat_nonanchor_squeeze)
                y_hat = y_hat_anchor + y_hat_nonanchor
                return y_hat

        self.img_encoder = Img_Encoder(in_planes=3, mid_planes=192, out_planes=192)
        self.img_decoder = Img_Decoder(in_planes=192, mid_planes=192, out_planes=3)
        self.img_entropymodel = Joint_Hy_Dual()
        # Per-rate channel-wise gain (4 rate points) applied to the latent,
        # plus a scalar bias used on the decoder side.
        self.img_var_factor = nn.Parameter(torch.ones([4, 1, 192, 1, 1]))
        self.img_var_bias = nn.Parameter(torch.ones(1))

    def forward(self, image):
        """Run the decoder-side network path on mocked latents.

        Args:
            image: input tensor of shape (B, 3, H, W). Only its size and
                device are used; the pixel content is never read.

        Returns:
            The reconstruction produced by the image decoder. The entropy
            parameter outputs are computed (so they are counted by the
            profiler) but intentionally discarded.
        """
        # FIX: derive the device from the input instead of hard-coding
        # .cuda(), so CPU profiling also works; behavior on GPU inputs is
        # unchanged.
        device = image.device
        batch = image.size(0)
        planes = self.img_entropymodel.planes
        height, width = image.size(2), image.size(3)
        # Mock of the entropy-decoded hyper latent: arithmetic decoding is
        # not a network op, so random data of the right shape suffices.
        z_hat = torch.rand((batch, planes, height // 64, width // 64), device=device)
        hyperprior = self.img_entropymodel.hyper_decoder(z_hat)
        # Anchor pass: the spatial-context half of the input is zeroed.
        params_anchor = self.img_entropymodel.entropy_parameters(
            torch.cat([hyperprior, torch.zeros_like(hyperprior)], dim=1)
        )
        # Mock of the decoded anchor latent feeding the spatial context.
        y_hat_anchor = torch.rand((batch, planes, height // 16, width // 16), device=device)
        spatialprior = self.img_entropymodel.dual_context(y_hat_anchor)
        params_nonanchor = self.img_entropymodel.entropy_parameters(
            torch.cat([hyperprior, spatialprior], dim=1)
        )
        y_hat = torch.rand((batch, planes, height // 16, width // 16), device=device)
        x_hat = self.img_decoder(y_hat)
        # FIX: the original computed x_hat but fell off the end of the
        # function, implicitly returning None.
        return x_hat

    def load_state_dict(self, state_dict, strict=True, update_buffer=True):
        """Load weights, resyncing the variable-length CDF buffers first.

        The quantized-CDF buffers of the entropy models change size after
        ``update()``, so they must be re-registered from the checkpoint
        before the regular state-dict load.
        """
        if update_buffer:
            update_registered_buffers(
                self.img_entropymodel.gaussian_conditional,
                "img_entropymodel.gaussian_conditional",
                ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
                state_dict,
            )
            update_registered_buffers(
                self.img_entropymodel.entropy_bottleneck,
                "img_entropymodel.entropy_bottleneck",
                ["_quantized_cdf", "_offset", "_cdf_length"],
                state_dict,
            )
        super().load_state_dict(state_dict, strict=strict)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        net = cls()
        net.load_state_dict(state_dict)
        return net

    def update(self, scale_table=None, force=False):
        """Refresh the CDF tables of every entropy model.

        Returns True if any table was actually updated.
        """
        if scale_table is None:
            scale_table = get_scale_table()
        updated = self.img_entropymodel.gaussian_conditional.update_scale_table(
            scale_table, force=force
        )
        for m in self.modules():
            if not isinstance(m, EntropyBottleneck):
                continue
            rv = m.update(force=force)
            updated |= rv
        return updated

+ 277
- 0
eval_dec_performance/calc_metrics.py View File

@@ -0,0 +1,277 @@
import argparse
import json
import os
import shutil

import tempfile
import time
import traceback
from abc import ABCMeta, abstractmethod
from concurrent import futures
from typing import Dict, List, Tuple

import cv2
import lpips
import numpy as np
import skimage
import torch
import torchvision.transforms as T
from DISTS_pytorch import DISTS
from PIL import Image
from glob import glob
from tqdm import tqdm
from torch.utils.data import Dataset
from pytorch_msssim import ms_ssim

from options import TestConfig
from logger import get_root_logger, bolded_log


class CustomConfig(TestConfig):
    """Test configuration fed by command-line arguments."""

    @classmethod
    def get_opt(cls) -> "CustomConfig":
        """Parse the CLI and build the config object in one step."""
        return cls(cls.arg_parse())

    @staticmethod
    def arg_parse() -> Dict:
        """Parse --real_dir / --fake_dir / --device into a plain dict."""
        parser = argparse.ArgumentParser()
        parser.add_argument("--real_dir", type=str)
        parser.add_argument("--fake_dir", type=str)
        parser.add_argument("-d", "--device", type=str, default="cuda:0")
        # argparse.Namespace -> Dict
        return vars(parser.parse_args())


class ImageDataset(Dataset):
    """Paired (real, fake) image dataset yielding transformed tensors.

    ``scale`` selects the value range: "0_1" keeps ToTensor's [0, 1]
    output, "-1_1" additionally normalizes each channel to [-1, 1].
    """

    def __init__(
        self, real_path_list: List[str], fake_path_list: List[str], scale: str
    ) -> None:
        super().__init__()
        assert scale in ["0_1", "-1_1"]
        self.real_path_list = real_path_list
        self.fake_path_list = fake_path_list
        steps = [T.ToTensor()]
        if scale == "-1_1":
            steps.append(T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
        self.transform = T.Compose(steps)

    def __len__(self) -> int:
        return len(self.real_path_list)

    def read_img(self, path):
        """Load the file as RGB and apply the configured transform."""
        return self.transform(Image.open(path).convert("RGB"))

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return (
            self.read_img(self.real_path_list[index]),
            self.read_img(self.fake_path_list[index]),
        )


def get_dataloader(real_path_list: List[str], fake_path_list: List[str], scale: str):
    """Wrap the paired image dataset in a deterministic batch-1 loader."""
    return torch.utils.data.DataLoader(
        ImageDataset(real_path_list, fake_path_list, scale),
        batch_size=1,
        drop_last=False,
        shuffle=False,
        num_workers=1,
    )


class BaseMetric(metaclass=ABCMeta):
    """Common driver for full-reference image metrics.

    Subclasses implement ``calc_metric``; ``run`` pairs up the PNG files
    of the two directories and logs the per-image average.
    """

    def __init__(self, opt, metric_name: str):
        self.opt = opt
        self.metric_name = metric_name
        self.logger = get_root_logger()

    @staticmethod
    def get_real_fake_path_list(
        real_dir: str, fake_dir: str
    ) -> Tuple[List[str], List[str]]:
        """Collect matching PNG paths from both directories, sorted.

        Fake file names are expected to carry an 8-character suffix tag
        before ``.png`` (e.g. ``a_dec.png`` pairs with ``a.png``).
        """
        assert os.path.exists(real_dir)
        assert os.path.exists(fake_dir)
        real_path_list = sorted(glob(os.path.join(real_dir, f"*.png")))
        fake_path_list = sorted(glob(os.path.join(fake_dir, f"*.png")))
        assert len(real_path_list) == len(fake_path_list)
        for real_p, fake_p in zip(real_path_list, fake_path_list):
            assert os.path.basename(real_p) == os.path.basename(fake_p[:-8] + ".png")
        return real_path_list, fake_path_list

    @abstractmethod
    def calc_metric(self, real_path_list, fake_path_list) -> float:
        raise NotImplementedError()

    def run(self, real_dir: str, fake_dir: str) -> float:
        """Compute the metric over all pairs, log and return the average."""
        real_path_list, fake_path_list = self.get_real_fake_path_list(
            real_dir, fake_dir
        )
        num_img = len(real_path_list)
        img_avg_value = self.calc_metric(real_path_list, fake_path_list)
        self.logger.info(f"{num_img} images: {self.metric_name}: {img_avg_value:.4}")
        return img_avg_value


class PSNRMetric(BaseMetric):
    """Per-image PSNR (peak 255) computed with a thread pool for I/O."""

    def __init__(self, opt):
        super().__init__(opt, "PSNR")

    @staticmethod
    def read_img(img_path: str) -> np.ndarray:
        """Return the image as float32 RGB in [0, 255].

        FIX: the previous docstring claimed [0, 1]; ``np.asarray`` of a
        PIL image keeps the 8-bit range, matching the 255 peak used in
        the PSNR formula below.
        """
        img_np = np.asarray(Image.open(img_path).convert("RGB"), dtype=np.float32)
        return img_np

    def calc_metric(self, real_path_list, fake_path_list) -> float:
        """Average per-image PSNR over all pairs (threaded file reads)."""
        self.num_dims = 0
        self.sqerror_values = []
        self.img_psnr_values = []
        self._done = False  # lets the monitor thread exit even on failure
        max_workers = 8
        work_futures = []
        with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            executor.submit(self.monitor_progress, total=len(real_path_list))
            try:
                for real_path, fake_path in zip(real_path_list, fake_path_list):
                    work_futures.append(
                        executor.submit(
                            self._calc_one_image, real_path=real_path, fake_path=fake_path
                        )
                    )
                # FIX: the original built the as_completed iterator without
                # consuming it, so worker exceptions were silently dropped
                # (and a failed worker hung the progress monitor forever).
                for fut in futures.as_completed(work_futures):
                    fut.result()
            finally:
                self._done = True
        img_avg_psnr = np.mean(self.img_psnr_values)
        return img_avg_psnr

    def monitor_progress(self, total):
        """Print a single-line progress indicator until all pairs are done."""
        strt = time.time()
        n = len(self.sqerror_values)
        while n < total and not self._done:
            print(
                f"\r{n:6} / {total} imgs {time.time()-strt:5.2f}s", flush=True, end=""
            )
            time.sleep(0.1)
            n = len(self.sqerror_values)
        print("\r", end="")

    def _calc_one_image(self, real_path, fake_path):
        """Accumulate squared error and per-image PSNR for one pair.

        list.append is atomic under the GIL, so the shared lists are safe
        across worker threads; ``num_dims`` is only read after the pool
        has been joined.
        """
        assert os.path.basename(fake_path[:-8] + ".png") == os.path.basename(real_path)
        image0 = self.read_img(real_path)
        image1 = self.read_img(fake_path)
        self.num_dims += image0.size
        sqerror = np.sum(np.square(image1 - image0))
        self.sqerror_values.append(sqerror)
        _mse = sqerror / image0.size
        self.img_psnr_values.append(20.0 * np.log10(255.0) - 10.0 * np.log10(_mse))


class MSSSIMMetric(BaseMetric):
    """MS-SSIM over [0, 1] tensors, averaged across image pairs."""

    def __init__(self, opt):
        super().__init__(opt, "MSSSIM")
        self.device = opt.device

    @torch.no_grad()
    def calc_metric(self, real_path_list, fake_path_list) -> float:
        """Average MS-SSIM over the paired dataloader."""
        values = []
        loader = get_dataloader(real_path_list, fake_path_list, scale="0_1")
        for img_real, img_fake in tqdm(
            loader, total=len(loader), leave=False, ncols=80
        ):
            assert img_fake.shape == img_real.shape
            score = ms_ssim(img_real, img_fake, data_range=1.0)
            values.append(score.item())
        return np.mean(values)


class LPIPSMetric(BaseMetric):
    """LPIPS (VGG backbone) over [-1, 1] tensors, averaged across pairs."""

    def __init__(self, opt):
        super().__init__(opt, "LPIPS")
        self.device = opt.device
        self.lpips_fn = lpips.LPIPS(net="vgg").to(self.device)
        self.wandb_save_type = "pix_avg"

    @torch.no_grad()
    def calc_metric(self, real_path_list, fake_path_list) -> float:
        """Average LPIPS distance over the paired dataloader."""
        values = []
        loader = get_dataloader(real_path_list, fake_path_list, scale="-1_1")
        for img_real, img_fake in tqdm(
            loader, total=len(loader), leave=False, ncols=80
        ):
            assert img_fake.shape == img_real.shape
            # calc LPIPS
            dist = self.lpips_fn.forward(
                img_fake.to(self.device), img_real.to(self.device)
            )
            values.append(dist.item())
        return np.mean(values)


class DISTSMetric(BaseMetric):
    """DISTS over [0, 1] tensors, averaged across image pairs."""

    def __init__(self, opt):
        super().__init__(opt, "DISTS")
        self.device = opt.device
        self.dists_fn = DISTS().to(self.device)

    @torch.no_grad()
    def calc_metric(self, real_path_list, fake_path_list) -> float:
        """Average DISTS distance over the paired dataloader."""
        values = []
        loader = get_dataloader(real_path_list, fake_path_list, scale="0_1")
        for img_real, img_fake in tqdm(
            loader, total=len(loader), leave=False, ncols=80
        ):
            assert img_fake.shape == img_real.shape
            # calculate DISTS between X, Y (a batch of RGB images, data range: 0~1)
            dist = self.dists_fn.forward(
                img_fake.to(self.device), img_real.to(self.device), require_grad=False
            )
            values.append(dist.item())
        return np.mean(values)


def main():
    """Compute PSNR / MS-SSIM / LPIPS / DISTS for one real/fake dir pair.

    Writes the results to ``dec_metrics.json`` inside the fake directory
    and logs them; a metric that fails is skipped rather than aborting
    the whole run.
    """
    opt = CustomConfig.get_opt()
    logger = get_root_logger(log_level="INFO")
    real_dir = opt.real_dir
    fake_dir = opt.fake_dir

    metrics_dict = {
        "PSNR": PSNRMetric(opt),
        "MS-SSIM": MSSSIMMetric(opt),
        "LPIPS": LPIPSMetric(opt),
        "DISTS": DISTSMetric(opt),
    }

    logger.info("Calculate " + ", ".join(list(metrics_dict.keys())))

    out_dict = {}

    for metrics_name, met_obj in metrics_dict.items():
        bolded_log(msg=f"Calc {metrics_name}", level="INFO", new_line=True)
        try:
            out_dict[metrics_name] = met_obj.run(real_dir, fake_dir)
        except KeyboardInterrupt:
            traceback.print_exc()
            # FIX: exit() is meant for interactive use; raise SystemExit with
            # a non-zero status so callers see the failure.
            raise SystemExit(1)
        except Exception:
            # FIX: was a bare `except:`, which also swallowed SystemExit
            # and GeneratorExit.
            logger.error(f"ERROR: skip {fake_dir}")
            traceback.print_exc()

    json_path = os.path.join(fake_dir, "dec_metrics.json")
    with open(json_path, "w") as f:
        json.dump(out_dict, f, indent=4)

    logger.info(f"Results: {fake_dir}")
    for k, v in out_dict.items():
        logger.info(f"{k:>7}: {v:.4f}")

+ 2
- 0
eval_dec_performance/calc_metrics.sh View File

@@ -0,0 +1,2 @@
# Verify every decoded image's bitstream fits the bpp constraint.
python check_bpp.py check /path_to_dec/
# Compute PSNR / MS-SSIM / LPIPS / DISTS between ground truth and decoded images.
python calc_metrics.py --real_dir '/path_to_gt/' --fake_dir '/path_to_dec/'

+ 42
- 0
eval_dec_performance/check_bpp.py View File

@@ -0,0 +1,42 @@
from pathlib import Path
import argparse
import glob
import sys
from PIL import Image
# Maximum allowed bits-per-pixel for a decoded image's bitstream.
bpp_constraint = 0.1


def filesize(filepath: str) -> int:
    """Return the size of *filepath* in bytes; raise ValueError if it is
    not an existing regular file."""
    path = Path(filepath)
    if not path.is_file():
        raise ValueError(f'Invalid file "{filepath}".')
    return path.stat().st_size
def get_bpp(argv):
    """Check every ``*.bits`` stream under the input prefix against the
    bpp budget.

    For each ``X.bits`` the decoded image is expected at ``X_dec.png``;
    prints the stream path followed by True/False for the constraint.
    """
    parser = argparse.ArgumentParser(description="Evaluate per-frame bpp.")
    parser.add_argument("input", type=str)
    args = parser.parse_args(argv)
    for bits_path in glob.glob(args.input + "*.bits"):
        print(bits_path)
        bit_count = filesize(bits_path) * 8
        decoded = Image.open(bits_path[:-5] + "_dec.png")
        # Budget = width * height * allowed bits-per-pixel.
        budget = decoded.size[0] * decoded.size[1] * bpp_constraint
        print(bit_count <= budget)
def parse_args(argv):
    """Parse the sub-command; "check" is the only valid choice."""
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("command", choices=["check"])
    return parser.parse_args(argv)
def main(argv):
    """Dispatch the CLI: argv[1] is the command, the rest its arguments."""
    args = parse_args(argv[1:2])
    if args.command == "check":
        get_bpp(argv[2:])
if __name__ == "__main__":
main(sys.argv)

+ 115
- 0
eval_dec_performance/logger.py View File

@@ -0,0 +1,115 @@
from collections import defaultdict
import copy
import logging
import os
import re
import pandas as pd

from python_log_indenter import IndentedLoggerAdapter

from typing import Dict, Optional, Union
import importlib
import os.path as osp
from glob import glob

import subprocess
import time
from datetime import datetime

import torch
from tqdm import tqdm

class Color:
    """ANSI SGR escape codes for coloring/styling terminal output."""

    # Foreground colors (SGR 30-37; 39 restores the default).
    BLACK = "\033[30m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"
    WHITE = "\033[37m"
    COLOR_DEFAULT = "\033[39m"
    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    INVISIBLE = "\033[08m"
    REVERCE = "\033[07m"  # sic: reverse video (swap fg/bg)
    # Background colors (SGR 40-47; 49 restores the default).
    BG_BLACK = "\033[40m"
    BG_RED = "\033[41m"
    BG_GREEN = "\033[42m"
    BG_YELLOW = "\033[43m"
    BG_BLUE = "\033[44m"
    BG_MAGENTA = "\033[45m"
    BG_CYAN = "\033[46m"
    BG_WHITE = "\033[47m"
    BG_DEFAULT = "\033[49m"
    # Reset all attributes.
    RESET = "\033[0m"

initialized_logger = {}


def bolded_log(msg: str, level: Union[int, str]='INFO', new_line: bool=False, prefix: str='===== ', suffix: str=' ====='):
    """Emit *msg* through the root logger, bolded and framed by
    decorative prefix/suffix markers."""
    if new_line:
        print()
    logger = get_root_logger()
    if isinstance(level, str):
        level = getattr(logging, level)
    decorated = f'{Color.BOLD}{prefix}{msg}{suffix}{Color.RESET}'
    logger.log(level=level, msg=decorated)

class ColorStreamHandler(logging.StreamHandler):
    """StreamHandler that rewrites the level name into a colored badge."""

    # Badge text (and color) per standard level name.
    mapping = {
        "TRACE": "[ TRACE ]",
        "DEBUG": "[ DEBUG ]",
        "INFO": "[ INFO ]",
        "WARNING": f"{Color.RED}[ WARNING]{Color.RESET}",
        "WARN": f"{Color.RED}[ WARN ]{Color.RESET}",
        "ERROR": f"{Color.BG_RED}[ ERROR ]{Color.RESET}",
        "ALERT": f"{Color.BG_RED}[ ALERT ]{Color.RESET}",
        "CRITICAL": f"{Color.BG_RED}[CRITICAL]{Color.RESET}",
    }

    def emit(self, record):
        # Copy so the colored name does not leak into other handlers
        # (e.g. a plain-text file handler). FIX: a shallow copy suffices
        # since only `levelname` is rebound, and it avoids deepcopy
        # failures on records whose args are not deep-copyable.
        record = copy.copy(record)
        # FIX: unknown/custom level names used to raise KeyError; fall
        # back to the raw level name instead of crashing the handler.
        record.levelname = ColorStreamHandler.mapping.get(
            record.levelname, record.levelname
        )
        super().emit(record)


def get_root_logger(logger_name='basiccomp', log_level: Union[int, str]=logging.INFO, log_file=None):
    """Get the root logger.

    The logger will be initialized if it has not been initialized. By default a
    StreamHandler will be added. If `log_file` is specified, a FileHandler will
    also be added.

    Args:
        logger_name (str): root logger name. Default: 'basiccomp'.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.
        log_level (int or str): level of the console handler; the file
            handler (if any) always logs at DEBUG.

    Returns:
        IndentedLoggerAdapter: the (cached) root logger adapter.
    """
    if logger_name in initialized_logger:
        return initialized_logger[logger_name]
    if isinstance(log_level, str):
        log_level = getattr(logging, log_level)

    logger = IndentedLoggerAdapter(logging.getLogger(logger_name), spaces=2)
    logger.logger.setLevel(logging.DEBUG)

    format_str = '%(levelname)-10s %(message)s'
    stream_handler = ColorStreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    stream_handler.setLevel(log_level)
    logger.logger.addHandler(stream_handler)
    logger.logger.propagate = False

    # add file handler
    if log_file is not None:
        class _StripColorFormatter(logging.Formatter):
            """Formatter that removes ANSI escape codes for plain-text files."""

            _ansi_re = re.compile(r'\033\[[0-9;]*m')

            def format(self, record):
                return self._ansi_re.sub('', super().format(record))

        file_handler = logging.FileHandler(log_file, 'w')
        file_format_str = '%(asctime)s %(levelname)-8s: %(message)s'
        # FIX: the original referenced DelColorFormatter, which is not
        # defined anywhere in this module — passing log_file raised a
        # NameError. The local formatter above provides the intended
        # color-stripping behavior.
        file_handler.setFormatter(_StripColorFormatter(file_format_str))
        file_handler.setLevel(logging.DEBUG)
        logger.logger.addHandler(file_handler)
    initialized_logger[logger_name] = logger
    return logger

+ 317
- 0
eval_dec_performance/options.py View File

@@ -0,0 +1,317 @@
# from https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py
import argparse
import copy
import os
import os.path as osp
from typing import Dict, List, Optional, Tuple

import yaml
from addict import Dict as Addict

import re
import os
import os.path as osp
from datetime import datetime
from glob import glob
from typing import Optional, Dict

def check_file_exist(filename: str, msg_tmpl='file "{}" does not exist'):
    """Raise FileNotFoundError unless *filename* is an existing regular file."""
    if osp.isfile(filename):
        return
    raise FileNotFoundError(msg_tmpl.format(filename))


class PathHandler(object):
    """Resolve experiment-relative paths under ``ckpt_root/exp``."""

    def __init__(self, ckpt_root: str, exp: str) -> None:
        self.ckpt_root = ckpt_root
        self.exp = exp
        self.job_dir = os.path.join(self.ckpt_root, self.exp)

    def make_job_dir(self) -> None:
        """Create the model/ and sample/ subdirectories of the job dir."""
        job_dir = os.path.join(self.ckpt_root, self.exp)
        os.makedirs(os.path.join(job_dir, 'model'), exist_ok=True)
        os.makedirs(os.path.join(job_dir, 'sample'), exist_ok=True)

    def get_exp_path_dict(self) -> Dict:
        """Return the canonical path layout for this experiment.

        Note: ``log_msg_path`` embeds the current timestamp, so every
        call yields a fresh log file name.
        """
        job_dir = os.path.join(self.ckpt_root, self.exp)
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        # FIX: 'sample_dir' appeared twice in the original dict literal;
        # the duplicate (dead) entry has been removed.
        return {
            'ckpt_root': self.ckpt_root,
            'job_dir': job_dir,
            'model_dir': osp.join(job_dir, 'model'),
            'sample_dir': osp.join(job_dir, 'sample'),
            'log_loss_path': osp.join(job_dir, 'log_loss.csv'),
            'log_eval_path': osp.join(job_dir, 'eval_result.csv'),
            'log_msg_path': osp.join(job_dir, f'train_{timestamp}.log'),
        }

    def get_ckpt_path(self, label: str, itr: int):
        """Checkpoint path like ``<model_dir>/<label>_iter<itr>.pth.tar``."""
        model_dir = self.get_exp_path_dict()['model_dir']
        itr_str = self.iter2str(itr)
        return os.path.join(model_dir, f'{label}_iter{itr_str}.pth.tar')

    @staticmethod
    def iter2str(itr: int) -> str:
        """Render an iteration count compactly (e.g. 2000 -> '2K')."""
        if itr % 1000 == 0:
            return str(itr // 1000) + 'K'
        return str(itr)

# Special keys recognized inside yaml config files (mmcv-style):
BASE_KEY = "_base_"  # value names parent yaml file(s) this config inherits from
DELETE_KEY = "_delete_"  # set truthy in a child dict to discard the base value
DEPRECATION_KEY = "_deprecation_"  # declared for mmcv parity; not used in the visible code
# Keys user configs may not define: they collide with Config properties.
RESERVED_KEYS = ["filename", "text"]


class ConfigDict(Addict):
    """Attribute-style dict that raises on missing keys.

    ``addict.Dict`` silently creates an empty child for a missing key;
    this subclass forbids that and raises instead.
    """

    def __missing__(self, name):
        # Disable addict's auto-creation of missing entries.
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            return super(ConfigDict, self).__getattr__(name)
        except KeyError:
            # Missing key surfaces as the conventional AttributeError.
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no "
                f"attribute '{name}'"
            ) from None


# class Config:
class BaseConfig:
    """A facility for config and config files.

    Wraps a (possibly yaml-loaded) dict in a ``ConfigDict`` so values can be
    read both as items (``cfg['a']``) and as attributes (``cfg.a``).  Yaml
    files may inherit from parent files via the ``_base_`` key (see
    ``_file2dict_yaml``).

    Example:
        >>> cfg = BaseConfig(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
    """

    @staticmethod
    def _file2dict_yaml(filename: str) -> Tuple[Dict, str, List]:
        """Load a ``.yaml`` config, recursively resolving ``_base_`` parents.

        Returns:
            Tuple of (merged config dict, concatenated raw yaml text of all
            involved files, list of every yaml path that was loaded).

        Raises:
            FileNotFoundError: If *filename* does not exist.
            IOError: If the extension is not ``.yaml``.
            KeyError: If two base files define the same top-level key.
        """
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        loaded_yamls = [filename]
        fileExtname = osp.splitext(filename)[1]
        if fileExtname not in [".yaml"]:
            raise IOError("Only yaml type are supported now!")

        with open(filename) as f:
            cfg_dict: Dict = yaml.safe_load(f)

        # cfg_text records provenance: the file path followed by its raw text.
        cfg_text = filename + "\n"
        with open(filename, "r", encoding="utf-8") as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            # _base_ may be a single path or a list; normalize to a list.
            # Paths are resolved relative to the current config's directory.
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = (
                base_filename if isinstance(base_filename, list) else [base_filename]
            )

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text, _loaded_yamls = BaseConfig._file2dict_yaml(
                    osp.join(cfg_dir, f)
                )
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)
                loaded_yamls.extend(_loaded_yamls)

            # Combine all bases; sibling bases must not share top-level keys.
            base_cfg_dict = dict()
            for c in cfg_dict_list:
                duplicate_keys = base_cfg_dict.keys() & c.keys()
                if len(duplicate_keys) > 0:
                    raise KeyError(
                        "Duplicate key is not allowed among bases. "
                        f"Duplicate keys: {duplicate_keys}"
                    )
                base_cfg_dict.update(c)

            # The child config wins over the merged bases.
            base_cfg_dict = BaseConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = "\n".join(cfg_text_list)

        return cfg_dict, cfg_text, loaded_yamls

    @staticmethod
    def _merge_a_into_b(a: Dict, b: Dict) -> Dict:
        """Recursively merge dict *a* into a shallow copy of dict *b*.

        Values from *a* take precedence.  A nested dict in *a* carrying a
        truthy ``_delete_`` key replaces the base value instead of merging.

        NOTE(review): ``v.pop(DELETE_KEY, False)`` mutates the nested dicts of
        the caller's *a* in place (only *b* is copied) — confirm intended.
        """
        b = b.copy()
        for k, v in a.items():
            # if (1) v is dict, (2) both a and b have k, (3) v does not have DELETE_KEY
            if isinstance(v, dict) and (k in b) and not v.pop(DELETE_KEY, False):
                if not isinstance(b[k], dict):
                    raise TypeError(
                        f"{k}={v} in child config cannot inherit from base "
                        f"because {k} is a dict in the child config but is of "
                        f"type {type(b[k])} in base config. You may set "
                        f"`{DELETE_KEY}=True` to ignore the base config"
                    )
                b[k] = BaseConfig._merge_a_into_b(v, b[k])
            else:
                b[k] = v
        return b

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        """Wrap *cfg_dict* (default: empty) and remember its source file/text.

        Raises:
            TypeError: If *cfg_dict* is not a dict.
            KeyError: If *cfg_dict* uses a key from RESERVED_KEYS.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f"{key} is reserved for config file")

        # Use object.__setattr__ (via super) so our overridden __setattr__
        # does not route these internals into _cfg_dict itself.
        super(BaseConfig, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
        super(BaseConfig, self).__setattr__("_filename", filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, "r") as f:
                text = f.read()
        else:
            text = ""
        super(BaseConfig, self).__setattr__("_text", text)

    @property
    def filename(self):
        # Source config path, or None if built from a plain dict.
        return self._filename

    @property
    def text(self):
        # Raw text of the config file(s), or "" if none.
        return self._text

    def __repr__(self):
        return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # Only called for names not found on the instance: delegate to the
        # wrapped ConfigDict so config entries read as attributes.
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        # All writes go into the wrapped dict; plain dicts are upgraded so
        # nested attribute access keeps working.
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def __getstate__(self):
        # Pickle support: serialize the three hidden internals explicitly,
        # bypassing the attribute-proxy dunders above.
        return (self._cfg_dict, self._filename, self._text)

    def __setstate__(self, state):
        _cfg_dict, _filename, _text = state
        super(BaseConfig, self).__setattr__("_cfg_dict", _cfg_dict)
        super(BaseConfig, self).__setattr__("_filename", _filename)
        super(BaseConfig, self).__setattr__("_text", _text)

    def dump(self, filename: str) -> None:
        """Write the current config out to *filename* as yaml."""
        cfg_dict = super(BaseConfig, self).__getattribute__("_cfg_dict").to_dict()
        with open(filename, "w") as f:
            yaml.dump(cfg_dict, f)


class TestConfig(BaseConfig):
    """Config for evaluation runs: CLI arguments merged over a yaml file."""

    @classmethod
    def get_opt(cls, config_dir: str, arg_dict: Optional[Dict] = None) -> "TestConfig":
        """Build a TestConfig from CLI args layered over the yaml at ``config_path``.

        Args:
            config_dir: NOTE(review) — unused in this method; the config file
                is taken from ``arg_dict['config_path']``. Confirm whether it
                can be dropped or should be honored.
            arg_dict: Pre-parsed argument dict; parsed from sys.argv if None.
        """
        if not (arg_dict):
            arg_dict = cls.arg_parse()

        filename = arg_dict["config_path"]
        cfg_dict, cfg_text, loaded_yamls = cls._file2dict_yaml(filename)
        cfg_dict["loaded_yamls"] = loaded_yamls
        # CLI arguments take precedence over values from the yaml file.
        arg_dict = cls._merge_a_into_b(arg_dict, cfg_dict)
        # Experiment name = config filename stem, e.g. "exp1-2_1.yaml" -> "exp1-2_1".
        arg_dict["exp"] = os.path.basename(filename).split(".")[0]
        arg_dict["path"] = cls.get_path_dict(arg_dict)
        arg_dict["host"] = os.uname()[1]  # hostname, for logging provenance
        arg_dict["is_train"] = False
        arg_dict["dataset"]["test_dataset"] = copy.deepcopy(
            cls.get_test_dataset_config(arg_dict)
        )
        return cls(arg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def arg_parse() -> Dict:
        """Parse test-time CLI arguments from sys.argv into a plain dict."""
        parser = argparse.ArgumentParser()
        parser.add_argument("config_path", type=str)
        parser.add_argument("iter", type=int)
        parser.add_argument("test_dataset_name", type=str)
        parser.add_argument("-s", "--sample_size", type=int, default=1000000)
        parser.add_argument("-d", "--device", type=str, default="cuda:0")

        store_true_args = ["notsave", "notgtmask", "debug"]
        for key in store_true_args:
            parser.add_argument(f"--{key}", action="store_true")

        args = parser.parse_args()
        out_dict = vars(args)  # argparse.Namespace -> Dict

        # Drop None-valued and unset boolean flags, presumably so absent CLI
        # options do not clobber yaml values during the later merge — verify.
        nonetype_keys = [k for k, v in out_dict.items() if v is None]
        for k in nonetype_keys:
            del out_dict[k]
        for key in store_true_args:
            if not (out_dict[key]):
                del out_dict[key]

        return out_dict

    @staticmethod
    def get_path_dict(cfg_dict: Dict) -> Dict:
        """Resolve model/sample paths for the experiment and asserted checkpoint.

        Raises:
            AssertionError: If the checkpoint for ``iter`` does not exist.
        """
        ckpt_root = cfg_dict["ckpt_root"]
        exp_name = cfg_dict["exp"]  # like "exp1-2_1"
        load_iter = cfg_dict["iter"]
        dataset_name = cfg_dict["test_dataset_name"]
        path_handler = PathHandler(ckpt_root, exp_name)
        path_dict = path_handler.get_exp_path_dict()
        model_path = path_handler.get_ckpt_path("comp_model", itr=load_iter)
        assert os.path.exists(model_path), f'model_path "{model_path}" does not exist.'
        # NOTE(review): load_iter//1000 always appends 'K', unlike
        # PathHandler.iter2str which does so only for multiples of 1000 —
        # the sample dir name can disagree with the ckpt name; confirm.
        sample_dir = os.path.join(
            path_dict["sample_dir"], f"{exp_name}_iter{load_iter//1000}K_{dataset_name}"
        )
        return {
            "ckpt_root": ckpt_root,
            "model_dir": path_dict["model_dir"],
            "sample_dir": sample_dir,
            "model_path": model_path,
        }

    @staticmethod
    def get_test_dataset_config(cfg_dict: Dict) -> Dict:
        """Derive the test-dataset config from the eval-dataset config.

        Copies ``dataset.eval_dataset``, renames it to the requested test
        dataset, and disables ground-truth masks when ``--notgtmask`` was set.
        """
        test_dataset_config = copy.deepcopy(cfg_dict["dataset"]["eval_dataset"])
        test_dataset_name = cfg_dict["test_dataset_name"]
        test_dataset_config["name"] = test_dataset_name
        if "notgtmask" in cfg_dict:
            test_dataset_config["use_gt_mask"] = not (cfg_dict["notgtmask"])
        return test_dataset_config

+ 9
- 0
eval_dec_performance/readme.txt View File

@@ -0,0 +1,9 @@
1. check_bpp.py 检查码率是否超过限制,输入路径为码流路径
2. python calc_metrics.py --real_dir './' --fake_dir './'
real_dir为原图路径,fake_dir为解码图像路径,输出4个指标结果在json文件中,如下所示:
{
"PSNR": 27.006997590951404,
"MS-SSIM": 0.9165153607726098,
"LPIPS": 0.40229594334959984,
"DISTS": 0.20327724069356917
}

Loading…
Cancel
Save
Baidu
map