@@ -17,7 +17,7 @@ from modules.core import (
RectifiedFlow, PitchRectifiedFlow, MultiVarianceRectifiedFlow
)
from modules.fastspeech.acoustic_encoder import FastSpeech2Acoustic
from modules.fastspeech.bbc_mask import fast_bbc_mask
from modules.fastspeech.bbc_mask import fast_bbc_mask, fast_fast_bbc_mask
from modules.fastspeech.param_adaptor import ParameterAdaptorModule
from modules.fastspeech.tts_modules import RhythmRegulator, LengthRegulator
from modules.fastspeech.variance_encoder import FastSpeech2Variance, MelodyEncoder
@@ -181,12 +181,18 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
)
else:
raise ValueError(f"Invalid diffusion type: {self.diffusion_type}")
self.use_bbc_encoder = hparams.get('use_bbc_encoder', False)
if self.use_bbc_encoder:
self.bbc_mask_len = hparams['bbc_mask_len']
self.bbc_min_segment_length=hparams['bbc_min_segment_length']
self.bbc_mask_prob=hparams['bbc_mask_prob']
self.bbc_mask_emb=nn.Parameter(torch.randn(1, 1, hparams['hidden_size']))
self.use_me_bbc_encoder = hparams.get('use_me_bbc_encoder', False)
if self.use_me_bbc_encoder:
self.me_bbc_mask_len = hparams['me_bbc_mask_len']
self.me_bbc_min_segment_length=hparams['me_bbc_min_segment_length']
self.me_bbc_mask_prob=hparams['me_bbc_mask_prob']
self.me_bbc_mask_emb=nn.Parameter(torch.randn(1, 1, hparams['hidden_size']))
if self.predict_variances:
self.pitch_embed = Linear(1, hparams['hidden_size'])
@@ -202,6 +208,28 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
else:
raise NotImplementedError(self.diffusion_type)
self.use_variance_scaling = hparams.get('use_variance_scaling', False)
self.custom_variance_scaling_factor = {
'energy': 1. / 96,
'breathiness': 1. / 96,
'voicing': 1. / 96,
'tension': 0.1,
'key_shift': 1. / 12,
'speed': 1.
}
self.default_variance_scaling_factor = {
'energy': 1.,
'breathiness': 1.,
'voicing': 1.,
'tension': 1.,
'key_shift': 1.,
'speed': 1.
}
if self.use_variance_scaling:
self.variance_retake_scaling = self.custom_variance_scaling_factor
else:
self.variance_retake_scaling = self.default_variance_scaling_factor
def forward(
self, txt_tokens, midi, ph2word, ph_dur=None, word_dur=None, mel2ph=None,
note_midi=None, note_rest=None, note_dur=None, note_glide=None, mel2note=None,
@@ -244,7 +272,7 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
encoder_out=torch.cat([self.bbc_mask_emb.expand(mel2ph.shape[0],1,encoder_out.shape[-1]),encoder_out],dim=1)
encoder_out = F.pad(encoder_out, [0, 0, 1, 0])
mel2ph=fast_bbc_mask(mel2ph,mask_length=self.bbc_mask_len,min_segment_length=self.bbc_min_segment_length,mask_prob=self.bbc_mask_prob)
            mel2ph=fast_fast_bbc_mask(mel2ph,mask_length=self.bbc_mask_len,min_segment_length=self.bbc_min_segment_length,mask_prob=self.bbc_mask_prob)
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
condition = torch.gather(encoder_out, 1, mel2ph_)
else:
@@ -261,9 +289,24 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
note_midi, note_rest, note_dur,
glide=note_glide
)
melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0])
mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']])
melody_condition = torch.gather(melody_encoder_out, 1, mel2note_)
if self.use_me_bbc_encoder:
melody_encoder_out = torch.cat(
[self.me_bbc_mask_emb.expand(mel2note.shape[0], 1, melody_encoder_out.shape[-1]), melody_encoder_out], dim=1)
melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0])
                    mel2note = fast_fast_bbc_mask(mel2note, mask_length=self.me_bbc_mask_len,
                                                  min_segment_length=self.me_bbc_min_segment_length,
                                                  mask_prob=self.me_bbc_mask_prob)
mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']])
melody_condition = torch.gather(melody_encoder_out, 1, mel2note_)
else:
melody_encoder_out = F.pad(melody_encoder_out, [0, 0, 1, 0])
mel2note_ = mel2note[..., None].repeat([1, 1, hparams['hidden_size']])
melody_condition = torch.gather(melody_encoder_out, 1, mel2note_)
pitch_cond = condition + melody_condition
else:
pitch_cond = condition.clone() # preserve the original tensor to avoid further inplace operations
@@ -290,11 +333,17 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
delta_pitch_in = torch.zeros_like(base_pitch)
else:
delta_pitch_in = (pitch - base_pitch) * ~pitch_retake
pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None])
if self.use_variance_scaling:
pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None] / 12)
else:
pitch_cond += self.delta_pitch_embed(delta_pitch_in[:, :, None])
else:
if not retake_unset: # retake
base_pitch = base_pitch * pitch_retake + pitch * ~pitch_retake
pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if self.use_variance_scaling:
pitch_cond += self.base_pitch_embed(base_pitch[:, :, None] / 128)
else:
pitch_cond += self.base_pitch_embed(base_pitch[:, :, None])
if infer:
pitch_pred_out = self.pitch_predictor(pitch_cond, infer=True)
@@ -308,12 +357,16 @@ class DiffSingerVariance(CategorizedModule, ParameterAdaptorModule):
if pitch is None:
pitch = base_pitch + pitch_pred_out
var_cond = condition + self.pitch_embed(pitch[:, :, None])
if self.use_variance_scaling:
var_cond = condition + self.pitch_embed(pitch[:, :, None] / 12)
else:
var_cond = condition + self.pitch_embed(pitch[:, :, None])
variance_inputs = self.collect_variance_inputs(**kwargs)
if variance_retake is not None:
variance_embeds = [
self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None]
self.variance_embeds[v_name](v_input[:, :, None]) * ~variance_retake[v_name][:, :, None] * self.variance_retake_scaling[v_name]
for v_name, v_input in zip(self.variance_prediction_list, variance_inputs)
]
var_cond += torch.stack(variance_embeds, dim=-1).sum(-1)