Plachta commited on
Commit
095aba1
·
1 Parent(s): 558a3d7

Update ONNXVITS_infer.py

Browse files
Files changed (1) hide show
  1. ONNXVITS_infer.py +127 -8
ONNXVITS_infer.py CHANGED
@@ -1,6 +1,102 @@
1
  import torch
2
  import commons
3
  import models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  class SynthesizerTrn(models.SynthesizerTrn):
5
  """
6
  Synthesizer for Training
@@ -26,6 +122,7 @@ class SynthesizerTrn(models.SynthesizerTrn):
26
  n_speakers=0,
27
  gin_channels=0,
28
  use_sdp=True,
 
29
  **kwargs):
30
 
31
  super().__init__(
@@ -50,16 +147,21 @@ class SynthesizerTrn(models.SynthesizerTrn):
50
  use_sdp=use_sdp,
51
  **kwargs
52
  )
 
 
 
 
 
 
 
 
 
 
53
 
54
- def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None):
55
  from ONNXVITS_utils import runonnx
56
 
57
- #x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths)
58
- x, m_p, logs_p, x_mask = runonnx("ONNX_net/enc_p.onnx", x=x.numpy(), x_lengths=x_lengths.numpy())
59
- x = torch.from_numpy(x)
60
- m_p = torch.from_numpy(m_p)
61
- logs_p = torch.from_numpy(logs_p)
62
- x_mask = torch.from_numpy(x_mask)
63
 
64
  if self.n_speakers > 0:
65
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
@@ -151,4 +253,21 @@ class SynthesizerTrn(models.SynthesizerTrn):
151
  o = runonnx("ONNX_net/dec.onnx", z_in=(z * y_mask)[:,:,:max_len].numpy(), g=g.numpy())
152
  o = torch.from_numpy(o[0])
153
 
154
- return o, attn, y_mask, (z, z_p, m_p, logs_p)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import commons
3
  import models
4
+
5
+ import math
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ import modules
10
+ import attentions
11
+ import monotonic_align
12
+
13
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
14
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
15
+ from commons import init_weights, get_padding
16
+
17
+ class TextEncoder(nn.Module):
18
+ def __init__(self,
19
+ n_vocab,
20
+ out_channels,
21
+ hidden_channels,
22
+ filter_channels,
23
+ n_heads,
24
+ n_layers,
25
+ kernel_size,
26
+ p_dropout,
27
+ emotion_embedding):
28
+ super().__init__()
29
+ self.n_vocab = n_vocab
30
+ self.out_channels = out_channels
31
+ self.hidden_channels = hidden_channels
32
+ self.filter_channels = filter_channels
33
+ self.n_heads = n_heads
34
+ self.n_layers = n_layers
35
+ self.kernel_size = kernel_size
36
+ self.p_dropout = p_dropout
37
+ self.emotion_embedding = emotion_embedding
38
+
39
+ if self.n_vocab!=0:
40
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
41
+ if emotion_embedding:
42
+ self.emo_proj = nn.Linear(1024, hidden_channels)
43
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
44
+
45
+ self.encoder = attentions.Encoder(
46
+ hidden_channels,
47
+ filter_channels,
48
+ n_heads,
49
+ n_layers,
50
+ kernel_size,
51
+ p_dropout)
52
+ self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1)
53
+
54
+ def forward(self, x, x_lengths, emotion_embedding=None):
55
+ if self.n_vocab!=0:
56
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
57
+ if emotion_embedding is not None:
58
+ print("emotion added")
59
+ x = x + self.emo_proj(emotion_embedding.unsqueeze(1))
60
+ x = torch.transpose(x, 1, -1) # [b, h, t]
61
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
62
+
63
+ x = self.encoder(x * x_mask, x_mask)
64
+ stats = self.proj(x) * x_mask
65
+
66
+ m, logs = torch.split(stats, self.out_channels, dim=1)
67
+ return x, m, logs, x_mask
68
+
69
+ class PosteriorEncoder(nn.Module):
70
+ def __init__(self,
71
+ in_channels,
72
+ out_channels,
73
+ hidden_channels,
74
+ kernel_size,
75
+ dilation_rate,
76
+ n_layers,
77
+ gin_channels=0):
78
+ super().__init__()
79
+ self.in_channels = in_channels
80
+ self.out_channels = out_channels
81
+ self.hidden_channels = hidden_channels
82
+ self.kernel_size = kernel_size
83
+ self.dilation_rate = dilation_rate
84
+ self.n_layers = n_layers
85
+ self.gin_channels = gin_channels
86
+
87
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
88
+ self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
89
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
90
+
91
+ def forward(self, x, x_lengths, g=None):
92
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
93
+ x = self.pre(x) * x_mask
94
+ x = self.enc(x, x_mask, g=g)
95
+ stats = self.proj(x) * x_mask
96
+ m, logs = torch.split(stats, self.out_channels, dim=1)
97
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
98
+ return z, m, logs, x_mask
99
+
100
  class SynthesizerTrn(models.SynthesizerTrn):
101
  """
102
  Synthesizer for Training
 
122
  n_speakers=0,
123
  gin_channels=0,
124
  use_sdp=True,
125
+ emotion_embedding=False,
126
  **kwargs):
127
 
128
  super().__init__(
 
147
  use_sdp=use_sdp,
148
  **kwargs
149
  )
150
+ self.enc_p = TextEncoder(n_vocab,
151
+ inter_channels,
152
+ hidden_channels,
153
+ filter_channels,
154
+ n_heads,
155
+ n_layers,
156
+ kernel_size,
157
+ p_dropout,
158
+ emotion_embedding)
159
+ self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
160
 
161
+ def infer(self, x, x_lengths, sid=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None, emotion_embedding=None):
162
  from ONNXVITS_utils import runonnx
163
 
164
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emotion_embedding)
 
 
 
 
 
165
 
166
  if self.n_speakers > 0:
167
  g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
 
253
  o = runonnx("ONNX_net/dec.onnx", z_in=(z * y_mask)[:,:,:max_len].numpy(), g=g.numpy())
254
  o = torch.from_numpy(o[0])
255
 
256
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
257
+
258
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
259
+ from ONNXVITS_utils import runonnx
260
+ assert self.n_speakers > 0, "n_speakers have to be larger than 0."
261
+ g_src = self.emb_g(sid_src).unsqueeze(-1)
262
+ g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
263
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
264
+ # z_p = self.flow(z, y_mask, g=g_src)
265
+ z_p = runonnx("ONNX_net/flow.onnx", z_p=z.numpy(), y_mask=y_mask.numpy(), g=g_src.numpy())
266
+ z_p = torch.from_numpy(z_p[0])
267
+ # z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
268
+ z_hat = runonnx("ONNX_net/flow.onnx", z_p=z_p.numpy(), y_mask=y_mask.numpy(), g=g_tgt.numpy())
269
+ z_hat = torch.from_numpy(z_hat[0])
270
+ # o_hat = self.dec(z_hat * y_mask, g=g_tgt)
271
+ o_hat = runonnx("ONNX_net/dec.onnx", z_in=(z_hat * y_mask).numpy(), g=g_tgt.numpy())
272
+ o_hat = torch.from_numpy(o_hat[0])
273
+ return o_hat, y_mask, (z, z_p, z_hat)