@@ -321,6 +321,7 @@ class VAE:
         self.latent_channels = 4
         self.latent_dim = 2
         self.output_channels = 3
+        self.pad_channel_value = None
         self.process_input = lambda image: image * 2.0 - 1.0
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
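
Note: the baseline config now initializes `pad_channel_value = None`, i.e. no channel padding unless one of the model branches below opts in. For context, the surrounding `process_input`/`process_output` lambdas are the usual [0, 1] <-> [-1, 1] remap; a standalone toy check (a sketch, not part of the diff):

    import torch

    image = torch.tensor([0.0, 0.5, 1.0])
    x = image * 2.0 - 1.0                                      # process_input
    restored = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)  # process_output
    print(x)         # tensor([-1.,  0.,  1.])
    print(restored)  # tensor([0.0000, 0.5000, 1.0000])
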
@@ -435,6 +436,7 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype)
             self.latent_channels = 64
             self.output_channels = 2
+            self.pad_channel_value = "replicate"
             self.upscale_ratio = 2048
             self.downscale_ratio = 2048
             self.latent_dim = 1
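
This branch is a two-channel (stereo) audio VAE, so `pad_channel_value = "replicate"` lets the rewritten `vae_encode_crop_pixels` (further down) widen a mono input by repeating its single channel. A minimal sketch of what replicate padding on the trailing channel dimension does, assuming a hypothetical (batch, samples, channels) layout:

    import torch
    import torch.nn.functional as F

    mono = torch.randn(2, 2048, 1)                  # hypothetical mono batch
    stereo = F.pad(mono, (0, 1), mode="replicate")  # repeat the edge channel
    print(stereo.shape)                                 # torch.Size([2, 2048, 2])
    print(torch.equal(stereo[..., 0], stereo[..., 1]))  # True
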
@@ -546,7 +548,9 @@ class VAE:
             self.downscale_index_formula = (4, 8, 8)
             self.latent_dim = 3
             self.latent_channels = 16
-            ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+            self.output_channels = sd["encoder.conv1.weight"].shape[1]
+            self.pad_channel_value = 1.0
+            ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
             self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
             self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
             self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
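
The Wan branch now infers `output_channels` from the state dict instead of assuming 3: a conv weight is laid out `(out_channels, in_channels, *kernel_dims)`, so `shape[1]` of the first encoder conv is the image channel count the model expects, and any missing input channels get a constant 1.0 fill. (`temperal_downsample` is the parameter's actual spelling in the Wan VAE implementation, so it is left untouched.) A quick standalone illustration of the weight-shape trick:

    import torch

    # in_channels is the second dimension of a conv weight.
    conv = torch.nn.Conv3d(in_channels=4, out_channels=96, kernel_size=3)
    print(conv.weight.shape)     # torch.Size([96, 4, 3, 3, 3])
    print(conv.weight.shape[1])  # 4
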
@@ -582,6 +586,7 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
             self.latent_channels = 8
             self.output_channels = 2
+            self.pad_channel_value = "replicate"
             self.upscale_ratio = 4096
             self.downscale_ratio = 4096
             self.latent_dim = 2
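
Same pattern as the branch at 435: a two-channel audio model that can replicate-pad mono input. The decode-memory lambda is a rough bytes estimate scaled by dtype size; a back-of-envelope reading with a hypothetical latent shape and bfloat16:

    shape = (1, 8, 16, 1024)  # hypothetical latent (batch, channels, d2, d3)
    dtype_size = 2            # bytes per bfloat16 element
    print(shape[2] * shape[3] * 87000 * dtype_size)  # 2850816000 bytes, ~2.9 GB
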
@@ -690,17 +695,28 @@ class VAE:
             raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")

     def vae_encode_crop_pixels(self, pixels):
-        if not self.crop_input:
-            return pixels
-
-        downscale_ratio = self.spacial_compression_encode()
-
-        dims = pixels.shape[1:-1]
-        for d in range(len(dims)):
-            x = (dims[d] // downscale_ratio) * downscale_ratio
-            x_offset = (dims[d] % downscale_ratio) // 2
-            if x != dims[d]:
-                pixels = pixels.narrow(d + 1, x_offset, x)
+        if self.crop_input:
+            downscale_ratio = self.spacial_compression_encode()
+
+            dims = pixels.shape[1:-1]
+            for d in range(len(dims)):
+                x = (dims[d] // downscale_ratio) * downscale_ratio
+                x_offset = (dims[d] % downscale_ratio) // 2
+                if x != dims[d]:
+                    pixels = pixels.narrow(d + 1, x_offset, x)
+
+        if pixels.shape[-1] > self.output_channels:
+            pixels = pixels[..., :self.output_channels]
+        elif pixels.shape[-1] < self.output_channels:
+            if self.pad_channel_value is not None:
+                if isinstance(self.pad_channel_value, str):
+                    mode = self.pad_channel_value
+                    value = None
+                else:
+                    mode = "constant"
+                    value = self.pad_channel_value
+
+                pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
         return pixels

     def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
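
Rewritten, `vae_encode_crop_pixels` no longer bails out early when `crop_input` is false: trimming each spatial dimension to the nearest multiple of the compression ratio (split evenly on both sides) is now guarded by `if self.crop_input:`, and the channel fix-up that follows runs unconditionally. Inputs with too many trailing channels are truncated; inputs with too few are padded, where a string `pad_channel_value` selects a pad mode (e.g. "replicate") and a number selects a constant fill. A standalone sketch of the padding branch, with hypothetical shapes and values:

    import torch
    import torch.nn.functional as F

    # Hypothetical NHWC RGB batch fed to a model expecting 4 channels,
    # with pad_channel_value = 1.0 (constant fill, e.g. an opaque alpha).
    pixels = torch.rand(1, 64, 64, 3)
    output_channels, pad_channel_value = 4, 1.0

    if pixels.shape[-1] < output_channels:
        pixels = F.pad(pixels, (0, output_channels - pixels.shape[-1]),
                       mode="constant", value=pad_channel_value)
    print(pixels.shape)           # torch.Size([1, 64, 64, 4])
    print(pixels[..., 3].mean())  # tensor(1.)

Leaving `value = None` for mode strings matches `torch.nn.functional.pad`'s default, since `value` only applies in "constant" mode.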