3 Commits

Author SHA1 Message Date
  yxlllc 7fb8139f14 optimize 2 months ago
  yxlllc fe1bbca134 fix 2 months ago
  yxlllc 16492d90f8 fix 2 months ago
1 changed files with 3 additions and 4 deletions
Split View
  1. +3
    -4
      modules/fastspeech/tts_modules.py

+ 3
- 4
modules/fastspeech/tts_modules.py View File

@@ -347,12 +347,11 @@ class StretchRegulator(torch.nn.Module):
"""
if dur is None:
dur = mel2ph_to_dur(mel2ph, mel2ph.max())
dur = F.pad(dur, [1, 0], value=1) # Avoid dividing by zero
dur = torch.cat([torch.ones_like(dur[:, :1]), dur], dim=1) # Avoid dividing by zero
mel2dur = torch.gather(dur, 1, mel2ph)
bound_mask = torch.gt(mel2ph[:, 1:], mel2ph[:, :-1])
bound_mask = F.pad(bound_mask, [0, 1], mode='constant', value=True)
stretch_delta = 1 - bound_mask * mel2dur
stretch_delta = F.pad(stretch_delta, [1, -1], mode='constant', value=0)
stretch_delta = 1 - bound_mask * mel2dur[:, :-1]
stretch_delta = F.pad(stretch_delta, [1, 0])
stretch_denorm = torch.cumsum(stretch_delta, dim=1)
stretch = stretch_denorm.float() / mel2dur
return stretch * (mel2ph > 0)


Loading…
Cancel
Save
Baidu
map