|
|
|
@@ -0,0 +1,431 @@ |
|
|
|
# Copyright (c) 2021-2022, InterDigital Communications, Inc
|
|
|
|
# All rights reserved.
|
|
|
|
|
|
|
|
# Redistribution and use in source and binary forms, with or without
|
|
|
|
# modification, are permitted (subject to the limitations in the disclaimer
|
|
|
|
# below) provided that the following conditions are met:
|
|
|
|
|
|
|
|
# * Redistributions of source code must retain the above copyright notice,
|
|
|
|
# this list of conditions and the following disclaimer.
|
|
|
|
# * Redistributions in binary form must reproduce the above copyright notice,
|
|
|
|
# this list of conditions and the following disclaimer in the documentation
|
|
|
|
# and/or other materials provided with the distribution.
|
|
|
|
# * Neither the name of InterDigital Communications, Inc nor the names of its
|
|
|
|
# contributors may be used to endorse or promote products derived from this
|
|
|
|
# software without specific prior written permission.
|
|
|
|
|
|
|
|
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
|
|
|
|
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
|
|
|
|
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
|
|
|
|
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
|
|
|
|
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
|
|
|
|
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
|
|
|
|
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
|
|
|
|
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
|
|
|
|
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
|
|
|
|
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
|
|
|
|
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
|
|
|
|
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
|
|
|
|
import math
|
|
|
|
|
|
|
|
from typing import List
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
from compressai.entropy_models import GaussianConditional, EntropyBottleneck
|
|
|
|
|
|
|
|
from compressai.models.google import CompressionModel, get_scale_table
|
|
|
|
from compressai.models.utils import (
|
|
|
|
conv,
|
|
|
|
deconv,
|
|
|
|
quantize_ste,
|
|
|
|
update_registered_buffers,
|
|
|
|
)
|
|
|
|
from ckbd import *
|
|
|
|
|
|
|
|
def conv3x3(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """Return a 3x3 convolution; padding=1 preserves spatial size at stride 1."""
    layer = nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=3,
        stride=stride,
        padding=1,
    )
    return layer
|
|
|
|
|
|
|
|
|
|
|
|
def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """Return a 1x1 (pointwise) convolution, optionally strided."""
    return nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=1,
        stride=stride,
    )
|
|
|
|
|
|
|
|
def subpel_conv1x1(in_ch, out_ch, r=1):
    """1x1 sub-pixel convolution for up-sampling.

    A pointwise conv expands channels by r**2, then PixelShuffle rearranges
    them into an r-times larger spatial grid.
    """
    expand = nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=1, padding=0)
    shuffle = nn.PixelShuffle(r)
    return nn.Sequential(expand, shuffle)
|
|
|
|
|
|
|
|
class ResidualBlockWithStride(nn.Module):
    """Residual block whose first 3x3 convolution may be strided.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): stride value (default: 2)
    """

    def __init__(self, in_ch, out_ch, stride=2):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
        self.leaky_relu = nn.LeakyReLU()
        self.conv2 = conv3x3(out_ch, out_ch)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1)
        # When strided, the skip path needs a strided 1x1 projection so its
        # shape matches the main branch; otherwise the input passes through.
        self.downsample = conv1x1(in_ch, out_ch, stride=stride) if stride != 1 else None

    def forward(self, x):
        out = self.leaky_relu(self.conv1(x))
        out = self.leaky_relu2(self.conv2(out))
        skip = x if self.downsample is None else self.downsample(x)
        return out + skip
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualBlockUpsample(nn.Module):
    """Residual block that up-samples both branches via sub-pixel convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        upsample (int): upsampling factor (default: 2)
    """

    def __init__(self, in_ch, out_ch, upsample=2):
        super().__init__()
        self.subpel_conv = subpel_conv1x1(in_ch, out_ch, upsample)
        self.leaky_relu = nn.LeakyReLU()
        self.conv = conv3x3(out_ch, out_ch)
        self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1)
        # The skip path uses its own, independently weighted sub-pixel conv.
        self.upsample = subpel_conv1x1(in_ch, out_ch, upsample)

    def forward(self, x):
        out = self.subpel_conv(x)
        out = self.leaky_relu(out)
        out = self.leaky_relu2(self.conv(out))
        skip = self.upsample(x)
        return out + skip
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
    """Simple residual block with two 3x3 convolutions.

    The skip connection adds the raw input, so the residual add requires
    in_ch == out_ch (there is no projection on the skip path).

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        leaky_relu_slope (float): negative slope shared by both activations
    """

    def __init__(self, in_ch, out_ch, leaky_relu_slope=0.01):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch)
        self.leaky_relu = nn.LeakyReLU(negative_slope=leaky_relu_slope)
        self.conv2 = conv3x3(out_ch, out_ch)

    def forward(self, x):
        out = self.leaky_relu(self.conv1(x))
        out = self.leaky_relu(self.conv2(out))
        return out + x
|
|
|
|
|
|
|
|
|
|
|
|
class LVC_exp_spy_res(nn.Module):
    """Learned image codec: analysis/synthesis transforms plus a hyperprior
    entropy model with a two-pass ("dual"/checkerboard-style) spatial context.

    All sub-networks are declared as local classes inside ``__init__`` and
    instantiated as submodules.  NOTE(review): the public ``forward`` below
    runs decoder-side modules on random latents and returns nothing — it looks
    like a complexity/timing probe, not a real coding pass; confirm intent.
    """

    def __init__(
        self,
    ):
        super().__init__()

        class Img_Encoder(nn.Sequential):
            # Analysis transform: four stride-2 stages (16x total downsampling).
            def __init__(
                self, in_planes: int, mid_planes: int = 192, out_planes: int = 192
            ):
                super().__init__(
                    ResidualBlockWithStride(in_planes, mid_planes, stride=2),
                    ResidualBlock(mid_planes, mid_planes),
                    ResidualBlockWithStride(mid_planes, mid_planes, stride=2),
                    ResidualBlock(mid_planes, mid_planes),
                    ResidualBlockWithStride(mid_planes, mid_planes, stride=2),
                    ResidualBlock(mid_planes, mid_planes),
                    conv3x3(mid_planes, out_planes, stride=2),
                )

        class Img_Decoder(nn.Sequential):
            # Synthesis transform: four 2x sub-pixel upsampling stages (16x total).
            def __init__(
                self, out_planes: int, in_planes: int = 192, mid_planes: int = 192
            ):
                super().__init__(
                    ResidualBlock(in_planes, mid_planes),
                    ResidualBlockUpsample(mid_planes, mid_planes, 2),
                    ResidualBlock(mid_planes, mid_planes),
                    ResidualBlockUpsample(mid_planes, mid_planes, 2),
                    ResidualBlock(mid_planes, mid_planes),
                    ResidualBlockUpsample(mid_planes, mid_planes, 2),
                    ResidualBlock(mid_planes, mid_planes),
                    subpel_conv1x1(mid_planes, out_planes, 2),
                )

        class Img_HyperEncoder(nn.Sequential):
            # Hyper-analysis: two stride-2 convs, so z is 4x smaller than y.
            def __init__(
                self, in_planes: int = 192, mid_planes: int = 192, out_planes: int = 192
            ):
                super().__init__(
                    conv3x3(in_planes, mid_planes),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, mid_planes),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, mid_planes, stride=2),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, mid_planes),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, out_planes, stride=2),
                )

        class Img_HyperDecoder(nn.Sequential):
            # Hyper-synthesis: upsamples z_hat back to y resolution and emits
            # out_planes * 2 channels (scale and mean halves of the prior).
            def __init__(
                self, in_planes: int = 192, mid_planes: int = 192, out_planes: int = 192
            ):
                super().__init__(
                    conv3x3(in_planes, mid_planes),
                    nn.LeakyReLU(),
                    subpel_conv1x1(mid_planes, mid_planes, 2),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes, mid_planes * 3 // 2),
                    nn.LeakyReLU(),
                    subpel_conv1x1(mid_planes * 3 // 2, mid_planes * 3 // 2, 2),
                    nn.LeakyReLU(),
                    conv3x3(mid_planes * 3 // 2, out_planes * 2),
                )

        class Joint_Hy_Dual(CompressionModel):
            """Hyperprior + dual spatial context entropy model.

            Anchor positions are coded from the hyperprior alone (context slot
            zeroed); non-anchor positions additionally see a context computed
            from the already-decoded anchors.  The dual_* / ckbd_* helpers are
            star-imported from `ckbd`; presumably they mask / squeeze / merge
            the two (checkerboard-like) position sets — confirm against that
            module.
            """

            def __init__(self, planes: int = 192, mid_planes: int = 192):
                super().__init__(entropy_bottleneck_channels=mid_planes)
                self.planes = planes
                self.hyper_encoder = Img_HyperEncoder(planes, mid_planes, planes)
                self.hyper_decoder = Img_HyperDecoder(planes, mid_planes, planes)
                #self.dual_context = nn.Conv2d(in_channels=planes, out_channels=planes * 2, kernel_size=5, stride=1, padding=2)
                # Context network over quantized anchors; outputs planes*2
                # channels to match the hyperprior half of the concatenation.
                self.dual_context = nn.Sequential(
                    conv3x3(planes, planes),
                    nn.LeakyReLU(),
                    conv3x3(planes, planes * 3 // 2),
                    nn.LeakyReLU(),
                    conv3x3(planes * 3 // 2, planes * 2),
                )

                # Fuses hyperprior (planes*2) + context (planes*2) into
                # planes*2 channels, later chunked into (scales, means).
                self.entropy_parameters = nn.Sequential(
                    conv1x1(planes * 4, planes * 3),
                    nn.LeakyReLU(),
                    conv1x1(planes * 3, planes * 2),
                    nn.LeakyReLU(),
                    conv1x1(planes * 2, planes * 2),
                )

                # Scale table is attached later via update()/load_state_dict().
                self.gaussian_conditional = GaussianConditional(None)

            def forward(self, y):
                """Training pass: return (y_hat, {"y": ..., "z": ...} likelihoods)."""
                z = self.hyper_encoder(y)
                z_hat, z_likelihoods = self.entropy_bottleneck(z)
                hyperprior = self.hyper_decoder(z_hat)

                # Pass 1 (anchors): no context is available yet, so the
                # context half of the input is zeroed.
                y_anchor, y_nonanchor = dual_split(y)
                params_anchor = self.entropy_parameters(torch.cat([hyperprior, torch.zeros_like(hyperprior)], dim=1))
                scales_anchor, means_anchor = params_anchor.chunk(2, 1)
                scales_anchor = dual_anchor(scales_anchor)
                means_anchor = dual_anchor(means_anchor)
                # Straight-through quantization centered on the predicted mean.
                y_anchor = quantize_ste(y_anchor - means_anchor) + means_anchor
                y_anchor = dual_anchor(y_anchor)

                # Pass 2 (non-anchors): spatial context from quantized anchors.
                spatialprior = self.dual_context(y_anchor)
                params_nonanchor = self.entropy_parameters(torch.cat([hyperprior, spatialprior], dim=1))
                scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
                scales_nonanchor = dual_nonanchor(scales_nonanchor)
                means_nonanchor = dual_nonanchor(means_nonanchor)
                # Merge the two parameter sets and score the full latent y.
                scales = ckbd_merge(scales_anchor, scales_nonanchor)
                means = ckbd_merge(means_anchor, means_nonanchor)
                _, y_likelihoods = self.gaussian_conditional(y, scales, means)
                y_nonanchor = quantize_ste(y_nonanchor - means_nonanchor) + means_nonanchor
                y_nonanchor = dual_nonanchor(y_nonanchor)
                # Anchors and non-anchors presumably occupy disjoint positions,
                # so the sum recombines them — confirm against `ckbd`.
                y_hat = y_anchor + y_nonanchor

                return y_hat, {"y": y_likelihoods, "z": z_likelihoods}

            def compress(self, y):
                """Actual encoding: return (y_hat, {"strings": [...], "shape": ...})."""
                z = self.hyper_encoder(y)
                z_string = self.entropy_bottleneck.compress(z)
                # Decode z locally so encoder and decoder see the same z_hat.
                z_hat = self.entropy_bottleneck.decompress(z_string, z.size()[-2:])
                hyperprior = self.hyper_decoder(z_hat)

                # Pass 1: code anchors from the hyperprior alone.
                y_anchor, y_nonanchor = dual_split(y)
                params_anchor = self.entropy_parameters(torch.cat([hyperprior, torch.zeros_like(hyperprior)], dim=1))
                scales_anchor, means_anchor = params_anchor.chunk(2, 1)
                # "sequeeze" (sic, from ckbd) packs anchor positions densely.
                anchor_squeeze = dual_anchor_sequeeze(y_anchor)
                scales_anchor_squeeze = dual_anchor_sequeeze(scales_anchor)
                means_anchor_squeeze = dual_anchor_sequeeze(means_anchor)
                indexes_anchor = self.gaussian_conditional.build_indexes(scales_anchor_squeeze)
                y_string_anchor = self.gaussian_conditional.compress(anchor_squeeze, indexes_anchor, means_anchor_squeeze)
                # Reproduce the decoder-side quantized anchors for the context.
                y_hat_anchor_squeeze = self.gaussian_conditional.quantize(anchor_squeeze, "dequantize", means_anchor_squeeze)
                y_hat_anchor = dual_anchor_unsequeeze(y_hat_anchor_squeeze)

                # Pass 2: code non-anchors with the anchor-derived context.
                spatialprior = self.dual_context(y_hat_anchor)
                params_nonanchor = self.entropy_parameters(torch.cat([hyperprior, spatialprior], dim=1))
                scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
                nonanchor_squeeze = dual_nonanchor_sequeeze(y_nonanchor)
                scales_nonanchor_squeeze = dual_nonanchor_sequeeze(scales_nonanchor)
                means_nonanchor_squeeze = dual_nonanchor_sequeeze(means_nonanchor)
                indexes_nonanchor = self.gaussian_conditional.build_indexes(scales_nonanchor_squeeze)
                y_string_nonanchor = self.gaussian_conditional.compress(nonanchor_squeeze, indexes_nonanchor, means_nonanchor_squeeze)
                y_hat_nonanchor_squeeze = self.gaussian_conditional.quantize(nonanchor_squeeze, "dequantize", means_nonanchor_squeeze)
                y_hat_nonanchor = dual_nonanchor_unsequeeze(y_hat_nonanchor_squeeze)
                y_hat = y_hat_anchor + y_hat_nonanchor

                return y_hat, {"strings": [y_string_anchor, y_string_nonanchor, z_string], "shape": z.size()[-2:]}

            def decompress(self, strings, shape):
                """Decode y_hat from [anchor, nonanchor, z] strings; mirrors compress()."""
                assert isinstance(strings, list) and len(strings) == 3
                z_hat = self.entropy_bottleneck.decompress(strings[2], shape)
                hyperprior = self.hyper_decoder(z_hat)

                # Pass 1: decode anchors from the hyperprior alone.
                params_anchor = self.entropy_parameters(torch.cat([hyperprior, torch.zeros_like(hyperprior)], dim=1))
                scales_anchor, means_anchor = params_anchor.chunk(2, 1)
                scales_anchor_squeeze = dual_anchor_sequeeze(scales_anchor)
                means_anchor_squeeze = dual_anchor_sequeeze(means_anchor)
                indexes_anchor = self.gaussian_conditional.build_indexes(scales_anchor_squeeze)
                y_hat_anchor_squeeze = self.gaussian_conditional.decompress(strings[0], indexes_anchor, z_hat.dtype, means_anchor_squeeze)
                y_hat_anchor = dual_anchor_unsequeeze(y_hat_anchor_squeeze)

                # Pass 2: decode non-anchors with the anchor-derived context.
                spatialprior = self.dual_context(y_hat_anchor)
                params_nonanchor = self.entropy_parameters(torch.cat([hyperprior, spatialprior], dim=1))
                scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
                scales_nonanchor_squeeze = dual_nonanchor_sequeeze(scales_nonanchor)
                means_nonanchor_squeeze = dual_nonanchor_sequeeze(means_nonanchor)
                indexes_nonanchor = self.gaussian_conditional.build_indexes(scales_nonanchor_squeeze)
                y_hat_nonanchor_squeeze = self.gaussian_conditional.decompress(strings[1], indexes_nonanchor, z_hat.dtype, means_nonanchor_squeeze)
                y_hat_nonanchor = dual_nonanchor_unsequeeze(y_hat_nonanchor_squeeze)
                y_hat = y_hat_anchor + y_hat_nonanchor

                return y_hat

        self.img_encoder = Img_Encoder(in_planes=3, mid_planes=192, out_planes=192)
        self.img_decoder = Img_Decoder(in_planes=192, mid_planes=192, out_planes=3)
        self.img_entropymodel = Joint_Hy_Dual()

        # Per-rate latent scaling: 4 factor vectors over 192 channels plus a
        # scalar bias — presumably variable-rate control, used as
        # img_var_factor[factor] * y in the commented-out forward_keyframe;
        # confirm against the training script.
        self.img_var_factor = nn.Parameter(torch.ones([4, 1, 192, 1, 1]))
        self.img_var_bias = nn.Parameter(torch.ones(1))

    # def forward(self, frames, factor):
    #     if not isinstance(frames, List):
    #         raise RuntimeError(f"Invalid number of frames: {len(frames)}.")

    #     reconstructions = []
    #     frames_likelihoods = []

    #     x_hat, likelihoods = self.forward_keyframe(frames[0], factor)
    #     reconstructions.append(x_hat)
    #     frames_likelihoods.append(likelihoods)

    #     return {
    #         "x_hat": reconstructions,
    #         "likelihoods": frames_likelihoods,
    #     }

    # def forward_keyframe(self, x, factor):
    #     y = self.img_encoder(x)
    #     y = self.img_var_factor[factor] * y
    #     y_hat, likelihoods = self.img_entropymodel(y)
    #     y_hat = self.img_var_bias * y_hat / self.img_var_factor[factor]
    #     x_hat = self.img_decoder(y_hat)
    #     return x_hat, {"keyframe": likelihoods}

    # def encode_keyframe(self, x, factor):
    #     y = self.img_encoder(x)
    #     y = self.img_var_factor[factor] * y
    #     y_hat, out_keyframe = self.img_entropymodel.compress(y)
    #     y_hat = self.img_var_bias * y_hat / self.img_var_factor[factor]
    #     x_hat = self.img_decoder(y_hat)

    #     return x_hat, out_keyframe

    # def decode_keyframe(self, strings, shape, factor):
    #     y_hat = self.img_entropymodel.decompress(strings, shape)
    #     y_hat = self.img_var_bias * y_hat / self.img_var_factor[factor]
    #     x_hat = self.img_decoder(y_hat)

    #     return x_hat

    def forward(self, image):
        # NOTE(review): runs the decoder-side modules (hyper-decoder, context,
        # entropy parameters, image decoder) on *random* latents at the
        # bottleneck resolutions (H/64 x W/64 for z_hat, H/16 x W/16 for y),
        # discards the parameter tensors, and returns None — looks like a
        # decoder complexity/timing probe rather than a real codec pass;
        # confirm intent.  CUDA is hard-coded via .cuda().
        z_hat = torch.rand((image.size(0), self.img_entropymodel.planes, image.size(2) // 64, image.size(3) // 64)).cuda()
        hyperprior = self.img_entropymodel.hyper_decoder(z_hat)

        params_anchor = self.img_entropymodel.entropy_parameters(torch.cat([hyperprior, torch.zeros_like(hyperprior)], dim=1))
        y_hat_anchor = torch.rand((image.size(0), self.img_entropymodel.planes, image.size(2) // 16, image.size(3) // 16)).cuda()

        spatialprior = self.img_entropymodel.dual_context(y_hat_anchor)
        params_nonanchor = self.img_entropymodel.entropy_parameters(torch.cat([hyperprior, spatialprior], dim=1))

        y_hat = torch.rand((image.size(0), self.img_entropymodel.planes, image.size(2) // 16, image.size(3) // 16)).cuda()
        x_hat = self.img_decoder(y_hat)

    def load_state_dict(self, state_dict, strict=True, update_buffer=True):
        """Load weights, optionally resizing the entropy coders' CDF buffers.

        Args:
            state_dict: checkpoint mapping to load.
            strict: forwarded to nn.Module.load_state_dict.
            update_buffer: when True, first resize the registered CDF buffers
                (_quantized_cdf/_offset/_cdf_length/scale_table) to match the
                checkpoint, as CompressAI's update_registered_buffers requires.
        """
        if update_buffer:
            # Dynamically update the entropy bottleneck buffers related to the CDFs
            update_registered_buffers(
                self.img_entropymodel.gaussian_conditional,
                "img_entropymodel.gaussian_conditional",
                ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
                state_dict,
            )
            update_registered_buffers(
                self.img_entropymodel.entropy_bottleneck,
                "img_entropymodel.entropy_bottleneck",
                ["_quantized_cdf", "_offset", "_cdf_length"],
                state_dict,
            )

        super().load_state_dict(state_dict, strict=strict)

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        net = cls()
        net.load_state_dict(state_dict)
        return net

    def update(self, scale_table=None, force=False):
        """Refresh the CDF tables needed for actual entropy coding.

        Args:
            scale_table: scales for the GaussianConditional; defaults to
                CompressAI's get_scale_table().
            force: overwrite previously computed tables.

        Returns:
            bool: True if any table was updated.
        """
        if scale_table is None:
            scale_table = get_scale_table()

        updated = self.img_entropymodel.gaussian_conditional.update_scale_table(
            scale_table, force=force
        )
        # updated |= super().update(force=force)

        # nn.Module has no update(); walk submodules and refresh every
        # EntropyBottleneck directly instead of relying on a parent class.
        for m in self.modules():
            if not isinstance(m, EntropyBottleneck):
                continue
            rv = m.update(force=force)
            updated |= rv

        return updated
|