19 Commits

Author SHA1 Message Date
  autumn 4f28958a79 Merge remote-tracking branch 'origin/muon_lynxnet2_bbc' into main-bbc 5 months ago
  autumn 2ade7e48a0 add bbc encoder for me and fast bbc mask 5 months ago
  yxlllc 2a10f27194 optimize smooth width 5 months ago
  yxlllc f6af252039 support one smooth_kernel 5 months ago
  yxlllc f5915341d0 support atanglu 6 months ago
  yxlllc 0db91f27d9 save memory 7 months ago
  yxlllc d099e3c952 support bf16 calculation 7 months ago
  yxlllc 3ae76d72fe avoid precision conversions 7 months ago
  Kakaru b0ae9ca8ba [DONE]Supplement the Variance Model Scaling / Retake Scaling / Conditioner cache on LYNXNet2 (#259) 7 months ago
  yxlllc 2ea898f516 variance scaling for onnx 7 months ago
  yxlllc 954e41c890 variance scaling 7 months ago
  Kakaru 14c360938d Fix some issue about Initialization (#250) 8 months ago
  yxlllc 300676a8c1 stabilize fp16 training 8 months ago
  yxlllc eb3b606de6 stabilize fp16 training 8 months ago
  yxlllc 5f7a1be7f6 Merge branch 'main' into muon_lynxnet2 8 months ago
  yxlllc f9fda27814 optimize 8 months ago
  yxlllc 4a4ee3defb support muon optimizer 8 months ago
  yxlllc 51d3d3d263 Merge branch 'lynxnet2' into muon_lynxnet2 8 months ago
  yxlllc 13406a2832 update lynxnet2 backbone 9 months ago
22 changed files with 810 additions and 139 deletions
  1. basics/base_task.py (+1, -1)
  2. configs/acoustic.yaml (+19, -12)
  3. configs/templates/config_acoustic.yaml (+14, -8)
  4. configs/templates/config_variance.yaml (+35, -28)
  5. configs/templates/config_variance_bbc.yaml (+5, -0)
  6. configs/variance.yaml (+35, -20)
  7. deployment/modules/fastspeech2.py (+8, -2)
  8. deployment/modules/toplevel.py (+13, -4)
  9. modules/backbones/__init__.py (+3, -1)
  10. modules/backbones/lynxnet.py (+3, -18)
  11. modules/backbones/lynxnet2.py (+117, -0)
  12. modules/backbones/wavenet.py (+1, -6)
  13. modules/commons/common_layers.py (+52, -2)
  14. modules/fastspeech/acoustic_encoder.py (+30, -4)
  15. modules/fastspeech/bbc_mask.py (+84, -0)
  16. modules/fastspeech/tts_modules.py (+28, -11)
  17. modules/fastspeech/variance_encoder.py (+15, -6)
  18. modules/optimizer/chained_optimizer.py (+122, -0)
  19. modules/optimizer/muon.py (+152, -0)
  20. modules/toplevel.py (+63, -10)
  21. utils/__init__.py (+2, -1)
  22. utils/binarizer_utils.py (+8, -5)

basics/base_task.py (+1, -1)

@@ -307,7 +307,7 @@ class BaseTask(pl.LightningModule):
optimizer = build_object_from_class_name(
optimizer_args['optimizer_cls'],
torch.optim.Optimizer,
- model.parameters(),
+ model if optimizer_args['optimizer_cls'] == 'modules.optimizer.muon.Muon_AdamW' else model.parameters(),
**optimizer_args
)
return optimizer
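
The branch above exists because Muon_AdamW (modules/optimizer/muon.py below) takes the model itself rather than a flat parameter list: its filter needs the owning module's type to exclude embedding weights. A toy sketch of that dependency (hypothetical model; the filter mirrors get_params_for_muon below):

    import torch.nn as nn
    model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8))
    muon_params = [
        p for m in model.modules() for p in m.parameters(recurse=False)
        if p.requires_grad and p.ndim >= 2 and not isinstance(m, nn.Embedding)
    ]  # only the Linear weight qualifies; a bare parameter list cannot tell these apart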


configs/acoustic.yaml (+19, -12)

@@ -42,10 +42,10 @@ spec_max: [0]
mel_vmin: -14.
mel_vmax: 4.
mel_base: 'e'
- energy_smooth_width: 0.12
- breathiness_smooth_width: 0.12
- voicing_smooth_width: 0.12
- tension_smooth_width: 0.12
+ energy_smooth_width: 0.06
+ breathiness_smooth_width: 0.06
+ voicing_smooth_width: 0.06
+ tension_smooth_width: 0.06

use_lang_id: false
num_lang: 1
@@ -60,7 +60,7 @@ use_speed_embed: false

use_bbc_encoder: false
bbc_mask_len: 8
- bbc_min_segment_length: 64
+ bbc_min_segment_length: 16
bbc_mask_prob: 1.

diffusion_type: reflow
@@ -69,19 +69,21 @@ timesteps: 1000
max_beta: 0.02
enc_ffn_kernel_size: 3
use_rope: true
+ use_variance_scaling: true
rel_pos: true
sampling_algorithm: euler
sampling_steps: 20
diff_accelerator: ddim
diff_speedup: 10
hidden_size: 256
- backbone_type: 'lynxnet'
+ backbone_type: 'lynxnet2'
backbone_args:
num_channels: 1024
num_layers: 6
kernel_size: 31
dropout_rate: 0.0
- strong_cond: true
+ use_conditioner_cache: true
+ glu_type: 'atanglu'
main_loss_type: l2
main_loss_log_norm: false
schedule_type: 'linear'
@@ -110,20 +112,25 @@ lambda_aux_mel_loss: 0.2
# train and eval
num_sanity_val_steps: 1
optimizer_args:
+ optimizer_cls: modules.optimizer.muon.Muon_AdamW
lr: 0.0006
+ muon_args:
+ weight_decay: 0.1
+ adamw_args:
+ weight_decay: 0.0
lr_scheduler_args:
- step_size: 10000
- gamma: 0.75
+ step_size: 5000
+ gamma: 0.8
max_batch_frames: 50000
max_batch_size: 64
dataset_size_key: 'lengths'
val_with_vocoder: true
val_check_interval: 2000
num_valid_plots: 10
- max_updates: 160000
+ max_updates: 100000
num_ckpt_keep: 5
- permanent_ckpt_start: 80000
- permanent_ckpt_interval: 20000
+ permanent_ckpt_start: 60000
+ permanent_ckpt_interval: 10000

finetune_enabled: false
finetune_ckpt_path: null
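
How these keys reach the optimizer (a sketch; the kwarg names match Muon_AdamW's signature in modules/optimizer/muon.py below, and the wiring through optimizer_args is as in base_task.py above):

    from modules.optimizer.muon import Muon_AdamW
    optimizer = Muon_AdamW(
        model, lr=0.0006,
        muon_args={'weight_decay': 0.1},   # applied to >=2-D non-embedding weights
        adamw_args={'weight_decay': 0.0},  # applied to everything else
    )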


configs/templates/config_acoustic.yaml (+14, -8)

@@ -71,18 +71,20 @@ augmentation_args:
diffusion_type: reflow
enc_ffn_kernel_size: 3
use_rope: true
+ use_variance_scaling: true
use_shallow_diffusion: true
T_start: 0.4
T_start_infer: 0.4
K_step: 300
K_step_infer: 300
- backbone_type: 'lynxnet'
+ backbone_type: 'lynxnet2'
backbone_args:
num_channels: 1024
num_layers: 6
kernel_size: 31
dropout_rate: 0.0
- strong_cond: true
+ use_conditioner_cache: true
+ glu_type: 'atanglu'
#backbone_type: 'wavenet'
#backbone_args:
# num_channels: 512
@@ -102,20 +104,24 @@ shallow_diffusion_args:
lambda_aux_mel_loss: 0.2

optimizer_args:
+ optimizer_cls: modules.optimizer.muon.Muon_AdamW
lr: 0.0006
+ muon_args:
+ weight_decay: 0.1
+ adamw_args:
+ weight_decay: 0.0
lr_scheduler_args:
scheduler_cls: torch.optim.lr_scheduler.StepLR
- step_size: 10000
- gamma: 0.75
+ step_size: 5000
+ gamma: 0.8
max_batch_frames: 50000
max_batch_size: 64
- max_updates: 160000
+ max_updates: 100000

num_valid_plots: 10
val_with_vocoder: true
val_check_interval: 2000
num_ckpt_keep: 5
- permanent_ckpt_start: 120000
- permanent_ckpt_interval: 20000
+ permanent_ckpt_start: 60000
+ permanent_ckpt_interval: 10000
pl_trainer_devices: 'auto'
pl_trainer_precision: '16-mixed'

configs/templates/config_variance.yaml (+35, -28)

@@ -65,10 +65,11 @@ tension_logit_max: 10.0

enc_ffn_kernel_size: 3
use_rope: true
+ use_variance_scaling: true
hidden_size: 256
dur_prediction_args:
- arch: fs2
- hidden_size: 512
+ arch: resnet
+ hidden_size: 256
dropout: 0.1
num_layers: 5
kernel_size: 3
@@ -78,7 +79,7 @@ dur_prediction_args:
lambda_wdur_loss: 1.0
lambda_sdur_loss: 3.0

- use_melody_encoder: false
+ use_melody_encoder: true
melody_encoder_args:
hidden_size: 128
enc_layers: 4
@@ -94,50 +95,56 @@ pitch_prediction_args:
pitd_clip_min: -12.0
pitd_clip_max: 12.0
repeat_bins: 64
- backbone_type: 'wavenet'
- backbone_args:
- num_layers: 20
- num_channels: 256
- dilation_cycle_length: 5
- # backbone_type: 'lynxnet'
+ # backbone_type: 'wavenet'
# backbone_args:
- # num_layers: 6
- # num_channels: 512
- # dropout_rate: 0.0
- # strong_cond: true
+ # num_layers: 20
+ # num_channels: 256
+ # dilation_cycle_length: 5
+ backbone_type: 'lynxnet2'
+ backbone_args:
+ num_layers: 6
+ num_channels: 512
+ dropout_rate: 0.0
+ use_conditioner_cache: true
+ glu_type: 'atanglu'

variances_prediction_args:
total_repeat_bins: 48
- backbone_type: 'wavenet'
- backbone_args:
- num_layers: 10
- num_channels: 192
- dilation_cycle_length: 4
- # backbone_type: 'lynxnet'
+ # backbone_type: 'wavenet'
# backbone_args:
- # num_layers: 6
- # num_channels: 384
- # dropout_rate: 0.0
- # strong_cond: true
+ # num_layers: 10
+ # num_channels: 192
+ # dilation_cycle_length: 4
+ backbone_type: 'lynxnet2'
+ backbone_args:
+ num_layers: 6
+ num_channels: 384
+ dropout_rate: 0.0
+ use_conditioner_cache: true
+ glu_type: 'atanglu'

lambda_dur_loss: 1.0
lambda_pitch_loss: 1.0
lambda_var_loss: 1.0

optimizer_args:
+ optimizer_cls: modules.optimizer.muon.Muon_AdamW
lr: 0.0006
+ muon_args:
+ weight_decay: 0.1
+ adamw_args:
+ weight_decay: 0.0
lr_scheduler_args:
scheduler_cls: torch.optim.lr_scheduler.StepLR
- step_size: 10000
- gamma: 0.75
+ step_size: 5000
+ gamma: 0.8
max_batch_frames: 80000
max_batch_size: 48
- max_updates: 160000
+ max_updates: 100000

num_valid_plots: 10
val_check_interval: 2000
num_ckpt_keep: 5
- permanent_ckpt_start: 80000
+ permanent_ckpt_start: 60000
permanent_ckpt_interval: 10000
pl_trainer_devices: 'auto'
pl_trainer_precision: '16-mixed'

configs/templates/config_variance_bbc.yaml (+5, -0)

@@ -67,6 +67,11 @@ use_bbc_encoder: true
bbc_mask_len: 8
bbc_min_segment_length: 16
bbc_mask_prob: 1.

+ use_me_bbc_encoder: true
+ me_bbc_mask_len: 8
+ me_bbc_min_segment_length: 16
+ me_bbc_mask_prob: 1.
enc_ffn_kernel_size: 3
use_rope: true
hidden_size: 256


configs/variance.yaml (+35, -20)

@@ -37,16 +37,22 @@ predict_tension: false

use_bbc_encoder: false
bbc_mask_len: 8
- bbc_min_segment_length: 64
+ bbc_min_segment_length: 16
bbc_mask_prob: 1.

+ use_me_bbc_encoder: false
+ me_bbc_mask_len: 8
+ me_bbc_min_segment_length: 16
+ me_bbc_mask_prob: 1.
enc_ffn_kernel_size: 3
use_rope: true
+ use_variance_scaling: true
rel_pos: true
hidden_size: 256

dur_prediction_args:
- arch: fs2
- hidden_size: 512
+ arch: resnet
+ hidden_size: 256
dropout: 0.1
num_layers: 5
kernel_size: 3
@@ -56,7 +62,7 @@ dur_prediction_args:
lambda_wdur_loss: 1.0
lambda_sdur_loss: 3.0

- use_melody_encoder: false
+ use_melody_encoder: true
melody_encoder_args:
hidden_size: 128
enc_layers: 4
@@ -70,34 +76,38 @@ pitch_prediction_args:
pitd_clip_min: -12.0
pitd_clip_max: 12.0
repeat_bins: 64
- backbone_type: 'wavenet'
+ backbone_type: 'lynxnet2'
backbone_args:
- num_layers: 20
- num_channels: 256
- dilation_cycle_length: 5
+ num_layers: 6
+ num_channels: 512
+ dropout_rate: 0.0
+ use_conditioner_cache: true
+ glu_type: 'atanglu'

energy_db_min: -96.0
energy_db_max: -12.0
- energy_smooth_width: 0.12
+ energy_smooth_width: 0.06

breathiness_db_min: -96.0
breathiness_db_max: -20.0
- breathiness_smooth_width: 0.12
+ breathiness_smooth_width: 0.06
voicing_db_min: -96.0
voicing_db_max: -12.0
- voicing_smooth_width: 0.12
+ voicing_smooth_width: 0.06

tension_logit_min: -10.0
tension_logit_max: 10.0
- tension_smooth_width: 0.12
+ tension_smooth_width: 0.06

variances_prediction_args:
total_repeat_bins: 48
- backbone_type: 'wavenet'
+ backbone_type: 'lynxnet2'
backbone_args:
- num_layers: 10
- num_channels: 192
- dilation_cycle_length: 4
+ num_layers: 6
+ num_channels: 384
+ dropout_rate: 0.0
+ use_conditioner_cache: true
+ glu_type: 'atanglu'

lambda_dur_loss: 1.0
lambda_pitch_loss: 1.0
@@ -119,18 +129,23 @@ diff_speedup: 10
# train and eval
num_sanity_val_steps: 1
optimizer_args:
+ optimizer_cls: modules.optimizer.muon.Muon_AdamW
lr: 0.0006
+ muon_args:
+ weight_decay: 0.1
+ adamw_args:
+ weight_decay: 0.0
lr_scheduler_args:
- step_size: 10000
- gamma: 0.75
+ step_size: 5000
+ gamma: 0.8
max_batch_frames: 80000
max_batch_size: 48
dataset_size_key: 'lengths'
val_check_interval: 2000
num_valid_plots: 10
- max_updates: 160000
+ max_updates: 100000
num_ckpt_keep: 5
- permanent_ckpt_start: 80000
+ permanent_ckpt_start: 60000
permanent_ckpt_interval: 10000

finetune_enabled: false


deployment/modules/fastspeech2.py (+8, -2)

@@ -75,7 +75,7 @@ class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
durations = durations * (tokens > 0)
mel2ph = self.lr(durations)
f0 = f0 * (mel2ph > 0)
- dur_embed = self.dur_embed(durations.float()[:, :, None])
+ if self.use_variance_scaling:
+ dur_embed = self.dur_embed(torch.log(1 + durations.float())[:, :, None])
+ else:
+ dur_embed = self.dur_embed(durations.float()[:, :, None])
@@ -113,6 +113,7 @@ class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
if self.use_variance_embeds:
variance_embeds = torch.stack([
self.variance_embeds[v_name](variances[v_name][:, :, None])
+ * self.variance_scaling_factor[v_name]
for v_name in self.variance_embed_list
], dim=-1).sum(-1)
condition += variance_embeds
@@ -125,6 +126,7 @@ class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
gender_mask = (gender < 0.).float()
key_shift = gender * ((1. - gender_mask) * self.shift_max + gender_mask * abs(self.shift_min))
key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
+ key_shift_embed *= self.variance_scaling_factor['key_shift']
condition += key_shift_embed

if hparams['use_speed_embed']:
@@ -133,6 +135,7 @@ class FastSpeech2AcousticONNX(FastSpeech2Acoustic):
speed_embed = self.speed_embed(velocity[:, :, None])
else:
speed_embed = self.speed_embed(torch.FloatTensor([1.]).to(condition.device)[:, None, None])
+ speed_embed *= self.variance_scaling_factor['speed']
condition += speed_embed

if hparams['use_spk_id']:
@@ -175,7 +178,10 @@ class FastSpeech2VarianceONNX(FastSpeech2Variance):

def forward_encoder_phoneme(self, tokens, ph_dur, languages=None):
txt_embed = self.txt_embed(tokens)
- ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
+ if self.use_variance_scaling:
+ ph_dur_embed = self.ph_dur_embed(torch.log(1 + ph_dur.float())[:, :, None])
+ else:
+ ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
if self.use_lang_id:
lang_mask = torch.any(
tokens[..., None] == self.cross_lingual_token_idx[None, None],


deployment/modules/toplevel.py (+13, -4)

@@ -252,10 +252,16 @@ class DiffSingerVarianceONNX(DiffSingerVariance):
base_pitch = self.smooth(frame_midi_pitch)
if self.use_melody_encoder:
delta_pitch = (pitch - base_pitch) * ~retake
- pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
+ if self.use_variance_scaling:
+ pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None] / 12)
+ else:
+ pitch_cond += self.delta_pitch_embed(delta_pitch[:, :, None])
else:
base_pitch = base_pitch * retake + pitch * ~retake
- pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
+ if self.use_variance_scaling:
+ pitch_cond += self.base_pitch_embed(base_pitch[:, :, None] / 128)
+ else:
+ pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if hparams['use_spk_id'] and spk_embed is not None:
pitch_cond += spk_embed
return pitch_cond, base_pitch
@@ -275,13 +281,16 @@ class DiffSingerVarianceONNX(DiffSingerVariance):
variances: dict = None, retake=None, spk_embed=None
):
condition = self.forward_mel2x_gather(encoder_out, ph_dur, x_dim=self.hidden_size)
- variance_cond = condition + self.pitch_embed(pitch[:, :, None])
+ if self.use_variance_scaling:
+ variance_cond = condition + self.pitch_embed(pitch[:, :, None] / 12)
+ else:
+ variance_cond = condition + self.pitch_embed(pitch[:, :, None])
non_retake_masks = [
v_retake.float() # [B, T, 1]
for v_retake in (~retake).split(1, dim=2)
]
variance_embeds = [
- self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks
+ self.variance_embeds[v_name](variances[v_name][:, :, None]) * v_masks * self.variance_retake_scaling[v_name]
for v_name, v_masks in zip(self.variance_prediction_list, non_retake_masks)
]
variance_cond += torch.stack(variance_embeds, dim=-1).sum(-1)


modules/backbones/__init__.py (+3, -1)

@@ -1,11 +1,13 @@
import torch.nn
from modules.backbones.wavenet import WaveNet
from modules.backbones.lynxnet import LYNXNet
+ from modules.backbones.lynxnet2 import LYNXNet2
from utils import filter_kwargs

BACKBONES = {
'wavenet': WaveNet,
- 'lynxnet': LYNXNet
+ 'lynxnet': LYNXNet,
+ 'lynxnet2': LYNXNet2,
}




modules/backbones/lynxnet.py (+3, -18)

@@ -6,26 +6,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

- from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU
+ from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, Transpose
+ from modules.commons.common_layers import KaimingNormalConv1d as Conv1d
from utils.hparams import hparams


- class Conv1d(torch.nn.Conv1d):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- nn.init.kaiming_normal_(self.weight)


- class Transpose(nn.Module):
- def __init__(self, dims):
- super().__init__()
- assert len(dims) == 2, 'dims must be a tuple of two dimensions'
- self.dims = dims

- def forward(self, x):
- return x.transpose(*self.dims)


class LYNXConvModule(nn.Module):
@staticmethod
def calc_same_padding(kernel_size):
@@ -150,7 +135,7 @@ class LYNXNet(nn.Module):
# post-norm
x = self.norm(x.transpose(1, 2)).transpose(1, 2)

- # MLP and GLU
+ # output_projection
x = self.output_projection(x) # [B, 128, T]

if self.n_feats == 1:


modules/backbones/lynxnet2.py (+117, -0)

@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from modules.commons.common_layers import SinusoidalPosEmb, SwiGLU, ATanGLU, Transpose
from utils.hparams import hparams


class LYNXNet2Block(nn.Module):
def __init__(self, dim, expansion_factor, kernel_size=31, dropout=0., glu_type='swiglu'):
super().__init__()
inner_dim = int(dim * expansion_factor)
if glu_type == 'swiglu':
_glu = SwiGLU()
elif glu_type == 'atanglu':
_glu = ATanGLU()
else:
raise ValueError(f'{glu_type} is not a valid activation')
if float(dropout) > 0.:
_dropout = nn.Dropout(dropout)
else:
_dropout = nn.Identity()
self.net = nn.Sequential(
nn.LayerNorm(dim),
Transpose((1, 2)),
nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim),
Transpose((1, 2)),
nn.Linear(dim, inner_dim * 2),
_glu,
nn.Linear(inner_dim, inner_dim * 2),
_glu,
nn.Linear(inner_dim, dim),
_dropout
)

def forward(self, x):
return x + self.net(x)


class LYNXNet2(nn.Module):
def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=1, kernel_size=31,
dropout=0.0, use_conditioner_cache=False, glu_type='swiglu'):
"""
LYNXNet2 (Linear Gated Depthwise Separable Convolution Network Version 2)
"""
super().__init__()
self.in_dims = in_dims
self.n_feats = n_feats
self.input_projection = nn.Linear(in_dims * n_feats, num_channels)
self.use_conditioner_cache = use_conditioner_cache
if self.use_conditioner_cache:
# This may need to be modified at some point for compatibility with the conditioner cache
self.conditioner_projection = nn.Conv1d(hparams['hidden_size'], num_channels, 1)
else:
self.conditioner_projection = nn.Linear(hparams['hidden_size'], num_channels)
self.diffusion_embedding = nn.Sequential(
SinusoidalPosEmb(num_channels),
nn.Linear(num_channels, num_channels * 4),
nn.GELU(),
nn.Linear(num_channels * 4, num_channels),
)
self.residual_layers = nn.ModuleList(
[
LYNXNet2Block(
dim=num_channels,
expansion_factor=expansion_factor,
kernel_size=kernel_size,
dropout=dropout,
glu_type=glu_type
)
for i in range(num_layers)
]
)
self.norm = nn.LayerNorm(num_channels)
self.output_projection = nn.Linear(num_channels, in_dims * n_feats)
nn.init.kaiming_normal_(self.input_projection.weight)
nn.init.kaiming_normal_(self.conditioner_projection.weight)
nn.init.zeros_(self.output_projection.weight)

def forward(self, spec, diffusion_step, cond):
"""
:param spec: [B, F, M, T]
:param diffusion_step: [B, 1]
:param cond: [B, H, T]
:return:
"""

if self.n_feats == 1:
x = spec[:, 0] # [B, M, T]
else:
x = spec.flatten(start_dim=1, end_dim=2) # [B, F x M, T]

x = self.input_projection(x.transpose(1, 2)) # [B, T, F x M]
if self.use_conditioner_cache:
# This may need to be modified at some point for compatibility with the conditioner cache
x = x + self.conditioner_projection(cond).transpose(1, 2)
else:
x = x + self.conditioner_projection(cond.transpose(1, 2))
x = x + self.diffusion_embedding(diffusion_step).unsqueeze(1)

for layer in self.residual_layers:
x = layer(x)

# post-norm
x = self.norm(x)

# output projection
x = self.output_projection(x).transpose(1, 2) # [B, F x M, T]

if self.n_feats == 1:
x = x[:, None, :, :]
else:
# This is the temporary solution since PyTorch 1.13
# does not support exporting aten::unflatten to ONNX
# x = x.unflatten(dim=1, sizes=(self.n_feats, self.in_dims))
x = x.reshape(-1, self.n_feats, self.in_dims, x.shape[2])
return x
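
A minimal shape check for the new backbone (illustrative values; assumes hparams is loaded with hidden_size = 256):

    import torch
    from modules.backbones.lynxnet2 import LYNXNet2
    net = LYNXNet2(in_dims=128, n_feats=1, num_layers=6, num_channels=512, glu_type='atanglu')
    spec = torch.randn(2, 1, 128, 100)   # [B, F, M, T]
    step = torch.randint(0, 1000, (2,))  # diffusion step
    cond = torch.randn(2, 256, 100)      # [B, H, T]
    out = net(spec, step, cond)          # -> [2, 1, 128, 100]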

modules/backbones/wavenet.py (+1, -6)

@@ -6,15 +6,10 @@ import torch.nn as nn
import torch.nn.functional as F

from modules.commons.common_layers import SinusoidalPosEmb
+ from modules.commons.common_layers import KaimingNormalConv1d as Conv1d
from utils.hparams import hparams


- class Conv1d(torch.nn.Conv1d):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- nn.init.kaiming_normal_(self.weight)


class ResidualBlock(nn.Module):
def __init__(self, encoder_hidden, residual_channels, dilation):
super().__init__()


modules/commons/common_layers.py (+52, -2)

@@ -114,9 +114,49 @@ class SwiGLU(nn.Module):
# out, gate = x.chunk(2, dim=self.dim)
# Using torch.split instead of chunk for ONNX export compatibility.
out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
- return out * F.silu(gate)
+ gate = F.silu(gate)
+ if x.dtype == torch.float16:
+ out_min, out_max = torch.aminmax(out.detach())
+ gate_min, gate_max = torch.aminmax(gate.detach())
+ max_abs_out = torch.max(-out_min, out_max).float()
+ max_abs_gate = torch.max(-gate_min, gate_max).float()
+ max_abs_value = max_abs_out * max_abs_gate
+ if max_abs_value > 1000:
+ ratio = (1000 / max_abs_value).half()
+ gate *= ratio
+ return (out * gate).clamp(-1000 * ratio, 1000 * ratio) / ratio
+ return out * gate


+ class ATanGLU(nn.Module):
+ # Applies the gated linear unit function with an arctan gate.
+ def __init__(self, dim=-1):
+ super().__init__()
+ self.dim = dim

+ def forward(self, x):
+ # out, gate = x.chunk(2, dim=self.dim)
+ # Using torch.split instead of chunk for ONNX export compatibility.
+ out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim)
+ return out * torch.atan(gate)


+ class KaimingNormalConv1d(torch.nn.Conv1d):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ nn.init.kaiming_normal_(self.weight)


+ class Transpose(nn.Module):
+ def __init__(self, dims):
+ super().__init__()
+ assert len(dims) == 2, 'dims must be a tuple of two dimensions'
+ self.dims = dims

+ def forward(self, x):
+ return x.transpose(*self.dims)

class TransformerFFNLayer(nn.Module):
def __init__(self, hidden_size, filter_size, kernel_size=1, dropout=0., act='gelu'):
super().__init__()
@@ -133,6 +173,9 @@ class TransformerFFNLayer(nn.Module):
elif self.act == 'swiglu':
self.act_fn = SwiGLU()
filter_size_1 = filter_size * 2
+ elif self.act == 'atanglu':
+ self.act_fn = ATanGLU()
+ filter_size_1 = filter_size * 2
else:
raise ValueError(f'{act} is not a valid activation')
self.ffn_1 = nn.Conv1d(hidden_size, filter_size_1, kernel_size, padding=kernel_size // 2)
@@ -166,10 +209,17 @@ class MultiheadSelfAttentionWithRoPE(nn.Module):
# Dropout layer
self.dropout = nn.Dropout(dropout)
# Rotary Embeddings
self.rotary_embed = rotary_embed
+ # Parameter initialization
+ nn.init.xavier_uniform_(self.in_proj.weight)
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if bias:
+ nn.init.constant_(self.in_proj.bias, 0.0)
+ nn.init.constant_(self.out_proj.bias, 0.0)

def forward(self, x, key_padding_mask=None):
# x: (B, L, C)
# key_padding_mask: (B, L)
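
The two GLU variants side by side (shapes are illustrative). atan bounds the gate to (-pi/2, pi/2), so unlike SiLU it cannot push activations past the fp16 range; that is why SwiGLU needs the clamp path above while ATanGLU does not:

    import torch
    from modules.commons.common_layers import SwiGLU, ATanGLU
    x = torch.randn(2, 10, 512)  # last dim is split into (out, gate)
    y1 = SwiGLU(dim=-1)(x)       # -> [2, 10, 256]
    y2 = ATanGLU(dim=-1)(x)      # -> [2, 10, 256], gate bounded by pi/2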


modules/fastspeech/acoustic_encoder.py (+30, -4)

@@ -6,7 +6,7 @@ from modules.commons.common_layers import (
NormalInitEmbedding as Embedding,
XavierUniformInitLinear as Linear,
)
- from modules.fastspeech.bbc_mask import fast_bbc_mask
+ from modules.fastspeech.bbc_mask import fast_bbc_mask, fast_fast_bbc_mask
from modules.fastspeech.tts_modules import FastSpeech2Encoder, mel2ph_to_dur
from utils.hparams import hparams
from utils.phoneme_utils import PAD_INDEX
@@ -50,6 +50,26 @@ class FastSpeech2Acoustic(nn.Module):
for v_name in self.variance_embed_list
})

+ self.use_variance_scaling = hparams.get('use_variance_scaling', False)
+ if self.use_variance_scaling:
+ self.variance_scaling_factor = {
+ 'energy': 1. / 96,
+ 'breathiness': 1. / 96,
+ 'voicing': 1. / 96,
+ 'tension': 0.1,
+ 'key_shift': 1. / 12,
+ 'speed': 1.
+ }
+ else:
+ self.variance_scaling_factor = {
+ 'energy': 1.,
+ 'breathiness': 1.,
+ 'voicing': 1.,
+ 'tension': 1.,
+ 'key_shift': 1.,
+ 'speed': 1.
+ }

self.use_key_shift_embed = hparams.get('use_key_shift_embed', False)
if self.use_key_shift_embed:
self.key_shift_embed = Linear(1, hparams['hidden_size'])
@@ -72,17 +92,20 @@ class FastSpeech2Acoustic(nn.Module):
def forward_variance_embedding(self, condition, key_shift=None, speed=None, **variances):
if self.use_variance_embeds:
variance_embeds = torch.stack([
- self.variance_embeds[v_name](variances[v_name][:, :, None])
+ self.variance_embeds[v_name](variances[v_name][:, :, None])
+ * self.variance_scaling_factor[v_name]
for v_name in self.variance_embed_list
], dim=-1).sum(-1)
condition += variance_embeds

if self.use_key_shift_embed:
key_shift_embed = self.key_shift_embed(key_shift[:, :, None])
+ key_shift_embed *= self.variance_scaling_factor['key_shift']
condition += key_shift_embed

if self.use_speed_embed:
speed_embed = self.speed_embed(speed[:, :, None])
+ speed_embed *= self.variance_scaling_factor['speed']
condition += speed_embed

return condition
@@ -95,7 +118,10 @@ class FastSpeech2Acoustic(nn.Module):
):
txt_embed = self.txt_embed(txt_tokens)
dur = mel2ph_to_dur(mel2ph, txt_tokens.shape[1]).float()
- dur_embed = self.dur_embed(dur[:, :, None])
+ if self.use_variance_scaling:
+ dur_embed = self.dur_embed(torch.log(1 + dur[:, :, None]))
+ else:
+ dur_embed = self.dur_embed(dur[:, :, None])
if self.use_lang_id:
lang_embed = self.lang_embed(languages)
extra_embed = dur_embed + lang_embed
@@ -106,7 +132,7 @@ class FastSpeech2Acoustic(nn.Module):

encoder_out=torch.cat([self.bbc_mask_emb.expand(mel2ph.shape[0],1,encoder_out.shape[-1]),encoder_out],dim=1)
encoder_out = F.pad(encoder_out, [0, 0, 1, 0])
- mel2ph=fast_bbc_mask(mel2ph,mask_length=self.bbc_mask_len,min_segment_length=self.bbc_min_segment_length,mask_prob=self.bbc_mask_prob)
+ mel2ph=fast_fast_bbc_mask(mel2ph,mask_length=self.bbc_mask_len,min_segment_length=self.bbc_min_segment_length,mask_prob=self.bbc_mask_prob)
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
condition = torch.gather(encoder_out, 1, mel2ph_)
else:
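
The index layout behind this gather, in isolation (a toy sketch with hand-picked values): index 0 resolves to the zero padding row, index 1 to the learned mask embedding, and index k >= 2 to phoneme k - 1, which is exactly the index space fast_fast_bbc_mask emits:

    import torch
    import torch.nn.functional as F
    H = 4
    enc = torch.arange(2 * H).float().view(1, 2, H)  # two phoneme vectors
    mask_emb = torch.full((1, 1, H), -1.)            # stands in for bbc_mask_emb
    table = torch.cat([mask_emb, enc], dim=1)        # index 1 = mask embedding
    table = F.pad(table, [0, 0, 1, 0])               # index 0 = zero padding
    mel2ph = torch.tensor([[2, 2, 1, 3, 0]])         # 1 = masked frame, 0 = padding
    cond = torch.gather(table, 1, mel2ph[..., None].repeat(1, 1, H))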


modules/fastspeech/bbc_mask.py (+84, -0)

@@ -40,3 +40,87 @@ def fast_bbc_mask(mel2ph, mask_length=3, min_segment_length=5, mask_prob=1.):
result[batch_idx, mask_start:end] = 1

return result


def fast_fast_bbc_mask(mel2ph, mask_length=3, min_segment_length=5, mask_prob=0.3):

batch_size, seq_len = mel2ph.shape
device = mel2ph.device

masked_mel2ph = torch.where(mel2ph > 0, mel2ph + 1, mel2ph)


padded = F.pad(masked_mel2ph, [1, 1], value=-1) # [B, L+2]


boundaries = (padded[:, 1:] != padded[:, :-1]).float() # [B, L+1]

segment_ids = torch.cumsum(boundaries, dim=1) - 1 # [B, L+1]
segment_ids = segment_ids[:, :-1]
max_segments = int(segment_ids.max().item()) + 1

segment_starts, segment_lengths, segment_values = compute_segment_info_parallel(
masked_mel2ph, segment_ids, max_segments, seq_len, device
)


valid_segments = (segment_values != 0) & (segment_lengths >= min_segment_length)

random_vals = torch.rand_like(segment_values.float())
mask_decisions = valid_segments & (random_vals < mask_prob)

result = apply_masks_parallel(
masked_mel2ph, segment_starts, segment_lengths, mask_decisions,
mask_length, seq_len, device
)

    return result


def compute_segment_info_parallel(masked_mel2ph, segment_ids, max_segments, seq_len, device):

pos_idx = torch.arange(seq_len, device=device)[None, None, :] # [1, 1, L]
seg_idx = torch.arange(max_segments, device=device)[None, :, None] # [1, S, 1]


segment_mask = (segment_ids[:, None, :] == seg_idx) # [B, S, L]


pos_masked = torch.where(segment_mask, pos_idx, seq_len)
segment_starts = pos_masked.min(dim=2)[0] # [B, S]


segment_lengths = segment_mask.sum(dim=2) # [B, S]

first_pos_mask = (pos_masked == segment_starts[:, :, None])
values_masked = torch.where(first_pos_mask, masked_mel2ph[:, None, :], 0)
segment_values = values_masked.sum(dim=2) # [B, S]

return segment_starts, segment_lengths, segment_values



def apply_masks_parallel(masked_mel2ph, segment_starts, segment_lengths, mask_decisions,
mask_length, seq_len, device):

result = masked_mel2ph.clone()


pos_indices = torch.arange(seq_len, device=device)[None, None, :] # [1, 1, L]

segment_ends = segment_starts + segment_lengths # [B, S]
mask_starts = torch.clamp(segment_ends - mask_length, min=segment_starts) # [B, S]
mask_ends = segment_ends # [B, S]


mask_matrix = (
(pos_indices >= mask_starts[:, :, None]) &
(pos_indices < mask_ends[:, :, None]) &
mask_decisions[:, :, None]
) # [B, S, L]


final_mask = mask_matrix.any(dim=1) # [B, L]


result = torch.where(final_mask, torch.ones_like(result), result)

return result
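
A hand-checked toy run of the vectorized path (mask_prob=1. makes it deterministic):

    import torch
    from modules.fastspeech.bbc_mask import fast_fast_bbc_mask
    mel2ph = torch.tensor([[1, 1, 1, 1, 1, 2, 2, 3, 3, 3]])
    out = fast_fast_bbc_mask(mel2ph, mask_length=2, min_segment_length=3, mask_prob=1.)
    # Non-zero ids are shifted by +1 (index 1 is reserved for the mask embedding),
    # then the last 2 frames of every segment of length >= 3 are overwritten with 1:
    # tensor([[2, 2, 2, 1, 1, 3, 3, 4, 1, 1]])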

modules/fastspeech/tts_modules.py (+28, -11)

@@ -62,7 +62,7 @@ class DurationPredictor(torch.nn.Module):
"""

def __init__(self, in_dims, n_layers=2, n_chans=384, kernel_size=3,
- dropout_rate=0.1, offset=1.0, dur_loss_type='mse'):
+ dropout_rate=0.1, offset=1.0, dur_loss_type='mse', arch='resnet'):
"""Initialize duration predictor module.
Args:
in_dims (int): Input dimension.
@@ -76,16 +76,29 @@ class DurationPredictor(torch.nn.Module):
self.offset = offset
self.conv = torch.nn.ModuleList()
self.kernel_size = kernel_size
+ self.use_resnet = (arch == 'resnet')
for idx in range(n_layers):
in_chans = in_dims if idx == 0 else n_chans
- self.conv.append(torch.nn.Sequential(
- torch.nn.Identity(), # this is a placeholder for ConstantPad1d which is now merged into Conv1d
- torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2),
- torch.nn.ReLU(),
- LayerNorm(n_chans, dim=1),
- torch.nn.Dropout(dropout_rate)
- ))
+ if self.use_resnet:
+ self.conv.append(nn.Sequential(
+ LayerNorm(in_chans, dim=1),
+ nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2),
+ nn.ReLU(),
+ nn.Conv1d(n_chans, n_chans, 1),
+ nn.Dropout(dropout_rate)
+ ))
+ else:
+ self.conv.append(nn.Sequential(
+ nn.Identity(), # this is a placeholder for ConstantPad1d which is now merged into Conv1d
+ nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=kernel_size // 2),
+ nn.ReLU(),
+ LayerNorm(n_chans, dim=1),
+ nn.Dropout(dropout_rate)
+ ))
+ if self.use_resnet and in_dims != n_chans:
+ self.res_conv = nn.Conv1d(in_dims, n_chans, 1)
+ else:
+ self.res_conv = None
self.loss_type = dur_loss_type
if self.loss_type in ['mse', 'huber']:
self.out_dims = 1
@@ -121,8 +134,12 @@ class DurationPredictor(torch.nn.Module):
xs = xs.transpose(1, -1) # (B, idim, Tmax)
masks = 1 - x_masks.float()
masks_ = masks[:, None, :]
- for f in self.conv:
- xs = f(xs) # (B, C, Tmax)
+ for idx, f in enumerate(self.conv):
+ if self.use_resnet:
+ residual = self.res_conv(xs) if idx == 0 and self.res_conv is not None else xs
+ xs = residual + f(xs)
+ else:
+ xs = f(xs)
if x_masks is not None:
xs = xs * masks_
xs = self.linear(xs.transpose(1, -1)) # [B, T, C]
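
The 'resnet' arch in formula form (a reading of the code above, not a separate spec): each layer is a pre-norm block added back to its input, with a 1x1 projection aligning channels on the first layer only,

    xs_0 = res_conv(xs_in) + block_0(xs_in)   # res_conv only when in_dims != n_chans
    xs_l = xs_{l-1} + block_l(xs_{l-1})       # l >= 1
    block = Dropout(Conv1x1(ReLU(Conv_k(LayerNorm(.)))))

whereas the legacy arch simply chains post-norm conv blocks without the skip connection.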


modules/fastspeech/variance_encoder.py (+15, -6)

@@ -17,7 +17,7 @@ class FastSpeech2Variance(nn.Module):
self.predict_dur = hparams['predict_dur']
self.linguistic_mode = 'word' if hparams['predict_dur'] else 'phoneme'
self.use_lang_id = hparams['use_lang_id']
+ self.use_variance_scaling = hparams.get('use_variance_scaling', False)
self.txt_embed = Embedding(vocab_size, hparams['hidden_size'], PAD_INDEX)
if self.use_lang_id:
self.lang_embed = Embedding(hparams['num_lang'] + 1, hparams['hidden_size'], padding_idx=0)
@@ -46,7 +46,8 @@ class FastSpeech2Variance(nn.Module):
dropout_rate=dur_hparams['dropout'],
kernel_size=dur_hparams['kernel_size'],
offset=dur_hparams['log_offset'],
- dur_loss_type=dur_hparams['loss_type']
+ dur_loss_type=dur_hparams['loss_type'],
+ arch=dur_hparams['arch']
)

def forward(
@@ -79,9 +80,11 @@ class FastSpeech2Variance(nn.Module):
word_dur = torch.gather(F.pad(word_dur, [1, 0], value=0), 1, ph2word) # [B, T_w] => [B, T_ph]
word_dur_embed = self.word_dur_embed(word_dur.float()[:, :, None])
extra_embed = onset_embed + word_dur_embed
+ elif self.use_variance_scaling:
+ extra_embed = self.ph_dur_embed(torch.log(1 + ph_dur.float())[:, :, None])
else:
- ph_dur_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
- extra_embed = ph_dur_embed
+ extra_embed = self.ph_dur_embed(ph_dur.float()[:, :, None])
if self.use_lang_id:
lang_embed = self.lang_embed(languages)
extra_embed += lang_embed
@@ -108,6 +111,7 @@ class MelodyEncoder(nn.Module):

# MIDI inputs
hidden_size = get_hparam('hidden_size')
+ self.use_variance_scaling = hparams.get('use_variance_scaling', False)
self.note_midi_embed = Linear(1, hidden_size)
self.note_dur_embed = Linear(1, hidden_size)

@@ -135,8 +139,13 @@ class MelodyEncoder(nn.Module):
:param glide: int64 [B, T_n]
:return: [B, T_n, H]
"""
- midi_embed = self.note_midi_embed(note_midi[:, :, None]) * ~note_rest[:, :, None]
- dur_embed = self.note_dur_embed(note_dur.float()[:, :, None])
+ if self.use_variance_scaling:
+ midi_embed = self.note_midi_embed(note_midi[:, :, None] / 128)
+ dur_embed = self.note_dur_embed(torch.log(1 + note_dur.float())[:, :, None])
+ else:
+ midi_embed = self.note_midi_embed(note_midi[:, :, None])
+ dur_embed = self.note_dur_embed(note_dur.float()[:, :, None])
+ midi_embed *= ~note_rest[:, :, None]
ornament_embed = 0
if self.use_glide_embed:
ornament_embed += self.note_glide_embed(glide) * self.glide_embed_scale
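
Why the transforms (a reading of the ranges involved): note_midi / 128 maps the MIDI range into [0, 1), and log(1 + d) compresses raw frame-count durations into a narrow band before the linear embedding. The values below are exact:

    import torch
    dur = torch.tensor([1., 10., 100., 400.])
    torch.log(1 + dur)  # tensor([0.6931, 2.3979, 4.6151, 5.9940])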


modules/optimizer/chained_optimizer.py (+122, -0)

@@ -0,0 +1,122 @@
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.optimizer import ParamsT
from dataclasses import dataclass
from typing import Any, Dict, List, Type, Callable, Optional, Iterable


@dataclass
class OptimizerSpec:
"""Spec for creating an optimizer that is part of a `ChainedOptimizer`."""

class_type: Type[Optimizer]
init_args: Dict[str, Any]
param_filter: Optional[Callable[[Tensor], bool]]


class ChainedOptimizer(Optimizer):
"""
A wrapper around multiple optimizers that allows for chaining them together.
The optimizers are applied in the order they are passed in the constructor.
Each optimizer is responsible for updating a subset of the parameters, which
is determined by the `param_filter` function. If no optimizer is found for a
parameter group, an exception is raised.
"""

def __init__(
self,
params: ParamsT,
optimizer_specs: List[OptimizerSpec],
lr: float,
weight_decay: float = 0.0,
optimizer_selection_callback: Optional[Callable[[Tensor, int], None]] = None,
**common_kwargs,
):
self.optimizer_specs = optimizer_specs
self.optimizer_selection_callback = optimizer_selection_callback
self.optimizers: List[Optimizer] = []
defaults = dict(lr=lr, weight_decay=weight_decay)
super().__init__(params, defaults)

# Split the params for each optimizer
params_for_optimizers = [[] for _ in optimizer_specs]
for param_group in self.param_groups:
params = param_group["params"]
indices = param_group["optimizer_and_param_group_indices"] = set()
for param in params:
assert isinstance(param, Tensor), f"Expected a Tensor, got {type(param)}"
for index, spec in enumerate(optimizer_specs):
if spec.param_filter is None or spec.param_filter(param):
if self.optimizer_selection_callback is not None:
self.optimizer_selection_callback(param, index)
params_for_optimizers[index].append(param)
indices.add((index, 0))
break

# Initialize the optimizers
for spec, selected_params in zip(optimizer_specs, params_for_optimizers):
optimizer_args = {
'lr': lr,
'weight_decay': weight_decay,
}
optimizer_args.update(common_kwargs)
optimizer_args.update(spec.init_args)
optimizer = spec.class_type(selected_params, **optimizer_args)
self.optimizers.append(optimizer)

def state_dict(self) -> Dict[str, Any]:
return {
"optimizers": [opt.state_dict() for opt in self.optimizers],
**super().state_dict(),
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
optimizers = state_dict.pop("optimizers")
super().load_state_dict(state_dict)
for i in range(len(self.optimizers)):
self.optimizers[i].load_state_dict(optimizers[i])

def zero_grad(self, set_to_none: bool = True) -> None:
for opt in self.optimizers:
opt.zero_grad(set_to_none=set_to_none)

def _copy_lr_to_optimizers(self) -> None:
for param_group in self.param_groups:
indices = param_group["optimizer_and_param_group_indices"]
for optimizer_idx, param_group_idx in indices:
self.optimizers[optimizer_idx].param_groups[param_group_idx]["lr"] = param_group["lr"]

def step(self, closure=None) -> None:
self._copy_lr_to_optimizers()
for opt in self.optimizers:
opt.step(closure)

def add_param_group(self, param_group: Dict[str, Any]) -> None:
super().add_param_group(param_group)

# If optimizer has not been initialized, skip adding the param groups
if not self.optimizers:
return

# Split the params for each optimizer
params_for_optimizers = [[] for _ in self.optimizer_specs]
params = param_group["params"]
indices = param_group["optimizer_and_param_group_indices"] = set()
for param in params:
assert isinstance(param, Tensor), f"Expected a Tensor, got {type(param)}"
found_optimizer = False
for index, spec in enumerate(self.optimizer_specs):
if spec.param_filter is None or spec.param_filter(param):
if self.optimizer_selection_callback is not None:
self.optimizer_selection_callback(param, index)
params_for_optimizers[index].append(param)
indices.add((index, len(self.optimizers[index].param_groups)))
found_optimizer = True
break
if not found_optimizer:
raise ValueError("No valid optimizer found for the given parameter group")

# Add the selected param group to the optimizers
for optimizer, selected_params in zip(self.optimizers, params_for_optimizers):
if selected_params:
optimizer.add_param_group({"params": selected_params})

modules/optimizer/muon.py (+152, -0)

@@ -0,0 +1,152 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Module, Parameter, Embedding
from typing import List
from .chained_optimizer import ChainedOptimizer, OptimizerSpec


def get_bf16_support_map():
bf16_support_map = {}

if not torch.cuda.is_available():
return bf16_support_map

device_count = torch.cuda.device_count()
if device_count == 0:
return bf16_support_map

for i in range(device_count):
device = torch.device(f'cuda:{i}')
major, minor = torch.cuda.get_device_capability(device)
bf16_support_map[device] = (major >= 8)
    return bf16_support_map


def zeropower_via_newtonschulz5(G: Tensor, steps: int, use_bf16: bool) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim == 3 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
if use_bf16:
X = G.bfloat16()
else:
X = G.float()
if G.size(-2) > G.size(-1):
X = X.mT

# Ensure spectral norm is at most 1
X = F.normalize(X, p=2.0, dim=(-2, -1), eps=1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = torch.baddbmm(A, A, A, beta=b, alpha=c)
X = torch.baddbmm(X, B, X, beta=a, alpha=1)
if G.size(-2) > G.size(-1):
X = X.mT
return X.to(G)


class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz

https://kellerjordan.github.io/posts/muon/

Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.

Some warnings:
- This optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.

Arguments:
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
"""

def __init__(self, params, lr=5e-4, weight_decay=0.1, momentum=0.95, nesterov=True, ns_steps=5):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
super().__init__(params, defaults)
self.bf16_support_map = get_bf16_support_map()

@torch.no_grad()
def step(self, closure=None):
for group in self.param_groups:
shape_groups = {}
for p in filter(lambda p: p.grad is not None, group["params"]):
g = p.grad
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
key = (p.shape, p.device, p.dtype)
if key not in shape_groups:
shape_groups[key] = {"params": [], "grads": [], "buffers": []}
shape_groups[key]["params"].append(p)
shape_groups[key]["grads"].append(g)
shape_groups[key]["buffers"].append(buf)
for key in shape_groups:
group_data = shape_groups[key]
g = torch.stack(group_data["grads"])
buf = torch.stack(group_data["buffers"])
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
if g.ndim >= 4: # for the case of conv filters
g = g.view(g.size(0), g.size(1), -1)
use_bf16 = self.bf16_support_map.get(g.device, False)
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"], use_bf16=use_bf16)
for i, p in enumerate(group_data["params"]):
if group["weight_decay"] > 0:
p.data.mul_(1 - group["lr"] * group["weight_decay"])
p.data.add_(g[i].view_as(p), alpha=-group["lr"] * max(g[i].size()) ** 0.5)
self.state[p]["momentum_buffer"] = buf[i].clone()


def get_params_for_muon(model) -> List[Parameter]:
"""
Filter parameters of a model into two groups: those that can be optimized by Muon,
and those that should be optimized by a standard optimizer.
Args:
model: The model whose parameters are filtered.
Returns:
A list of parameters that should be optimized with Muon.
"""
muon_params = []
for module in model.modules():
for param in module.parameters(recurse=False):
if not param.requires_grad:
continue
if not isinstance(module, nn.Embedding) and param.ndim >= 2:
muon_params.append(param)
return muon_params


class Muon_AdamW(ChainedOptimizer):
def __init__(self, model, lr=0.0005, weight_decay=0.0, muon_args={}, adamw_args={}, verbose=False):
muon_params_id_set = set(id(p) for p in get_params_for_muon(model))
spec_muon = OptimizerSpec(Muon, muon_args, lambda param: id(param) in muon_params_id_set)
spec_adamw = OptimizerSpec(torch.optim.AdamW, adamw_args, None)
specs = [spec_muon, spec_adamw]
callback = None
if verbose:
callback = lambda p, spec_idx: print(
f"Adding param {p.shape} to optimizer{spec_idx} {str(specs[spec_idx].class_type)}"
)
super().__init__(model.parameters(), specs, lr=lr, weight_decay=weight_decay, optimizer_selection_callback=callback)
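
A quick numerical check of the Newton-Schulz step (loose by design: per the docstring the iteration drives singular values into roughly (0.5, 1.5), not exactly 1):

    import torch
    from modules.optimizer.muon import zeropower_via_newtonschulz5
    G = torch.randn(1, 64, 32)             # batched, as the assert requires
    X = zeropower_via_newtonschulz5(G, steps=5, use_bf16=False)
    print(torch.linalg.svdvals(X[0]))      # all values near 1, none tiny or huge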

modules/toplevel.py (+63, -10)

@@ -17,7 +17,7 @@ from modules.core import (
RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow
)
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
- from modules.fastspeech.bbc_mask import fast_bbc_mask
+ from modules.fastspeech.bbc_mask import fast_bbc_mask, fast_fast_bbc_mask
from modules.fastspeech.param_adaptor import ParameterAdaptorModule
from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator
from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder
@@ -181,12 +181,18 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
)
else:
raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
self.use_bbc_encoder = hparams.get('use_bbc_encoder', False)
if self.use_bbc_encoder:
self.bbc_mask_len = hparams['bbc_mask_len']
self.bbc_min_segment_length=hparams['bbc_min_segment_length']
self.bbc_mask_prob=hparams['bbc_mask_prob']
self.bbc_mask_emb=nn.Parameter(torch.randn(1, 1, hparams['hidden_size']))
+ self.use_me_bbc_encoder = hparams.get('use_me_bbc_encoder', False)
+ if self.use_me_bbc_encoder:
+ self.me_bbc_mask_len = hparams['me_bbc_mask_len']
+ self.me_bbc_min_segment_length=hparams['me_bbc_min_segment_length']
+ self.me_bbc_mask_prob=hparams['me_bbc_mask_prob']
+ self.me_bbc_mask_emb=nn.Parameter(torch.randn(1, 1, hparams['hidden_size']))

if self.predict_variances:
self.pitch_embed = Linear(1, hparams['hidden_size'])
@@ -202,6 +208,28 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
else:
raise NotImplementedError(self.diffusion_type)

+ self.use_variance_scaling = hparams.get('use_variance_scaling', False)
+ self.custom_variance_scaling_factor = {
+ 'energy': 1. / 96,
+ 'breathiness': 1. / 96,
+ 'voicing': 1. / 96,
+ 'tension': 0.1,
+ 'key_shift': 1. / 12,
+ 'speed': 1.
+ }
+ self.default_variance_scaling_factor = {
+ 'energy': 1.,
+ 'breathiness': 1.,
+ 'voicing': 1.,
+ 'tension': 1.,
+ 'key_shift': 1.,
+ 'speed': 1.
+ }
+ if self.use_variance_scaling:
+ self.variance_retake_scaling = self.custom_variance_scaling_factor
+ else:
+ self.variance_retake_scaling = self.default_variance_scaling_factor

def forward(
self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None,
note_midi=None, note_rest=None, note_dur=None, note_glide=None, mel2note=None,
@@ -244,7 +272,7 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):

encoder_out=torch.cat([self.bbc_mask_emb.expand(mel2ph.shape[0],1,encoder_out.shape[-1]),encoder_out],dim=1)
encoder_out = F.pad(encoder_out, [0, 0, 1, 0])
- mel2ph=fast_bbc_mask(mel2ph,mask_length=self.bbc_mask_len,min_segment_length=self.bbc_min_segment_length,mask_prob=self.bbc_mask_prob)
+ mel2ph=fast_fast_bbc_mask(mel2ph,mask_length=self.bbc_mask_len,min_segment_length=self.bbc_min_segment_length,mask_prob=self.bbc_mask_prob)
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
condition = torch.gather(encoder_out, 1, mel2ph_)
else:
@@ -261,9 +289,24 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
note_midi, note_rest, note_dur,
glide=note_glide
)
- melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0])
- mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']])
- melody_condition = torch.gather(melody_encoder_out, 1, mel2note_)
+ if self.use_me_bbc_encoder:
+ melody_encoder_out = torch.cat(
+ [self.me_bbc_mask_emb.expand(mel2note.shape[0], 1, melody_encoder_out.shape[-1]), melody_encoder_out], dim=1)
+ melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0])
+ mel2note = fast_fast_bbc_mask(mel2note, mask_length=self.me_bbc_mask_len,
+ min_segment_length=self.me_bbc_min_segment_length,
+ mask_prob=self.me_bbc_mask_prob)
+ mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']])
+ melody_condition = torch.gather(melody_encoder_out, 1, mel2note_)
+ else:
+ melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0])
+ mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']])
+ melody_condition = torch.gather(melody_encoder_out, 1, mel2note_)


pitch_cond = condition + melody_condition
else:
pitch_cond = condition.clone() # preserve the original tensor to avoid further inplace operations
@@ -290,11 +333,17 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
delta_pitch_in = torch.zeros_like(base_pitch)
else:
delta_pitch_in = (pitch - base_pitch) * ~pitch_retake
- pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None])
+ if self.use_variance_scaling:
+ pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None] / 12)
+ else:
+ pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None])
else:
if not retake_unset: # retake
base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake
- pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
+ if self.use_variance_scaling:
+ pitch_cond += self.base_pitch_embed(base_pitch[:, :, None] / 128)
+ else:
+ pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])

if infer:
pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True)
@@ -308,12 +357,16 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):

if pitch is None:
pitch = base_pitch + pitch_pred_out
- var_cond = condition + self.pitch_embed(pitch[:, :, None])
+ if self.use_variance_scaling:
+ var_cond = condition + self.pitch_embed(pitch[:, :, None] / 12)
+ else:
+ var_cond = condition + self.pitch_embed(pitch[:, :, None])

variance_inputs = self.collect_variance_inputs(**kwargs)

if variance_retake is not None:
variance_embeds = [
- self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None]
+ self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] * self.variance_retake_scaling[v_name]
for v_name, v_input in zip(self.variance_prediction_list, variance_inputs)
]
var_cond += torch.stack(variance_embeds, dim=-1).sum(-1)
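
Where the scaling factors come from (inferred from the value ranges in configs/variance.yaml; a reading, not a spec): dB-valued variances span roughly [-96, 0], tension logits span [-10, 10], key shifts are semitones, and base pitch is a MIDI number, so each factor maps its input near [-1, 1]:

    scaled = {
        'energy':    -48.0 / 96,   # dB in [-96, 0]     -> [-1, 0]
        'tension':     5.0 * 0.1,  # logit in [-10, 10] -> [-1, 1]
        'key_shift':   6 / 12,     # semitones; a full octave maps to 1.0
        'base_pitch': 69 / 128,    # MIDI note A4 -> ~0.54
    }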


utils/__init__.py (+2, -1)

@@ -315,8 +315,9 @@ def build_lr_scheduler_from_config(optimizer, scheduler_args):


def simulate_lr_scheduler(optimizer_args, scheduler_args, step_count, num_param_groups=1):
+ optimizer_cls = optimizer_args['optimizer_cls']
optimizer = build_object_from_class_name(
- optimizer_args['optimizer_cls'],
+ 'torch.optim.AdamW' if optimizer_cls == 'modules.optimizer.muon.Muon_AdamW' else optimizer_cls,
torch.optim.Optimizer,
[{'params': torch.nn.Parameter(), 'initial_lr': optimizer_args['lr']} for _ in range(num_param_groups)],
**optimizer_args


utils/binarizer_utils.py (+8, -5)

@@ -214,13 +214,16 @@ class SinusoidalSmoothingConv1d(torch.nn.Conv1d):
super().__init__(
in_channels=1,
out_channels=1,
- kernel_size=kernel_size,
+ kernel_size=max(kernel_size, 1),
bias=False,
padding='same',
padding_mode='replicate'
)
- smooth_kernel = torch.sin(torch.from_numpy(
- np.linspace(0, 1, kernel_size).astype(np.float32) * np.pi
- ))
- smooth_kernel /= smooth_kernel.sum()
+ if kernel_size > 1:
+ smooth_kernel = torch.sin(torch.from_numpy(
+ np.linspace(0, 1, kernel_size).astype(np.float32) * np.pi
+ ))
+ smooth_kernel /= smooth_kernel.sum()
+ else:
+ smooth_kernel = torch.tensor([1.0], dtype=torch.float32)
self.weight.data = smooth_kernel[None, None]
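
How the smooth_width values in the configs above reach this class (a sketch; the seconds-to-frames conversion is an assumption, with hypothetical sample rate and hop size): halving energy_smooth_width from 0.12 to 0.06 halves the kernel, and the new guard keeps a degenerate width valid as an identity kernel:

    import torch
    from utils.binarizer_utils import SinusoidalSmoothingConv1d
    kernel_size = round(0.06 * 44100 / 512)  # width in seconds -> ~5 frames
    smooth = SinusoidalSmoothingConv1d(kernel_size)
    curve = torch.randn(1, 1000)             # [B, T]
    smoothed = smooth(curve[:, None])[:, 0]  # add/drop the channel dim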
