CosyVoice commited on
Commit
9504c3f
1 Parent(s): 553244b

fix flow matching training for zero shot inference

Browse files
Files changed (1) hide show
  1. cosyvoice/flow/flow.py +6 -0
cosyvoice/flow/flow.py CHANGED
@@ -12,6 +12,7 @@
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  import logging
 
15
  from typing import Dict, Optional
16
  import torch
17
  import torch.nn as nn
@@ -77,6 +78,11 @@ class MaskedDiffWithXvec(torch.nn.Module):
77
 
78
  # get conditions
79
  conds = torch.zeros(feat.shape, device=token.device)
 
 
 
 
 
80
  conds = conds.transpose(1, 2)
81
 
82
  mask = (~make_pad_mask(feat_len)).to(h)
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
  import logging
15
+ import random
16
  from typing import Dict, Optional
17
  import torch
18
  import torch.nn as nn
 
78
 
79
  # get conditions
80
  conds = torch.zeros(feat.shape, device=token.device)
81
+ for i, j in enumerate(feat_len):
82
+ if random.random() < 0.5:
83
+ continue
84
+ index = random.randint(0, int(0.3 * j))
85
+ conds[i, :index] = feat[i, :index]
86
  conds = conds.transpose(1, 2)
87
 
88
  mask = (~make_pad_mask(feat_len)).to(h)