CaiRou-Huang committed on
Commit adfdf4e · verified · 1 Parent(s): a94413f

Upload 26 files

Files changed (25)
  1. .gitignore +1 -0
  2. LICENSE +661 -0
  3. README.md +8 -5
  4. app.py +424 -0
  5. attentions.py +462 -0
  6. bert_gen.py +84 -0
  7. commons.py +152 -0
  8. config.py +269 -0
  9. config.yml +51 -0
  10. data_utils.py +425 -0
  11. default_config.yml +81 -0
  12. infer.py +306 -0
  13. losses.py +153 -0
  14. mel_processing.py +146 -0
  15. models.py +1024 -0
  16. models_jp_extra.py +1071 -0
  17. modules.py +581 -0
  18. preprocess_text.py +146 -0
  19. re_matching.py +81 -0
  20. requirements.txt +27 -0
  21. server_fastapi.py +263 -0
  22. spec_gen.py +87 -0
  23. style_gen.py +128 -0
  24. transforms.py +209 -0
  25. utils.py +501 -0
.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__/
LICENSE ADDED
@@ -0,0 +1,661 @@
1
+ GNU AFFERO GENERAL PUBLIC LICENSE
2
+ Version 3, 19 November 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU Affero General Public License is a free, copyleft license for
11
+ software and other kinds of works, specifically designed to ensure
12
+ cooperation with the community in the case of network server software.
13
+
14
+ The licenses for most software and other practical works are designed
15
+ to take away your freedom to share and change the works. By contrast,
16
+ our General Public Licenses are intended to guarantee your freedom to
17
+ share and change all versions of a program--to make sure it remains free
18
+ software for all its users.
19
+
20
+ When we speak of free software, we are referring to freedom, not
21
+ price. Our General Public Licenses are designed to make sure that you
22
+ have the freedom to distribute copies of free software (and charge for
23
+ them if you wish), that you receive source code or can get it if you
24
+ want it, that you can change the software or use pieces of it in new
25
+ free programs, and that you know you can do these things.
26
+
27
+ Developers that use our General Public Licenses protect your rights
28
+ with two steps: (1) assert copyright on the software, and (2) offer
29
+ you this License which gives you legal permission to copy, distribute
30
+ and/or modify the software.
31
+
32
+ A secondary benefit of defending all users' freedom is that
33
+ improvements made in alternate versions of the program, if they
34
+ receive widespread use, become available for other developers to
35
+ incorporate. Many developers of free software are heartened and
36
+ encouraged by the resulting cooperation. However, in the case of
37
+ software used on network servers, this result may fail to come about.
38
+ The GNU General Public License permits making a modified version and
39
+ letting the public access it on a server without ever releasing its
40
+ source code to the public.
41
+
42
+ The GNU Affero General Public License is designed specifically to
43
+ ensure that, in such cases, the modified source code becomes available
44
+ to the community. It requires the operator of a network server to
45
+ provide the source code of the modified version running there to the
46
+ users of that server. Therefore, public use of a modified version, on
47
+ a publicly accessible server, gives the public access to the source
48
+ code of the modified version.
49
+
50
+ An older license, called the Affero General Public License and
51
+ published by Affero, was designed to accomplish similar goals. This is
52
+ a different license, not a version of the Affero GPL, but Affero has
53
+ released a new version of the Affero GPL which permits relicensing under
54
+ this license.
55
+
56
+ The precise terms and conditions for copying, distribution and
57
+ modification follow.
58
+
59
+ TERMS AND CONDITIONS
60
+
61
+ 0. Definitions.
62
+
63
+ "This License" refers to version 3 of the GNU Affero General Public License.
64
+
65
+ "Copyright" also means copyright-like laws that apply to other kinds of
66
+ works, such as semiconductor masks.
67
+
68
+ "The Program" refers to any copyrightable work licensed under this
69
+ License. Each licensee is addressed as "you". "Licensees" and
70
+ "recipients" may be individuals or organizations.
71
+
72
+ To "modify" a work means to copy from or adapt all or part of the work
73
+ in a fashion requiring copyright permission, other than the making of an
74
+ exact copy. The resulting work is called a "modified version" of the
75
+ earlier work or a work "based on" the earlier work.
76
+
77
+ A "covered work" means either the unmodified Program or a work based
78
+ on the Program.
79
+
80
+ To "propagate" a work means to do anything with it that, without
81
+ permission, would make you directly or secondarily liable for
82
+ infringement under applicable copyright law, except executing it on a
83
+ computer or modifying a private copy. Propagation includes copying,
84
+ distribution (with or without modification), making available to the
85
+ public, and in some countries other activities as well.
86
+
87
+ To "convey" a work means any kind of propagation that enables other
88
+ parties to make or receive copies. Mere interaction with a user through
89
+ a computer network, with no transfer of a copy, is not conveying.
90
+
91
+ An interactive user interface displays "Appropriate Legal Notices"
92
+ to the extent that it includes a convenient and prominently visible
93
+ feature that (1) displays an appropriate copyright notice, and (2)
94
+ tells the user that there is no warranty for the work (except to the
95
+ extent that warranties are provided), that licensees may convey the
96
+ work under this License, and how to view a copy of this License. If
97
+ the interface presents a list of user commands or options, such as a
98
+ menu, a prominent item in the list meets this criterion.
99
+
100
+ 1. Source Code.
101
+
102
+ The "source code" for a work means the preferred form of the work
103
+ for making modifications to it. "Object code" means any non-source
104
+ form of a work.
105
+
106
+ A "Standard Interface" means an interface that either is an official
107
+ standard defined by a recognized standards body, or, in the case of
108
+ interfaces specified for a particular programming language, one that
109
+ is widely used among developers working in that language.
110
+
111
+ The "System Libraries" of an executable work include anything, other
112
+ than the work as a whole, that (a) is included in the normal form of
113
+ packaging a Major Component, but which is not part of that Major
114
+ Component, and (b) serves only to enable use of the work with that
115
+ Major Component, or to implement a Standard Interface for which an
116
+ implementation is available to the public in source code form. A
117
+ "Major Component", in this context, means a major essential component
118
+ (kernel, window system, and so on) of the specific operating system
119
+ (if any) on which the executable work runs, or a compiler used to
120
+ produce the work, or an object code interpreter used to run it.
121
+
122
+ The "Corresponding Source" for a work in object code form means all
123
+ the source code needed to generate, install, and (for an executable
124
+ work) run the object code and to modify the work, including scripts to
125
+ control those activities. However, it does not include the work's
126
+ System Libraries, or general-purpose tools or generally available free
127
+ programs which are used unmodified in performing those activities but
128
+ which are not part of the work. For example, Corresponding Source
129
+ includes interface definition files associated with source files for
130
+ the work, and the source code for shared libraries and dynamically
131
+ linked subprograms that the work is specifically designed to require,
132
+ such as by intimate data communication or control flow between those
133
+ subprograms and other parts of the work.
134
+
135
+ The Corresponding Source need not include anything that users
136
+ can regenerate automatically from other parts of the Corresponding
137
+ Source.
138
+
139
+ The Corresponding Source for a work in source code form is that
140
+ same work.
141
+
142
+ 2. Basic Permissions.
143
+
144
+ All rights granted under this License are granted for the term of
145
+ copyright on the Program, and are irrevocable provided the stated
146
+ conditions are met. This License explicitly affirms your unlimited
147
+ permission to run the unmodified Program. The output from running a
148
+ covered work is covered by this License only if the output, given its
149
+ content, constitutes a covered work. This License acknowledges your
150
+ rights of fair use or other equivalent, as provided by copyright law.
151
+
152
+ You may make, run and propagate covered works that you do not
153
+ convey, without conditions so long as your license otherwise remains
154
+ in force. You may convey covered works to others for the sole purpose
155
+ of having them make modifications exclusively for you, or provide you
156
+ with facilities for running those works, provided that you comply with
157
+ the terms of this License in conveying all material for which you do
158
+ not control copyright. Those thus making or running the covered works
159
+ for you must do so exclusively on your behalf, under your direction
160
+ and control, on terms that prohibit them from making any copies of
161
+ your copyrighted material outside their relationship with you.
162
+
163
+ Conveying under any other circumstances is permitted solely under
164
+ the conditions stated below. Sublicensing is not allowed; section 10
165
+ makes it unnecessary.
166
+
167
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
168
+
169
+ No covered work shall be deemed part of an effective technological
170
+ measure under any applicable law fulfilling obligations under article
171
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
172
+ similar laws prohibiting or restricting circumvention of such
173
+ measures.
174
+
175
+ When you convey a covered work, you waive any legal power to forbid
176
+ circumvention of technological measures to the extent such circumvention
177
+ is effected by exercising rights under this License with respect to
178
+ the covered work, and you disclaim any intention to limit operation or
179
+ modification of the work as a means of enforcing, against the work's
180
+ users, your or third parties' legal rights to forbid circumvention of
181
+ technological measures.
182
+
183
+ 4. Conveying Verbatim Copies.
184
+
185
+ You may convey verbatim copies of the Program's source code as you
186
+ receive it, in any medium, provided that you conspicuously and
187
+ appropriately publish on each copy an appropriate copyright notice;
188
+ keep intact all notices stating that this License and any
189
+ non-permissive terms added in accord with section 7 apply to the code;
190
+ keep intact all notices of the absence of any warranty; and give all
191
+ recipients a copy of this License along with the Program.
192
+
193
+ You may charge any price or no price for each copy that you convey,
194
+ and you may offer support or warranty protection for a fee.
195
+
196
+ 5. Conveying Modified Source Versions.
197
+
198
+ You may convey a work based on the Program, or the modifications to
199
+ produce it from the Program, in the form of source code under the
200
+ terms of section 4, provided that you also meet all of these conditions:
201
+
202
+ a) The work must carry prominent notices stating that you modified
203
+ it, and giving a relevant date.
204
+
205
+ b) The work must carry prominent notices stating that it is
206
+ released under this License and any conditions added under section
207
+ 7. This requirement modifies the requirement in section 4 to
208
+ "keep intact all notices".
209
+
210
+ c) You must license the entire work, as a whole, under this
211
+ License to anyone who comes into possession of a copy. This
212
+ License will therefore apply, along with any applicable section 7
213
+ additional terms, to the whole of the work, and all its parts,
214
+ regardless of how they are packaged. This License gives no
215
+ permission to license the work in any other way, but it does not
216
+ invalidate such permission if you have separately received it.
217
+
218
+ d) If the work has interactive user interfaces, each must display
219
+ Appropriate Legal Notices; however, if the Program has interactive
220
+ interfaces that do not display Appropriate Legal Notices, your
221
+ work need not make them do so.
222
+
223
+ A compilation of a covered work with other separate and independent
224
+ works, which are not by their nature extensions of the covered work,
225
+ and which are not combined with it such as to form a larger program,
226
+ in or on a volume of a storage or distribution medium, is called an
227
+ "aggregate" if the compilation and its resulting copyright are not
228
+ used to limit the access or legal rights of the compilation's users
229
+ beyond what the individual works permit. Inclusion of a covered work
230
+ in an aggregate does not cause this License to apply to the other
231
+ parts of the aggregate.
232
+
233
+ 6. Conveying Non-Source Forms.
234
+
235
+ You may convey a covered work in object code form under the terms
236
+ of sections 4 and 5, provided that you also convey the
237
+ machine-readable Corresponding Source under the terms of this License,
238
+ in one of these ways:
239
+
240
+ a) Convey the object code in, or embodied in, a physical product
241
+ (including a physical distribution medium), accompanied by the
242
+ Corresponding Source fixed on a durable physical medium
243
+ customarily used for software interchange.
244
+
245
+ b) Convey the object code in, or embodied in, a physical product
246
+ (including a physical distribution medium), accompanied by a
247
+ written offer, valid for at least three years and valid for as
248
+ long as you offer spare parts or customer support for that product
249
+ model, to give anyone who possesses the object code either (1) a
250
+ copy of the Corresponding Source for all the software in the
251
+ product that is covered by this License, on a durable physical
252
+ medium customarily used for software interchange, for a price no
253
+ more than your reasonable cost of physically performing this
254
+ conveying of source, or (2) access to copy the
255
+ Corresponding Source from a network server at no charge.
256
+
257
+ c) Convey individual copies of the object code with a copy of the
258
+ written offer to provide the Corresponding Source. This
259
+ alternative is allowed only occasionally and noncommercially, and
260
+ only if you received the object code with such an offer, in accord
261
+ with subsection 6b.
262
+
263
+ d) Convey the object code by offering access from a designated
264
+ place (gratis or for a charge), and offer equivalent access to the
265
+ Corresponding Source in the same way through the same place at no
266
+ further charge. You need not require recipients to copy the
267
+ Corresponding Source along with the object code. If the place to
268
+ copy the object code is a network server, the Corresponding Source
269
+ may be on a different server (operated by you or a third party)
270
+ that supports equivalent copying facilities, provided you maintain
271
+ clear directions next to the object code saying where to find the
272
+ Corresponding Source. Regardless of what server hosts the
273
+ Corresponding Source, you remain obligated to ensure that it is
274
+ available for as long as needed to satisfy these requirements.
275
+
276
+ e) Convey the object code using peer-to-peer transmission, provided
277
+ you inform other peers where the object code and Corresponding
278
+ Source of the work are being offered to the general public at no
279
+ charge under subsection 6d.
280
+
281
+ A separable portion of the object code, whose source code is excluded
282
+ from the Corresponding Source as a System Library, need not be
283
+ included in conveying the object code work.
284
+
285
+ A "User Product" is either (1) a "consumer product", which means any
286
+ tangible personal property which is normally used for personal, family,
287
+ or household purposes, or (2) anything designed or sold for incorporation
288
+ into a dwelling. In determining whether a product is a consumer product,
289
+ doubtful cases shall be resolved in favor of coverage. For a particular
290
+ product received by a particular user, "normally used" refers to a
291
+ typical or common use of that class of product, regardless of the status
292
+ of the particular user or of the way in which the particular user
293
+ actually uses, or expects or is expected to use, the product. A product
294
+ is a consumer product regardless of whether the product has substantial
295
+ commercial, industrial or non-consumer uses, unless such uses represent
296
+ the only significant mode of use of the product.
297
+
298
+ "Installation Information" for a User Product means any methods,
299
+ procedures, authorization keys, or other information required to install
300
+ and execute modified versions of a covered work in that User Product from
301
+ a modified version of its Corresponding Source. The information must
302
+ suffice to ensure that the continued functioning of the modified object
303
+ code is in no case prevented or interfered with solely because
304
+ modification has been made.
305
+
306
+ If you convey an object code work under this section in, or with, or
307
+ specifically for use in, a User Product, and the conveying occurs as
308
+ part of a transaction in which the right of possession and use of the
309
+ User Product is transferred to the recipient in perpetuity or for a
310
+ fixed term (regardless of how the transaction is characterized), the
311
+ Corresponding Source conveyed under this section must be accompanied
312
+ by the Installation Information. But this requirement does not apply
313
+ if neither you nor any third party retains the ability to install
314
+ modified object code on the User Product (for example, the work has
315
+ been installed in ROM).
316
+
317
+ The requirement to provide Installation Information does not include a
318
+ requirement to continue to provide support service, warranty, or updates
319
+ for a work that has been modified or installed by the recipient, or for
320
+ the User Product in which it has been modified or installed. Access to a
321
+ network may be denied when the modification itself materially and
322
+ adversely affects the operation of the network or violates the rules and
323
+ protocols for communication across the network.
324
+
325
+ Corresponding Source conveyed, and Installation Information provided,
326
+ in accord with this section must be in a format that is publicly
327
+ documented (and with an implementation available to the public in
328
+ source code form), and must require no special password or key for
329
+ unpacking, reading or copying.
330
+
331
+ 7. Additional Terms.
332
+
333
+ "Additional permissions" are terms that supplement the terms of this
334
+ License by making exceptions from one or more of its conditions.
335
+ Additional permissions that are applicable to the entire Program shall
336
+ be treated as though they were included in this License, to the extent
337
+ that they are valid under applicable law. If additional permissions
338
+ apply only to part of the Program, that part may be used separately
339
+ under those permissions, but the entire Program remains governed by
340
+ this License without regard to the additional permissions.
341
+
342
+ When you convey a copy of a covered work, you may at your option
343
+ remove any additional permissions from that copy, or from any part of
344
+ it. (Additional permissions may be written to require their own
345
+ removal in certain cases when you modify the work.) You may place
346
+ additional permissions on material, added by you to a covered work,
347
+ for which you have or can give appropriate copyright permission.
348
+
349
+ Notwithstanding any other provision of this License, for material you
350
+ add to a covered work, you may (if authorized by the copyright holders of
351
+ that material) supplement the terms of this License with terms:
352
+
353
+ a) Disclaiming warranty or limiting liability differently from the
354
+ terms of sections 15 and 16 of this License; or
355
+
356
+ b) Requiring preservation of specified reasonable legal notices or
357
+ author attributions in that material or in the Appropriate Legal
358
+ Notices displayed by works containing it; or
359
+
360
+ c) Prohibiting misrepresentation of the origin of that material, or
361
+ requiring that modified versions of such material be marked in
362
+ reasonable ways as different from the original version; or
363
+
364
+ d) Limiting the use for publicity purposes of names of licensors or
365
+ authors of the material; or
366
+
367
+ e) Declining to grant rights under trademark law for use of some
368
+ trade names, trademarks, or service marks; or
369
+
370
+ f) Requiring indemnification of licensors and authors of that
371
+ material by anyone who conveys the material (or modified versions of
372
+ it) with contractual assumptions of liability to the recipient, for
373
+ any liability that these contractual assumptions directly impose on
374
+ those licensors and authors.
375
+
376
+ All other non-permissive additional terms are considered "further
377
+ restrictions" within the meaning of section 10. If the Program as you
378
+ received it, or any part of it, contains a notice stating that it is
379
+ governed by this License along with a term that is a further
380
+ restriction, you may remove that term. If a license document contains
381
+ a further restriction but permits relicensing or conveying under this
382
+ License, you may add to a covered work material governed by the terms
383
+ of that license document, provided that the further restriction does
384
+ not survive such relicensing or conveying.
385
+
386
+ If you add terms to a covered work in accord with this section, you
387
+ must place, in the relevant source files, a statement of the
388
+ additional terms that apply to those files, or a notice indicating
389
+ where to find the applicable terms.
390
+
391
+ Additional terms, permissive or non-permissive, may be stated in the
392
+ form of a separately written license, or stated as exceptions;
393
+ the above requirements apply either way.
394
+
395
+ 8. Termination.
396
+
397
+ You may not propagate or modify a covered work except as expressly
398
+ provided under this License. Any attempt otherwise to propagate or
399
+ modify it is void, and will automatically terminate your rights under
400
+ this License (including any patent licenses granted under the third
401
+ paragraph of section 11).
402
+
403
+ However, if you cease all violation of this License, then your
404
+ license from a particular copyright holder is reinstated (a)
405
+ provisionally, unless and until the copyright holder explicitly and
406
+ finally terminates your license, and (b) permanently, if the copyright
407
+ holder fails to notify you of the violation by some reasonable means
408
+ prior to 60 days after the cessation.
409
+
410
+ Moreover, your license from a particular copyright holder is
411
+ reinstated permanently if the copyright holder notifies you of the
412
+ violation by some reasonable means, this is the first time you have
413
+ received notice of violation of this License (for any work) from that
414
+ copyright holder, and you cure the violation prior to 30 days after
415
+ your receipt of the notice.
416
+
417
+ Termination of your rights under this section does not terminate the
418
+ licenses of parties who have received copies or rights from you under
419
+ this License. If your rights have been terminated and not permanently
420
+ reinstated, you do not qualify to receive new licenses for the same
421
+ material under section 10.
422
+
423
+ 9. Acceptance Not Required for Having Copies.
424
+
425
+ You are not required to accept this License in order to receive or
426
+ run a copy of the Program. Ancillary propagation of a covered work
427
+ occurring solely as a consequence of using peer-to-peer transmission
428
+ to receive a copy likewise does not require acceptance. However,
429
+ nothing other than this License grants you permission to propagate or
430
+ modify any covered work. These actions infringe copyright if you do
431
+ not accept this License. Therefore, by modifying or propagating a
432
+ covered work, you indicate your acceptance of this License to do so.
433
+
434
+ 10. Automatic Licensing of Downstream Recipients.
435
+
436
+ Each time you convey a covered work, the recipient automatically
437
+ receives a license from the original licensors, to run, modify and
438
+ propagate that work, subject to this License. You are not responsible
439
+ for enforcing compliance by third parties with this License.
440
+
441
+ An "entity transaction" is a transaction transferring control of an
442
+ organization, or substantially all assets of one, or subdividing an
443
+ organization, or merging organizations. If propagation of a covered
444
+ work results from an entity transaction, each party to that
445
+ transaction who receives a copy of the work also receives whatever
446
+ licenses to the work the party's predecessor in interest had or could
447
+ give under the previous paragraph, plus a right to possession of the
448
+ Corresponding Source of the work from the predecessor in interest, if
449
+ the predecessor has it or can get it with reasonable efforts.
450
+
451
+ You may not impose any further restrictions on the exercise of the
452
+ rights granted or affirmed under this License. For example, you may
453
+ not impose a license fee, royalty, or other charge for exercise of
454
+ rights granted under this License, and you may not initiate litigation
455
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
456
+ any patent claim is infringed by making, using, selling, offering for
457
+ sale, or importing the Program or any portion of it.
458
+
459
+ 11. Patents.
460
+
461
+ A "contributor" is a copyright holder who authorizes use under this
462
+ License of the Program or a work on which the Program is based. The
463
+ work thus licensed is called the contributor's "contributor version".
464
+
465
+ A contributor's "essential patent claims" are all patent claims
466
+ owned or controlled by the contributor, whether already acquired or
467
+ hereafter acquired, that would be infringed by some manner, permitted
468
+ by this License, of making, using, or selling its contributor version,
469
+ but do not include claims that would be infringed only as a
470
+ consequence of further modification of the contributor version. For
471
+ purposes of this definition, "control" includes the right to grant
472
+ patent sublicenses in a manner consistent with the requirements of
473
+ this License.
474
+
475
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
476
+ patent license under the contributor's essential patent claims, to
477
+ make, use, sell, offer for sale, import and otherwise run, modify and
478
+ propagate the contents of its contributor version.
479
+
480
+ In the following three paragraphs, a "patent license" is any express
481
+ agreement or commitment, however denominated, not to enforce a patent
482
+ (such as an express permission to practice a patent or covenant not to
483
+ sue for patent infringement). To "grant" such a patent license to a
484
+ party means to make such an agreement or commitment not to enforce a
485
+ patent against the party.
486
+
487
+ If you convey a covered work, knowingly relying on a patent license,
488
+ and the Corresponding Source of the work is not available for anyone
489
+ to copy, free of charge and under the terms of this License, through a
490
+ publicly available network server or other readily accessible means,
491
+ then you must either (1) cause the Corresponding Source to be so
492
+ available, or (2) arrange to deprive yourself of the benefit of the
493
+ patent license for this particular work, or (3) arrange, in a manner
494
+ consistent with the requirements of this License, to extend the patent
495
+ license to downstream recipients. "Knowingly relying" means you have
496
+ actual knowledge that, but for the patent license, your conveying the
497
+ covered work in a country, or your recipient's use of the covered work
498
+ in a country, would infringe one or more identifiable patents in that
499
+ country that you have reason to believe are valid.
500
+
501
+ If, pursuant to or in connection with a single transaction or
502
+ arrangement, you convey, or propagate by procuring conveyance of, a
503
+ covered work, and grant a patent license to some of the parties
504
+ receiving the covered work authorizing them to use, propagate, modify
505
+ or convey a specific copy of the covered work, then the patent license
506
+ you grant is automatically extended to all recipients of the covered
507
+ work and works based on it.
508
+
509
+ A patent license is "discriminatory" if it does not include within
510
+ the scope of its coverage, prohibits the exercise of, or is
511
+ conditioned on the non-exercise of one or more of the rights that are
512
+ specifically granted under this License. You may not convey a covered
513
+ work if you are a party to an arrangement with a third party that is
514
+ in the business of distributing software, under which you make payment
515
+ to the third party based on the extent of your activity of conveying
516
+ the work, and under which the third party grants, to any of the
517
+ parties who would receive the covered work from you, a discriminatory
518
+ patent license (a) in connection with copies of the covered work
519
+ conveyed by you (or copies made from those copies), or (b) primarily
520
+ for and in connection with specific products or compilations that
521
+ contain the covered work, unless you entered into that arrangement,
522
+ or that patent license was granted, prior to 28 March 2007.
523
+
524
+ Nothing in this License shall be construed as excluding or limiting
525
+ any implied license or other defenses to infringement that may
526
+ otherwise be available to you under applicable patent law.
527
+
528
+ 12. No Surrender of Others' Freedom.
529
+
530
+ If conditions are imposed on you (whether by court order, agreement or
531
+ otherwise) that contradict the conditions of this License, they do not
532
+ excuse you from the conditions of this License. If you cannot convey a
533
+ covered work so as to satisfy simultaneously your obligations under this
534
+ License and any other pertinent obligations, then as a consequence you may
535
+ not convey it at all. For example, if you agree to terms that obligate you
536
+ to collect a royalty for further conveying from those to whom you convey
537
+ the Program, the only way you could satisfy both those terms and this
538
+ License would be to refrain entirely from conveying the Program.
539
+
540
+ 13. Remote Network Interaction; Use with the GNU General Public License.
541
+
542
+ Notwithstanding any other provision of this License, if you modify the
543
+ Program, your modified version must prominently offer all users
544
+ interacting with it remotely through a computer network (if your version
545
+ supports such interaction) an opportunity to receive the Corresponding
546
+ Source of your version by providing access to the Corresponding Source
547
+ from a network server at no charge, through some standard or customary
548
+ means of facilitating copying of software. This Corresponding Source
549
+ shall include the Corresponding Source for any work covered by version 3
550
+ of the GNU General Public License that is incorporated pursuant to the
551
+ following paragraph.
552
+
553
+ Notwithstanding any other provision of this License, you have
554
+ permission to link or combine any covered work with a work licensed
555
+ under version 3 of the GNU General Public License into a single
556
+ combined work, and to convey the resulting work. The terms of this
557
+ License will continue to apply to the part which is the covered work,
558
+ but the work with which it is combined will remain governed by version
559
+ 3 of the GNU General Public License.
560
+
561
+ 14. Revised Versions of this License.
562
+
563
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
README.md CHANGED
@@ -1,10 +1,13 @@
1
  ---
2
- title: Bert Vits2
3
- emoji: 📈
4
- colorFrom: yellow
5
- colorTo: green
6
- sdk: docker
 
 
7
  pinned: false
 
8
  ---
9
 
10
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Style-Bert-VITS2 JVNV
3
+ emoji: 😡😊😱😫
4
+ colorFrom: blue
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.16.0
8
+ app_file: app.py
9
  pinned: false
10
+ license: agpl-3.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,424 @@
1
+ import argparse
2
+ import datetime
3
+ import json
4
+ import os
5
+ import sys
6
+ from typing import Optional
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import yaml
11
+
12
+ from common.constants import (
13
+ DEFAULT_ASSIST_TEXT_WEIGHT,
14
+ DEFAULT_LENGTH,
15
+ DEFAULT_LINE_SPLIT,
16
+ DEFAULT_NOISE,
17
+ DEFAULT_NOISEW,
18
+ DEFAULT_SDP_RATIO,
19
+ DEFAULT_SPLIT_INTERVAL,
20
+ DEFAULT_STYLE,
21
+ DEFAULT_STYLE_WEIGHT,
22
+ Languages,
23
+ )
24
+ from common.log import logger
25
+ from common.tts_model import ModelHolder
26
+ from infer import InvalidToneError
27
+ from text.japanese import g2kata_tone, kata_tone2phone_tone, text_normalize
28
+
29
+ is_hf_spaces = os.getenv("SYSTEM") == "spaces"
30
+ limit = 10000
31
+
32
+ # Get path settings
33
+ with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f:
34
+ path_config: dict[str, str] = yaml.safe_load(f.read())
35
+ # dataset_root = path_config["dataset_root"]
36
+ assets_root = path_config["assets_root"]
37
+
38
+ languages = [l.value for l in Languages]
39
+
40
+
41
+ def tts_fn(
42
+ model_name,
43
+ model_path,
44
+ text,
45
+ language,
46
+ reference_audio_path,
47
+ sdp_ratio,
48
+ noise_scale,
49
+ noise_scale_w,
50
+ length_scale,
51
+ line_split,
52
+ split_interval,
53
+ assist_text,
54
+ assist_text_weight,
55
+ use_assist_text,
56
+ style,
57
+ style_weight,
58
+ kata_tone_json_str,
59
+ use_tone,
60
+ speaker,
61
+ ):
62
+ if is_hf_spaces and len(text) > limit:
63
+ logger.error(f"Text is too long: {len(text)}")
64
+ return (
65
+ f"Error: 文字数が多すぎます({limit}文字以下にしてください)",
66
+ None,
67
+ kata_tone_json_str,
68
+ )
69
+ model_holder.load_model_gr(model_name, model_path)
70
+
71
+ wrong_tone_message = ""
72
+ kata_tone: Optional[list[tuple[str, int]]] = None
73
+ if use_tone and kata_tone_json_str != "":
74
+ if language != "JP":
75
+ logger.warning("Only Japanese is supported for tone generation.")
76
+ wrong_tone_message = "アクセント指定は現在日本語のみ対応しています。"
77
+ if line_split:
78
+ logger.warning("Tone generation is not supported for line split.")
79
+ wrong_tone_message = (
80
+ "アクセント指定は改行で分けて生成を使わない場合のみ対応しています。"
81
+ )
82
+ try:
83
+ kata_tone = []
84
+ json_data = json.loads(kata_tone_json_str)
85
+ # convert each entry to a tuple
86
+ for kana, tone in json_data:
87
+ assert isinstance(kana, str) and tone in (0, 1), f"{kana}, {tone}"
88
+ kata_tone.append((kana, tone))
89
+ except Exception as e:
90
+ logger.warning(f"Error occurred when parsing kana_tone_json: {e}")
91
+ wrong_tone_message = f"アクセント指定が不正です: {e}"
92
+ kata_tone = None
93
+
94
+ # tone becomes not None only when it is actually passed into synthesis
95
+ tone: Optional[list[int]] = None
96
+ if kata_tone is not None:
97
+ phone_tone = kata_tone2phone_tone(kata_tone)
98
+ tone = [t for _, t in phone_tone]
99
+
100
+ speaker_id = model_holder.current_model.spk2id[speaker]
101
+
102
+ start_time = datetime.datetime.now()
103
+
104
+ try:
105
+ sr, audio = model_holder.current_model.infer(
106
+ text=text,
107
+ language=language,
108
+ reference_audio_path=reference_audio_path,
109
+ sdp_ratio=sdp_ratio,
110
+ noise=noise_scale,
111
+ noisew=noise_scale_w,
112
+ length=length_scale,
113
+ line_split=line_split,
114
+ split_interval=split_interval,
115
+ assist_text=assist_text,
116
+ assist_text_weight=assist_text_weight,
117
+ use_assist_text=use_assist_text,
118
+ style=style,
119
+ style_weight=style_weight,
120
+ given_tone=tone,
121
+ sid=speaker_id,
122
+ )
123
+ except InvalidToneError as e:
124
+ logger.error(f"Tone error: {e}")
125
+ return f"Error: アクセント指定が不正です:\n{e}", None, kata_tone_json_str
126
+ except ValueError as e:
127
+ logger.error(f"Value error: {e}")
128
+ return f"Error: {e}", None, kata_tone_json_str
129
+
130
+ end_time = datetime.datetime.now()
131
+ duration = (end_time - start_time).total_seconds()
132
+
133
+ if tone is None and language == "JP":
134
+ # return the accent info so it can be reused for manual accent control
135
+ norm_text = text_normalize(text)
136
+ kata_tone = g2kata_tone(norm_text)
137
+ kata_tone_json_str = json.dumps(kata_tone, ensure_ascii=False)
138
+ elif tone is None:
139
+ kata_tone_json_str = ""
140
+ message = f"Success, time: {duration} seconds."
141
+ if wrong_tone_message != "":
142
+ message = wrong_tone_message + "\n" + message
143
+ return message, (sr, audio), kata_tone_json_str
144
+
145
+
146
+ initial_text = "こんにちは、初めまして。あなたの名前はなんていうの?"
147
+
148
+ example_hf_spaces = [
149
+ [initial_text, "JP"],
150
+ ["えっと、私、あなたのことが好きです!もしよければ付き合ってくれませんか?", "JP"],
151
+ ["吾輩は猫である。名前はまだ無い。", "JP"],
152
+ ["桜の樹の下には屍体が埋まっている!これは信じていいことなんだよ。", "JP"],
153
+ ["やったー!テストで満点取れたよ!私とっても嬉しいな!", "JP"],
154
+ [
155
+ "どうして私の意見を無視するの?許せない!ムカつく!あんたなんか死ねばいいのに。",
156
+ "JP",
157
+ ],
158
+ ["あはははっ!この漫画めっちゃ笑える、見てよこれ、ふふふ、あはは。", "JP"],
159
+ [
160
+ "あなたがいなくなって、私は一人になっちゃって、泣いちゃいそうなほど悲しい。",
161
+ "JP",
162
+ ],
163
+ [
164
+ "深層学習の応用により、感情やアクセントを含む声質の微妙な変化も再現されている。",
165
+ "JP",
166
+ ],
167
+ ]
168
+ initial_md = """
169
+ # Style-Bert-VITS2 音声合成デモ
170
+ 入力テキストの意味に応じて感情豊かな読み上げを生成でき、さらに怒り・悲しみ・喜び等の感情スタイルを強弱付きで制御できる、[Style-Bert-VITS2](https://github.com/litagin02/Style-Bert-VITS2)のデモです。
171
+ 入力上限文字数は100文字までにしています。
172
+ このデモでは[jvnvのモデル](https://huggingface.co/litagin/style_bert_vits2_jvnv)を使っており、[JVNVコーパス(言語音声と非言語音声を持つ日本語感情音声コーパス)](https://sites.google.com/site/shinnosuketakamichi/research-topics/jvnv_corpus)で学習されたモデルです。
173
+ """
174
+
175
+ style_md = f"""
176
+ - プリセットまたは音声ファイルから読み上げの声音・感情・スタイルのようなものを制御できます。
177
+ - デフォルトの{DEFAULT_STYLE}でも、十分に読み上げる文に応じた感情で感情豊かに読み上げられます。このスタイル制御は、それを重み付きで上書きするような感じです。
178
+ - 強さを大きくしすぎると発音が変になったり声にならなかったりと崩壊することがあります。
179
+ - どのくらいに強さがいいかはモデルやスタイルによって異なるようです。
180
+ - 音声ファイルを入力する場合は、学習データと似た声音の話者(特に同じ性別)でないとよい効果が出ないかもしれません。
181
+ """
182
+
183
+
184
+ def make_interactive():
185
+ return gr.update(interactive=True, value="音声合成")
186
+
187
+
188
+ def make_non_interactive():
189
+ return gr.update(interactive=False, value="音声合成(モデルをロードしてください)")
190
+
191
+
192
+ def gr_util(item):
193
+ if item == "プリセットから選ぶ":
194
+ return (gr.update(visible=True), gr.Audio(visible=False, value=None))
195
+ else:
196
+ return (gr.update(visible=False), gr.update(visible=True))
197
+
198
+
199
+ if __name__ == "__main__":
200
+ parser = argparse.ArgumentParser()
201
+ parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
202
+ parser.add_argument(
203
+ "--dir", "-d", type=str, help="Model directory", default=assets_root
204
+ )
205
+ parser.add_argument(
206
+ "--share", action="store_true", help="Share this app publicly", default=False
207
+ )
208
+ parser.add_argument(
209
+ "--server-name",
210
+ type=str,
211
+ default=None,
212
+ help="Server name for Gradio app",
213
+ )
214
+ parser.add_argument(
215
+ "--no-autolaunch",
216
+ action="store_true",
217
+ default=False,
218
+ help="Do not launch app automatically",
219
+ )
220
+ args = parser.parse_args()
221
+ model_dir = args.dir
222
+
223
+ if args.cpu:
224
+ device = "cpu"
225
+ else:
226
+ device = "cuda" if torch.cuda.is_available() else "cpu"
227
+
228
+ model_holder = ModelHolder(model_dir, device)
229
+
230
+ model_names = model_holder.model_names
231
+ if len(model_names) == 0:
232
+ logger.error(
233
+ f"モデルが見つかりませんでした。{model_dir}にモデルを置いてください。"
234
+ )
235
+ sys.exit(1)
236
+ initial_id = 0
237
+ initial_pth_files = model_holder.model_files_dict[model_names[initial_id]]
238
+
239
+ with gr.Blocks(theme="NoCrypt/miku") as app:
240
+ gr.Markdown(initial_md)
241
+ with gr.Row():
242
+ with gr.Column():
243
+ with gr.Row():
244
+ with gr.Column(scale=3):
245
+ model_name = gr.Dropdown(
246
+ label="モデル一覧",
247
+ choices=model_names,
248
+ value=model_names[initial_id],
249
+ )
250
+ model_path = gr.Dropdown(
251
+ label="モデルファイル",
252
+ choices=initial_pth_files,
253
+ value=initial_pth_files[0],
254
+ )
255
+ refresh_button = gr.Button("更新", scale=1, visible=False)
256
+ load_button = gr.Button("ロード", scale=1, variant="primary")
257
+ text_input = gr.TextArea(label="テキスト", value=initial_text)
258
+
259
+ line_split = gr.Checkbox(
260
+ label="改行で分けて生成(分けたほうが感情が乗ります)",
261
+ value=DEFAULT_LINE_SPLIT,
262
+ )
263
+ split_interval = gr.Slider(
264
+ minimum=0.0,
265
+ maximum=2,
266
+ value=DEFAULT_SPLIT_INTERVAL,
267
+ step=0.1,
268
+ label="改行ごとに挟む無音の長さ(秒)",
269
+ )
270
+ line_split.change(
271
+ lambda x: (gr.Slider(visible=x)),
272
+ inputs=[line_split],
273
+ outputs=[split_interval],
274
+ )
275
+ tone = gr.Textbox(
276
+ label="アクセント調整(数値は 0=低 か1=高 のみ)",
277
+ info="改行で分けない場合のみ使えます。万能ではありません。",
278
+ )
279
+ use_tone = gr.Checkbox(label="アクセント調整を使う", value=False)
280
+ use_tone.change(
281
+ lambda x: (gr.Checkbox(value=False) if x else gr.Checkbox()),
282
+ inputs=[use_tone],
283
+ outputs=[line_split],
284
+ )
285
+ language = gr.Dropdown(choices=["JP"], value="JP", label="Language")
286
+ speaker = gr.Dropdown(label="話者")
287
+ with gr.Accordion(label="詳細設定", open=False):
288
+ sdp_ratio = gr.Slider(
289
+ minimum=0,
290
+ maximum=1,
291
+ value=DEFAULT_SDP_RATIO,
292
+ step=0.1,
293
+ label="SDP Ratio",
294
+ )
295
+ noise_scale = gr.Slider(
296
+ minimum=0.1,
297
+ maximum=2,
298
+ value=DEFAULT_NOISE,
299
+ step=0.1,
300
+ label="Noise",
301
+ )
302
+ noise_scale_w = gr.Slider(
303
+ minimum=0.1,
304
+ maximum=2,
305
+ value=DEFAULT_NOISEW,
306
+ step=0.1,
307
+ label="Noise_W",
308
+ )
309
+ length_scale = gr.Slider(
310
+ minimum=0.1,
311
+ maximum=2,
312
+ value=DEFAULT_LENGTH,
313
+ step=0.1,
314
+ label="Length",
315
+ )
316
+ use_assist_text = gr.Checkbox(
317
+ label="Assist textを使う", value=False
318
+ )
319
+ assist_text = gr.Textbox(
320
+ label="Assist text",
321
+ placeholder="どうして私の意見を無視するの?許せない、ムカつく!死ねばいいのに。",
322
+ info="このテキストの読み上げと似た声音・感情になりやすくなります。ただ抑揚やテンポ等が犠牲になる傾向があります。",
323
+ visible=False,
324
+ )
325
+ assist_text_weight = gr.Slider(
326
+ minimum=0,
327
+ maximum=1,
328
+ value=DEFAULT_ASSIST_TEXT_WEIGHT,
329
+ step=0.1,
330
+ label="Assist textの強さ",
331
+ visible=False,
332
+ )
333
+ use_assist_text.change(
334
+ lambda x: (gr.Textbox(visible=x), gr.Slider(visible=x)),
335
+ inputs=[use_assist_text],
336
+ outputs=[assist_text, assist_text_weight],
337
+ )
338
+ with gr.Column():
339
+ with gr.Accordion("スタイルについて詳細", open=False):
340
+ gr.Markdown(style_md)
341
+ style_mode = gr.Radio(
342
+ ["プリセットから選ぶ", "音声ファイルを入力"],
343
+ label="スタイルの指定方法",
344
+ value="プリセットから選ぶ",
345
+ )
346
+ style = gr.Dropdown(
347
+ label=f"スタイル({DEFAULT_STYLE}が平均スタイル)",
348
+ choices=["モデルをロードしてください"],
349
+ value="モデルをロードしてください",
350
+ )
351
+ style_weight = gr.Slider(
352
+ minimum=0,
353
+ maximum=50,
354
+ value=DEFAULT_STYLE_WEIGHT,
355
+ step=0.1,
356
+ label="スタイルの強さ",
357
+ )
358
+ ref_audio_path = gr.Audio(
359
+ label="参照音声", type="filepath", visible=False
360
+ )
361
+ tts_button = gr.Button(
362
+ "音声合成(モデルをロードしてください)",
363
+ variant="primary",
364
+ interactive=False,
365
+ )
366
+ text_output = gr.Textbox(label="情報")
367
+ audio_output = gr.Audio(label="結果")
368
+ with gr.Accordion("テキスト例", open=True):
369
+ gr.Examples(example_hf_spaces, inputs=[text_input, language])
370
+
371
+ tts_button.click(
372
+ tts_fn,
373
+ inputs=[
374
+ model_name,
375
+ model_path,
376
+ text_input,
377
+ language,
378
+ ref_audio_path,
379
+ sdp_ratio,
380
+ noise_scale,
381
+ noise_scale_w,
382
+ length_scale,
383
+ line_split,
384
+ split_interval,
385
+ assist_text,
386
+ assist_text_weight,
387
+ use_assist_text,
388
+ style,
389
+ style_weight,
390
+ tone,
391
+ use_tone,
392
+ speaker,
393
+ ],
394
+ outputs=[text_output, audio_output, tone],
395
+ )
396
+
397
+ model_name.change(
398
+ model_holder.update_model_files_gr,
399
+ inputs=[model_name],
400
+ outputs=[model_path],
401
+ )
402
+
403
+ model_path.change(make_non_interactive, outputs=[tts_button])
404
+
405
+ refresh_button.click(
406
+ model_holder.update_model_names_gr,
407
+ outputs=[model_name, model_path, tts_button],
408
+ )
409
+
410
+ load_button.click(
411
+ model_holder.load_model_gr,
412
+ inputs=[model_name, model_path],
413
+ outputs=[style, tts_button, speaker],
414
+ )
415
+
416
+ style_mode.change(
417
+ gr_util,
418
+ inputs=[style_mode],
419
+ outputs=[style, ref_audio_path],
420
+ )
421
+
422
+ app.launch(
423
+ inbrowser=not args.no_autolaunch, share=args.share, server_name=args.server_name
424
+ )
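
For context, below is a minimal sketch of driving the same inference path without the Gradio UI, using only the APIs and constants that app.py imports above (ModelHolder and common.constants). It is a sketch under assumptions: whether load_model_gr behaves identically outside a Gradio callback is not confirmed by this commit, and the "model_assets" directory stands in for the assets_root that app.py reads from configs/paths.yml.

```python
# Hypothetical standalone driver for the inference path that app.py wires into Gradio.
import torch

from common.constants import (
    DEFAULT_ASSIST_TEXT_WEIGHT,
    DEFAULT_LENGTH,
    DEFAULT_LINE_SPLIT,
    DEFAULT_NOISE,
    DEFAULT_NOISEW,
    DEFAULT_SDP_RATIO,
    DEFAULT_SPLIT_INTERVAL,
    DEFAULT_STYLE,
    DEFAULT_STYLE_WEIGHT,
)
from common.tts_model import ModelHolder

device = "cuda" if torch.cuda.is_available() else "cpu"
holder = ModelHolder("model_assets", device)  # placeholder for assets_root from configs/paths.yml

# Pick the first available model file, as app.py does for its initial dropdown values.
model_name = holder.model_names[0]
model_file = holder.model_files_dict[model_name][0]
holder.load_model_gr(model_name, model_file)

# app.py resolves the speaker id through current_model.spk2id; take the first entry here.
speaker_id = list(holder.current_model.spk2id.values())[0]

sr, audio = holder.current_model.infer(
    text="こんにちは、初めまして。あなたの名前はなんていうの?",
    language="JP",
    reference_audio_path=None,
    sdp_ratio=DEFAULT_SDP_RATIO,
    noise=DEFAULT_NOISE,
    noisew=DEFAULT_NOISEW,
    length=DEFAULT_LENGTH,
    line_split=DEFAULT_LINE_SPLIT,
    split_interval=DEFAULT_SPLIT_INTERVAL,
    assist_text="",
    assist_text_weight=DEFAULT_ASSIST_TEXT_WEIGHT,
    use_assist_text=False,
    style=DEFAULT_STYLE,
    style_weight=DEFAULT_STYLE_WEIGHT,
    given_tone=None,
    sid=speaker_id,
)
# (sr, audio) is the sample rate and waveform that tts_fn returns to the Gradio Audio output.
```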
attentions.py ADDED
@@ -0,0 +1,462 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ from common.log import logger as logging
8
+
9
+
10
+ class LayerNorm(nn.Module):
11
+ def __init__(self, channels, eps=1e-5):
12
+ super().__init__()
13
+ self.channels = channels
14
+ self.eps = eps
15
+
16
+ self.gamma = nn.Parameter(torch.ones(channels))
17
+ self.beta = nn.Parameter(torch.zeros(channels))
18
+
19
+ def forward(self, x):
20
+ x = x.transpose(1, -1)
21
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
22
+ return x.transpose(1, -1)
23
+
24
+
25
+ @torch.jit.script
26
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
27
+ n_channels_int = n_channels[0]
28
+ in_act = input_a + input_b
29
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
30
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
31
+ acts = t_act * s_act
32
+ return acts
33
+
34
+
35
+ class Encoder(nn.Module):
36
+ def __init__(
37
+ self,
38
+ hidden_channels,
39
+ filter_channels,
40
+ n_heads,
41
+ n_layers,
42
+ kernel_size=1,
43
+ p_dropout=0.0,
44
+ window_size=4,
45
+ isflow=True,
46
+ **kwargs
47
+ ):
48
+ super().__init__()
49
+ self.hidden_channels = hidden_channels
50
+ self.filter_channels = filter_channels
51
+ self.n_heads = n_heads
52
+ self.n_layers = n_layers
53
+ self.kernel_size = kernel_size
54
+ self.p_dropout = p_dropout
55
+ self.window_size = window_size
56
+ # if isflow:
57
+ # cond_layer = torch.nn.Conv1d(256, 2*hidden_channels*n_layers, 1)
58
+ # self.cond_pre = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, 1)
59
+ # self.cond_layer = weight_norm(cond_layer, name='weight')
60
+ # self.gin_channels = 256
61
+ self.cond_layer_idx = self.n_layers
62
+ if "gin_channels" in kwargs:
63
+ self.gin_channels = kwargs["gin_channels"]
64
+ if self.gin_channels != 0:
65
+ self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
66
+ # vits2 says 3rd block, so idx is 2 by default
67
+ self.cond_layer_idx = (
68
+ kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
69
+ )
70
+ # logging.debug(self.gin_channels, self.cond_layer_idx)
71
+ assert (
72
+ self.cond_layer_idx < self.n_layers
73
+ ), "cond_layer_idx should be less than n_layers"
74
+ self.drop = nn.Dropout(p_dropout)
75
+ self.attn_layers = nn.ModuleList()
76
+ self.norm_layers_1 = nn.ModuleList()
77
+ self.ffn_layers = nn.ModuleList()
78
+ self.norm_layers_2 = nn.ModuleList()
79
+ for i in range(self.n_layers):
80
+ self.attn_layers.append(
81
+ MultiHeadAttention(
82
+ hidden_channels,
83
+ hidden_channels,
84
+ n_heads,
85
+ p_dropout=p_dropout,
86
+ window_size=window_size,
87
+ )
88
+ )
89
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
90
+ self.ffn_layers.append(
91
+ FFN(
92
+ hidden_channels,
93
+ hidden_channels,
94
+ filter_channels,
95
+ kernel_size,
96
+ p_dropout=p_dropout,
97
+ )
98
+ )
99
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
100
+
101
+ def forward(self, x, x_mask, g=None):
102
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
103
+ x = x * x_mask
104
+ for i in range(self.n_layers):
105
+ if i == self.cond_layer_idx and g is not None:
106
+ g = self.spk_emb_linear(g.transpose(1, 2))
107
+ g = g.transpose(1, 2)
108
+ x = x + g
109
+ x = x * x_mask
110
+ y = self.attn_layers[i](x, x, attn_mask)
111
+ y = self.drop(y)
112
+ x = self.norm_layers_1[i](x + y)
113
+
114
+ y = self.ffn_layers[i](x, x_mask)
115
+ y = self.drop(y)
116
+ x = self.norm_layers_2[i](x + y)
117
+ x = x * x_mask
118
+ return x
119
+
120
+
121
+ class Decoder(nn.Module):
122
+ def __init__(
123
+ self,
124
+ hidden_channels,
125
+ filter_channels,
126
+ n_heads,
127
+ n_layers,
128
+ kernel_size=1,
129
+ p_dropout=0.0,
130
+ proximal_bias=False,
131
+ proximal_init=True,
132
+ **kwargs
133
+ ):
134
+ super().__init__()
135
+ self.hidden_channels = hidden_channels
136
+ self.filter_channels = filter_channels
137
+ self.n_heads = n_heads
138
+ self.n_layers = n_layers
139
+ self.kernel_size = kernel_size
140
+ self.p_dropout = p_dropout
141
+ self.proximal_bias = proximal_bias
142
+ self.proximal_init = proximal_init
143
+
144
+ self.drop = nn.Dropout(p_dropout)
145
+ self.self_attn_layers = nn.ModuleList()
146
+ self.norm_layers_0 = nn.ModuleList()
147
+ self.encdec_attn_layers = nn.ModuleList()
148
+ self.norm_layers_1 = nn.ModuleList()
149
+ self.ffn_layers = nn.ModuleList()
150
+ self.norm_layers_2 = nn.ModuleList()
151
+ for i in range(self.n_layers):
152
+ self.self_attn_layers.append(
153
+ MultiHeadAttention(
154
+ hidden_channels,
155
+ hidden_channels,
156
+ n_heads,
157
+ p_dropout=p_dropout,
158
+ proximal_bias=proximal_bias,
159
+ proximal_init=proximal_init,
160
+ )
161
+ )
162
+ self.norm_layers_0.append(LayerNorm(hidden_channels))
163
+ self.encdec_attn_layers.append(
164
+ MultiHeadAttention(
165
+ hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
166
+ )
167
+ )
168
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
169
+ self.ffn_layers.append(
170
+ FFN(
171
+ hidden_channels,
172
+ hidden_channels,
173
+ filter_channels,
174
+ kernel_size,
175
+ p_dropout=p_dropout,
176
+ causal=True,
177
+ )
178
+ )
179
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
180
+
181
+ def forward(self, x, x_mask, h, h_mask):
182
+ """
183
+ x: decoder input
184
+ h: encoder output
185
+ """
186
+ self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
187
+ device=x.device, dtype=x.dtype
188
+ )
189
+ encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
190
+ x = x * x_mask
191
+ for i in range(self.n_layers):
192
+ y = self.self_attn_layers[i](x, x, self_attn_mask)
193
+ y = self.drop(y)
194
+ x = self.norm_layers_0[i](x + y)
195
+
196
+ y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
197
+ y = self.drop(y)
198
+ x = self.norm_layers_1[i](x + y)
199
+
200
+ y = self.ffn_layers[i](x, x_mask)
201
+ y = self.drop(y)
202
+ x = self.norm_layers_2[i](x + y)
203
+ x = x * x_mask
204
+ return x
205
+
206
+
207
+ class MultiHeadAttention(nn.Module):
208
+ def __init__(
209
+ self,
210
+ channels,
211
+ out_channels,
212
+ n_heads,
213
+ p_dropout=0.0,
214
+ window_size=None,
215
+ heads_share=True,
216
+ block_length=None,
217
+ proximal_bias=False,
218
+ proximal_init=False,
219
+ ):
220
+ super().__init__()
221
+ assert channels % n_heads == 0
222
+
223
+ self.channels = channels
224
+ self.out_channels = out_channels
225
+ self.n_heads = n_heads
226
+ self.p_dropout = p_dropout
227
+ self.window_size = window_size
228
+ self.heads_share = heads_share
229
+ self.block_length = block_length
230
+ self.proximal_bias = proximal_bias
231
+ self.proximal_init = proximal_init
232
+ self.attn = None
233
+
234
+ self.k_channels = channels // n_heads
235
+ self.conv_q = nn.Conv1d(channels, channels, 1)
236
+ self.conv_k = nn.Conv1d(channels, channels, 1)
237
+ self.conv_v = nn.Conv1d(channels, channels, 1)
238
+ self.conv_o = nn.Conv1d(channels, out_channels, 1)
239
+ self.drop = nn.Dropout(p_dropout)
240
+
241
+ if window_size is not None:
242
+ n_heads_rel = 1 if heads_share else n_heads
243
+ rel_stddev = self.k_channels**-0.5
244
+ self.emb_rel_k = nn.Parameter(
245
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
246
+ * rel_stddev
247
+ )
248
+ self.emb_rel_v = nn.Parameter(
249
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
250
+ * rel_stddev
251
+ )
252
+
253
+ nn.init.xavier_uniform_(self.conv_q.weight)
254
+ nn.init.xavier_uniform_(self.conv_k.weight)
255
+ nn.init.xavier_uniform_(self.conv_v.weight)
256
+ if proximal_init:
257
+ with torch.no_grad():
258
+ self.conv_k.weight.copy_(self.conv_q.weight)
259
+ self.conv_k.bias.copy_(self.conv_q.bias)
260
+
261
+ def forward(self, x, c, attn_mask=None):
262
+ q = self.conv_q(x)
263
+ k = self.conv_k(c)
264
+ v = self.conv_v(c)
265
+
266
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
267
+
268
+ x = self.conv_o(x)
269
+ return x
270
+
271
+ def attention(self, query, key, value, mask=None):
272
+ # reshape [b, d, t] -> [b, n_h, t, d_k]
273
+ b, d, t_s, t_t = (*key.size(), query.size(2))
274
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
275
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
276
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
277
+
278
+ scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
279
+ if self.window_size is not None:
280
+ assert (
281
+ t_s == t_t
282
+ ), "Relative attention is only available for self-attention."
283
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
284
+ rel_logits = self._matmul_with_relative_keys(
285
+ query / math.sqrt(self.k_channels), key_relative_embeddings
286
+ )
287
+ scores_local = self._relative_position_to_absolute_position(rel_logits)
288
+ scores = scores + scores_local
289
+ if self.proximal_bias:
290
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
291
+ scores = scores + self._attention_bias_proximal(t_s).to(
292
+ device=scores.device, dtype=scores.dtype
293
+ )
294
+ if mask is not None:
295
+ scores = scores.masked_fill(mask == 0, -1e4)
296
+ if self.block_length is not None:
297
+ assert (
298
+ t_s == t_t
299
+ ), "Local attention is only available for self-attention."
300
+ block_mask = (
301
+ torch.ones_like(scores)
302
+ .triu(-self.block_length)
303
+ .tril(self.block_length)
304
+ )
305
+ scores = scores.masked_fill(block_mask == 0, -1e4)
306
+ p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
307
+ p_attn = self.drop(p_attn)
308
+ output = torch.matmul(p_attn, value)
309
+ if self.window_size is not None:
310
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
311
+ value_relative_embeddings = self._get_relative_embeddings(
312
+ self.emb_rel_v, t_s
313
+ )
314
+ output = output + self._matmul_with_relative_values(
315
+ relative_weights, value_relative_embeddings
316
+ )
317
+ output = (
318
+ output.transpose(2, 3).contiguous().view(b, d, t_t)
319
+ ) # [b, n_h, t_t, d_k] -> [b, d, t_t]
320
+ return output, p_attn
321
+
322
+ def _matmul_with_relative_values(self, x, y):
323
+ """
324
+ x: [b, h, l, m]
325
+ y: [h or 1, m, d]
326
+ ret: [b, h, l, d]
327
+ """
328
+ ret = torch.matmul(x, y.unsqueeze(0))
329
+ return ret
330
+
331
+ def _matmul_with_relative_keys(self, x, y):
332
+ """
333
+ x: [b, h, l, d]
334
+ y: [h or 1, m, d]
335
+ ret: [b, h, l, m]
336
+ """
337
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
338
+ return ret
339
+
340
+ def _get_relative_embeddings(self, relative_embeddings, length):
341
+ # the relative embedding table spans 2 * self.window_size + 1 positions
342
+ # Pad first before slice to avoid using cond ops.
343
+ pad_length = max(length - (self.window_size + 1), 0)
344
+ slice_start_position = max((self.window_size + 1) - length, 0)
345
+ slice_end_position = slice_start_position + 2 * length - 1
346
+ if pad_length > 0:
347
+ padded_relative_embeddings = F.pad(
348
+ relative_embeddings,
349
+ commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
350
+ )
351
+ else:
352
+ padded_relative_embeddings = relative_embeddings
353
+ used_relative_embeddings = padded_relative_embeddings[
354
+ :, slice_start_position:slice_end_position
355
+ ]
356
+ return used_relative_embeddings
357
+
358
+ def _relative_position_to_absolute_position(self, x):
359
+ """
360
+ x: [b, h, l, 2*l-1]
361
+ ret: [b, h, l, l]
362
+ """
363
+ batch, heads, length, _ = x.size()
364
+ # Concat columns of pad to shift from relative to absolute indexing.
365
+ x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
366
+
367
+ # Concat extra elements so to add up to shape (len+1, 2*len-1).
368
+ x_flat = x.view([batch, heads, length * 2 * length])
369
+ x_flat = F.pad(
370
+ x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
371
+ )
372
+
373
+ # Reshape and slice out the padded elements.
374
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
375
+ :, :, :length, length - 1 :
376
+ ]
377
+ return x_final
378
+
379
+ def _absolute_position_to_relative_position(self, x):
380
+ """
381
+ x: [b, h, l, l]
382
+ ret: [b, h, l, 2*l-1]
383
+ """
384
+ batch, heads, length, _ = x.size()
385
+ # pad along column
386
+ x = F.pad(
387
+ x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
388
+ )
389
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
390
+ # add 0's in the beginning that will skew the elements after reshape
391
+ x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
392
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
393
+ return x_final
394
+
395
+ def _attention_bias_proximal(self, length):
396
+ """Bias for self-attention to encourage attention to close positions.
397
+ Args:
398
+ length: an integer scalar.
399
+ Returns:
400
+ a Tensor with shape [1, 1, length, length]
401
+ """
402
+ r = torch.arange(length, dtype=torch.float32)
403
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
404
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
405
+
406
+
407
+ class FFN(nn.Module):
408
+ def __init__(
409
+ self,
410
+ in_channels,
411
+ out_channels,
412
+ filter_channels,
413
+ kernel_size,
414
+ p_dropout=0.0,
415
+ activation=None,
416
+ causal=False,
417
+ ):
418
+ super().__init__()
419
+ self.in_channels = in_channels
420
+ self.out_channels = out_channels
421
+ self.filter_channels = filter_channels
422
+ self.kernel_size = kernel_size
423
+ self.p_dropout = p_dropout
424
+ self.activation = activation
425
+ self.causal = causal
426
+
427
+ if causal:
428
+ self.padding = self._causal_padding
429
+ else:
430
+ self.padding = self._same_padding
431
+
432
+ self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
433
+ self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
434
+ self.drop = nn.Dropout(p_dropout)
435
+
436
+ def forward(self, x, x_mask):
437
+ x = self.conv_1(self.padding(x * x_mask))
438
+ if self.activation == "gelu":
439
+ x = x * torch.sigmoid(1.702 * x)
440
+ else:
441
+ x = torch.relu(x)
442
+ x = self.drop(x)
443
+ x = self.conv_2(self.padding(x * x_mask))
444
+ return x * x_mask
445
+
446
+ def _causal_padding(self, x):
447
+ if self.kernel_size == 1:
448
+ return x
449
+ pad_l = self.kernel_size - 1
450
+ pad_r = 0
451
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
452
+ x = F.pad(x, commons.convert_pad_shape(padding))
453
+ return x
454
+
455
+ def _same_padding(self, x):
456
+ if self.kernel_size == 1:
457
+ return x
458
+ pad_l = (self.kernel_size - 1) // 2
459
+ pad_r = self.kernel_size // 2
460
+ padding = [[0, 0], [0, 0], [pad_l, pad_r]]
461
+ x = F.pad(x, commons.convert_pad_shape(padding))
462
+ return x
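A quick shape check for the Encoder above can be useful when porting. This is a minimal sketch, not part of the upload: the hyperparameters (192 hidden channels, 768 filter channels, 2 heads, 6 layers, gin_channels=256) are illustrative, and sequence_mask comes from the commons.py added below.

import torch
import commons
from attentions import Encoder

enc = Encoder(192, 768, n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.1, gin_channels=256)
x = torch.randn(2, 192, 50)  # [batch, hidden_channels, time]
x_mask = commons.sequence_mask(torch.tensor([50, 35]), 50).unsqueeze(1).to(x.dtype)  # [batch, 1, time]
g = torch.randn(2, 256, 1)  # speaker embedding, injected at cond_layer_idx (the 3rd block)
y = enc(x, x_mask, g=g)  # -> torch.Size([2, 192, 50])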
bert_gen.py ADDED
@@ -0,0 +1,84 @@
1
+ import argparse
2
+ import sys
3
+ from multiprocessing import Pool
4
+
5
+ import torch
6
+ import torch.multiprocessing as mp
7
+ from tqdm import tqdm
8
+
9
+ import commons
10
+ import utils
11
+ from config import config
12
+ from text import cleaned_text_to_sequence, get_bert
13
+
14
+
15
+ def process_line(x):
16
+ line, add_blank = x
17
+ device = config.bert_gen_config.device
18
+ if config.bert_gen_config.use_multi_device:
19
+ rank = mp.current_process()._identity
20
+ rank = rank[0] if len(rank) > 0 else 0
21
+ if torch.cuda.is_available():
22
+ gpu_id = rank % torch.cuda.device_count()
23
+ device = torch.device(f"cuda:{gpu_id}")
24
+ else:
25
+ device = torch.device("cpu")
26
+ wav_path, _, language_str, text, phones, tone, word2ph = line.strip().split("|")
27
+ phone = phones.split(" ")
28
+ tone = [int(i) for i in tone.split(" ")]
29
+ word2ph = [int(i) for i in word2ph.split(" ")]
30
+ word2ph = [i for i in word2ph]
31
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
32
+
33
+ if add_blank:
34
+ phone = commons.intersperse(phone, 0)
35
+ tone = commons.intersperse(tone, 0)
36
+ language = commons.intersperse(language, 0)
37
+ for i in range(len(word2ph)):
38
+ word2ph[i] = word2ph[i] * 2
39
+ word2ph[0] += 1
40
+
41
+ bert_path = wav_path.replace(".WAV", ".wav").replace(".wav", ".bert.pt")
42
+
43
+ try:
44
+ bert = torch.load(bert_path)
45
+ assert bert.shape[-1] == len(phone)
46
+ except Exception:
47
+ bert = get_bert(text, word2ph, language_str, device)
48
+ assert bert.shape[-1] == len(phone)
49
+ torch.save(bert, bert_path)
50
+
51
+
52
+ preprocess_text_config = config.preprocess_text_config
53
+
54
+ if __name__ == "__main__":
55
+ parser = argparse.ArgumentParser()
56
+ parser.add_argument(
57
+ "-c", "--config", type=str, default=config.bert_gen_config.config_path
58
+ )
59
+ parser.add_argument(
60
+ "--num_processes", type=int, default=config.bert_gen_config.num_processes
61
+ )
62
+ args, _ = parser.parse_known_args()
63
+ config_path = args.config
64
+ hps = utils.get_hparams_from_file(config_path)
65
+ lines = []
66
+ with open(hps.data.training_files, encoding="utf-8") as f:
67
+ lines.extend(f.readlines())
68
+
69
+ with open(hps.data.validation_files, encoding="utf-8") as f:
70
+ lines.extend(f.readlines())
71
+ add_blank = [hps.data.add_blank] * len(lines)
72
+
73
+ if len(lines) != 0:
74
+ num_processes = args.num_processes
75
+ with Pool(processes=num_processes) as pool:
76
+ for _ in tqdm(
77
+ pool.imap_unordered(process_line, zip(lines, add_blank)),
78
+ total=len(lines),
79
+ file=sys.stdout,
80
+ ):
81
+ # The loop body is intentionally empty; iterating the results drives the worker pool.
82
+ pass  # placeholder statement
83
+
84
+ print(f"bert.pt is generated! total: {len(lines)} bert.pt files.")
commons.py ADDED
@@ -0,0 +1,152 @@
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
+ def init_weights(m, mean=0.0, std=0.01):
7
+ classname = m.__class__.__name__
8
+ if classname.find("Conv") != -1:
9
+ m.weight.data.normal_(mean, std)
10
+
11
+
12
+ def get_padding(kernel_size, dilation=1):
13
+ return int((kernel_size * dilation - dilation) / 2)
14
+
15
+
16
+ def convert_pad_shape(pad_shape):
17
+ layer = pad_shape[::-1]
18
+ pad_shape = [item for sublist in layer for item in sublist]
19
+ return pad_shape
20
+
21
+
22
+ def intersperse(lst, item):
23
+ result = [item] * (len(lst) * 2 + 1)
24
+ result[1::2] = lst
25
+ return result
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += (
32
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33
+ )
34
+ return kl
35
+
36
+
37
+ def rand_gumbel(shape):
38
+ """Sample from the Gumbel distribution, protect from overflows."""
39
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
40
+ return -torch.log(-torch.log(uniform_samples))
41
+
42
+
43
+ def rand_gumbel_like(x):
44
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
45
+ return g
46
+
47
+
48
+ def slice_segments(x, ids_str, segment_size=4):
49
+ gather_indices = ids_str.view(x.size(0), 1, 1).repeat(
50
+ 1, x.size(1), 1
51
+ ) + torch.arange(segment_size, device=x.device)
52
+ return torch.gather(x, 2, gather_indices)
53
+
54
+
55
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
56
+ b, d, t = x.size()
57
+ if x_lengths is None:
58
+ x_lengths = t
59
+ ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
60
+ ids_str = (torch.rand([b], device=x.device) * ids_str_max).to(dtype=torch.long)
61
+ ret = slice_segments(x, ids_str, segment_size)
62
+ return ret, ids_str
63
+
64
+
65
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
66
+ position = torch.arange(length, dtype=torch.float)
67
+ num_timescales = channels // 2
68
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
69
+ num_timescales - 1
70
+ )
71
+ inv_timescales = min_timescale * torch.exp(
72
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
73
+ )
74
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
75
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
76
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
77
+ signal = signal.view(1, channels, length)
78
+ return signal
79
+
80
+
81
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
82
+ b, channels, length = x.size()
83
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
84
+ return x + signal.to(dtype=x.dtype, device=x.device)
85
+
86
+
87
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
88
+ b, channels, length = x.size()
89
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
90
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
91
+
92
+
93
+ def subsequent_mask(length):
94
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
95
+ return mask
96
+
97
+
98
+ @torch.jit.script
99
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
100
+ n_channels_int = n_channels[0]
101
+ in_act = input_a + input_b
102
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
103
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
104
+ acts = t_act * s_act
105
+ return acts
106
+
107
+
108
+ def shift_1d(x):
109
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
110
+ return x
111
+
112
+
113
+ def sequence_mask(length, max_length=None):
114
+ if max_length is None:
115
+ max_length = length.max()
116
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
117
+ return x.unsqueeze(0) < length.unsqueeze(1)
118
+
119
+
120
+ def generate_path(duration, mask):
121
+ """
122
+ duration: [b, 1, t_x]
123
+ mask: [b, 1, t_y, t_x]
124
+ """
125
+
126
+ b, _, t_y, t_x = mask.shape
127
+ cum_duration = torch.cumsum(duration, -1)
128
+
129
+ cum_duration_flat = cum_duration.view(b * t_x)
130
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
131
+ path = path.view(b, t_x, t_y)
132
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
133
+ path = path.unsqueeze(1).transpose(2, 3) * mask
134
+ return path
135
+
136
+
137
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
138
+ if isinstance(parameters, torch.Tensor):
139
+ parameters = [parameters]
140
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
141
+ norm_type = float(norm_type)
142
+ if clip_value is not None:
143
+ clip_value = float(clip_value)
144
+
145
+ total_norm = 0
146
+ for p in parameters:
147
+ param_norm = p.grad.data.norm(norm_type)
148
+ total_norm += param_norm.item() ** norm_type
149
+ if clip_value is not None:
150
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
151
+ total_norm = total_norm ** (1.0 / norm_type)
152
+ return total_norm
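A few of these helpers in isolation; the values are illustrative and the results follow directly from the definitions above.

import torch
import commons

commons.intersperse([1, 2, 3], 0)  # [0, 1, 0, 2, 0, 3, 0]  (blank token between phonemes)
commons.convert_pad_shape([[0, 0], [0, 0], [1, 0]])  # [1, 0, 0, 0, 0, 0]  (F.pad wants the last dim first)
commons.sequence_mask(torch.tensor([3, 5]), 5)  # row 0: T T T F F, row 1: all True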
config.py ADDED
@@ -0,0 +1,269 @@
1
+ """
2
+ @Desc: Reads the global configuration file.
3
+ """
4
+ import argparse
5
+ import os
6
+ import shutil
7
+ from typing import Dict, List
8
+
9
+ import yaml
10
+
11
+ from common.log import logger
12
+
13
+
14
+ class Resample_config:
15
+ """Resampling configuration."""
16
+
17
+ def __init__(self, in_dir: str, out_dir: str, sampling_rate: int = 44100):
18
+ self.sampling_rate: int = sampling_rate  # target sampling rate
19
+ self.in_dir: str = in_dir  # directory of audio files to be processed
20
+ self.out_dir: str = out_dir  # output directory for resampled audio
21
+
22
+ @classmethod
23
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
24
+ """Create an instance from a dict."""
25
+
26
+ # Path validity is not checked here; that logic lives in resample.py.
27
+ data["in_dir"] = os.path.join(dataset_path, data["in_dir"])
28
+ data["out_dir"] = os.path.join(dataset_path, data["out_dir"])
29
+
30
+ return cls(**data)
31
+
32
+
33
+ class Preprocess_text_config:
34
+ """Text preprocessing configuration."""
35
+
36
+ def __init__(
37
+ self,
38
+ transcription_path: str,
39
+ cleaned_path: str,
40
+ train_path: str,
41
+ val_path: str,
42
+ config_path: str,
43
+ val_per_lang: int = 5,
44
+ max_val_total: int = 10000,
45
+ clean: bool = True,
46
+ ):
47
+ self.transcription_path: str = transcription_path  # raw transcription file; each line should be {wav_path}|{speaker_name}|{language}|{text}
48
+ self.cleaned_path: str = cleaned_path  # cleaned text path; optional, generated in the raw text directory if left empty
49
+ self.train_path: str = train_path  # training list path; optional, generated in the raw text directory if left empty
50
+ self.val_path: str = val_path  # validation list path; optional, generated in the raw text directory if left empty
51
+ self.config_path: str = config_path  # config file path
52
+ self.val_per_lang: int = val_per_lang  # number of validation entries per speaker
53
+ self.max_val_total: int = max_val_total  # maximum size of the validation set; the excess is moved back into the training set
54
+ self.clean: bool = clean  # whether to clean the data
55
+
56
+ @classmethod
57
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
58
+ """Create an instance from a dict."""
59
+
60
+ data["transcription_path"] = os.path.join(
61
+ dataset_path, data["transcription_path"]
62
+ )
63
+ if data["cleaned_path"] == "" or data["cleaned_path"] is None:
64
+ data["cleaned_path"] = None
65
+ else:
66
+ data["cleaned_path"] = os.path.join(dataset_path, data["cleaned_path"])
67
+ data["train_path"] = os.path.join(dataset_path, data["train_path"])
68
+ data["val_path"] = os.path.join(dataset_path, data["val_path"])
69
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
70
+
71
+ return cls(**data)
72
+
73
+
74
+ class Bert_gen_config:
75
+ """bert_gen configuration."""
76
+
77
+ def __init__(
78
+ self,
79
+ config_path: str,
80
+ num_processes: int = 2,
81
+ device: str = "cuda",
82
+ use_multi_device: bool = False,
83
+ ):
84
+ self.config_path = config_path
85
+ self.num_processes = num_processes
86
+ self.device = device
87
+ self.use_multi_device = use_multi_device
88
+
89
+ @classmethod
90
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
91
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
92
+
93
+ return cls(**data)
94
+
95
+
96
+ class Style_gen_config:
97
+ """style_gen configuration."""
98
+
99
+ def __init__(
100
+ self,
101
+ config_path: str,
102
+ num_processes: int = 4,
103
+ device: str = "cuda",
104
+ ):
105
+ self.config_path = config_path
106
+ self.num_processes = num_processes
107
+ self.device = device
108
+
109
+ @classmethod
110
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
111
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
112
+
113
+ return cls(**data)
114
+
115
+
116
+ class Train_ms_config:
117
+ """Training configuration."""
118
+
119
+ def __init__(
120
+ self,
121
+ config_path: str,
122
+ env: Dict[str, any],
123
+ # base: Dict[str, any],
124
+ model_dir: str,
125
+ num_workers: int,
126
+ spec_cache: bool,
127
+ keep_ckpts: int,
128
+ ):
129
+ self.env = env  # environment variables to load
130
+ # self.base = base  # base (pretrained) model configuration
131
+ self.model_dir = model_dir  # directory where trained models are stored, relative to dataset_path rather than the project root
132
+ self.config_path = config_path  # config file path
133
+ self.num_workers = num_workers  # number of workers
134
+ self.spec_cache = spec_cache  # whether to cache spectrograms
135
+ self.keep_ckpts = keep_ckpts  # number of checkpoints to keep
136
+
137
+ @classmethod
138
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
139
+ # data["model"] = os.path.join(dataset_path, data["model"])
140
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
141
+
142
+ return cls(**data)
143
+
144
+
145
+ class Webui_config:
146
+ """WebUI configuration."""
147
+
148
+ def __init__(
149
+ self,
150
+ device: str,
151
+ model: str,
152
+ config_path: str,
153
+ language_identification_library: str,
154
+ port: int = 7860,
155
+ share: bool = False,
156
+ debug: bool = False,
157
+ ):
158
+ self.device: str = device
159
+ self.model: str = model  # model path
160
+ self.config_path: str = config_path  # config file path
161
+ self.port: int = port  # port number
162
+ self.share: bool = share  # whether to deploy publicly (expose to the internet)
163
+ self.debug: bool = debug  # whether to enable debug mode
164
+ self.language_identification_library: str = (
165
+ language_identification_library  # language identification library
166
+ )
167
+
168
+ @classmethod
169
+ def from_dict(cls, dataset_path: str, data: Dict[str, any]):
170
+ data["config_path"] = os.path.join(dataset_path, data["config_path"])
171
+ data["model"] = os.path.join(dataset_path, data["model"])
172
+ return cls(**data)
173
+
174
+
175
+ class Server_config:
176
+ def __init__(
177
+ self,
178
+ port: int = 5000,
179
+ device: str = "cuda",
180
+ limit: int = 100,
181
+ language: str = "JP",
182
+ origins: List[str] = None,
183
+ ):
184
+ self.port: int = port
185
+ self.device: str = device
186
+ self.language: str = language
187
+ self.limit: int = limit
188
+ self.origins: List[str] = origins
189
+
190
+ @classmethod
191
+ def from_dict(cls, data: Dict[str, any]):
192
+ return cls(**data)
193
+
194
+
195
+ class Translate_config:
196
+ """Translation API configuration."""
197
+
198
+ def __init__(self, app_key: str, secret_key: str):
199
+ self.app_key = app_key
200
+ self.secret_key = secret_key
201
+
202
+ @classmethod
203
+ def from_dict(cls, data: Dict[str, any]):
204
+ return cls(**data)
205
+
206
+
207
+ class Config:
208
+ def __init__(self, config_path: str, path_config: dict[str, str]):
209
+ if not os.path.isfile(config_path) and os.path.isfile("default_config.yml"):
210
+ shutil.copy(src="default_config.yml", dst=config_path)
211
+ logger.info(
212
+ f"A configuration file {config_path} has been generated based on the default configuration file default_config.yml."
213
+ )
214
+ logger.info(
215
+ "If you have no special needs, please do not modify default_config.yml."
216
+ )
217
+ # sys.exit(0)
218
+ with open(file=config_path, mode="r", encoding="utf-8") as file:
219
+ yaml_config: Dict[str, any] = yaml.safe_load(file.read())
220
+ model_name: str = yaml_config["model_name"]
221
+ self.model_name: str = model_name
222
+ if "dataset_path" in yaml_config:
223
+ dataset_path = yaml_config["dataset_path"]
224
+ else:
225
+ dataset_path = os.path.join(path_config["dataset_root"], model_name)
226
+ self.dataset_path: str = dataset_path
227
+ self.assets_root: str = path_config["assets_root"]
228
+ self.out_dir = os.path.join(self.assets_root, model_name)
229
+ self.resample_config: Resample_config = Resample_config.from_dict(
230
+ dataset_path, yaml_config["resample"]
231
+ )
232
+ self.preprocess_text_config: Preprocess_text_config = (
233
+ Preprocess_text_config.from_dict(
234
+ dataset_path, yaml_config["preprocess_text"]
235
+ )
236
+ )
237
+ self.bert_gen_config: Bert_gen_config = Bert_gen_config.from_dict(
238
+ dataset_path, yaml_config["bert_gen"]
239
+ )
240
+ self.style_gen_config: Style_gen_config = Style_gen_config.from_dict(
241
+ dataset_path, yaml_config["style_gen"]
242
+ )
243
+ self.train_ms_config: Train_ms_config = Train_ms_config.from_dict(
244
+ dataset_path, yaml_config["train_ms"]
245
+ )
246
+ self.webui_config: Webui_config = Webui_config.from_dict(
247
+ dataset_path, yaml_config["webui"]
248
+ )
249
+ self.server_config: Server_config = Server_config.from_dict(
250
+ yaml_config["server"]
251
+ )
252
+ # self.translate_config: Translate_config = Translate_config.from_dict(
253
+ # yaml_config["translate"]
254
+ # )
255
+
256
+
257
+ with open(os.path.join("configs", "paths.yml"), "r", encoding="utf-8") as f:
258
+ path_config: dict[str, str] = yaml.safe_load(f.read())
259
+ # Should contain the following keys:
260
+ # - dataset_root: the root directory of the dataset, default to "Data"
261
+ # - assets_root: the root directory of the assets, default to "model_assets"
262
+
263
+
264
+ try:
265
+ config = Config("config.yml", path_config)
266
+ except (TypeError, KeyError):
267
+ logger.warning("Old config.yml found. Replace it with default_config.yml.")
268
+ shutil.copy(src="default_config.yml", dst="config.yml")
269
+ config = Config("config.yml", path_config)
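Importing this module reads configs/paths.yml and config.yml at import time, so both files must exist (config.yml is copied from default_config.yml if it is missing). A minimal bootstrap sketch, assuming the default directory names from the comment above:

import os
import yaml

os.makedirs("configs", exist_ok=True)
with open(os.path.join("configs", "paths.yml"), "w", encoding="utf-8") as f:
    yaml.safe_dump({"dataset_root": "Data", "assets_root": "model_assets"}, f)

from config import config

print(config.model_name, config.dataset_path, config.out_dir)
print(config.bert_gen_config.device, config.train_ms_config.num_workers)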
config.yml ADDED
@@ -0,0 +1,51 @@
1
+ bert_gen:
2
+ config_path: config.json
3
+ device: cpu
4
+ num_processes: 2
5
+ use_multi_device: false
6
+ dataset_path: Data\model_name
7
+ model_name: model_name
8
+ preprocess_text:
9
+ clean: true
10
+ cleaned_path: ''
11
+ config_path: config.json
12
+ max_val_total: 12
13
+ train_path: train.list
14
+ transcription_path: esd.list
15
+ val_path: val.list
16
+ val_per_lang: 4
17
+ resample:
18
+ in_dir: raw
19
+ out_dir: wavs
20
+ sampling_rate: 44100
21
+ server:
22
+ device: cuda
23
+ language: JP
24
+ limit: 100
25
+ origins:
26
+ - '*'
27
+ port: 5000
28
+ style_gen:
29
+ config_path: config.json
30
+ device: cpu
31
+ num_processes: 4
32
+ train_ms:
33
+ config_path: config.json
34
+ env:
35
+ LOCAL_RANK: 0
36
+ MASTER_ADDR: localhost
37
+ MASTER_PORT: 10086
38
+ RANK: 0
39
+ WORLD_SIZE: 1
40
+ keep_ckpts: 1
41
+ model_dir: models
42
+ num_workers: 16
43
+ spec_cache: true
44
+ webui:
45
+ config_path: config.json
46
+ debug: false
47
+ device: cuda
48
+ language_identification_library: langid
49
+ model: models/G_8000.pth
50
+ port: 7860
51
+ share: false
data_utils.py ADDED
@@ -0,0 +1,425 @@
1
+ import os
2
+ import random
3
+ import torch
4
+ import torch.utils.data
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ from tools.log import logger
8
+ import commons
9
+ from mel_processing import spectrogram_torch, mel_spectrogram_torch
10
+ from utils import load_wav_to_torch, load_filepaths_and_text
11
+ from text import cleaned_text_to_sequence
12
+ from config import config
13
+
14
+ """Multi speaker version"""
15
+
16
+
17
+ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
18
+ """
19
+ 1) loads audio, speaker_id, text pairs
20
+ 2) normalizes text and converts them to sequences of integers
21
+ 3) computes spectrograms from audio files.
22
+ """
23
+
24
+ def __init__(self, audiopaths_sid_text, hparams):
25
+ self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
26
+ self.max_wav_value = hparams.max_wav_value
27
+ self.sampling_rate = hparams.sampling_rate
28
+ self.filter_length = hparams.filter_length
29
+ self.hop_length = hparams.hop_length
30
+ self.win_length = hparams.win_length
31
+ self.sampling_rate = hparams.sampling_rate
32
+ self.spk_map = hparams.spk2id
33
+ self.hparams = hparams
34
+
35
+ self.use_mel_spec_posterior = getattr(
36
+ hparams, "use_mel_posterior_encoder", False
37
+ )
38
+ if self.use_mel_spec_posterior:
39
+ self.n_mel_channels = getattr(hparams, "n_mel_channels", 80)
40
+
41
+ self.cleaned_text = getattr(hparams, "cleaned_text", False)
42
+
43
+ self.add_blank = hparams.add_blank
44
+ self.min_text_len = getattr(hparams, "min_text_len", 1)
45
+ self.max_text_len = getattr(hparams, "max_text_len", 384)
46
+
47
+ random.seed(1234)
48
+ random.shuffle(self.audiopaths_sid_text)
49
+ self._filter()
50
+
51
+ def _filter(self):
52
+ """
53
+ Filter text & store spec lengths
54
+ """
55
+ # Store spectrogram lengths for Bucketing
56
+ # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2)
57
+ # spec_length = wav_length // hop_length
58
+
59
+ audiopaths_sid_text_new = []
60
+ lengths = []
61
+ skipped = 0
62
+ logger.info("Init dataset...")
63
+ for _id, spk, language, text, phones, tone, word2ph in tqdm(
64
+ self.audiopaths_sid_text
65
+ ):
66
+ audiopath = f"{_id}"
67
+ if self.min_text_len <= len(phones) and len(phones) <= self.max_text_len:
68
+ phones = phones.split(" ")
69
+ tone = [int(i) for i in tone.split(" ")]
70
+ word2ph = [int(i) for i in word2ph.split(" ")]
71
+ audiopaths_sid_text_new.append(
72
+ [audiopath, spk, language, text, phones, tone, word2ph]
73
+ )
74
+ lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length))
75
+ else:
76
+ skipped += 1
77
+ logger.info(
78
+ "skipped: "
79
+ + str(skipped)
80
+ + ", total: "
81
+ + str(len(self.audiopaths_sid_text))
82
+ )
83
+ self.audiopaths_sid_text = audiopaths_sid_text_new
84
+ self.lengths = lengths
85
+
86
+ def get_audio_text_speaker_pair(self, audiopath_sid_text):
87
+ # separate filename, speaker_id and text
88
+ audiopath, sid, language, text, phones, tone, word2ph = audiopath_sid_text
89
+
90
+ bert, ja_bert, en_bert, phones, tone, language = self.get_text(
91
+ text, word2ph, phones, tone, language, audiopath
92
+ )
93
+
94
+ spec, wav = self.get_audio(audiopath)
95
+ sid = torch.LongTensor([int(self.spk_map[sid])])
96
+ style_vec = torch.FloatTensor(np.load(f"{audiopath}.npy"))
97
+ return (
98
+ phones,
99
+ spec,
100
+ wav,
101
+ sid,
102
+ tone,
103
+ language,
104
+ bert,
105
+ ja_bert,
106
+ en_bert,
107
+ style_vec,
108
+ )
109
+
110
+ def get_audio(self, filename):
111
+ audio, sampling_rate = load_wav_to_torch(filename)
112
+ if sampling_rate != self.sampling_rate:
113
+ raise ValueError(
114
+ "{} {} SR doesn't match target {} SR".format(
115
+ filename, sampling_rate, self.sampling_rate
116
+ )
117
+ )
118
+ audio_norm = audio / self.max_wav_value
119
+ audio_norm = audio_norm.unsqueeze(0)
120
+ spec_filename = filename.replace(".wav", ".spec.pt")
121
+ if self.use_mel_spec_posterior:
122
+ spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
123
+ try:
124
+ spec = torch.load(spec_filename)
125
+ except Exception:
126
+ if self.use_mel_spec_posterior:
127
+ spec = mel_spectrogram_torch(
128
+ audio_norm,
129
+ self.filter_length,
130
+ self.n_mel_channels,
131
+ self.sampling_rate,
132
+ self.hop_length,
133
+ self.win_length,
134
+ self.hparams.mel_fmin,
135
+ self.hparams.mel_fmax,
136
+ center=False,
137
+ )
138
+ else:
139
+ spec = spectrogram_torch(
140
+ audio_norm,
141
+ self.filter_length,
142
+ self.sampling_rate,
143
+ self.hop_length,
144
+ self.win_length,
145
+ center=False,
146
+ )
147
+ spec = torch.squeeze(spec, 0)
148
+ if config.train_ms_config.spec_cache:
149
+ torch.save(spec, spec_filename)
150
+ return spec, audio_norm
151
+
152
+ def get_text(self, text, word2ph, phone, tone, language_str, wav_path):
153
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
154
+ if self.add_blank:
155
+ phone = commons.intersperse(phone, 0)
156
+ tone = commons.intersperse(tone, 0)
157
+ language = commons.intersperse(language, 0)
158
+ for i in range(len(word2ph)):
159
+ word2ph[i] = word2ph[i] * 2
160
+ word2ph[0] += 1
161
+ bert_path = wav_path.replace(".wav", ".bert.pt")
162
+ try:
163
+ bert_ori = torch.load(bert_path)
164
+ assert bert_ori.shape[-1] == len(phone)
165
+ except Exception as e:
166
+ logger.warning("Bert load Failed")
167
+ logger.warning(e)
168
+
169
+ if language_str == "ZH":
170
+ bert = bert_ori
171
+ ja_bert = torch.zeros(1024, len(phone))
172
+ en_bert = torch.zeros(1024, len(phone))
173
+ elif language_str == "JP":
174
+ bert = torch.zeros(1024, len(phone))
175
+ ja_bert = bert_ori
176
+ en_bert = torch.zeros(1024, len(phone))
177
+ elif language_str == "EN":
178
+ bert = torch.zeros(1024, len(phone))
179
+ ja_bert = torch.zeros(1024, len(phone))
180
+ en_bert = bert_ori
181
+ phone = torch.LongTensor(phone)
182
+ tone = torch.LongTensor(tone)
183
+ language = torch.LongTensor(language)
184
+ return bert, ja_bert, en_bert, phone, tone, language
185
+
186
+ def get_sid(self, sid):
187
+ sid = torch.LongTensor([int(sid)])
188
+ return sid
189
+
190
+ def __getitem__(self, index):
191
+ return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
192
+
193
+ def __len__(self):
194
+ return len(self.audiopaths_sid_text)
195
+
196
+
197
+ class TextAudioSpeakerCollate:
198
+ """Zero-pads model inputs and targets"""
199
+
200
+ def __init__(self, return_ids=False):
201
+ self.return_ids = return_ids
202
+
203
+ def __call__(self, batch):
204
+ """Collates a training batch from normalized text, audio and speaker identities
205
+ PARAMS
206
+ ------
207
+ batch: [text_normalized, spec_normalized, wav_normalized, sid]
208
+ """
209
+ # Right zero-pad all one-hot text sequences to max input length
210
+ _, ids_sorted_decreasing = torch.sort(
211
+ torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True
212
+ )
213
+
214
+ max_text_len = max([len(x[0]) for x in batch])
215
+ max_spec_len = max([x[1].size(1) for x in batch])
216
+ max_wav_len = max([x[2].size(1) for x in batch])
217
+
218
+ text_lengths = torch.LongTensor(len(batch))
219
+ spec_lengths = torch.LongTensor(len(batch))
220
+ wav_lengths = torch.LongTensor(len(batch))
221
+ sid = torch.LongTensor(len(batch))
222
+
223
+ text_padded = torch.LongTensor(len(batch), max_text_len)
224
+ tone_padded = torch.LongTensor(len(batch), max_text_len)
225
+ language_padded = torch.LongTensor(len(batch), max_text_len)
226
+ bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
227
+ ja_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
228
+ en_bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len)
229
+ style_vec = torch.FloatTensor(len(batch), 256)
230
+
231
+ spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
232
+ wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
233
+ text_padded.zero_()
234
+ tone_padded.zero_()
235
+ language_padded.zero_()
236
+ spec_padded.zero_()
237
+ wav_padded.zero_()
238
+ bert_padded.zero_()
239
+ ja_bert_padded.zero_()
240
+ en_bert_padded.zero_()
241
+ style_vec.zero_()
242
+
243
+ for i in range(len(ids_sorted_decreasing)):
244
+ row = batch[ids_sorted_decreasing[i]]
245
+
246
+ text = row[0]
247
+ text_padded[i, : text.size(0)] = text
248
+ text_lengths[i] = text.size(0)
249
+
250
+ spec = row[1]
251
+ spec_padded[i, :, : spec.size(1)] = spec
252
+ spec_lengths[i] = spec.size(1)
253
+
254
+ wav = row[2]
255
+ wav_padded[i, :, : wav.size(1)] = wav
256
+ wav_lengths[i] = wav.size(1)
257
+
258
+ sid[i] = row[3]
259
+
260
+ tone = row[4]
261
+ tone_padded[i, : tone.size(0)] = tone
262
+
263
+ language = row[5]
264
+ language_padded[i, : language.size(0)] = language
265
+
266
+ bert = row[6]
267
+ bert_padded[i, :, : bert.size(1)] = bert
268
+
269
+ ja_bert = row[7]
270
+ ja_bert_padded[i, :, : ja_bert.size(1)] = ja_bert
271
+
272
+ en_bert = row[8]
273
+ en_bert_padded[i, :, : en_bert.size(1)] = en_bert
274
+
275
+ style_vec[i, :] = row[9]
276
+
277
+ return (
278
+ text_padded,
279
+ text_lengths,
280
+ spec_padded,
281
+ spec_lengths,
282
+ wav_padded,
283
+ wav_lengths,
284
+ sid,
285
+ tone_padded,
286
+ language_padded,
287
+ bert_padded,
288
+ ja_bert_padded,
289
+ en_bert_padded,
290
+ style_vec,
291
+ )
292
+
293
+
294
+ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
295
+ """
296
+ Maintain similar input lengths in a batch.
297
+ Length groups are specified by boundaries.
298
+ Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
299
+
300
+ It removes samples which are not included in the boundaries.
301
+ Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
302
+ """
303
+
304
+ def __init__(
305
+ self,
306
+ dataset,
307
+ batch_size,
308
+ boundaries,
309
+ num_replicas=None,
310
+ rank=None,
311
+ shuffle=True,
312
+ ):
313
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
314
+ self.lengths = dataset.lengths
315
+ self.batch_size = batch_size
316
+ self.boundaries = boundaries
317
+
318
+ self.buckets, self.num_samples_per_bucket = self._create_buckets()
319
+ logger.info(f"Bucket info: {self.num_samples_per_bucket}")
320
+ logger.info(
321
+ f"Unused samples: {len(self.lengths) - sum(self.num_samples_per_bucket)}"
322
+ )
323
+ self.total_size = sum(self.num_samples_per_bucket)
324
+ self.num_samples = self.total_size // self.num_replicas
325
+
326
+ def _create_buckets(self):
327
+ buckets = [[] for _ in range(len(self.boundaries) - 1)]
328
+ for i in range(len(self.lengths)):
329
+ length = self.lengths[i]
330
+ idx_bucket = self._bisect(length)
331
+ if idx_bucket != -1:
332
+ buckets[idx_bucket].append(i)
333
+
334
+ try:
335
+ for i in range(len(buckets) - 1, 0, -1):
336
+ if len(buckets[i]) == 0:
337
+ buckets.pop(i)
338
+ self.boundaries.pop(i + 1)
339
+ assert all(len(bucket) > 0 for bucket in buckets)
340
+ # When one bucket is not traversed
341
+ except Exception as e:
342
+ print("Bucket warning ", e)
343
+ for i in range(len(buckets) - 1, -1, -1):
344
+ if len(buckets[i]) == 0:
345
+ buckets.pop(i)
346
+ self.boundaries.pop(i + 1)
347
+
348
+ num_samples_per_bucket = []
349
+ for i in range(len(buckets)):
350
+ len_bucket = len(buckets[i])
351
+ total_batch_size = self.num_replicas * self.batch_size
352
+ rem = (
353
+ total_batch_size - (len_bucket % total_batch_size)
354
+ ) % total_batch_size
355
+ num_samples_per_bucket.append(len_bucket + rem)
356
+ return buckets, num_samples_per_bucket
357
+
358
+ def __iter__(self):
359
+ # deterministically shuffle based on epoch
360
+ g = torch.Generator()
361
+ g.manual_seed(self.epoch)
362
+
363
+ indices = []
364
+ if self.shuffle:
365
+ for bucket in self.buckets:
366
+ indices.append(torch.randperm(len(bucket), generator=g).tolist())
367
+ else:
368
+ for bucket in self.buckets:
369
+ indices.append(list(range(len(bucket))))
370
+
371
+ batches = []
372
+ for i in range(len(self.buckets)):
373
+ bucket = self.buckets[i]
374
+ len_bucket = len(bucket)
375
+ if len_bucket == 0:
376
+ continue
377
+ ids_bucket = indices[i]
378
+ num_samples_bucket = self.num_samples_per_bucket[i]
379
+
380
+ # add extra samples to make it evenly divisible
381
+ rem = num_samples_bucket - len_bucket
382
+ ids_bucket = (
383
+ ids_bucket
384
+ + ids_bucket * (rem // len_bucket)
385
+ + ids_bucket[: (rem % len_bucket)]
386
+ )
387
+
388
+ # subsample
389
+ ids_bucket = ids_bucket[self.rank :: self.num_replicas]
390
+
391
+ # batching
392
+ for j in range(len(ids_bucket) // self.batch_size):
393
+ batch = [
394
+ bucket[idx]
395
+ for idx in ids_bucket[
396
+ j * self.batch_size : (j + 1) * self.batch_size
397
+ ]
398
+ ]
399
+ batches.append(batch)
400
+
401
+ if self.shuffle:
402
+ batch_ids = torch.randperm(len(batches), generator=g).tolist()
403
+ batches = [batches[i] for i in batch_ids]
404
+ self.batches = batches
405
+
406
+ assert len(self.batches) * self.batch_size == self.num_samples
407
+ return iter(self.batches)
408
+
409
+ def _bisect(self, x, lo=0, hi=None):
410
+ if hi is None:
411
+ hi = len(self.boundaries) - 1
412
+
413
+ if hi > lo:
414
+ mid = (hi + lo) // 2
415
+ if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
416
+ return mid
417
+ elif x <= self.boundaries[mid]:
418
+ return self._bisect(x, lo, mid)
419
+ else:
420
+ return self._bisect(x, mid + 1, hi)
421
+ else:
422
+ return -1
423
+
424
+ def __len__(self):
425
+ return self.num_samples // self.batch_size
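A sketch of wiring the dataset, collate function and bucket sampler above into a single-process DataLoader; the config path, batch size, boundaries and worker count are illustrative, and hps is loaded with utils.get_hparams_from_file.

from torch.utils.data import DataLoader

import utils
from data_utils import (
    TextAudioSpeakerLoader,
    TextAudioSpeakerCollate,
    DistributedBucketSampler,
)

hps = utils.get_hparams_from_file("Data/model_name/config.json")
train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
train_sampler = DistributedBucketSampler(
    train_dataset,
    batch_size=4,
    boundaries=[32, 300, 400, 500, 600, 700, 800, 900, 1000],
    num_replicas=1,
    rank=0,
    shuffle=True,
)
train_loader = DataLoader(
    train_dataset,
    num_workers=4,
    collate_fn=TextAudioSpeakerCollate(),
    batch_sampler=train_sampler,
    pin_memory=True,
)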
default_config.yml ADDED
@@ -0,0 +1,81 @@
1
+ # Global configuration file for Bert-VITS2
2
+
3
+ model_name: "model_name"
4
+
5
+ out_dir: "model_assets"
6
+
7
+ # If you want to use a specific dataset path, uncomment the following line.
8
+ # Otherwise, the dataset path is `Data/{model_name}`.
9
+
10
+ # dataset_path: "your/dataset/path"
11
+
12
+ resample:
13
+ sampling_rate: 44100
14
+ in_dir: "audios/raw"
15
+ out_dir: "audios/wavs"
16
+
17
+ preprocess_text:
18
+ transcription_path: "filelists/esd.list"
19
+ cleaned_path: ""
20
+ train_path: "filelists/train.list"
21
+ val_path: "filelists/val.list"
22
+ config_path: "config.json"
23
+ val_per_lang: 4
24
+ max_val_total: 12
25
+ clean: true
26
+
27
+ bert_gen:
28
+ config_path: "config.json"
29
+ num_processes: 4
30
+ device: "cuda"
31
+ use_multi_device: false
32
+
33
+ style_gen:
34
+ config_path: "config.json"
35
+ num_processes: 4
36
+ device: "cuda"
37
+
38
+ train_ms:
39
+ env:
40
+ MASTER_ADDR: "localhost"
41
+ MASTER_PORT: 10086
42
+ WORLD_SIZE: 1
43
+ LOCAL_RANK: 0
44
+ RANK: 0
45
+ model: "models"
46
+ config_path: "config.json"
47
+ num_workers: 16
48
+ spec_cache: True
49
+ keep_ckpts: 1 # Set this to 0 to keep all checkpoints
50
+
51
+ webui:
52
+ # inference device
53
+ device: "cuda"
54
+ # model path
55
+ model: "models/G_8000.pth"
56
+ # config file path
57
+ config_path: "config.json"
58
+ # port number
59
+ port: 7860
60
+ # whether to deploy publicly (expose to the internet)
61
+ share: false
62
+ # whether to enable debug mode
63
+ debug: false
64
+ # language identification library; langid or fastlid
65
+ language_identification_library: "langid"
66
+
67
+ # server_fastapi's config
68
+ # TODO: `server_fastapi.py` is not implemented yet for this version
69
+ server:
70
+ port: 5000
71
+ device: "cuda"
72
+ models:
73
+ - model: ""
74
+ config: ""
75
+ device: "cuda"
76
+ language: "ZH"
77
+ - model: ""
78
+ config: ""
79
+ device: "cpu"
80
+ language: "JP"
81
+ speakers: []
infer.py ADDED
@@ -0,0 +1,306 @@
1
+ import torch
2
+
3
+ import commons
4
+ import utils
5
+ from models import SynthesizerTrn
6
+ from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra
7
+ from text import cleaned_text_to_sequence, get_bert
8
+ from text.cleaner import clean_text
9
+ from text.symbols import symbols
10
+ from common.log import logger
11
+
12
+
13
+ class InvalidToneError(ValueError):
14
+ pass
15
+
16
+
17
+ def get_net_g(model_path: str, version: str, device: str, hps):
18
+ if version.endswith("JP-Extra"):
19
+ logger.info("Using JP-Extra model")
20
+ net_g = SynthesizerTrnJPExtra(
21
+ len(symbols),
22
+ hps.data.filter_length // 2 + 1,
23
+ hps.train.segment_size // hps.data.hop_length,
24
+ n_speakers=hps.data.n_speakers,
25
+ **hps.model,
26
+ ).to(device)
27
+ else:
28
+ logger.info("Using normal model")
29
+ net_g = SynthesizerTrn(
30
+ len(symbols),
31
+ hps.data.filter_length // 2 + 1,
32
+ hps.train.segment_size // hps.data.hop_length,
33
+ n_speakers=hps.data.n_speakers,
34
+ **hps.model,
35
+ ).to(device)
36
+ net_g.state_dict()
37
+ _ = net_g.eval()
38
+ if model_path.endswith(".pth") or model_path.endswith(".pt"):
39
+ _ = utils.load_checkpoint(model_path, net_g, None, skip_optimizer=True)
40
+ elif model_path.endswith(".safetensors"):
41
+ _ = utils.load_safetensors(model_path, net_g, True)
42
+ else:
43
+ raise ValueError(f"Unknown model format: {model_path}")
44
+ return net_g
45
+
46
+
47
+ def get_text(
48
+ text,
49
+ language_str,
50
+ hps,
51
+ device,
52
+ assist_text=None,
53
+ assist_text_weight=0.7,
54
+ given_tone=None,
55
+ ):
56
+ use_jp_extra = hps.version.endswith("JP-Extra")
57
+ norm_text, phone, tone, word2ph = clean_text(text, language_str, use_jp_extra)
58
+ if given_tone is not None:
59
+ if len(given_tone) != len(phone):
60
+ raise InvalidToneError(
61
+ f"Length of given_tone ({len(given_tone)}) != length of phone ({len(phone)})"
62
+ )
63
+ tone = given_tone
64
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
65
+
66
+ if hps.data.add_blank:
67
+ phone = commons.intersperse(phone, 0)
68
+ tone = commons.intersperse(tone, 0)
69
+ language = commons.intersperse(language, 0)
70
+ for i in range(len(word2ph)):
71
+ word2ph[i] = word2ph[i] * 2
72
+ word2ph[0] += 1
73
+ bert_ori = get_bert(
74
+ norm_text, word2ph, language_str, device, assist_text, assist_text_weight
75
+ )
76
+ del word2ph
77
+ assert bert_ori.shape[-1] == len(phone), phone
78
+
79
+ if language_str == "ZH":
80
+ bert = bert_ori
81
+ ja_bert = torch.zeros(1024, len(phone))
82
+ en_bert = torch.zeros(1024, len(phone))
83
+ elif language_str == "JP":
84
+ bert = torch.zeros(1024, len(phone))
85
+ ja_bert = bert_ori
86
+ en_bert = torch.zeros(1024, len(phone))
87
+ elif language_str == "EN":
88
+ bert = torch.zeros(1024, len(phone))
89
+ ja_bert = torch.zeros(1024, len(phone))
90
+ en_bert = bert_ori
91
+ else:
92
+ raise ValueError("language_str should be ZH, JP or EN")
93
+
94
+ assert bert.shape[-1] == len(
95
+ phone
96
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
97
+
98
+ phone = torch.LongTensor(phone)
99
+ tone = torch.LongTensor(tone)
100
+ language = torch.LongTensor(language)
101
+ return bert, ja_bert, en_bert, phone, tone, language
102
+
103
+
104
+ def infer(
105
+ text,
106
+ style_vec,
107
+ sdp_ratio,
108
+ noise_scale,
109
+ noise_scale_w,
110
+ length_scale,
111
+ sid: int,  # in the original Bert-VITS2 this is speaker_name: str, but here it is the speaker id
112
+ language,
113
+ hps,
114
+ net_g,
115
+ device,
116
+ skip_start=False,
117
+ skip_end=False,
118
+ assist_text=None,
119
+ assist_text_weight=0.7,
120
+ given_tone=None,
121
+ ):
122
+ is_jp_extra = hps.version.endswith("JP-Extra")
123
+ bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
124
+ text,
125
+ language,
126
+ hps,
127
+ device,
128
+ assist_text=assist_text,
129
+ assist_text_weight=assist_text_weight,
130
+ given_tone=given_tone,
131
+ )
132
+ if skip_start:
133
+ phones = phones[3:]
134
+ tones = tones[3:]
135
+ lang_ids = lang_ids[3:]
136
+ bert = bert[:, 3:]
137
+ ja_bert = ja_bert[:, 3:]
138
+ en_bert = en_bert[:, 3:]
139
+ if skip_end:
140
+ phones = phones[:-2]
141
+ tones = tones[:-2]
142
+ lang_ids = lang_ids[:-2]
143
+ bert = bert[:, :-2]
144
+ ja_bert = ja_bert[:, :-2]
145
+ en_bert = en_bert[:, :-2]
146
+ with torch.no_grad():
147
+ x_tst = phones.to(device).unsqueeze(0)
148
+ tones = tones.to(device).unsqueeze(0)
149
+ lang_ids = lang_ids.to(device).unsqueeze(0)
150
+ bert = bert.to(device).unsqueeze(0)
151
+ ja_bert = ja_bert.to(device).unsqueeze(0)
152
+ en_bert = en_bert.to(device).unsqueeze(0)
153
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
154
+ style_vec = torch.from_numpy(style_vec).to(device).unsqueeze(0)
155
+ del phones
156
+ sid_tensor = torch.LongTensor([sid]).to(device)
157
+ if is_jp_extra:
158
+ output = net_g.infer(
159
+ x_tst,
160
+ x_tst_lengths,
161
+ sid_tensor,
162
+ tones,
163
+ lang_ids,
164
+ ja_bert,
165
+ style_vec=style_vec,
166
+ sdp_ratio=sdp_ratio,
167
+ noise_scale=noise_scale,
168
+ noise_scale_w=noise_scale_w,
169
+ length_scale=length_scale,
170
+ )
171
+ else:
172
+ output = net_g.infer(
173
+ x_tst,
174
+ x_tst_lengths,
175
+ sid_tensor,
176
+ tones,
177
+ lang_ids,
178
+ bert,
179
+ ja_bert,
180
+ en_bert,
181
+ style_vec=style_vec,
182
+ sdp_ratio=sdp_ratio,
183
+ noise_scale=noise_scale,
184
+ noise_scale_w=noise_scale_w,
185
+ length_scale=length_scale,
186
+ )
187
+ audio = output[0][0, 0].data.cpu().float().numpy()
188
+ del (
189
+ x_tst,
190
+ tones,
191
+ lang_ids,
192
+ bert,
193
+ x_tst_lengths,
194
+ sid_tensor,
195
+ ja_bert,
196
+ en_bert,
197
+ style_vec,
198
+ ) # , emo
199
+ if torch.cuda.is_available():
200
+ torch.cuda.empty_cache()
201
+ return audio
202
+
203
+
204
+ def infer_multilang(
205
+ text,
206
+ style_vec,
207
+ sdp_ratio,
208
+ noise_scale,
209
+ noise_scale_w,
210
+ length_scale,
211
+ sid,
212
+ language,
213
+ hps,
214
+ net_g,
215
+ device,
216
+ skip_start=False,
217
+ skip_end=False,
218
+ ):
219
+ bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
220
+ # emo = get_emo_(reference_audio, emotion, sid)
221
+ # if isinstance(reference_audio, np.ndarray):
222
+ # emo = get_clap_audio_feature(reference_audio, device)
223
+ # else:
224
+ # emo = get_clap_text_feature(emotion, device)
225
+ # emo = torch.squeeze(emo, dim=1)
226
+ for idx, (txt, lang) in enumerate(zip(text, language)):
227
+ _skip_start = (idx != 0) or (skip_start and idx == 0)
228
+ _skip_end = (idx != len(language) - 1) or skip_end
229
+ (
230
+ temp_bert,
231
+ temp_ja_bert,
232
+ temp_en_bert,
233
+ temp_phones,
234
+ temp_tones,
235
+ temp_lang_ids,
236
+ ) = get_text(txt, lang, hps, device)
237
+ if _skip_start:
238
+ temp_bert = temp_bert[:, 3:]
239
+ temp_ja_bert = temp_ja_bert[:, 3:]
240
+ temp_en_bert = temp_en_bert[:, 3:]
241
+ temp_phones = temp_phones[3:]
242
+ temp_tones = temp_tones[3:]
243
+ temp_lang_ids = temp_lang_ids[3:]
244
+ if _skip_end:
245
+ temp_bert = temp_bert[:, :-2]
246
+ temp_ja_bert = temp_ja_bert[:, :-2]
247
+ temp_en_bert = temp_en_bert[:, :-2]
248
+ temp_phones = temp_phones[:-2]
249
+ temp_tones = temp_tones[:-2]
250
+ temp_lang_ids = temp_lang_ids[:-2]
251
+ bert.append(temp_bert)
252
+ ja_bert.append(temp_ja_bert)
253
+ en_bert.append(temp_en_bert)
254
+ phones.append(temp_phones)
255
+ tones.append(temp_tones)
256
+ lang_ids.append(temp_lang_ids)
257
+ bert = torch.concatenate(bert, dim=1)
258
+ ja_bert = torch.concatenate(ja_bert, dim=1)
259
+ en_bert = torch.concatenate(en_bert, dim=1)
260
+ phones = torch.concatenate(phones, dim=0)
261
+ tones = torch.concatenate(tones, dim=0)
262
+ lang_ids = torch.concatenate(lang_ids, dim=0)
263
+ with torch.no_grad():
264
+ x_tst = phones.to(device).unsqueeze(0)
265
+ tones = tones.to(device).unsqueeze(0)
266
+ lang_ids = lang_ids.to(device).unsqueeze(0)
267
+ bert = bert.to(device).unsqueeze(0)
268
+ ja_bert = ja_bert.to(device).unsqueeze(0)
269
+ en_bert = en_bert.to(device).unsqueeze(0)
270
+ # emo = emo.to(device).unsqueeze(0)
271
+ x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
272
+ del phones
273
+ speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)
274
+ audio = (
275
+ net_g.infer(
276
+ x_tst,
277
+ x_tst_lengths,
278
+ speakers,
279
+ tones,
280
+ lang_ids,
281
+ bert,
282
+ ja_bert,
283
+ en_bert,
284
+ style_vec=style_vec,
285
+ sdp_ratio=sdp_ratio,
286
+ noise_scale=noise_scale,
287
+ noise_scale_w=noise_scale_w,
288
+ length_scale=length_scale,
289
+ )[0][0, 0]
290
+ .data.cpu()
291
+ .float()
292
+ .numpy()
293
+ )
294
+ del (
295
+ x_tst,
296
+ tones,
297
+ lang_ids,
298
+ bert,
299
+ x_tst_lengths,
300
+ speakers,
301
+ ja_bert,
302
+ en_bert,
303
+ ) # , emo
304
+ if torch.cuda.is_available():
305
+ torch.cuda.empty_cache()
306
+ return audio
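A minimal usage sketch for the infer_multilang function added above. It assumes the hyperparameters (hps), generator (net_g) and a style embedding tensor (style_vec) have already been loaded by the rest of this repo; the speaker name and language codes below are illustrative, not taken from this commit.

# Usage sketch only -- `hps`, `net_g` and `style_vec` are assumed to be
# loaded elsewhere; `style_vec` is assumed to be a (1, 256) tensor on `device`.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
audio = infer_multilang(
    text=["こんにちは。", "Hello."],  # one chunk per language segment
    style_vec=style_vec,
    sdp_ratio=0.2,
    noise_scale=0.6,
    noise_scale_w=0.8,
    length_scale=1.0,
    sid="your_speaker_name",  # key into hps.data.spk2id
    language=["JP", "EN"],
    hps=hps,
    net_g=net_g,
    device=device,
)  # returns a float32 numpy waveform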
losses.py ADDED
@@ -0,0 +1,153 @@
1
+ import torch
2
+ import torchaudio
3
+ from transformers import AutoModel
4
+
5
+
6
+ def feature_loss(fmap_r, fmap_g):
7
+ loss = 0
8
+ for dr, dg in zip(fmap_r, fmap_g):
9
+ for rl, gl in zip(dr, dg):
10
+ rl = rl.float().detach()
11
+ gl = gl.float()
12
+ loss += torch.mean(torch.abs(rl - gl))
13
+
14
+ return loss * 2
15
+
16
+
17
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
18
+ loss = 0
19
+ r_losses = []
20
+ g_losses = []
21
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
22
+ dr = dr.float()
23
+ dg = dg.float()
24
+ r_loss = torch.mean((1 - dr) ** 2)
25
+ g_loss = torch.mean(dg**2)
26
+ loss += r_loss + g_loss
27
+ r_losses.append(r_loss.item())
28
+ g_losses.append(g_loss.item())
29
+
30
+ return loss, r_losses, g_losses
31
+
32
+
33
+ def generator_loss(disc_outputs):
34
+ loss = 0
35
+ gen_losses = []
36
+ for dg in disc_outputs:
37
+ dg = dg.float()
38
+ l = torch.mean((1 - dg) ** 2)
39
+ gen_losses.append(l)
40
+ loss += l
41
+
42
+ return loss, gen_losses
43
+
44
+
45
+ def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
46
+ """
47
+ z_p, logs_q: [b, h, t_t]
48
+ m_p, logs_p: [b, h, t_t]
49
+ """
50
+ z_p = z_p.float()
51
+ logs_q = logs_q.float()
52
+ m_p = m_p.float()
53
+ logs_p = logs_p.float()
54
+ z_mask = z_mask.float()
55
+
56
+ kl = logs_p - logs_q - 0.5
57
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
58
+ kl = torch.sum(kl * z_mask)
59
+ l = kl / torch.sum(z_mask)
60
+ return l
61
+
62
+
63
+ class WavLMLoss(torch.nn.Module):
64
+ def __init__(self, model, wd, model_sr, slm_sr=16000):
65
+ super(WavLMLoss, self).__init__()
66
+ self.wavlm = AutoModel.from_pretrained(model)
67
+ self.wd = wd
68
+ self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
69
+ self.wavlm.eval()
70
+ for param in self.wavlm.parameters():
71
+ param.requires_grad = False
72
+
73
+ def forward(self, wav, y_rec):
74
+ with torch.no_grad():
75
+ wav_16 = self.resample(wav)
76
+ wav_embeddings = self.wavlm(
77
+ input_values=wav_16, output_hidden_states=True
78
+ ).hidden_states
79
+ y_rec_16 = self.resample(y_rec)
80
+ y_rec_embeddings = self.wavlm(
81
+ input_values=y_rec_16.squeeze(), output_hidden_states=True
82
+ ).hidden_states
83
+
84
+ floss = 0
85
+ for er, eg in zip(wav_embeddings, y_rec_embeddings):
86
+ floss += torch.mean(torch.abs(er - eg))
87
+
88
+ return floss.mean()
89
+
90
+ def generator(self, y_rec):
91
+ y_rec_16 = self.resample(y_rec)
92
+ y_rec_embeddings = self.wavlm(
93
+ input_values=y_rec_16, output_hidden_states=True
94
+ ).hidden_states
95
+ y_rec_embeddings = (
96
+ torch.stack(y_rec_embeddings, dim=1)
97
+ .transpose(-1, -2)
98
+ .flatten(start_dim=1, end_dim=2)
99
+ )
100
+ y_df_hat_g = self.wd(y_rec_embeddings)
101
+ loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
102
+
103
+ return loss_gen
104
+
105
+ def discriminator(self, wav, y_rec):
106
+ with torch.no_grad():
107
+ wav_16 = self.resample(wav)
108
+ wav_embeddings = self.wavlm(
109
+ input_values=wav_16, output_hidden_states=True
110
+ ).hidden_states
111
+ y_rec_16 = self.resample(y_rec)
112
+ y_rec_embeddings = self.wavlm(
113
+ input_values=y_rec_16, output_hidden_states=True
114
+ ).hidden_states
115
+
116
+ y_embeddings = (
117
+ torch.stack(wav_embeddings, dim=1)
118
+ .transpose(-1, -2)
119
+ .flatten(start_dim=1, end_dim=2)
120
+ )
121
+ y_rec_embeddings = (
122
+ torch.stack(y_rec_embeddings, dim=1)
123
+ .transpose(-1, -2)
124
+ .flatten(start_dim=1, end_dim=2)
125
+ )
126
+
127
+ y_d_rs = self.wd(y_embeddings)
128
+ y_d_gs = self.wd(y_rec_embeddings)
129
+
130
+ y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
131
+
132
+ r_loss = torch.mean((1 - y_df_hat_r) ** 2)
133
+ g_loss = torch.mean((y_df_hat_g) ** 2)
134
+
135
+ loss_disc_f = r_loss + g_loss
136
+
137
+ return loss_disc_f.mean()
138
+
139
+ def discriminator_forward(self, wav):
140
+ with torch.no_grad():
141
+ wav_16 = self.resample(wav)
142
+ wav_embeddings = self.wavlm(
143
+ input_values=wav_16, output_hidden_states=True
144
+ ).hidden_states
145
+ y_embeddings = (
146
+ torch.stack(wav_embeddings, dim=1)
147
+ .transpose(-1, -2)
148
+ .flatten(start_dim=1, end_dim=2)
149
+ )
150
+
151
+ y_d_rs = self.wd(y_embeddings)
152
+
153
+ return y_d_rs
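feature_loss, discriminator_loss and generator_loss above are the standard least-squares GAN losses used in HiFi-GAN/VITS-style training, and WavLMLoss wraps a frozen WavLM model as a speech-language-model discriminator. A self-contained sanity-check sketch for the plain GAN losses, using random tensors in place of real discriminator outputs (shapes are illustrative only):

import torch

disc_real = [torch.rand(2, 100) for _ in range(3)]  # one tensor per sub-discriminator
disc_fake = [torch.rand(2, 100) for _ in range(3)]

loss_d, r_losses, g_losses = discriminator_loss(disc_real, disc_fake)
loss_g, gen_losses = generator_loss(disc_fake)
print(float(loss_d), float(loss_g))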
mel_processing.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ import torch.utils.data
3
+ from librosa.filters import mel as librosa_mel_fn
4
+ import warnings
5
+
6
+ # warnings.simplefilter(action='ignore', category=FutureWarning)
7
+ warnings.filterwarnings(action="ignore")
8
+ MAX_WAV_VALUE = 32768.0
9
+
10
+
11
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
12
+ """
13
+ PARAMS
14
+ ------
15
+ C: compression factor
16
+ """
17
+ return torch.log(torch.clamp(x, min=clip_val) * C)
18
+
19
+
20
+ def dynamic_range_decompression_torch(x, C=1):
21
+ """
22
+ PARAMS
23
+ ------
24
+ C: compression factor used to compress
25
+ """
26
+ return torch.exp(x) / C
27
+
28
+
29
+ def spectral_normalize_torch(magnitudes):
30
+ output = dynamic_range_compression_torch(magnitudes)
31
+ return output
32
+
33
+
34
+ def spectral_de_normalize_torch(magnitudes):
35
+ output = dynamic_range_decompression_torch(magnitudes)
36
+ return output
37
+
38
+
39
+ mel_basis = {}
40
+ hann_window = {}
41
+
42
+
43
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
44
+ if torch.min(y) < -1.0:
45
+ print("min value is ", torch.min(y))
46
+ if torch.max(y) > 1.0:
47
+ print("max value is ", torch.max(y))
48
+
49
+ global hann_window
50
+ dtype_device = str(y.dtype) + "_" + str(y.device)
51
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
52
+ if wnsize_dtype_device not in hann_window:
53
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
54
+ dtype=y.dtype, device=y.device
55
+ )
56
+
57
+ y = torch.nn.functional.pad(
58
+ y.unsqueeze(1),
59
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
60
+ mode="reflect",
61
+ )
62
+ y = y.squeeze(1)
63
+
64
+ spec = torch.stft(
65
+ y,
66
+ n_fft,
67
+ hop_length=hop_size,
68
+ win_length=win_size,
69
+ window=hann_window[wnsize_dtype_device],
70
+ center=center,
71
+ pad_mode="reflect",
72
+ normalized=False,
73
+ onesided=True,
74
+ return_complex=False,
75
+ )
76
+
77
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
78
+ return spec
79
+
80
+
81
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
82
+ global mel_basis
83
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
84
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
85
+ if fmax_dtype_device not in mel_basis:
86
+ mel = librosa_mel_fn(
87
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
88
+ )
89
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
90
+ dtype=spec.dtype, device=spec.device
91
+ )
92
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
93
+ spec = spectral_normalize_torch(spec)
94
+ return spec
95
+
96
+
97
+ def mel_spectrogram_torch(
98
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
99
+ ):
100
+ if torch.min(y) < -1.0:
101
+ print("min value is ", torch.min(y))
102
+ if torch.max(y) > 1.0:
103
+ print("max value is ", torch.max(y))
104
+
105
+ global mel_basis, hann_window
106
+ dtype_device = str(y.dtype) + "_" + str(y.device)
107
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
108
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
109
+ if fmax_dtype_device not in mel_basis:
110
+ mel = librosa_mel_fn(
111
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
112
+ )
113
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
114
+ dtype=y.dtype, device=y.device
115
+ )
116
+ if wnsize_dtype_device not in hann_window:
117
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
118
+ dtype=y.dtype, device=y.device
119
+ )
120
+
121
+ y = torch.nn.functional.pad(
122
+ y.unsqueeze(1),
123
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
124
+ mode="reflect",
125
+ )
126
+ y = y.squeeze(1)
127
+
128
+ spec = torch.stft(
129
+ y,
130
+ n_fft,
131
+ hop_length=hop_size,
132
+ win_length=win_size,
133
+ window=hann_window[wnsize_dtype_device],
134
+ center=center,
135
+ pad_mode="reflect",
136
+ normalized=False,
137
+ onesided=True,
138
+ return_complex=False,
139
+ )
140
+
141
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
142
+
143
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
144
+ spec = spectral_normalize_torch(spec)
145
+
146
+ return spec
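A quick shape check for mel_spectrogram_torch above. The STFT and mel settings used here are typical VITS-style values chosen for illustration; the actual values for this repo come from its config files, not from this sketch.

import torch

y = torch.randn(1, 44100).clamp(-1.0, 1.0)  # one second of dummy audio in [-1, 1]
mel = mel_spectrogram_torch(
    y,
    n_fft=2048,
    num_mels=128,
    sampling_rate=44100,
    hop_size=512,
    win_size=2048,
    fmin=0,
    fmax=None,
)
print(mel.shape)  # (1, num_mels, n_frames)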
models.py ADDED
@@ -0,0 +1,1024 @@
1
+ import math
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
9
+
10
+ import attentions
11
+ import commons
12
+ import modules
13
+ import monotonic_align
14
+ from commons import get_padding, init_weights
15
+ from text import num_languages, num_tones, symbols
16
+
17
+
18
+ class DurationDiscriminator(nn.Module): # vits2
19
+ def __init__(
20
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
21
+ ):
22
+ super().__init__()
23
+
24
+ self.in_channels = in_channels
25
+ self.filter_channels = filter_channels
26
+ self.kernel_size = kernel_size
27
+ self.p_dropout = p_dropout
28
+ self.gin_channels = gin_channels
29
+
30
+ self.drop = nn.Dropout(p_dropout)
31
+ self.conv_1 = nn.Conv1d(
32
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
33
+ )
34
+ self.norm_1 = modules.LayerNorm(filter_channels)
35
+ self.conv_2 = nn.Conv1d(
36
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
37
+ )
38
+ self.norm_2 = modules.LayerNorm(filter_channels)
39
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
40
+
41
+ self.pre_out_conv_1 = nn.Conv1d(
42
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
43
+ )
44
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
45
+ self.pre_out_conv_2 = nn.Conv1d(
46
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
47
+ )
48
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
49
+
50
+ if gin_channels != 0:
51
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
52
+
53
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
54
+
55
+ def forward_probability(self, x, x_mask, dur, g=None):
56
+ dur = self.dur_proj(dur)
57
+ x = torch.cat([x, dur], dim=1)
58
+ x = self.pre_out_conv_1(x * x_mask)
59
+ x = torch.relu(x)
60
+ x = self.pre_out_norm_1(x)
61
+ x = self.drop(x)
62
+ x = self.pre_out_conv_2(x * x_mask)
63
+ x = torch.relu(x)
64
+ x = self.pre_out_norm_2(x)
65
+ x = self.drop(x)
66
+ x = x * x_mask
67
+ x = x.transpose(1, 2)
68
+ output_prob = self.output_layer(x)
69
+ return output_prob
70
+
71
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
72
+ x = torch.detach(x)
73
+ if g is not None:
74
+ g = torch.detach(g)
75
+ x = x + self.cond(g)
76
+ x = self.conv_1(x * x_mask)
77
+ x = torch.relu(x)
78
+ x = self.norm_1(x)
79
+ x = self.drop(x)
80
+ x = self.conv_2(x * x_mask)
81
+ x = torch.relu(x)
82
+ x = self.norm_2(x)
83
+ x = self.drop(x)
84
+
85
+ output_probs = []
86
+ for dur in [dur_r, dur_hat]:
87
+ output_prob = self.forward_probability(x, x_mask, dur, g)
88
+ output_probs.append(output_prob)
89
+
90
+ return output_probs
91
+
92
+
93
+ class TransformerCouplingBlock(nn.Module):
94
+ def __init__(
95
+ self,
96
+ channels,
97
+ hidden_channels,
98
+ filter_channels,
99
+ n_heads,
100
+ n_layers,
101
+ kernel_size,
102
+ p_dropout,
103
+ n_flows=4,
104
+ gin_channels=0,
105
+ share_parameter=False,
106
+ ):
107
+ super().__init__()
108
+ self.channels = channels
109
+ self.hidden_channels = hidden_channels
110
+ self.kernel_size = kernel_size
111
+ self.n_layers = n_layers
112
+ self.n_flows = n_flows
113
+ self.gin_channels = gin_channels
114
+
115
+ self.flows = nn.ModuleList()
116
+
117
+ self.wn = (
118
+ attentions.FFT(
119
+ hidden_channels,
120
+ filter_channels,
121
+ n_heads,
122
+ n_layers,
123
+ kernel_size,
124
+ p_dropout,
125
+ isflow=True,
126
+ gin_channels=self.gin_channels,
127
+ )
128
+ if share_parameter
129
+ else None
130
+ )
131
+
132
+ for i in range(n_flows):
133
+ self.flows.append(
134
+ modules.TransformerCouplingLayer(
135
+ channels,
136
+ hidden_channels,
137
+ kernel_size,
138
+ n_layers,
139
+ n_heads,
140
+ p_dropout,
141
+ filter_channels,
142
+ mean_only=True,
143
+ wn_sharing_parameter=self.wn,
144
+ gin_channels=self.gin_channels,
145
+ )
146
+ )
147
+ self.flows.append(modules.Flip())
148
+
149
+ def forward(self, x, x_mask, g=None, reverse=False):
150
+ if not reverse:
151
+ for flow in self.flows:
152
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
153
+ else:
154
+ for flow in reversed(self.flows):
155
+ x = flow(x, x_mask, g=g, reverse=reverse)
156
+ return x
157
+
158
+
159
+ class StochasticDurationPredictor(nn.Module):
160
+ def __init__(
161
+ self,
162
+ in_channels,
163
+ filter_channels,
164
+ kernel_size,
165
+ p_dropout,
166
+ n_flows=4,
167
+ gin_channels=0,
168
+ ):
169
+ super().__init__()
170
+ filter_channels = in_channels # this override should be removed in a future version.
171
+ self.in_channels = in_channels
172
+ self.filter_channels = filter_channels
173
+ self.kernel_size = kernel_size
174
+ self.p_dropout = p_dropout
175
+ self.n_flows = n_flows
176
+ self.gin_channels = gin_channels
177
+
178
+ self.log_flow = modules.Log()
179
+ self.flows = nn.ModuleList()
180
+ self.flows.append(modules.ElementwiseAffine(2))
181
+ for i in range(n_flows):
182
+ self.flows.append(
183
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
184
+ )
185
+ self.flows.append(modules.Flip())
186
+
187
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
188
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
189
+ self.post_convs = modules.DDSConv(
190
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
191
+ )
192
+ self.post_flows = nn.ModuleList()
193
+ self.post_flows.append(modules.ElementwiseAffine(2))
194
+ for i in range(4):
195
+ self.post_flows.append(
196
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
197
+ )
198
+ self.post_flows.append(modules.Flip())
199
+
200
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
201
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
202
+ self.convs = modules.DDSConv(
203
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
204
+ )
205
+ if gin_channels != 0:
206
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
207
+
208
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
209
+ x = torch.detach(x)
210
+ x = self.pre(x)
211
+ if g is not None:
212
+ g = torch.detach(g)
213
+ x = x + self.cond(g)
214
+ x = self.convs(x, x_mask)
215
+ x = self.proj(x) * x_mask
216
+
217
+ if not reverse:
218
+ flows = self.flows
219
+ assert w is not None
220
+
221
+ logdet_tot_q = 0
222
+ h_w = self.post_pre(w)
223
+ h_w = self.post_convs(h_w, x_mask)
224
+ h_w = self.post_proj(h_w) * x_mask
225
+ e_q = (
226
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
227
+ * x_mask
228
+ )
229
+ z_q = e_q
230
+ for flow in self.post_flows:
231
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
232
+ logdet_tot_q += logdet_q
233
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
234
+ u = torch.sigmoid(z_u) * x_mask
235
+ z0 = (w - u) * x_mask
236
+ logdet_tot_q += torch.sum(
237
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
238
+ )
239
+ logq = (
240
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
241
+ - logdet_tot_q
242
+ )
243
+
244
+ logdet_tot = 0
245
+ z0, logdet = self.log_flow(z0, x_mask)
246
+ logdet_tot += logdet
247
+ z = torch.cat([z0, z1], 1)
248
+ for flow in flows:
249
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
250
+ logdet_tot = logdet_tot + logdet
251
+ nll = (
252
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
253
+ - logdet_tot
254
+ )
255
+ return nll + logq # [b]
256
+ else:
257
+ flows = list(reversed(self.flows))
258
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
259
+ z = (
260
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
261
+ * noise_scale
262
+ )
263
+ for flow in flows:
264
+ z = flow(z, x_mask, g=x, reverse=reverse)
265
+ z0, z1 = torch.split(z, [1, 1], 1)
266
+ logw = z0
267
+ return logw
268
+
269
+
270
+ class DurationPredictor(nn.Module):
271
+ def __init__(
272
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
273
+ ):
274
+ super().__init__()
275
+
276
+ self.in_channels = in_channels
277
+ self.filter_channels = filter_channels
278
+ self.kernel_size = kernel_size
279
+ self.p_dropout = p_dropout
280
+ self.gin_channels = gin_channels
281
+
282
+ self.drop = nn.Dropout(p_dropout)
283
+ self.conv_1 = nn.Conv1d(
284
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
285
+ )
286
+ self.norm_1 = modules.LayerNorm(filter_channels)
287
+ self.conv_2 = nn.Conv1d(
288
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
289
+ )
290
+ self.norm_2 = modules.LayerNorm(filter_channels)
291
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
292
+
293
+ if gin_channels != 0:
294
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
295
+
296
+ def forward(self, x, x_mask, g=None):
297
+ x = torch.detach(x)
298
+ if g is not None:
299
+ g = torch.detach(g)
300
+ x = x + self.cond(g)
301
+ x = self.conv_1(x * x_mask)
302
+ x = torch.relu(x)
303
+ x = self.norm_1(x)
304
+ x = self.drop(x)
305
+ x = self.conv_2(x * x_mask)
306
+ x = torch.relu(x)
307
+ x = self.norm_2(x)
308
+ x = self.drop(x)
309
+ x = self.proj(x * x_mask)
310
+ return x * x_mask
311
+
312
+
313
+ class TextEncoder(nn.Module):
314
+ def __init__(
315
+ self,
316
+ n_vocab,
317
+ out_channels,
318
+ hidden_channels,
319
+ filter_channels,
320
+ n_heads,
321
+ n_layers,
322
+ kernel_size,
323
+ p_dropout,
324
+ n_speakers,
325
+ gin_channels=0,
326
+ ):
327
+ super().__init__()
328
+ self.n_vocab = n_vocab
329
+ self.out_channels = out_channels
330
+ self.hidden_channels = hidden_channels
331
+ self.filter_channels = filter_channels
332
+ self.n_heads = n_heads
333
+ self.n_layers = n_layers
334
+ self.kernel_size = kernel_size
335
+ self.p_dropout = p_dropout
336
+ self.gin_channels = gin_channels
337
+ self.emb = nn.Embedding(len(symbols), hidden_channels)
338
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
339
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
340
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
341
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
342
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
343
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
344
+ self.ja_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
345
+ self.en_bert_proj = nn.Conv1d(1024, hidden_channels, 1)
346
+ self.style_proj = nn.Linear(256, hidden_channels)
347
+
348
+ self.encoder = attentions.Encoder(
349
+ hidden_channels,
350
+ filter_channels,
351
+ n_heads,
352
+ n_layers,
353
+ kernel_size,
354
+ p_dropout,
355
+ gin_channels=self.gin_channels,
356
+ )
357
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
358
+
359
+ def forward(
360
+ self,
361
+ x,
362
+ x_lengths,
363
+ tone,
364
+ language,
365
+ bert,
366
+ ja_bert,
367
+ en_bert,
368
+ style_vec,
369
+ sid,
370
+ g=None,
371
+ ):
372
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
373
+ ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
374
+ en_bert_emb = self.en_bert_proj(en_bert).transpose(1, 2)
375
+ style_emb = self.style_proj(style_vec.unsqueeze(1))
376
+
377
+ x = (
378
+ self.emb(x)
379
+ + self.tone_emb(tone)
380
+ + self.language_emb(language)
381
+ + bert_emb
382
+ + ja_bert_emb
383
+ + en_bert_emb
384
+ + style_emb
385
+ ) * math.sqrt(
386
+ self.hidden_channels
387
+ ) # [b, t, h]
388
+ x = torch.transpose(x, 1, -1) # [b, h, t]
389
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
390
+ x.dtype
391
+ )
392
+
393
+ x = self.encoder(x * x_mask, x_mask, g=g)
394
+ stats = self.proj(x) * x_mask
395
+
396
+ m, logs = torch.split(stats, self.out_channels, dim=1)
397
+ return x, m, logs, x_mask
398
+
399
+
400
+ class ResidualCouplingBlock(nn.Module):
401
+ def __init__(
402
+ self,
403
+ channels,
404
+ hidden_channels,
405
+ kernel_size,
406
+ dilation_rate,
407
+ n_layers,
408
+ n_flows=4,
409
+ gin_channels=0,
410
+ ):
411
+ super().__init__()
412
+ self.channels = channels
413
+ self.hidden_channels = hidden_channels
414
+ self.kernel_size = kernel_size
415
+ self.dilation_rate = dilation_rate
416
+ self.n_layers = n_layers
417
+ self.n_flows = n_flows
418
+ self.gin_channels = gin_channels
419
+
420
+ self.flows = nn.ModuleList()
421
+ for i in range(n_flows):
422
+ self.flows.append(
423
+ modules.ResidualCouplingLayer(
424
+ channels,
425
+ hidden_channels,
426
+ kernel_size,
427
+ dilation_rate,
428
+ n_layers,
429
+ gin_channels=gin_channels,
430
+ mean_only=True,
431
+ )
432
+ )
433
+ self.flows.append(modules.Flip())
434
+
435
+ def forward(self, x, x_mask, g=None, reverse=False):
436
+ if not reverse:
437
+ for flow in self.flows:
438
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
439
+ else:
440
+ for flow in reversed(self.flows):
441
+ x = flow(x, x_mask, g=g, reverse=reverse)
442
+ return x
443
+
444
+
445
+ class PosteriorEncoder(nn.Module):
446
+ def __init__(
447
+ self,
448
+ in_channels,
449
+ out_channels,
450
+ hidden_channels,
451
+ kernel_size,
452
+ dilation_rate,
453
+ n_layers,
454
+ gin_channels=0,
455
+ ):
456
+ super().__init__()
457
+ self.in_channels = in_channels
458
+ self.out_channels = out_channels
459
+ self.hidden_channels = hidden_channels
460
+ self.kernel_size = kernel_size
461
+ self.dilation_rate = dilation_rate
462
+ self.n_layers = n_layers
463
+ self.gin_channels = gin_channels
464
+
465
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
466
+ self.enc = modules.WN(
467
+ hidden_channels,
468
+ kernel_size,
469
+ dilation_rate,
470
+ n_layers,
471
+ gin_channels=gin_channels,
472
+ )
473
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
474
+
475
+ def forward(self, x, x_lengths, g=None):
476
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
477
+ x.dtype
478
+ )
479
+ x = self.pre(x) * x_mask
480
+ x = self.enc(x, x_mask, g=g)
481
+ stats = self.proj(x) * x_mask
482
+ m, logs = torch.split(stats, self.out_channels, dim=1)
483
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
484
+ return z, m, logs, x_mask
485
+
486
+
487
+ class Generator(torch.nn.Module):
488
+ def __init__(
489
+ self,
490
+ initial_channel,
491
+ resblock,
492
+ resblock_kernel_sizes,
493
+ resblock_dilation_sizes,
494
+ upsample_rates,
495
+ upsample_initial_channel,
496
+ upsample_kernel_sizes,
497
+ gin_channels=0,
498
+ ):
499
+ super(Generator, self).__init__()
500
+ self.num_kernels = len(resblock_kernel_sizes)
501
+ self.num_upsamples = len(upsample_rates)
502
+ self.conv_pre = Conv1d(
503
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
504
+ )
505
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
506
+
507
+ self.ups = nn.ModuleList()
508
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
509
+ self.ups.append(
510
+ weight_norm(
511
+ ConvTranspose1d(
512
+ upsample_initial_channel // (2**i),
513
+ upsample_initial_channel // (2 ** (i + 1)),
514
+ k,
515
+ u,
516
+ padding=(k - u) // 2,
517
+ )
518
+ )
519
+ )
520
+
521
+ self.resblocks = nn.ModuleList()
522
+ for i in range(len(self.ups)):
523
+ ch = upsample_initial_channel // (2 ** (i + 1))
524
+ for j, (k, d) in enumerate(
525
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
526
+ ):
527
+ self.resblocks.append(resblock(ch, k, d))
528
+
529
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
530
+ self.ups.apply(init_weights)
531
+
532
+ if gin_channels != 0:
533
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
534
+
535
+ def forward(self, x, g=None):
536
+ x = self.conv_pre(x)
537
+ if g is not None:
538
+ x = x + self.cond(g)
539
+
540
+ for i in range(self.num_upsamples):
541
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
542
+ x = self.ups[i](x)
543
+ xs = None
544
+ for j in range(self.num_kernels):
545
+ if xs is None:
546
+ xs = self.resblocks[i * self.num_kernels + j](x)
547
+ else:
548
+ xs += self.resblocks[i * self.num_kernels + j](x)
549
+ x = xs / self.num_kernels
550
+ x = F.leaky_relu(x)
551
+ x = self.conv_post(x)
552
+ x = torch.tanh(x)
553
+
554
+ return x
555
+
556
+ def remove_weight_norm(self):
557
+ print("Removing weight norm...")
558
+ for layer in self.ups:
559
+ remove_weight_norm(layer)
560
+ for layer in self.resblocks:
561
+ layer.remove_weight_norm()
562
+
563
+
564
+ class DiscriminatorP(torch.nn.Module):
565
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
566
+ super(DiscriminatorP, self).__init__()
567
+ self.period = period
568
+ self.use_spectral_norm = use_spectral_norm
569
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
570
+ self.convs = nn.ModuleList(
571
+ [
572
+ norm_f(
573
+ Conv2d(
574
+ 1,
575
+ 32,
576
+ (kernel_size, 1),
577
+ (stride, 1),
578
+ padding=(get_padding(kernel_size, 1), 0),
579
+ )
580
+ ),
581
+ norm_f(
582
+ Conv2d(
583
+ 32,
584
+ 128,
585
+ (kernel_size, 1),
586
+ (stride, 1),
587
+ padding=(get_padding(kernel_size, 1), 0),
588
+ )
589
+ ),
590
+ norm_f(
591
+ Conv2d(
592
+ 128,
593
+ 512,
594
+ (kernel_size, 1),
595
+ (stride, 1),
596
+ padding=(get_padding(kernel_size, 1), 0),
597
+ )
598
+ ),
599
+ norm_f(
600
+ Conv2d(
601
+ 512,
602
+ 1024,
603
+ (kernel_size, 1),
604
+ (stride, 1),
605
+ padding=(get_padding(kernel_size, 1), 0),
606
+ )
607
+ ),
608
+ norm_f(
609
+ Conv2d(
610
+ 1024,
611
+ 1024,
612
+ (kernel_size, 1),
613
+ 1,
614
+ padding=(get_padding(kernel_size, 1), 0),
615
+ )
616
+ ),
617
+ ]
618
+ )
619
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
620
+
621
+ def forward(self, x):
622
+ fmap = []
623
+
624
+ # 1d to 2d
625
+ b, c, t = x.shape
626
+ if t % self.period != 0: # pad first
627
+ n_pad = self.period - (t % self.period)
628
+ x = F.pad(x, (0, n_pad), "reflect")
629
+ t = t + n_pad
630
+ x = x.view(b, c, t // self.period, self.period)
631
+
632
+ for layer in self.convs:
633
+ x = layer(x)
634
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
635
+ fmap.append(x)
636
+ x = self.conv_post(x)
637
+ fmap.append(x)
638
+ x = torch.flatten(x, 1, -1)
639
+
640
+ return x, fmap
641
+
642
+
643
+ class DiscriminatorS(torch.nn.Module):
644
+ def __init__(self, use_spectral_norm=False):
645
+ super(DiscriminatorS, self).__init__()
646
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
647
+ self.convs = nn.ModuleList(
648
+ [
649
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
650
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
651
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
652
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
653
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
654
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
655
+ ]
656
+ )
657
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
658
+
659
+ def forward(self, x):
660
+ fmap = []
661
+
662
+ for layer in self.convs:
663
+ x = layer(x)
664
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
665
+ fmap.append(x)
666
+ x = self.conv_post(x)
667
+ fmap.append(x)
668
+ x = torch.flatten(x, 1, -1)
669
+
670
+ return x, fmap
671
+
672
+
673
+ class MultiPeriodDiscriminator(torch.nn.Module):
674
+ def __init__(self, use_spectral_norm=False):
675
+ super(MultiPeriodDiscriminator, self).__init__()
676
+ periods = [2, 3, 5, 7, 11]
677
+
678
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
679
+ discs = discs + [
680
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
681
+ ]
682
+ self.discriminators = nn.ModuleList(discs)
683
+
684
+ def forward(self, y, y_hat):
685
+ y_d_rs = []
686
+ y_d_gs = []
687
+ fmap_rs = []
688
+ fmap_gs = []
689
+ for i, d in enumerate(self.discriminators):
690
+ y_d_r, fmap_r = d(y)
691
+ y_d_g, fmap_g = d(y_hat)
692
+ y_d_rs.append(y_d_r)
693
+ y_d_gs.append(y_d_g)
694
+ fmap_rs.append(fmap_r)
695
+ fmap_gs.append(fmap_g)
696
+
697
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
698
+
699
+
700
+ class ReferenceEncoder(nn.Module):
701
+ """
702
+ inputs --- [N, Ty/r, n_mels*r] mels
703
+ outputs --- [N, ref_enc_gru_size]
704
+ """
705
+
706
+ def __init__(self, spec_channels, gin_channels=0):
707
+ super().__init__()
708
+ self.spec_channels = spec_channels
709
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
710
+ K = len(ref_enc_filters)
711
+ filters = [1] + ref_enc_filters
712
+ convs = [
713
+ weight_norm(
714
+ nn.Conv2d(
715
+ in_channels=filters[i],
716
+ out_channels=filters[i + 1],
717
+ kernel_size=(3, 3),
718
+ stride=(2, 2),
719
+ padding=(1, 1),
720
+ )
721
+ )
722
+ for i in range(K)
723
+ ]
724
+ self.convs = nn.ModuleList(convs)
725
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
726
+
727
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
728
+ self.gru = nn.GRU(
729
+ input_size=ref_enc_filters[-1] * out_channels,
730
+ hidden_size=256 // 2,
731
+ batch_first=True,
732
+ )
733
+ self.proj = nn.Linear(128, gin_channels)
734
+
735
+ def forward(self, inputs, mask=None):
736
+ N = inputs.size(0)
737
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
738
+ for conv in self.convs:
739
+ out = conv(out)
740
+ # out = wn(out)
741
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
742
+
743
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
744
+ T = out.size(1)
745
+ N = out.size(0)
746
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
747
+
748
+ self.gru.flatten_parameters()
749
+ memory, out = self.gru(out) # out --- [1, N, 128]
750
+
751
+ return self.proj(out.squeeze(0))
752
+
753
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
754
+ for i in range(n_convs):
755
+ L = (L - kernel_size + 2 * pad) // stride + 1
756
+ return L
757
+
758
+
759
+ class SynthesizerTrn(nn.Module):
760
+ """
761
+ Synthesizer for Training
762
+ """
763
+
764
+ def __init__(
765
+ self,
766
+ n_vocab,
767
+ spec_channels,
768
+ segment_size,
769
+ inter_channels,
770
+ hidden_channels,
771
+ filter_channels,
772
+ n_heads,
773
+ n_layers,
774
+ kernel_size,
775
+ p_dropout,
776
+ resblock,
777
+ resblock_kernel_sizes,
778
+ resblock_dilation_sizes,
779
+ upsample_rates,
780
+ upsample_initial_channel,
781
+ upsample_kernel_sizes,
782
+ n_speakers=256,
783
+ gin_channels=256,
784
+ use_sdp=True,
785
+ n_flow_layer=4,
786
+ n_layers_trans_flow=4,
787
+ flow_share_parameter=False,
788
+ use_transformer_flow=True,
789
+ **kwargs,
790
+ ):
791
+ super().__init__()
792
+ self.n_vocab = n_vocab
793
+ self.spec_channels = spec_channels
794
+ self.inter_channels = inter_channels
795
+ self.hidden_channels = hidden_channels
796
+ self.filter_channels = filter_channels
797
+ self.n_heads = n_heads
798
+ self.n_layers = n_layers
799
+ self.kernel_size = kernel_size
800
+ self.p_dropout = p_dropout
801
+ self.resblock = resblock
802
+ self.resblock_kernel_sizes = resblock_kernel_sizes
803
+ self.resblock_dilation_sizes = resblock_dilation_sizes
804
+ self.upsample_rates = upsample_rates
805
+ self.upsample_initial_channel = upsample_initial_channel
806
+ self.upsample_kernel_sizes = upsample_kernel_sizes
807
+ self.segment_size = segment_size
808
+ self.n_speakers = n_speakers
809
+ self.gin_channels = gin_channels
810
+ self.n_layers_trans_flow = n_layers_trans_flow
811
+ self.use_spk_conditioned_encoder = kwargs.get(
812
+ "use_spk_conditioned_encoder", True
813
+ )
814
+ self.use_sdp = use_sdp
815
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
816
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
817
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
818
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
819
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
820
+ self.enc_gin_channels = gin_channels
821
+ self.enc_p = TextEncoder(
822
+ n_vocab,
823
+ inter_channels,
824
+ hidden_channels,
825
+ filter_channels,
826
+ n_heads,
827
+ n_layers,
828
+ kernel_size,
829
+ p_dropout,
830
+ self.n_speakers,
831
+ gin_channels=self.enc_gin_channels,
832
+ )
833
+ self.dec = Generator(
834
+ inter_channels,
835
+ resblock,
836
+ resblock_kernel_sizes,
837
+ resblock_dilation_sizes,
838
+ upsample_rates,
839
+ upsample_initial_channel,
840
+ upsample_kernel_sizes,
841
+ gin_channels=gin_channels,
842
+ )
843
+ self.enc_q = PosteriorEncoder(
844
+ spec_channels,
845
+ inter_channels,
846
+ hidden_channels,
847
+ 5,
848
+ 1,
849
+ 16,
850
+ gin_channels=gin_channels,
851
+ )
852
+ if use_transformer_flow:
853
+ self.flow = TransformerCouplingBlock(
854
+ inter_channels,
855
+ hidden_channels,
856
+ filter_channels,
857
+ n_heads,
858
+ n_layers_trans_flow,
859
+ 5,
860
+ p_dropout,
861
+ n_flow_layer,
862
+ gin_channels=gin_channels,
863
+ share_parameter=flow_share_parameter,
864
+ )
865
+ else:
866
+ self.flow = ResidualCouplingBlock(
867
+ inter_channels,
868
+ hidden_channels,
869
+ 5,
870
+ 1,
871
+ n_flow_layer,
872
+ gin_channels=gin_channels,
873
+ )
874
+ self.sdp = StochasticDurationPredictor(
875
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
876
+ )
877
+ self.dp = DurationPredictor(
878
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
879
+ )
880
+
881
+ if n_speakers >= 1:
882
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
883
+ else:
884
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
885
+
886
+ def forward(
887
+ self,
888
+ x,
889
+ x_lengths,
890
+ y,
891
+ y_lengths,
892
+ sid,
893
+ tone,
894
+ language,
895
+ bert,
896
+ ja_bert,
897
+ en_bert,
898
+ style_vec,
899
+ ):
900
+ if self.n_speakers > 0:
901
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
902
+ else:
903
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
904
+ x, m_p, logs_p, x_mask = self.enc_p(
905
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, style_vec, sid, g=g
906
+ )
907
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
908
+ z_p = self.flow(z, y_mask, g=g)
909
+
910
+ with torch.no_grad():
911
+ # negative cross-entropy
912
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
913
+ neg_cent1 = torch.sum(
914
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
915
+ ) # [b, 1, t_s]
916
+ neg_cent2 = torch.matmul(
917
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
918
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
919
+ neg_cent3 = torch.matmul(
920
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
921
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
922
+ neg_cent4 = torch.sum(
923
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
924
+ ) # [b, 1, t_s]
925
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
926
+ if self.use_noise_scaled_mas:
927
+ epsilon = (
928
+ torch.std(neg_cent)
929
+ * torch.randn_like(neg_cent)
930
+ * self.current_mas_noise_scale
931
+ )
932
+ neg_cent = neg_cent + epsilon
933
+
934
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
935
+ attn = (
936
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
937
+ .unsqueeze(1)
938
+ .detach()
939
+ )
940
+
941
+ w = attn.sum(2)
942
+
943
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
944
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
945
+
946
+ logw_ = torch.log(w + 1e-6) * x_mask
947
+ logw = self.dp(x, x_mask, g=g)
948
+ # logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
949
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
950
+ x_mask
951
+ ) # for averaging
952
+ # l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
953
+
954
+ l_length = l_length_dp + l_length_sdp
955
+
956
+ # expand prior
957
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
958
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
959
+
960
+ z_slice, ids_slice = commons.rand_slice_segments(
961
+ z, y_lengths, self.segment_size
962
+ )
963
+ o = self.dec(z_slice, g=g)
964
+ return (
965
+ o,
966
+ l_length,
967
+ attn,
968
+ ids_slice,
969
+ x_mask,
970
+ y_mask,
971
+ (z, z_p, m_p, logs_p, m_q, logs_q),
972
+ (x, logw, logw_),
973
+ )
974
+
975
+ def infer(
976
+ self,
977
+ x,
978
+ x_lengths,
979
+ sid,
980
+ tone,
981
+ language,
982
+ bert,
983
+ ja_bert,
984
+ en_bert,
985
+ style_vec,
986
+ noise_scale=0.667,
987
+ length_scale=1,
988
+ noise_scale_w=0.8,
989
+ max_len=None,
990
+ sdp_ratio=0,
991
+ y=None,
992
+ ):
993
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
994
+ # g = self.gst(y)
995
+ if self.n_speakers > 0:
996
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
997
+ else:
998
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
999
+ x, m_p, logs_p, x_mask = self.enc_p(
1000
+ x, x_lengths, tone, language, bert, ja_bert, en_bert, style_vec, sid, g=g
1001
+ )
1002
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1003
+ sdp_ratio
1004
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
1005
+ w = torch.exp(logw) * x_mask * length_scale
1006
+ w_ceil = torch.ceil(w)
1007
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1008
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1009
+ x_mask.dtype
1010
+ )
1011
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1012
+ attn = commons.generate_path(w_ceil, attn_mask)
1013
+
1014
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1015
+ 1, 2
1016
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1017
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1018
+ 1, 2
1019
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1020
+
1021
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1022
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1023
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
1024
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
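For reference, a shape sketch of the inputs SynthesizerTrn.infer expects, as implied by the TextEncoder above (batch size 1, T phoneme tokens). The dummy tensors only illustrate the layout and are not pushed through a model here.

import torch

T = 25                                   # number of phoneme tokens
x = torch.randint(0, 10, (1, T))         # phoneme ids
x_lengths = torch.LongTensor([T])
sid = torch.LongTensor([0])              # speaker id
tone = torch.zeros(1, T, dtype=torch.long)
language = torch.zeros(1, T, dtype=torch.long)
bert = torch.zeros(1, 1024, T)           # Chinese BERT features
ja_bert = torch.zeros(1, 1024, T)        # Japanese BERT features
en_bert = torch.zeros(1, 1024, T)        # English BERT features
style_vec = torch.zeros(1, 256)          # style embedding for style_proj
# net_g.infer(x, x_lengths, sid, tone, language, bert, ja_bert, en_bert,
#             style_vec, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8,
#             length_scale=1.0)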
models_jp_extra.py ADDED
@@ -0,0 +1,1071 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ import commons
7
+ import modules
8
+ import attentions
9
+ import monotonic_align
10
+
11
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
12
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
13
+
14
+ from commons import init_weights, get_padding
15
+ from text import symbols, num_tones, num_languages
16
+
17
+
18
+ class DurationDiscriminator(nn.Module): # vits2
19
+ def __init__(
20
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
21
+ ):
22
+ super().__init__()
23
+
24
+ self.in_channels = in_channels
25
+ self.filter_channels = filter_channels
26
+ self.kernel_size = kernel_size
27
+ self.p_dropout = p_dropout
28
+ self.gin_channels = gin_channels
29
+
30
+ self.drop = nn.Dropout(p_dropout)
31
+ self.conv_1 = nn.Conv1d(
32
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
33
+ )
34
+ self.norm_1 = modules.LayerNorm(filter_channels)
35
+ self.conv_2 = nn.Conv1d(
36
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
37
+ )
38
+ self.norm_2 = modules.LayerNorm(filter_channels)
39
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
40
+
41
+ self.LSTM = nn.LSTM(
42
+ 2 * filter_channels, filter_channels, batch_first=True, bidirectional=True
43
+ )
44
+
45
+ if gin_channels != 0:
46
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
47
+
48
+ self.output_layer = nn.Sequential(
49
+ nn.Linear(2 * filter_channels, 1), nn.Sigmoid()
50
+ )
51
+
52
+ def forward_probability(self, x, dur):
53
+ dur = self.dur_proj(dur)
54
+ x = torch.cat([x, dur], dim=1)
55
+ x = x.transpose(1, 2)
56
+ x, _ = self.LSTM(x)
57
+ output_prob = self.output_layer(x)
58
+ return output_prob
59
+
60
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
61
+ x = torch.detach(x)
62
+ if g is not None:
63
+ g = torch.detach(g)
64
+ x = x + self.cond(g)
65
+ x = self.conv_1(x * x_mask)
66
+ x = torch.relu(x)
67
+ x = self.norm_1(x)
68
+ x = self.drop(x)
69
+ x = self.conv_2(x * x_mask)
70
+ x = torch.relu(x)
71
+ x = self.norm_2(x)
72
+ x = self.drop(x)
73
+
74
+ output_probs = []
75
+ for dur in [dur_r, dur_hat]:
76
+ output_prob = self.forward_probability(x, dur)
77
+ output_probs.append(output_prob)
78
+
79
+ return output_probs
80
+
81
+
82
+ class TransformerCouplingBlock(nn.Module):
83
+ def __init__(
84
+ self,
85
+ channels,
86
+ hidden_channels,
87
+ filter_channels,
88
+ n_heads,
89
+ n_layers,
90
+ kernel_size,
91
+ p_dropout,
92
+ n_flows=4,
93
+ gin_channels=0,
94
+ share_parameter=False,
95
+ ):
96
+ super().__init__()
97
+ self.channels = channels
98
+ self.hidden_channels = hidden_channels
99
+ self.kernel_size = kernel_size
100
+ self.n_layers = n_layers
101
+ self.n_flows = n_flows
102
+ self.gin_channels = gin_channels
103
+
104
+ self.flows = nn.ModuleList()
105
+
106
+ self.wn = (
107
+ attentions.FFT(
108
+ hidden_channels,
109
+ filter_channels,
110
+ n_heads,
111
+ n_layers,
112
+ kernel_size,
113
+ p_dropout,
114
+ isflow=True,
115
+ gin_channels=self.gin_channels,
116
+ )
117
+ if share_parameter
118
+ else None
119
+ )
120
+
121
+ for i in range(n_flows):
122
+ self.flows.append(
123
+ modules.TransformerCouplingLayer(
124
+ channels,
125
+ hidden_channels,
126
+ kernel_size,
127
+ n_layers,
128
+ n_heads,
129
+ p_dropout,
130
+ filter_channels,
131
+ mean_only=True,
132
+ wn_sharing_parameter=self.wn,
133
+ gin_channels=self.gin_channels,
134
+ )
135
+ )
136
+ self.flows.append(modules.Flip())
137
+
138
+ def forward(self, x, x_mask, g=None, reverse=False):
139
+ if not reverse:
140
+ for flow in self.flows:
141
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
142
+ else:
143
+ for flow in reversed(self.flows):
144
+ x = flow(x, x_mask, g=g, reverse=reverse)
145
+ return x
146
+
147
+
148
+ class StochasticDurationPredictor(nn.Module):
149
+ def __init__(
150
+ self,
151
+ in_channels,
152
+ filter_channels,
153
+ kernel_size,
154
+ p_dropout,
155
+ n_flows=4,
156
+ gin_channels=0,
157
+ ):
158
+ super().__init__()
159
+ filter_channels = in_channels # this override should be removed in a future version.
160
+ self.in_channels = in_channels
161
+ self.filter_channels = filter_channels
162
+ self.kernel_size = kernel_size
163
+ self.p_dropout = p_dropout
164
+ self.n_flows = n_flows
165
+ self.gin_channels = gin_channels
166
+
167
+ self.log_flow = modules.Log()
168
+ self.flows = nn.ModuleList()
169
+ self.flows.append(modules.ElementwiseAffine(2))
170
+ for i in range(n_flows):
171
+ self.flows.append(
172
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
173
+ )
174
+ self.flows.append(modules.Flip())
175
+
176
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
177
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
178
+ self.post_convs = modules.DDSConv(
179
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
180
+ )
181
+ self.post_flows = nn.ModuleList()
182
+ self.post_flows.append(modules.ElementwiseAffine(2))
183
+ for i in range(4):
184
+ self.post_flows.append(
185
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
186
+ )
187
+ self.post_flows.append(modules.Flip())
188
+
189
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
190
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
191
+ self.convs = modules.DDSConv(
192
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
193
+ )
194
+ if gin_channels != 0:
195
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
196
+
197
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
198
+ x = torch.detach(x)
199
+ x = self.pre(x)
200
+ if g is not None:
201
+ g = torch.detach(g)
202
+ x = x + self.cond(g)
203
+ x = self.convs(x, x_mask)
204
+ x = self.proj(x) * x_mask
205
+
206
+ if not reverse:
207
+ flows = self.flows
208
+ assert w is not None
209
+
210
+ logdet_tot_q = 0
211
+ h_w = self.post_pre(w)
212
+ h_w = self.post_convs(h_w, x_mask)
213
+ h_w = self.post_proj(h_w) * x_mask
214
+ e_q = (
215
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
216
+ * x_mask
217
+ )
218
+ z_q = e_q
219
+ for flow in self.post_flows:
220
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
221
+ logdet_tot_q += logdet_q
222
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
223
+ u = torch.sigmoid(z_u) * x_mask
224
+ z0 = (w - u) * x_mask
225
+ logdet_tot_q += torch.sum(
226
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
227
+ )
228
+ logq = (
229
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
230
+ - logdet_tot_q
231
+ )
232
+
233
+ logdet_tot = 0
234
+ z0, logdet = self.log_flow(z0, x_mask)
235
+ logdet_tot += logdet
236
+ z = torch.cat([z0, z1], 1)
237
+ for flow in flows:
238
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
239
+ logdet_tot = logdet_tot + logdet
240
+ nll = (
241
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
242
+ - logdet_tot
243
+ )
244
+ return nll + logq # [b]
245
+ else:
246
+ flows = list(reversed(self.flows))
247
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
248
+ z = (
249
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
250
+ * noise_scale
251
+ )
252
+ for flow in flows:
253
+ z = flow(z, x_mask, g=x, reverse=reverse)
254
+ z0, z1 = torch.split(z, [1, 1], 1)
255
+ logw = z0
256
+ return logw
257
+
258
+
259
+ class DurationPredictor(nn.Module):
260
+ def __init__(
261
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
262
+ ):
263
+ super().__init__()
264
+
265
+ self.in_channels = in_channels
266
+ self.filter_channels = filter_channels
267
+ self.kernel_size = kernel_size
268
+ self.p_dropout = p_dropout
269
+ self.gin_channels = gin_channels
270
+
271
+ self.drop = nn.Dropout(p_dropout)
272
+ self.conv_1 = nn.Conv1d(
273
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
274
+ )
275
+ self.norm_1 = modules.LayerNorm(filter_channels)
276
+ self.conv_2 = nn.Conv1d(
277
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
278
+ )
279
+ self.norm_2 = modules.LayerNorm(filter_channels)
280
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
281
+
282
+ if gin_channels != 0:
283
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
284
+
285
+ def forward(self, x, x_mask, g=None):
286
+ x = torch.detach(x)
287
+ if g is not None:
288
+ g = torch.detach(g)
289
+ x = x + self.cond(g)
290
+ x = self.conv_1(x * x_mask)
291
+ x = torch.relu(x)
292
+ x = self.norm_1(x)
293
+ x = self.drop(x)
294
+ x = self.conv_2(x * x_mask)
295
+ x = torch.relu(x)
296
+ x = self.norm_2(x)
297
+ x = self.drop(x)
298
+ x = self.proj(x * x_mask)
299
+ return x * x_mask
300
+
301
+
302
+ class Bottleneck(nn.Sequential):
303
+ def __init__(self, in_dim, hidden_dim):
304
+ c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
305
+ c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
306
+ super().__init__(*[c_fc1, c_fc2])
307
+
308
+
309
+ class Block(nn.Module):
310
+ def __init__(self, in_dim, hidden_dim) -> None:
311
+ super().__init__()
312
+ self.norm = nn.LayerNorm(in_dim)
313
+ self.mlp = MLP(in_dim, hidden_dim)
314
+
315
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
316
+ x = x + self.mlp(self.norm(x))
317
+ return x
318
+
319
+
320
+ class MLP(nn.Module):
321
+ def __init__(self, in_dim, hidden_dim):
322
+ super().__init__()
323
+ self.c_fc1 = nn.Linear(in_dim, hidden_dim, bias=False)
324
+ self.c_fc2 = nn.Linear(in_dim, hidden_dim, bias=False)
325
+ self.c_proj = nn.Linear(hidden_dim, in_dim, bias=False)
326
+
327
+ def forward(self, x: torch.Tensor):
328
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
329
+ x = self.c_proj(x)
330
+ return x
331
+
332
+
333
+ class TextEncoder(nn.Module):
334
+ def __init__(
335
+ self,
336
+ n_vocab,
337
+ out_channels,
338
+ hidden_channels,
339
+ filter_channels,
340
+ n_heads,
341
+ n_layers,
342
+ kernel_size,
343
+ p_dropout,
344
+ gin_channels=0,
345
+ ):
346
+ super().__init__()
347
+ self.n_vocab = n_vocab
348
+ self.out_channels = out_channels
349
+ self.hidden_channels = hidden_channels
350
+ self.filter_channels = filter_channels
351
+ self.n_heads = n_heads
352
+ self.n_layers = n_layers
353
+ self.kernel_size = kernel_size
354
+ self.p_dropout = p_dropout
355
+ self.gin_channels = gin_channels
356
+ self.emb = nn.Embedding(len(symbols), hidden_channels)
357
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
358
+ self.tone_emb = nn.Embedding(num_tones, hidden_channels)
359
+ nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
360
+ self.language_emb = nn.Embedding(num_languages, hidden_channels)
361
+ nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
362
+ self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
363
+
364
+ # Remove emo_vq since it's not working well.
365
+ self.style_proj = nn.Linear(256, hidden_channels)
366
+
367
+ self.encoder = attentions.Encoder(
368
+ hidden_channels,
369
+ filter_channels,
370
+ n_heads,
371
+ n_layers,
372
+ kernel_size,
373
+ p_dropout,
374
+ gin_channels=self.gin_channels,
375
+ )
376
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
377
+
378
+ def forward(self, x, x_lengths, tone, language, bert, style_vec, g=None):
379
+ bert_emb = self.bert_proj(bert).transpose(1, 2)
380
+ style_emb = self.style_proj(style_vec.unsqueeze(1))
381
+ x = (
382
+ self.emb(x)
383
+ + self.tone_emb(tone)
384
+ + self.language_emb(language)
385
+ + bert_emb
386
+ + style_emb
387
+ ) * math.sqrt(
388
+ self.hidden_channels
389
+ ) # [b, t, h]
390
+ x = torch.transpose(x, 1, -1) # [b, h, t]
391
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
392
+ x.dtype
393
+ )
394
+
395
+ x = self.encoder(x * x_mask, x_mask, g=g)
396
+ stats = self.proj(x) * x_mask
397
+
398
+ m, logs = torch.split(stats, self.out_channels, dim=1)
399
+ return x, m, logs, x_mask
400
+
401
+
402
+ class ResidualCouplingBlock(nn.Module):
403
+ def __init__(
404
+ self,
405
+ channels,
406
+ hidden_channels,
407
+ kernel_size,
408
+ dilation_rate,
409
+ n_layers,
410
+ n_flows=4,
411
+ gin_channels=0,
412
+ ):
413
+ super().__init__()
414
+ self.channels = channels
415
+ self.hidden_channels = hidden_channels
416
+ self.kernel_size = kernel_size
417
+ self.dilation_rate = dilation_rate
418
+ self.n_layers = n_layers
419
+ self.n_flows = n_flows
420
+ self.gin_channels = gin_channels
421
+
422
+ self.flows = nn.ModuleList()
423
+ for i in range(n_flows):
424
+ self.flows.append(
425
+ modules.ResidualCouplingLayer(
426
+ channels,
427
+ hidden_channels,
428
+ kernel_size,
429
+ dilation_rate,
430
+ n_layers,
431
+ gin_channels=gin_channels,
432
+ mean_only=True,
433
+ )
434
+ )
435
+ self.flows.append(modules.Flip())
436
+
437
+ def forward(self, x, x_mask, g=None, reverse=False):
438
+ if not reverse:
439
+ for flow in self.flows:
440
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
441
+ else:
442
+ for flow in reversed(self.flows):
443
+ x = flow(x, x_mask, g=g, reverse=reverse)
444
+ return x
445
+
446
+
447
+ class PosteriorEncoder(nn.Module):
448
+ def __init__(
449
+ self,
450
+ in_channels,
451
+ out_channels,
452
+ hidden_channels,
453
+ kernel_size,
454
+ dilation_rate,
455
+ n_layers,
456
+ gin_channels=0,
457
+ ):
458
+ super().__init__()
459
+ self.in_channels = in_channels
460
+ self.out_channels = out_channels
461
+ self.hidden_channels = hidden_channels
462
+ self.kernel_size = kernel_size
463
+ self.dilation_rate = dilation_rate
464
+ self.n_layers = n_layers
465
+ self.gin_channels = gin_channels
466
+
467
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
468
+ self.enc = modules.WN(
469
+ hidden_channels,
470
+ kernel_size,
471
+ dilation_rate,
472
+ n_layers,
473
+ gin_channels=gin_channels,
474
+ )
475
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
476
+
477
+ def forward(self, x, x_lengths, g=None):
478
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
479
+ x.dtype
480
+ )
481
+ x = self.pre(x) * x_mask
482
+ x = self.enc(x, x_mask, g=g)
483
+ stats = self.proj(x) * x_mask
484
+ m, logs = torch.split(stats, self.out_channels, dim=1)
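+ # Sample the latent with the reparameterization trick: z = m + eps * exp(logs)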
485
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
486
+ return z, m, logs, x_mask
487
+
488
+
489
+ class Generator(torch.nn.Module):
490
+ def __init__(
491
+ self,
492
+ initial_channel,
493
+ resblock,
494
+ resblock_kernel_sizes,
495
+ resblock_dilation_sizes,
496
+ upsample_rates,
497
+ upsample_initial_channel,
498
+ upsample_kernel_sizes,
499
+ gin_channels=0,
500
+ ):
501
+ super(Generator, self).__init__()
502
+ self.num_kernels = len(resblock_kernel_sizes)
503
+ self.num_upsamples = len(upsample_rates)
504
+ self.conv_pre = Conv1d(
505
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
506
+ )
507
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
508
+
509
+ self.ups = nn.ModuleList()
510
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
511
+ self.ups.append(
512
+ weight_norm(
513
+ ConvTranspose1d(
514
+ upsample_initial_channel // (2**i),
515
+ upsample_initial_channel // (2 ** (i + 1)),
516
+ k,
517
+ u,
518
+ padding=(k - u) // 2,
519
+ )
520
+ )
521
+ )
522
+
523
+ self.resblocks = nn.ModuleList()
524
+ for i in range(len(self.ups)):
525
+ ch = upsample_initial_channel // (2 ** (i + 1))
526
+ for j, (k, d) in enumerate(
527
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
528
+ ):
529
+ self.resblocks.append(resblock(ch, k, d))
530
+
531
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
532
+ self.ups.apply(init_weights)
533
+
534
+ if gin_channels != 0:
535
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
536
+
537
+ def forward(self, x, g=None):
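+ # HiFi-GAN style decoder: each stage upsamples with a transposed conv, then
+ # averages the outputs of its multi-receptive-field ResBlocks.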
538
+ x = self.conv_pre(x)
539
+ if g is not None:
540
+ x = x + self.cond(g)
541
+
542
+ for i in range(self.num_upsamples):
543
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
544
+ x = self.ups[i](x)
545
+ xs = None
546
+ for j in range(self.num_kernels):
547
+ if xs is None:
548
+ xs = self.resblocks[i * self.num_kernels + j](x)
549
+ else:
550
+ xs += self.resblocks[i * self.num_kernels + j](x)
551
+ x = xs / self.num_kernels
552
+ x = F.leaky_relu(x)
553
+ x = self.conv_post(x)
554
+ x = torch.tanh(x)
555
+
556
+ return x
557
+
558
+ def remove_weight_norm(self):
559
+ print("Removing weight norm...")
560
+ for layer in self.ups:
561
+ remove_weight_norm(layer)
562
+ for layer in self.resblocks:
563
+ layer.remove_weight_norm()
564
+
565
+
566
+ class DiscriminatorP(torch.nn.Module):
567
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
568
+ super(DiscriminatorP, self).__init__()
569
+ self.period = period
570
+ self.use_spectral_norm = use_spectral_norm
571
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
572
+ self.convs = nn.ModuleList(
573
+ [
574
+ norm_f(
575
+ Conv2d(
576
+ 1,
577
+ 32,
578
+ (kernel_size, 1),
579
+ (stride, 1),
580
+ padding=(get_padding(kernel_size, 1), 0),
581
+ )
582
+ ),
583
+ norm_f(
584
+ Conv2d(
585
+ 32,
586
+ 128,
587
+ (kernel_size, 1),
588
+ (stride, 1),
589
+ padding=(get_padding(kernel_size, 1), 0),
590
+ )
591
+ ),
592
+ norm_f(
593
+ Conv2d(
594
+ 128,
595
+ 512,
596
+ (kernel_size, 1),
597
+ (stride, 1),
598
+ padding=(get_padding(kernel_size, 1), 0),
599
+ )
600
+ ),
601
+ norm_f(
602
+ Conv2d(
603
+ 512,
604
+ 1024,
605
+ (kernel_size, 1),
606
+ (stride, 1),
607
+ padding=(get_padding(kernel_size, 1), 0),
608
+ )
609
+ ),
610
+ norm_f(
611
+ Conv2d(
612
+ 1024,
613
+ 1024,
614
+ (kernel_size, 1),
615
+ 1,
616
+ padding=(get_padding(kernel_size, 1), 0),
617
+ )
618
+ ),
619
+ ]
620
+ )
621
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
622
+
623
+ def forward(self, x):
624
+ fmap = []
625
+
626
+ # 1d to 2d
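+ # Pad the waveform to a multiple of the period and fold it into 2D so that
+ # each column of the Conv2d input holds samples spaced `period` apart.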
627
+ b, c, t = x.shape
628
+ if t % self.period != 0: # pad first
629
+ n_pad = self.period - (t % self.period)
630
+ x = F.pad(x, (0, n_pad), "reflect")
631
+ t = t + n_pad
632
+ x = x.view(b, c, t // self.period, self.period)
633
+
634
+ for layer in self.convs:
635
+ x = layer(x)
636
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
637
+ fmap.append(x)
638
+ x = self.conv_post(x)
639
+ fmap.append(x)
640
+ x = torch.flatten(x, 1, -1)
641
+
642
+ return x, fmap
643
+
644
+
645
+ class DiscriminatorS(torch.nn.Module):
646
+ def __init__(self, use_spectral_norm=False):
647
+ super(DiscriminatorS, self).__init__()
648
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
649
+ self.convs = nn.ModuleList(
650
+ [
651
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
652
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
653
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
654
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
655
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
656
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
657
+ ]
658
+ )
659
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
660
+
661
+ def forward(self, x):
662
+ fmap = []
663
+
664
+ for layer in self.convs:
665
+ x = layer(x)
666
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
667
+ fmap.append(x)
668
+ x = self.conv_post(x)
669
+ fmap.append(x)
670
+ x = torch.flatten(x, 1, -1)
671
+
672
+ return x, fmap
673
+
674
+
675
+ class MultiPeriodDiscriminator(torch.nn.Module):
676
+ def __init__(self, use_spectral_norm=False):
677
+ super(MultiPeriodDiscriminator, self).__init__()
678
+ periods = [2, 3, 5, 7, 11]
679
+
680
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
681
+ discs = discs + [
682
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
683
+ ]
684
+ self.discriminators = nn.ModuleList(discs)
685
+
686
+ def forward(self, y, y_hat):
687
+ y_d_rs = []
688
+ y_d_gs = []
689
+ fmap_rs = []
690
+ fmap_gs = []
691
+ for i, d in enumerate(self.discriminators):
692
+ y_d_r, fmap_r = d(y)
693
+ y_d_g, fmap_g = d(y_hat)
694
+ y_d_rs.append(y_d_r)
695
+ y_d_gs.append(y_d_g)
696
+ fmap_rs.append(fmap_r)
697
+ fmap_gs.append(fmap_g)
698
+
699
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
700
+
701
+
702
+ class WavLMDiscriminator(nn.Module):
703
+ """docstring for Discriminator."""
704
+
705
+ def __init__(
706
+ self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
707
+ ):
708
+ super(WavLMDiscriminator, self).__init__()
709
+ norm_f = weight_norm if use_spectral_norm is False else spectral_norm
710
+ self.pre = norm_f(
711
+ Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
712
+ )
713
+
714
+ self.convs = nn.ModuleList(
715
+ [
716
+ norm_f(
717
+ nn.Conv1d(
718
+ initial_channel, initial_channel * 2, kernel_size=5, padding=2
719
+ )
720
+ ),
721
+ norm_f(
722
+ nn.Conv1d(
723
+ initial_channel * 2,
724
+ initial_channel * 4,
725
+ kernel_size=5,
726
+ padding=2,
727
+ )
728
+ ),
729
+ norm_f(
730
+ nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
731
+ ),
732
+ ]
733
+ )
734
+
735
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
736
+
737
+ def forward(self, x):
738
+ x = self.pre(x)
739
+
740
+ fmap = []
741
+ for l in self.convs:
742
+ x = l(x)
743
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
744
+ fmap.append(x)
745
+ x = self.conv_post(x)
746
+ x = torch.flatten(x, 1, -1)
747
+
748
+ return x
749
+
750
+
751
+ class ReferenceEncoder(nn.Module):
752
+ """
753
+ inputs --- [N, Ty/r, n_mels*r] mels
754
+ outputs --- [N, ref_enc_gru_size]
755
+ """
756
+
757
+ def __init__(self, spec_channels, gin_channels=0):
758
+ super().__init__()
759
+ self.spec_channels = spec_channels
760
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
761
+ K = len(ref_enc_filters)
762
+ filters = [1] + ref_enc_filters
763
+ convs = [
764
+ weight_norm(
765
+ nn.Conv2d(
766
+ in_channels=filters[i],
767
+ out_channels=filters[i + 1],
768
+ kernel_size=(3, 3),
769
+ stride=(2, 2),
770
+ padding=(1, 1),
771
+ )
772
+ )
773
+ for i in range(K)
774
+ ]
775
+ self.convs = nn.ModuleList(convs)
776
+ # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501
777
+
778
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
779
+ self.gru = nn.GRU(
780
+ input_size=ref_enc_filters[-1] * out_channels,
781
+ hidden_size=256 // 2,
782
+ batch_first=True,
783
+ )
784
+ self.proj = nn.Linear(128, gin_channels)
785
+
786
+ def forward(self, inputs, mask=None):
787
+ N = inputs.size(0)
788
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
789
+ for conv in self.convs:
790
+ out = conv(out)
791
+ # out = wn(out)
792
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
793
+
794
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
795
+ T = out.size(1)
796
+ N = out.size(0)
797
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
798
+
799
+ self.gru.flatten_parameters()
800
+ memory, out = self.gru(out) # out --- [1, N, 128]
801
+
802
+ return self.proj(out.squeeze(0))
803
+
804
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
805
+ for i in range(n_convs):
806
+ L = (L - kernel_size + 2 * pad) // stride + 1
807
+ return L
808
+
809
+
810
+ class SynthesizerTrn(nn.Module):
811
+ """
812
+ Synthesizer for Training
813
+ """
814
+
815
+ def __init__(
816
+ self,
817
+ n_vocab,
818
+ spec_channels,
819
+ segment_size,
820
+ inter_channels,
821
+ hidden_channels,
822
+ filter_channels,
823
+ n_heads,
824
+ n_layers,
825
+ kernel_size,
826
+ p_dropout,
827
+ resblock,
828
+ resblock_kernel_sizes,
829
+ resblock_dilation_sizes,
830
+ upsample_rates,
831
+ upsample_initial_channel,
832
+ upsample_kernel_sizes,
833
+ n_speakers=256,
834
+ gin_channels=256,
835
+ use_sdp=True,
836
+ n_flow_layer=4,
837
+ n_layers_trans_flow=6,
838
+ flow_share_parameter=False,
839
+ use_transformer_flow=True,
840
+ **kwargs
841
+ ):
842
+ super().__init__()
843
+ self.n_vocab = n_vocab
844
+ self.spec_channels = spec_channels
845
+ self.inter_channels = inter_channels
846
+ self.hidden_channels = hidden_channels
847
+ self.filter_channels = filter_channels
848
+ self.n_heads = n_heads
849
+ self.n_layers = n_layers
850
+ self.kernel_size = kernel_size
851
+ self.p_dropout = p_dropout
852
+ self.resblock = resblock
853
+ self.resblock_kernel_sizes = resblock_kernel_sizes
854
+ self.resblock_dilation_sizes = resblock_dilation_sizes
855
+ self.upsample_rates = upsample_rates
856
+ self.upsample_initial_channel = upsample_initial_channel
857
+ self.upsample_kernel_sizes = upsample_kernel_sizes
858
+ self.segment_size = segment_size
859
+ self.n_speakers = n_speakers
860
+ self.gin_channels = gin_channels
861
+ self.n_layers_trans_flow = n_layers_trans_flow
862
+ self.use_spk_conditioned_encoder = kwargs.get(
863
+ "use_spk_conditioned_encoder", True
864
+ )
865
+ self.use_sdp = use_sdp
866
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
867
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
868
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
869
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
870
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
871
+ self.enc_gin_channels = gin_channels
872
+ self.enc_p = TextEncoder(
873
+ n_vocab,
874
+ inter_channels,
875
+ hidden_channels,
876
+ filter_channels,
877
+ n_heads,
878
+ n_layers,
879
+ kernel_size,
880
+ p_dropout,
881
+ gin_channels=self.enc_gin_channels,
882
+ )
883
+ self.dec = Generator(
884
+ inter_channels,
885
+ resblock,
886
+ resblock_kernel_sizes,
887
+ resblock_dilation_sizes,
888
+ upsample_rates,
889
+ upsample_initial_channel,
890
+ upsample_kernel_sizes,
891
+ gin_channels=gin_channels,
892
+ )
893
+ self.enc_q = PosteriorEncoder(
894
+ spec_channels,
895
+ inter_channels,
896
+ hidden_channels,
897
+ 5,
898
+ 1,
899
+ 16,
900
+ gin_channels=gin_channels,
901
+ )
902
+ if use_transformer_flow:
903
+ self.flow = TransformerCouplingBlock(
904
+ inter_channels,
905
+ hidden_channels,
906
+ filter_channels,
907
+ n_heads,
908
+ n_layers_trans_flow,
909
+ 5,
910
+ p_dropout,
911
+ n_flow_layer,
912
+ gin_channels=gin_channels,
913
+ share_parameter=flow_share_parameter,
914
+ )
915
+ else:
916
+ self.flow = ResidualCouplingBlock(
917
+ inter_channels,
918
+ hidden_channels,
919
+ 5,
920
+ 1,
921
+ n_flow_layer,
922
+ gin_channels=gin_channels,
923
+ )
924
+ self.sdp = StochasticDurationPredictor(
925
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
926
+ )
927
+ self.dp = DurationPredictor(
928
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
929
+ )
930
+
931
+ if n_speakers >= 1:
932
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
933
+ else:
934
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels)
935
+
936
+ def forward(
937
+ self,
938
+ x,
939
+ x_lengths,
940
+ y,
941
+ y_lengths,
942
+ sid,
943
+ tone,
944
+ language,
945
+ bert,
946
+ style_vec,
947
+ ):
948
+ if self.n_speakers > 0:
949
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
950
+ else:
951
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
952
+ x, m_p, logs_p, x_mask = self.enc_p(
953
+ x, x_lengths, tone, language, bert, style_vec, g=g
954
+ )
955
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
956
+ z_p = self.flow(z, y_mask, g=g)
957
+
958
+ with torch.no_grad():
959
+ # negative cross-entropy
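+ # neg_cent1..4 expand log N(z_p; m_p, exp(logs_p)) for every (text, frame) pair;
+ # monotonic_align then searches for the monotonic alignment that maximizes it.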
960
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
961
+ neg_cent1 = torch.sum(
962
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
963
+ ) # [b, 1, t_s]
964
+ neg_cent2 = torch.matmul(
965
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
966
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
967
+ neg_cent3 = torch.matmul(
968
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
969
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
970
+ neg_cent4 = torch.sum(
971
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
972
+ ) # [b, 1, t_s]
973
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
974
+ if self.use_noise_scaled_mas:
975
+ epsilon = (
976
+ torch.std(neg_cent)
977
+ * torch.randn_like(neg_cent)
978
+ * self.current_mas_noise_scale
979
+ )
980
+ neg_cent = neg_cent + epsilon
981
+
982
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
983
+ attn = (
984
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
985
+ .unsqueeze(1)
986
+ .detach()
987
+ )
988
+
989
+ w = attn.sum(2)
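+ # Summing the alignment over spectrogram frames gives per-phoneme durations,
+ # which are the targets for both the stochastic (sdp) and regular (dp) predictors.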
990
+
991
+ l_length_sdp = self.sdp(x, x_mask, w, g=g)
992
+ l_length_sdp = l_length_sdp / torch.sum(x_mask)
993
+
994
+ logw_ = torch.log(w + 1e-6) * x_mask
995
+ logw = self.dp(x, x_mask, g=g)
996
+ # logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
997
+ l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
998
+ x_mask
999
+ ) # for averaging
1000
+ # l_length_sdp += torch.sum((logw_sdp - logw_) ** 2, [1, 2]) / torch.sum(x_mask)
1001
+
1002
+ l_length = l_length_dp + l_length_sdp
1003
+
1004
+ # expand prior
1005
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
1006
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
1007
+
1008
+ z_slice, ids_slice = commons.rand_slice_segments(
1009
+ z, y_lengths, self.segment_size
1010
+ )
1011
+ o = self.dec(z_slice, g=g)
1012
+ return (
1013
+ o,
1014
+ l_length,
1015
+ attn,
1016
+ ids_slice,
1017
+ x_mask,
1018
+ y_mask,
1019
+ (z, z_p, m_p, logs_p, m_q, logs_q),
1020
+ (x, logw, logw_), # , logw_sdp),
1021
+ g,
1022
+ )
1023
+
1024
+ def infer(
1025
+ self,
1026
+ x,
1027
+ x_lengths,
1028
+ sid,
1029
+ tone,
1030
+ language,
1031
+ bert,
1032
+ style_vec,
1033
+ noise_scale=0.667,
1034
+ length_scale=1,
1035
+ noise_scale_w=0.8,
1036
+ max_len=None,
1037
+ sdp_ratio=0,
1038
+ y=None,
1039
+ ):
1040
+ # x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, tone, language, bert)
1041
+ # g = self.gst(y)
1042
+ if self.n_speakers > 0:
1043
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1044
+ else:
1045
+ g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
1046
+ x, m_p, logs_p, x_mask = self.enc_p(
1047
+ x, x_lengths, tone, language, bert, style_vec, g=g
1048
+ )
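+ # Mix stochastic (SDP) and deterministic (DP) log-durations by sdp_ratio, then
+ # convert to frame counts via exp, length_scale and ceil.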
1049
+ logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
1050
+ sdp_ratio
1051
+ ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
1052
+ w = torch.exp(logw) * x_mask * length_scale
1053
+ w_ceil = torch.ceil(w)
1054
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1055
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1056
+ x_mask.dtype
1057
+ )
1058
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1059
+ attn = commons.generate_path(w_ceil, attn_mask)
1060
+
1061
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1062
+ 1, 2
1063
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1064
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1065
+ 1, 2
1066
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1067
+
1068
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1069
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1070
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
1071
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
modules.py ADDED
@@ -0,0 +1,581 @@
1
+ import math
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import Conv1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, weight_norm
9
+
10
+ import commons
11
+ from attentions import Encoder
12
+ from commons import get_padding, init_weights
13
+ from transforms import piecewise_rational_quadratic_transform
14
+
15
+ LRELU_SLOPE = 0.1
16
+
17
+
18
+ class LayerNorm(nn.Module):
19
+ def __init__(self, channels, eps=1e-5):
20
+ super().__init__()
21
+ self.channels = channels
22
+ self.eps = eps
23
+
24
+ self.gamma = nn.Parameter(torch.ones(channels))
25
+ self.beta = nn.Parameter(torch.zeros(channels))
26
+
27
+ def forward(self, x):
28
+ x = x.transpose(1, -1)
29
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
30
+ return x.transpose(1, -1)
31
+
32
+
33
+ class ConvReluNorm(nn.Module):
34
+ def __init__(
35
+ self,
36
+ in_channels,
37
+ hidden_channels,
38
+ out_channels,
39
+ kernel_size,
40
+ n_layers,
41
+ p_dropout,
42
+ ):
43
+ super().__init__()
44
+ self.in_channels = in_channels
45
+ self.hidden_channels = hidden_channels
46
+ self.out_channels = out_channels
47
+ self.kernel_size = kernel_size
48
+ self.n_layers = n_layers
49
+ self.p_dropout = p_dropout
50
+ assert n_layers > 1, "Number of layers should be larger than 1."
51
+
52
+ self.conv_layers = nn.ModuleList()
53
+ self.norm_layers = nn.ModuleList()
54
+ self.conv_layers.append(
55
+ nn.Conv1d(
56
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
57
+ )
58
+ )
59
+ self.norm_layers.append(LayerNorm(hidden_channels))
60
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
61
+ for _ in range(n_layers - 1):
62
+ self.conv_layers.append(
63
+ nn.Conv1d(
64
+ hidden_channels,
65
+ hidden_channels,
66
+ kernel_size,
67
+ padding=kernel_size // 2,
68
+ )
69
+ )
70
+ self.norm_layers.append(LayerNorm(hidden_channels))
71
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
72
+ self.proj.weight.data.zero_()
73
+ self.proj.bias.data.zero_()
74
+
75
+ def forward(self, x, x_mask):
76
+ x_org = x
77
+ for i in range(self.n_layers):
78
+ x = self.conv_layers[i](x * x_mask)
79
+ x = self.norm_layers[i](x)
80
+ x = self.relu_drop(x)
81
+ x = x_org + self.proj(x)
82
+ return x * x_mask
83
+
84
+
85
+ class DDSConv(nn.Module):
86
+ """
87
+ Dilated and Depth-Separable Convolution
88
+ """
89
+
90
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
91
+ super().__init__()
92
+ self.channels = channels
93
+ self.kernel_size = kernel_size
94
+ self.n_layers = n_layers
95
+ self.p_dropout = p_dropout
96
+
97
+ self.drop = nn.Dropout(p_dropout)
98
+ self.convs_sep = nn.ModuleList()
99
+ self.convs_1x1 = nn.ModuleList()
100
+ self.norms_1 = nn.ModuleList()
101
+ self.norms_2 = nn.ModuleList()
102
+ for i in range(n_layers):
103
+ dilation = kernel_size**i
104
+ padding = (kernel_size * dilation - dilation) // 2
105
+ self.convs_sep.append(
106
+ nn.Conv1d(
107
+ channels,
108
+ channels,
109
+ kernel_size,
110
+ groups=channels,
111
+ dilation=dilation,
112
+ padding=padding,
113
+ )
114
+ )
115
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
116
+ self.norms_1.append(LayerNorm(channels))
117
+ self.norms_2.append(LayerNorm(channels))
118
+
119
+ def forward(self, x, x_mask, g=None):
120
+ if g is not None:
121
+ x = x + g
122
+ for i in range(self.n_layers):
123
+ y = self.convs_sep[i](x * x_mask)
124
+ y = self.norms_1[i](y)
125
+ y = F.gelu(y)
126
+ y = self.convs_1x1[i](y)
127
+ y = self.norms_2[i](y)
128
+ y = F.gelu(y)
129
+ y = self.drop(y)
130
+ x = x + y
131
+ return x * x_mask
132
+
133
+
134
+ class WN(torch.nn.Module):
135
+ def __init__(
136
+ self,
137
+ hidden_channels,
138
+ kernel_size,
139
+ dilation_rate,
140
+ n_layers,
141
+ gin_channels=0,
142
+ p_dropout=0,
143
+ ):
144
+ super(WN, self).__init__()
145
+ assert kernel_size % 2 == 1
146
+ self.hidden_channels = hidden_channels
147
+ self.kernel_size = (kernel_size,)
148
+ self.dilation_rate = dilation_rate
149
+ self.n_layers = n_layers
150
+ self.gin_channels = gin_channels
151
+ self.p_dropout = p_dropout
152
+
153
+ self.in_layers = torch.nn.ModuleList()
154
+ self.res_skip_layers = torch.nn.ModuleList()
155
+ self.drop = nn.Dropout(p_dropout)
156
+
157
+ if gin_channels != 0:
158
+ cond_layer = torch.nn.Conv1d(
159
+ gin_channels, 2 * hidden_channels * n_layers, 1
160
+ )
161
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
162
+
163
+ for i in range(n_layers):
164
+ dilation = dilation_rate**i
165
+ padding = int((kernel_size * dilation - dilation) / 2)
166
+ in_layer = torch.nn.Conv1d(
167
+ hidden_channels,
168
+ 2 * hidden_channels,
169
+ kernel_size,
170
+ dilation=dilation,
171
+ padding=padding,
172
+ )
173
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
174
+ self.in_layers.append(in_layer)
175
+
176
+ # last one is not necessary
177
+ if i < n_layers - 1:
178
+ res_skip_channels = 2 * hidden_channels
179
+ else:
180
+ res_skip_channels = hidden_channels
181
+
182
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
183
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
184
+ self.res_skip_layers.append(res_skip_layer)
185
+
186
+ def forward(self, x, x_mask, g=None, **kwargs):
187
+ output = torch.zeros_like(x)
188
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
189
+
190
+ if g is not None:
191
+ g = self.cond_layer(g)
192
+
193
+ for i in range(self.n_layers):
194
+ x_in = self.in_layers[i](x)
195
+ if g is not None:
196
+ cond_offset = i * 2 * self.hidden_channels
197
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
198
+ else:
199
+ g_l = torch.zeros_like(x_in)
200
+
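+ # WaveNet-style gated activation: tanh and sigmoid halves are fused together
+ # with the (optional) speaker conditioning g_l.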
201
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
202
+ acts = self.drop(acts)
203
+
204
+ res_skip_acts = self.res_skip_layers[i](acts)
205
+ if i < self.n_layers - 1:
206
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
207
+ x = (x + res_acts) * x_mask
208
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
209
+ else:
210
+ output = output + res_skip_acts
211
+ return output * x_mask
212
+
213
+ def remove_weight_norm(self):
214
+ if self.gin_channels != 0:
215
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
216
+ for l in self.in_layers:
217
+ torch.nn.utils.remove_weight_norm(l)
218
+ for l in self.res_skip_layers:
219
+ torch.nn.utils.remove_weight_norm(l)
220
+
221
+
222
+ class ResBlock1(torch.nn.Module):
223
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
224
+ super(ResBlock1, self).__init__()
225
+ self.convs1 = nn.ModuleList(
226
+ [
227
+ weight_norm(
228
+ Conv1d(
229
+ channels,
230
+ channels,
231
+ kernel_size,
232
+ 1,
233
+ dilation=dilation[0],
234
+ padding=get_padding(kernel_size, dilation[0]),
235
+ )
236
+ ),
237
+ weight_norm(
238
+ Conv1d(
239
+ channels,
240
+ channels,
241
+ kernel_size,
242
+ 1,
243
+ dilation=dilation[1],
244
+ padding=get_padding(kernel_size, dilation[1]),
245
+ )
246
+ ),
247
+ weight_norm(
248
+ Conv1d(
249
+ channels,
250
+ channels,
251
+ kernel_size,
252
+ 1,
253
+ dilation=dilation[2],
254
+ padding=get_padding(kernel_size, dilation[2]),
255
+ )
256
+ ),
257
+ ]
258
+ )
259
+ self.convs1.apply(init_weights)
260
+
261
+ self.convs2 = nn.ModuleList(
262
+ [
263
+ weight_norm(
264
+ Conv1d(
265
+ channels,
266
+ channels,
267
+ kernel_size,
268
+ 1,
269
+ dilation=1,
270
+ padding=get_padding(kernel_size, 1),
271
+ )
272
+ ),
273
+ weight_norm(
274
+ Conv1d(
275
+ channels,
276
+ channels,
277
+ kernel_size,
278
+ 1,
279
+ dilation=1,
280
+ padding=get_padding(kernel_size, 1),
281
+ )
282
+ ),
283
+ weight_norm(
284
+ Conv1d(
285
+ channels,
286
+ channels,
287
+ kernel_size,
288
+ 1,
289
+ dilation=1,
290
+ padding=get_padding(kernel_size, 1),
291
+ )
292
+ ),
293
+ ]
294
+ )
295
+ self.convs2.apply(init_weights)
296
+
297
+ def forward(self, x, x_mask=None):
298
+ for c1, c2 in zip(self.convs1, self.convs2):
299
+ xt = F.leaky_relu(x, LRELU_SLOPE)
300
+ if x_mask is not None:
301
+ xt = xt * x_mask
302
+ xt = c1(xt)
303
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
304
+ if x_mask is not None:
305
+ xt = xt * x_mask
306
+ xt = c2(xt)
307
+ x = xt + x
308
+ if x_mask is not None:
309
+ x = x * x_mask
310
+ return x
311
+
312
+ def remove_weight_norm(self):
313
+ for l in self.convs1:
314
+ remove_weight_norm(l)
315
+ for l in self.convs2:
316
+ remove_weight_norm(l)
317
+
318
+
319
+ class ResBlock2(torch.nn.Module):
320
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
321
+ super(ResBlock2, self).__init__()
322
+ self.convs = nn.ModuleList(
323
+ [
324
+ weight_norm(
325
+ Conv1d(
326
+ channels,
327
+ channels,
328
+ kernel_size,
329
+ 1,
330
+ dilation=dilation[0],
331
+ padding=get_padding(kernel_size, dilation[0]),
332
+ )
333
+ ),
334
+ weight_norm(
335
+ Conv1d(
336
+ channels,
337
+ channels,
338
+ kernel_size,
339
+ 1,
340
+ dilation=dilation[1],
341
+ padding=get_padding(kernel_size, dilation[1]),
342
+ )
343
+ ),
344
+ ]
345
+ )
346
+ self.convs.apply(init_weights)
347
+
348
+ def forward(self, x, x_mask=None):
349
+ for c in self.convs:
350
+ xt = F.leaky_relu(x, LRELU_SLOPE)
351
+ if x_mask is not None:
352
+ xt = xt * x_mask
353
+ xt = c(xt)
354
+ x = xt + x
355
+ if x_mask is not None:
356
+ x = x * x_mask
357
+ return x
358
+
359
+ def remove_weight_norm(self):
360
+ for l in self.convs:
361
+ remove_weight_norm(l)
362
+
363
+
364
+ class Log(nn.Module):
365
+ def forward(self, x, x_mask, reverse=False, **kwargs):
366
+ if not reverse:
367
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
368
+ logdet = torch.sum(-y, [1, 2])
369
+ return y, logdet
370
+ else:
371
+ x = torch.exp(x) * x_mask
372
+ return x
373
+
374
+
375
+ class Flip(nn.Module):
376
+ def forward(self, x, *args, reverse=False, **kwargs):
377
+ x = torch.flip(x, [1])
378
+ if not reverse:
379
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
380
+ return x, logdet
381
+ else:
382
+ return x
383
+
384
+
385
+ class ElementwiseAffine(nn.Module):
386
+ def __init__(self, channels):
387
+ super().__init__()
388
+ self.channels = channels
389
+ self.m = nn.Parameter(torch.zeros(channels, 1))
390
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
391
+
392
+ def forward(self, x, x_mask, reverse=False, **kwargs):
393
+ if not reverse:
394
+ y = self.m + torch.exp(self.logs) * x
395
+ y = y * x_mask
396
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
397
+ return y, logdet
398
+ else:
399
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
400
+ return x
401
+
402
+
403
+ class ResidualCouplingLayer(nn.Module):
404
+ def __init__(
405
+ self,
406
+ channels,
407
+ hidden_channels,
408
+ kernel_size,
409
+ dilation_rate,
410
+ n_layers,
411
+ p_dropout=0,
412
+ gin_channels=0,
413
+ mean_only=False,
414
+ ):
415
+ assert channels % 2 == 0, "channels should be divisible by 2"
416
+ super().__init__()
417
+ self.channels = channels
418
+ self.hidden_channels = hidden_channels
419
+ self.kernel_size = kernel_size
420
+ self.dilation_rate = dilation_rate
421
+ self.n_layers = n_layers
422
+ self.half_channels = channels // 2
423
+ self.mean_only = mean_only
424
+
425
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
426
+ self.enc = WN(
427
+ hidden_channels,
428
+ kernel_size,
429
+ dilation_rate,
430
+ n_layers,
431
+ p_dropout=p_dropout,
432
+ gin_channels=gin_channels,
433
+ )
434
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
435
+ self.post.weight.data.zero_()
436
+ self.post.bias.data.zero_()
437
+
438
+ def forward(self, x, x_mask, g=None, reverse=False):
439
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
440
+ h = self.pre(x0) * x_mask
441
+ h = self.enc(h, x_mask, g=g)
442
+ stats = self.post(h) * x_mask
443
+ if not self.mean_only:
444
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
445
+ else:
446
+ m = stats
447
+ logs = torch.zeros_like(m)
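+ # Affine coupling: forward maps x1 -> m + x1 * exp(logs); reverse inverts it.
+ # With mean_only=True, logs is zero and the log-determinant vanishes.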
448
+
449
+ if not reverse:
450
+ x1 = m + x1 * torch.exp(logs) * x_mask
451
+ x = torch.cat([x0, x1], 1)
452
+ logdet = torch.sum(logs, [1, 2])
453
+ return x, logdet
454
+ else:
455
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
456
+ x = torch.cat([x0, x1], 1)
457
+ return x
458
+
459
+
460
+ class ConvFlow(nn.Module):
461
+ def __init__(
462
+ self,
463
+ in_channels,
464
+ filter_channels,
465
+ kernel_size,
466
+ n_layers,
467
+ num_bins=10,
468
+ tail_bound=5.0,
469
+ ):
470
+ super().__init__()
471
+ self.in_channels = in_channels
472
+ self.filter_channels = filter_channels
473
+ self.kernel_size = kernel_size
474
+ self.n_layers = n_layers
475
+ self.num_bins = num_bins
476
+ self.tail_bound = tail_bound
477
+ self.half_channels = in_channels // 2
478
+
479
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
480
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
481
+ self.proj = nn.Conv1d(
482
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
483
+ )
484
+ self.proj.weight.data.zero_()
485
+ self.proj.bias.data.zero_()
486
+
487
+ def forward(self, x, x_mask, g=None, reverse=False):
488
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
489
+ h = self.pre(x0)
490
+ h = self.convs(h, x_mask, g=g)
491
+ h = self.proj(h) * x_mask
492
+
493
+ b, c, t = x0.shape
494
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
495
+
496
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
497
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
498
+ self.filter_channels
499
+ )
500
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
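+ # These parameterize a monotonic rational-quadratic spline (neural spline flow)
+ # that is applied elementwise to the second half of the channels.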
501
+
502
+ x1, logabsdet = piecewise_rational_quadratic_transform(
503
+ x1,
504
+ unnormalized_widths,
505
+ unnormalized_heights,
506
+ unnormalized_derivatives,
507
+ inverse=reverse,
508
+ tails="linear",
509
+ tail_bound=self.tail_bound,
510
+ )
511
+
512
+ x = torch.cat([x0, x1], 1) * x_mask
513
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
514
+ if not reverse:
515
+ return x, logdet
516
+ else:
517
+ return x
518
+
519
+
520
+ class TransformerCouplingLayer(nn.Module):
521
+ def __init__(
522
+ self,
523
+ channels,
524
+ hidden_channels,
525
+ kernel_size,
526
+ n_layers,
527
+ n_heads,
528
+ p_dropout=0,
529
+ filter_channels=0,
530
+ mean_only=False,
531
+ wn_sharing_parameter=None,
532
+ gin_channels=0,
533
+ ):
534
+ assert channels % 2 == 0, "channels should be divisible by 2"
535
+ super().__init__()
536
+ self.channels = channels
537
+ self.hidden_channels = hidden_channels
538
+ self.kernel_size = kernel_size
539
+ self.n_layers = n_layers
540
+ self.half_channels = channels // 2
541
+ self.mean_only = mean_only
542
+
543
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
544
+ self.enc = (
545
+ Encoder(
546
+ hidden_channels,
547
+ filter_channels,
548
+ n_heads,
549
+ n_layers,
550
+ kernel_size,
551
+ p_dropout,
552
+ isflow=True,
553
+ gin_channels=gin_channels,
554
+ )
555
+ if wn_sharing_parameter is None
556
+ else wn_sharing_parameter
557
+ )
558
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
559
+ self.post.weight.data.zero_()
560
+ self.post.bias.data.zero_()
561
+
562
+ def forward(self, x, x_mask, g=None, reverse=False):
563
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
564
+ h = self.pre(x0) * x_mask
565
+ h = self.enc(h, x_mask, g=g)
566
+ stats = self.post(h) * x_mask
567
+ if not self.mean_only:
568
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
569
+ else:
570
+ m = stats
571
+ logs = torch.zeros_like(m)
572
+
573
+ if not reverse:
574
+ x1 = m + x1 * torch.exp(logs) * x_mask
575
+ x = torch.cat([x0, x1], 1)
576
+ logdet = torch.sum(logs, [1, 2])
577
+ return x, logdet
578
+ else:
579
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
580
+ x = torch.cat([x0, x1], 1)
581
+ return x
preprocess_text.py ADDED
@@ -0,0 +1,146 @@
1
+ import json
2
+ import os
3
+ import sys
4
+ from collections import defaultdict
5
+ from random import shuffle
6
+ from typing import Optional
7
+
8
+ import click
9
+ from tqdm import tqdm
10
+
11
+ from config import config
12
+ from text.cleaner import clean_text
13
+
14
+ preprocess_text_config = config.preprocess_text_config
15
+
16
+
17
+ @click.command()
18
+ @click.option(
19
+ "--transcription-path",
20
+ default=preprocess_text_config.transcription_path,
21
+ type=click.Path(exists=True, file_okay=True, dir_okay=False),
22
+ )
23
+ @click.option("--cleaned-path", default=preprocess_text_config.cleaned_path)
24
+ @click.option("--train-path", default=preprocess_text_config.train_path)
25
+ @click.option("--val-path", default=preprocess_text_config.val_path)
26
+ @click.option(
27
+ "--config-path",
28
+ default=preprocess_text_config.config_path,
29
+ type=click.Path(exists=True, file_okay=True, dir_okay=False),
30
+ )
31
+ @click.option("--val-per-lang", default=preprocess_text_config.val_per_lang)
32
+ @click.option("--max-val-total", default=preprocess_text_config.max_val_total)
33
+ @click.option("--clean/--no-clean", default=preprocess_text_config.clean)
34
+ @click.option("-y", "--yml_config")
35
+ def preprocess(
36
+ transcription_path: str,
37
+ cleaned_path: Optional[str],
38
+ train_path: str,
39
+ val_path: str,
40
+ config_path: str,
41
+ val_per_lang: int,
42
+ max_val_total: int,
43
+ clean: bool,
44
+ yml_config: str, # do not remove this
45
+ ):
46
+ if cleaned_path == "" or cleaned_path is None:
47
+ cleaned_path = transcription_path + ".cleaned"
48
+
49
+ if clean:
50
+ with open(cleaned_path, "w", encoding="utf-8") as out_file:
51
+ with open(transcription_path, "r", encoding="utf-8") as trans_file:
52
+ lines = trans_file.readlines()
53
+ # print(lines, ' ', len(lines))
54
+ if len(lines) != 0:
55
+ for line in tqdm(lines, file=sys.stdout):
56
+ try:
57
+ utt, spk, language, text = line.strip().split("|")
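+ # Each transcription line is "audio_path|speaker|language|text"; the cleaned
+ # file additionally stores phones, tones and word2ph.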
58
+ norm_text, phones, tones, word2ph = clean_text(
59
+ text, language
60
+ )
61
+ out_file.write(
62
+ "{}|{}|{}|{}|{}|{}|{}\n".format(
63
+ utt,
64
+ spk,
65
+ language,
66
+ norm_text,
67
+ " ".join(phones),
68
+ " ".join([str(i) for i in tones]),
69
+ " ".join([str(i) for i in word2ph]),
70
+ )
71
+ )
72
+ except Exception as e:
73
+ print(line)
74
+ print(
75
+ f"An error occurred while generating the training set and validation set! Details:\n{e}"
76
+ )
77
+
78
+ transcription_path = cleaned_path
79
+ spk_utt_map = defaultdict(list)
80
+ spk_id_map = {}
81
+ current_sid = 0
82
+
83
+ with open(transcription_path, "r", encoding="utf-8") as f:
84
+ audioPaths = set()
85
+ countSame = 0
86
+ countNotFound = 0
87
+ for line in f.readlines():
88
+ utt, spk, language, text, phones, tones, word2ph = line.strip().split("|")
89
+ if utt in audioPaths:
90
+ # Filter dataset errors: the same audio is matched to multiple texts, which breaks later BERT feature generation
91
+ print(f"Same audio matches multiple texts: {line}")
92
+ countSame += 1
93
+ continue
94
+ if not os.path.isfile(utt):
95
+ # Filter dataset errors: the corresponding audio file does not exist
96
+ print(f"Audio not found: {utt}")
97
+ countNotFound += 1
98
+ continue
99
+ audioPaths.add(utt)
100
+ spk_utt_map[language].append(line)
101
+ if spk not in spk_id_map.keys():
102
+ spk_id_map[spk] = current_sid
103
+ current_sid += 1
104
+ print(
105
+ f"Total repeated audios: {countSame}, Total number of audio not found: {countNotFound}"
106
+ )
107
+
108
+ train_list = []
109
+ val_list = []
110
+
111
+ for spk, utts in spk_utt_map.items():
112
+ shuffle(utts)
113
+ val_list += utts[:val_per_lang]
114
+ train_list += utts[val_per_lang:]
115
+
116
+ shuffle(val_list)
117
+ if len(val_list) > max_val_total:
118
+ train_list += val_list[max_val_total:]
119
+ val_list = val_list[:max_val_total]
120
+
121
+ with open(train_path, "w", encoding="utf-8") as f:
122
+ for line in train_list:
123
+ f.write(line)
124
+
125
+ with open(val_path, "w", encoding="utf-8") as f:
126
+ for line in val_list:
127
+ f.write(line)
128
+
129
+ with open(config_path, encoding="utf-8") as f:
+ json_config = json.load(f)
130
+ json_config["data"]["spk2id"] = spk_id_map
131
+ json_config["data"]["n_speakers"] = len(spk_id_map)
132
+ # Also write the training version and dataset paths
133
+ # json_config["version"] = latest_version
134
+ json_config["data"]["training_files"] = os.path.normpath(train_path).replace(
135
+ "\\", "/"
136
+ )
137
+ json_config["data"]["validation_files"] = os.path.normpath(val_path).replace(
138
+ "\\", "/"
139
+ )
140
+ with open(config_path, "w", encoding="utf-8") as f:
141
+ json.dump(json_config, f, indent=2, ensure_ascii=False)
142
+ print("Training set and validation set generation from texts is complete!")
143
+
144
+
145
+ if __name__ == "__main__":
146
+ preprocess()
re_matching.py ADDED
@@ -0,0 +1,81 @@
1
+ import re
2
+
3
+
4
+ def extract_language_and_text_updated(speaker, dialogue):
5
+ # Match <language> tags and the text that follows them with a regex
6
+ pattern_language_text = r"<(\S+?)>([^<]+)"
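+ # e.g. "<zh>你好<jp>こんにちは" -> [("ZH", "你好"), ("JP", "こんにちは")]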
7
+ matches = re.findall(pattern_language_text, dialogue, re.DOTALL)
8
+ speaker = speaker[1:-1]
9
+ # Clean up the text: strip surrounding whitespace
10
+ matches_cleaned = [(lang.upper(), text.strip()) for lang, text in matches]
11
+ matches_cleaned.append(speaker)
12
+ return matches_cleaned
13
+
14
+
15
+ def validate_text(input_text):
16
+ # Regex that validates the speaker format
17
+ pattern_speaker = r"(\[\S+?\])((?:\s*<\S+?>[^<\[\]]+?)+)"
18
+
19
+ # Use the re.DOTALL flag so that . also matches newlines
20
+ matches = re.findall(pattern_speaker, input_text, re.DOTALL)
21
+
22
+ # Further validate the dialogue content of every matched speaker
23
+ for _, dialogue in matches:
24
+ language_text_matches = extract_language_and_text_updated(_, dialogue)
25
+ if not language_text_matches:
26
+ return (
27
+ False,
28
+ "Error: Invalid format detected in dialogue content. Please check your input.",
29
+ )
30
+
31
+ # If no matches were found in the input text
32
+ if not matches:
33
+ return (
34
+ False,
35
+ "Error: No valid speaker format detected. Please check your input.",
36
+ )
37
+
38
+ return True, "Input is valid."
39
+
40
+
41
+ def text_matching(text: str) -> list:
42
+ speaker_pattern = r"(\[\S+?\])(.+?)(?=\[\S+?\]|$)"
43
+ matches = re.findall(speaker_pattern, text, re.DOTALL)
44
+ result = []
45
+ for speaker, dialogue in matches:
46
+ result.append(extract_language_and_text_updated(speaker, dialogue))
47
+ return result
48
+
49
+
50
+ def cut_para(text):
51
+ splitted_para = re.split("[\n]", text) # split into paragraphs
52
+ splitted_para = [
53
+ sentence.strip() for sentence in splitted_para if sentence.strip()
54
+ ] # drop empty strings
55
+ return splitted_para
56
+
57
+
58
+ def cut_sent(para):
59
+ para = re.sub("([。!;?\?])([^”’])", r"\1\n\2", para) # 单字符断句符
60
+ para = re.sub("(\.{6})([^”’])", r"\1\n\2", para) # 英文省略号
61
+ para = re.sub("(\…{2})([^”’])", r"\1\n\2", para) # 中文省略号
62
+ para = re.sub("([。!?\?][”’])([^,。!?\?])", r"\1\n\2", para)
63
+ para = para.rstrip() # drop any extra trailing \n at the end of the paragraph
64
+ return para.split("\n")
65
+
66
+
67
+ if __name__ == "__main__":
68
+ text = """
69
+ [说话人1]
70
+ [说话人2]<zh>你好吗?<jp>元気ですか?<jp>こんにちは,世界。<zh>你好吗?
71
+ [说话人3]<zh>谢谢。<jp>どういたしまして。
72
+ """
73
+ text_matching(text)
74
+ # test the functions
75
+ test_text = """
76
+ [说话人1]<zh>你好,こんにちは!<jp>こんにちは,世界。
77
+ [说话人2]<zh>你好吗?
78
+ """
79
+ text_matching(test_text)
80
+ res = validate_text(test_text)
81
+ print(res)
requirements.txt ADDED
@@ -0,0 +1,27 @@
1
+ cmudict
2
+ cn2an
3
+ g2p_en
4
+ GPUtil
5
+ gradio
6
+ jaconv
7
+ jieba
8
+ langid
9
+ librosa
10
+ loguru
11
+ matplotlib
12
+ mecab-python3
13
+ num2words
14
+ numba
15
+ numpy
16
+ psutil
17
+ pyannote.audio
18
+ pyopenjtalk-prebuilt
19
+ pypinyin
20
+ PyYAML
21
+ requests
22
+ safetensors
23
+ scipy
24
+ sentencepiece
25
+ tensorboard
26
+ torch>=2.1,<2.2 # For users without GPU or colab
27
+ transformers
server_fastapi.py ADDED
@@ -0,0 +1,263 @@
1
+ """
2
+ API server for TTS
3
+ """
4
+ import argparse
5
+ import os
6
+ import sys
7
+ from io import BytesIO
8
+ from typing import Dict, Optional, Union
9
+ from urllib.parse import unquote
10
+
11
+ import GPUtil
12
+ import psutil
13
+ import torch
14
+ import uvicorn
15
+ from fastapi import FastAPI, HTTPException, Query, Request, status
16
+ from fastapi.middleware.cors import CORSMiddleware
17
+ from fastapi.responses import FileResponse, Response
18
+ from scipy.io import wavfile
19
+
20
+ from common.constants import (
21
+ DEFAULT_ASSIST_TEXT_WEIGHT,
22
+ DEFAULT_LENGTH,
23
+ DEFAULT_LINE_SPLIT,
24
+ DEFAULT_NOISE,
25
+ DEFAULT_NOISEW,
26
+ DEFAULT_SDP_RATIO,
27
+ DEFAULT_SPLIT_INTERVAL,
28
+ DEFAULT_STYLE,
29
+ DEFAULT_STYLE_WEIGHT,
30
+ Languages,
31
+ )
32
+ from common.log import logger
33
+ from common.tts_model import Model, ModelHolder
34
+ from config import config
35
+
36
+ ln = config.server_config.language
37
+
38
+
39
+ def raise_validation_error(msg: str, param: str):
40
+ logger.warning(f"Validation error: {msg}")
41
+ raise HTTPException(
42
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
43
+ detail=[dict(type="invalid_params", msg=msg, loc=["query", param])],
44
+ )
45
+
46
+
47
+ class AudioResponse(Response):
48
+ media_type = "audio/wav"
49
+
50
+
51
+ def load_models(model_holder: ModelHolder):
52
+ model_holder.models = []
53
+ for model_name, model_paths in model_holder.model_files_dict.items():
54
+ model = Model(
55
+ model_path=model_paths[0],
56
+ config_path=os.path.join(model_holder.root_dir, model_name, "config.json"),
57
+ style_vec_path=os.path.join(
58
+ model_holder.root_dir, model_name, "style_vectors.npy"
59
+ ),
60
+ device=model_holder.device,
61
+ )
62
+ model.load_net_g()
63
+ model_holder.models.append(model)
64
+
65
+
66
+ if __name__ == "__main__":
67
+ parser = argparse.ArgumentParser()
68
+ parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
69
+ parser.add_argument(
70
+ "--dir", "-d", type=str, help="Model directory", default=config.assets_root
71
+ )
72
+ args = parser.parse_args()
73
+
74
+ if args.cpu:
75
+ device = "cpu"
76
+ else:
77
+ device = "cuda" if torch.cuda.is_available() else "cpu"
78
+
79
+ model_dir = args.dir
80
+ model_holder = ModelHolder(model_dir, device)
81
+ if len(model_holder.model_names) == 0:
82
+ logger.error(f"Models not found in {model_dir}.")
83
+ sys.exit(1)
84
+
85
+ logger.info("Loading models...")
86
+ load_models(model_holder)
87
+ limit = config.server_config.limit
88
+ app = FastAPI()
89
+ allow_origins = config.server_config.origins
90
+ if allow_origins:
91
+ logger.warning(
92
+ f"CORS allow_origins={config.server_config.origins}. If you don't want, modify config.yml"
93
+ )
94
+ app.add_middleware(
95
+ CORSMiddleware,
96
+ allow_origins=config.server_config.origins,
97
+ allow_credentials=True,
98
+ allow_methods=["*"],
99
+ allow_headers=["*"],
100
+ )
101
+ app.logger = logger
102
+
103
+ @app.get("/voice", response_class=AudioResponse)
104
+ async def voice(
105
+ request: Request,
106
+ text: str = Query(..., min_length=1, max_length=limit, description=f"セリフ"),
107
+ encoding: str = Query(None, description="textをURLデコードする(ex, `utf-8`)"),
108
+ model_id: int = Query(0, description="モデルID。`GET /models/info`のkeyの値を指定ください"),
109
+ speaker_name: str = Query(
110
+ None, description="話者名(speaker_idより優先)。esd.listの2列目の文字列を指定"
111
+ ),
112
+ speaker_id: int = Query(
113
+ 0, description="話者ID。model_assets>[model]>config.json内のspk2idを確認"
114
+ ),
115
+ sdp_ratio: float = Query(
116
+ DEFAULT_SDP_RATIO,
117
+ description="SDP(Stochastic Duration Predictor)/DP混合比。比率が高くなるほどトーンのばらつきが大きくなる",
118
+ ),
119
+ noise: float = Query(DEFAULT_NOISE, description="サンプルノイズの割合。大きくするほどランダム性が高まる"),
120
+ noisew: float = Query(
121
+ DEFAULT_NOISEW, description="SDPノイズ。大きくするほど発音の間隔にばらつきが出やすくなる"
122
+ ),
123
+ length: float = Query(
124
+ DEFAULT_LENGTH, description="話速。基準は1で大きくするほど音声は長くなり読み上げが遅まる"
125
+ ),
126
+ language: Languages = Query(ln, description=f"textの言語"),
127
+ auto_split: bool = Query(DEFAULT_LINE_SPLIT, description="改行で分けて生成"),
128
+ split_interval: float = Query(
129
+ DEFAULT_SPLIT_INTERVAL, description="分けた場合に挟む無音の長さ(秒)"
130
+ ),
131
+ assist_text: Optional[str] = Query(
132
+ None, description="このテキストの読み上げと似た声音・感情になりやすくなる。ただし抑揚やテンポ等が犠牲になる傾向がある"
133
+ ),
134
+ assist_text_weight: float = Query(
135
+ DEFAULT_ASSIST_TEXT_WEIGHT, description="assist_textの強さ"
136
+ ),
137
+ style: Optional[Union[int, str]] = Query(DEFAULT_STYLE, description="スタイル"),
138
+ style_weight: float = Query(DEFAULT_STYLE_WEIGHT, description="スタイルの強さ"),
139
+ reference_audio_path: Optional[str] = Query(None, description="スタイルを音声ファイルで行う"),
140
+ ):
141
+ """Infer text to speech(テキストから感情付き音声を生成する)"""
142
+ logger.info(
143
+ f"{request.client.host}:{request.client.port}/voice { unquote(str(request.query_params) )}"
144
+ )
145
+ if model_id >= len(model_holder.models): # /models/refresh があるためQuery(le)で表現不可
146
+ raise_validation_error(f"model_id={model_id} not found", "model_id")
147
+
148
+ model = model_holder.models[model_id]
149
+ if speaker_name is None:
150
+ if speaker_id not in model.id2spk.keys():
151
+ raise_validation_error(
152
+ f"speaker_id={speaker_id} not found", "speaker_id"
153
+ )
154
+ else:
155
+ if speaker_name not in model.spk2id.keys():
156
+ raise_validation_error(
157
+ f"speaker_name={speaker_name} not found", "speaker_name"
158
+ )
159
+ speaker_id = model.spk2id[speaker_name]
160
+ if style not in model.style2id.keys():
161
+ raise_validation_error(f"style={style} not found", "style")
162
+ if encoding is not None:
163
+ text = unquote(text, encoding=encoding)
164
+ sr, audio = model.infer(
165
+ text=text,
166
+ language=language,
167
+ sid=speaker_id,
168
+ reference_audio_path=reference_audio_path,
169
+ sdp_ratio=sdp_ratio,
170
+ noise=noise,
171
+ noisew=noisew,
172
+ length=length,
173
+ line_split=auto_split,
174
+ split_interval=split_interval,
175
+ assist_text=assist_text,
176
+ assist_text_weight=assist_text_weight,
177
+ use_assist_text=bool(assist_text),
178
+ style=style,
179
+ style_weight=style_weight,
180
+ )
181
+ logger.success("Audio data generated and sent successfully")
182
+ with BytesIO() as wavContent:
183
+ wavfile.write(wavContent, sr, audio)
184
+ return Response(content=wavContent.getvalue(), media_type="audio/wav")
185
+
186
+ @app.get("/models/info")
187
+ def get_loaded_models_info():
188
+ """ロードされたモデル情報の取得"""
189
+
190
+ result: Dict[str, Dict] = dict()
191
+ for model_id, model in enumerate(model_holder.models):
192
+ result[str(model_id)] = {
193
+ "config_path": model.config_path,
194
+ "model_path": model.model_path,
195
+ "device": model.device,
196
+ "spk2id": model.spk2id,
197
+ "id2spk": model.id2spk,
198
+ "style2id": model.style2id,
199
+ }
200
+ return result
201
+
202
+ @app.post("/models/refresh")
203
+ def refresh():
204
+ """モデルをパスに追加/削除した際などに読み込ませる"""
205
+ model_holder.refresh()
206
+ load_models(model_holder)
207
+ return get_loaded_models_info()
208
+
209
+ @app.get("/status")
210
+ def get_status():
211
+ """実行環境のステータスを取得"""
212
+ cpu_percent = psutil.cpu_percent(interval=1)
213
+ memory_info = psutil.virtual_memory()
214
+ memory_total = memory_info.total
215
+ memory_available = memory_info.available
216
+ memory_used = memory_info.used
217
+ memory_percent = memory_info.percent
218
+ gpuInfo = []
219
+ devices = ["cpu"]
220
+ for i in range(torch.cuda.device_count()):
221
+ devices.append(f"cuda:{i}")
222
+ gpus = GPUtil.getGPUs()
223
+ for gpu in gpus:
224
+ gpuInfo.append(
225
+ {
226
+ "gpu_id": gpu.id,
227
+ "gpu_load": gpu.load,
228
+ "gpu_memory": {
229
+ "total": gpu.memoryTotal,
230
+ "used": gpu.memoryUsed,
231
+ "free": gpu.memoryFree,
232
+ },
233
+ }
234
+ )
235
+ return {
236
+ "devices": devices,
237
+ "cpu_percent": cpu_percent,
238
+ "memory_total": memory_total,
239
+ "memory_available": memory_available,
240
+ "memory_used": memory_used,
241
+ "memory_percent": memory_percent,
242
+ "gpu": gpuInfo,
243
+ }
244
+
245
+ @app.get("/tools/get_audio", response_class=AudioResponse)
246
+ def get_audio(
247
+ request: Request, path: str = Query(..., description="local wav path")
248
+ ):
249
+ """wavデータを取得する"""
250
+ logger.info(
251
+ f"{request.client.host}:{request.client.port}/tools/get_audio { unquote(str(request.query_params) )}"
252
+ )
253
+ if not os.path.isfile(path):
254
+ raise_validation_error(f"path={path} not found", "path")
255
+ if not path.lower().endswith(".wav"):
256
+ raise_validation_error(f"wav file not found in {path}", "path")
257
+ return FileResponse(path=path, media_type="audio/wav")
258
+
259
+ logger.info(f"server listen: http://127.0.0.1:{config.server_config.port}")
260
+ logger.info(f"API docs: http://127.0.0.1:{config.server_config.port}/docs")
261
+ uvicorn.run(
262
+ app, port=config.server_config.port, host="0.0.0.0", log_level="warning"
263
+ )
spec_gen.py ADDED
@@ -0,0 +1,87 @@
1
+ import torch
2
+ from tqdm import tqdm
3
+ from multiprocessing import Pool
4
+ from mel_processing import spectrogram_torch, mel_spectrogram_torch
5
+ from utils import load_wav_to_torch
6
+
7
+
8
+ class AudioProcessor:
9
+ def __init__(
10
+ self,
11
+ max_wav_value,
12
+ use_mel_spec_posterior,
13
+ filter_length,
14
+ n_mel_channels,
15
+ sampling_rate,
16
+ hop_length,
17
+ win_length,
18
+ mel_fmin,
19
+ mel_fmax,
20
+ ):
21
+ self.max_wav_value = max_wav_value
22
+ self.use_mel_spec_posterior = use_mel_spec_posterior
23
+ self.filter_length = filter_length
24
+ self.n_mel_channels = n_mel_channels
25
+ self.sampling_rate = sampling_rate
26
+ self.hop_length = hop_length
27
+ self.win_length = win_length
28
+ self.mel_fmin = mel_fmin
29
+ self.mel_fmax = mel_fmax
30
+
31
+ def process_audio(self, filename):
32
+ audio, sampling_rate = load_wav_to_torch(filename)
33
+ audio_norm = audio / self.max_wav_value
34
+ audio_norm = audio_norm.unsqueeze(0)
35
+ spec_filename = filename.replace(".wav", ".spec.pt")
36
+ if self.use_mel_spec_posterior:
37
+ spec_filename = spec_filename.replace(".spec.pt", ".mel.pt")
38
+ try:
39
+ spec = torch.load(spec_filename)
40
+ except:
41
+ if self.use_mel_spec_posterior:
42
+ spec = mel_spectrogram_torch(
43
+ audio_norm,
44
+ self.filter_length,
45
+ self.n_mel_channels,
46
+ self.sampling_rate,
47
+ self.hop_length,
48
+ self.win_length,
49
+ self.mel_fmin,
50
+ self.mel_fmax,
51
+ center=False,
52
+ )
53
+ else:
54
+ spec = spectrogram_torch(
55
+ audio_norm,
56
+ self.filter_length,
57
+ self.sampling_rate,
58
+ self.hop_length,
59
+ self.win_length,
60
+ center=False,
61
+ )
62
+ spec = torch.squeeze(spec, 0)
63
+ torch.save(spec, spec_filename)
64
+ return spec, audio_norm
65
+
66
+
67
+ # Usage example
68
+ processor = AudioProcessor(
69
+ max_wav_value=32768.0,
70
+ use_mel_spec_posterior=False,
71
+ filter_length=2048,
72
+ n_mel_channels=128,
73
+ sampling_rate=44100,
74
+ hop_length=512,
75
+ win_length=2048,
76
+ mel_fmin=0.0,
77
+ mel_fmax="null",
78
+ )
79
+
80
+ with open("filelists/train.list", "r") as f:
81
+ filepaths = [line.split("|")[0] for line in f] # 取每一行的第一部分作为audiopath
82
+
83
+ # Process the files in parallel
84
+ with Pool(processes=32) as pool: # use 32 worker processes
85
+ with tqdm(total=len(filepaths)) as pbar:
86
+ for i, _ in enumerate(pool.imap_unordered(processor.process_audio, filepaths)):
87
+ pbar.update()
style_gen.py ADDED
@@ -0,0 +1,128 @@
1
+ import argparse
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ import warnings
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ import utils
10
+ from common.log import logger
11
+ from common.stdout_wrapper import SAFE_STDOUT
12
+ from config import config
13
+
14
+ warnings.filterwarnings("ignore", category=UserWarning)
15
+ from pyannote.audio import Inference, Model
16
+
17
+ model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
18
+ inference = Inference(model, window="whole")
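+ # A speaker embedding of the whole utterance (pyannote wespeaker, 256-dim,
+ # matching TextEncoder.style_proj) is used as the style vector.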
19
+ device = torch.device(config.style_gen_config.device)
20
+ inference.to(device)
21
+
22
+
23
+ class NaNValueError(ValueError):
24
+ """カスタム例外クラス。NaN値が見つかった場合に使用されます。"""
25
+
26
+ pass
27
+
28
+
29
+ # Kept as a (short) function so it can be imported at inference time
30
+ def get_style_vector(wav_path):
31
+ return inference(wav_path)
32
+
33
+
34
+ def save_style_vector(wav_path):
35
+ try:
36
+ style_vec = get_style_vector(wav_path)
37
+ except Exception as e:
38
+ print("\n")
39
+ logger.error(f"Error occurred with file: {wav_path}, Details:\n{e}\n")
40
+ raise
41
+ # NaN values cause problems downstream, so check for them
42
+ if np.isnan(style_vec).any():
43
+ print("\n")
44
+ logger.warning(f"NaN value found in style vector: {wav_path}")
45
+ raise NaNValueError(f"NaN value found in style vector: {wav_path}")
46
+ np.save(f"{wav_path}.npy", style_vec) # `test.wav` -> `test.wav.npy`
47
+
48
+
49
+ def process_line(line):
50
+ wavname = line.split("|")[0]
51
+ try:
52
+ save_style_vector(wavname)
53
+ return line, None
54
+ except NaNValueError:
55
+ return line, "nan_error"
56
+
57
+
58
+ def save_average_style_vector(style_vectors, filename="style_vectors.npy"):
59
+ average_vector = np.mean(style_vectors, axis=0)
60
+ np.save(filename, average_vector)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ parser = argparse.ArgumentParser()
65
+ parser.add_argument(
66
+ "-c", "--config", type=str, default=config.style_gen_config.config_path
67
+ )
68
+ parser.add_argument(
69
+ "--num_processes", type=int, default=config.style_gen_config.num_processes
70
+ )
71
+ args, _ = parser.parse_known_args()
72
+ config_path = args.config
73
+ num_processes = args.num_processes
74
+
75
+ hps = utils.get_hparams_from_file(config_path)
76
+
77
+ device = config.style_gen_config.device
78
+
79
+ training_lines = []
80
+ with open(hps.data.training_files, encoding="utf-8") as f:
81
+ training_lines.extend(f.readlines())
82
+ with ThreadPoolExecutor(max_workers=num_processes) as executor:
83
+ training_results = list(
84
+ tqdm(
85
+ executor.map(process_line, training_lines),
86
+ total=len(training_lines),
87
+ file=SAFE_STDOUT,
88
+ )
89
+ )
90
+ ok_training_lines = [line for line, error in training_results if error is None]
91
+ nan_training_lines = [
92
+ line for line, error in training_results if error == "nan_error"
93
+ ]
94
+ if nan_training_lines:
95
+ nan_files = [line.split("|")[0] for line in nan_training_lines]
96
+ logger.warning(
97
+ f"Found NaN value in {len(nan_training_lines)} files: {nan_files}, so they will be deleted from training data."
98
+ )
99
+
100
+ val_lines = []
101
+ with open(hps.data.validation_files, encoding="utf-8") as f:
102
+ val_lines.extend(f.readlines())
103
+
104
+ with ThreadPoolExecutor(max_workers=num_processes) as executor:
105
+ val_results = list(
106
+ tqdm(
107
+ executor.map(process_line, val_lines),
108
+ total=len(val_lines),
109
+ file=SAFE_STDOUT,
110
+ )
111
+ )
112
+ ok_val_lines = [line for line, error in val_results if error is None]
113
+ nan_val_lines = [line for line, error in val_results if error == "nan_error"]
114
+ if nan_val_lines:
115
+ nan_files = [line.split("|")[0] for line in nan_val_lines]
116
+ logger.warning(
117
+ f"Found NaN value in {len(nan_val_lines)} files: {nan_files}, so they will be deleted from validation data."
118
+ )
119
+
120
+ with open(hps.data.training_files, "w", encoding="utf-8") as f:
121
+ f.writelines(ok_training_lines)
122
+
123
+ with open(hps.data.validation_files, "w", encoding="utf-8") as f:
124
+ f.writelines(ok_val_lines)
125
+
126
+ ok_num = len(ok_training_lines) + len(ok_val_lines)
127
+
128
+ logger.info(f"Finished generating style vectors! total: {ok_num} npy files.")
transforms.py ADDED
@@ -0,0 +1,209 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ assert (discriminant >= 0).all()
172
+
173
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
174
+ outputs = root * input_bin_widths + input_cumwidths
175
+
176
+ theta_one_minus_theta = root * (1 - root)
177
+ denominator = input_delta + (
178
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
179
+ * theta_one_minus_theta
180
+ )
181
+ derivative_numerator = input_delta.pow(2) * (
182
+ input_derivatives_plus_one * root.pow(2)
183
+ + 2 * input_delta * theta_one_minus_theta
184
+ + input_derivatives * (1 - root).pow(2)
185
+ )
186
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
187
+
188
+ return outputs, -logabsdet
189
+ else:
190
+ theta = (inputs - input_cumwidths) / input_bin_widths
191
+ theta_one_minus_theta = theta * (1 - theta)
192
+
193
+ numerator = input_heights * (
194
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
195
+ )
196
+ denominator = input_delta + (
197
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
198
+ * theta_one_minus_theta
199
+ )
200
+ outputs = input_cumheights + numerator / denominator
201
+
202
+ derivative_numerator = input_delta.pow(2) * (
203
+ input_derivatives_plus_one * theta.pow(2)
204
+ + 2 * input_delta * theta_one_minus_theta
205
+ + input_derivatives * (1 - theta).pow(2)
206
+ )
207
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
208
+
209
+ return outputs, logabsdet
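Since the rational-quadratic spline is invertible, a forward pass followed by an inverse pass should reproduce the inputs, and the two log-determinants should cancel. A minimal self-check sketch (not part of transforms.py; the shapes and bin count are arbitrary):

import torch
from transforms import piecewise_rational_quadratic_transform

torch.manual_seed(0)
num_bins = 10
x = torch.rand(4, 7) * 2 - 1  # values inside the tail bound
w = torch.randn(4, 7, num_bins)
h = torch.randn(4, 7, num_bins)
d = torch.randn(4, 7, num_bins - 1)  # linear tails add the two boundary derivatives

y, logdet_fwd = piecewise_rational_quadratic_transform(
    x, w, h, d, inverse=False, tails="linear", tail_bound=1.0
)
x_rec, logdet_inv = piecewise_rational_quadratic_transform(
    y, w, h, d, inverse=True, tails="linear", tail_bound=1.0
)
print(torch.allclose(x, x_rec, atol=1e-4))    # expected: True
print((logdet_fwd + logdet_inv).abs().max())  # expected: close to 0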
utils.py ADDED
@@ -0,0 +1,501 @@
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ import subprocess
8
+
9
+ import numpy as np
10
+ import torch
11
+ from huggingface_hub import hf_hub_download
12
+ from safetensors import safe_open
13
+ from safetensors.torch import save_file
14
+ from scipy.io.wavfile import read
15
+
16
+ from common.log import logger
17
+
18
+ MATPLOTLIB_FLAG = False
19
+
20
+
21
+ def download_checkpoint(
22
+ dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
23
+ ):
24
+ repo_id = repo_config["repo_id"]
25
+ f_list = glob.glob(os.path.join(dir_path, regex))
26
+ if f_list:
27
+ print("Use existed model, skip downloading.")
28
+ return
29
+ for file in ["DUR_0.pth", "D_0.pth", "G_0.pth"]:
30
+ hf_hub_download(repo_id, file, local_dir=dir_path, local_dir_use_symlinks=False)
31
+
32
+
33
+ def load_checkpoint(
34
+ checkpoint_path, model, optimizer=None, skip_optimizer=False, for_infer=False
35
+ ):
36
+ assert os.path.isfile(checkpoint_path)
37
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
38
+ iteration = checkpoint_dict["iteration"]
39
+ learning_rate = checkpoint_dict["learning_rate"]
40
+ logger.info(
41
+ f"Loading model and optimizer at iteration {iteration} from {checkpoint_path}"
42
+ )
43
+ if (
44
+ optimizer is not None
45
+ and not skip_optimizer
46
+ and checkpoint_dict["optimizer"] is not None
47
+ ):
48
+ optimizer.load_state_dict(checkpoint_dict["optimizer"])
49
+ elif optimizer is None and not skip_optimizer:
50
+ # else: when inferring and resuming from a checkpoint, disable this branch and enable the one above
51
+ new_opt_dict = optimizer.state_dict()
52
+ new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
53
+ new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
54
+ new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
55
+ optimizer.load_state_dict(new_opt_dict)
56
+
57
+ saved_state_dict = checkpoint_dict["model"]
58
+ if hasattr(model, "module"):
59
+ state_dict = model.module.state_dict()
60
+ else:
61
+ state_dict = model.state_dict()
62
+
63
+ new_state_dict = {}
64
+ for k, v in state_dict.items():
65
+ try:
66
+ # assert "emb_g" not in k
67
+ new_state_dict[k] = saved_state_dict[k]
68
+ assert saved_state_dict[k].shape == v.shape, (
69
+ saved_state_dict[k].shape,
70
+ v.shape,
71
+ )
72
+ except Exception:
73
+ # For upgrading from the old version
74
+ if "ja_bert_proj" in k:
75
+ v = torch.zeros_like(v)
76
+ logger.warning(
77
+ f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
78
+ )
79
+ elif "enc_q" in k and for_infer:
80
+ continue
81
+ else:
82
+ logger.error(f"{k} is not in the checkpoint {checkpoint_path}")
83
+
84
+ new_state_dict[k] = v
85
+
86
+ if hasattr(model, "module"):
87
+ model.module.load_state_dict(new_state_dict, strict=False)
88
+ else:
89
+ model.load_state_dict(new_state_dict, strict=False)
90
+
91
+ logger.info("Loaded '{}' (iteration {})".format(checkpoint_path, iteration))
92
+
93
+ return model, optimizer, learning_rate, iteration
94
+
95
+
96
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
97
+ logger.info(
98
+ "Saving model and optimizer state at iteration {} to {}".format(
99
+ iteration, checkpoint_path
100
+ )
101
+ )
102
+ if hasattr(model, "module"):
103
+ state_dict = model.module.state_dict()
104
+ else:
105
+ state_dict = model.state_dict()
106
+ torch.save(
107
+ {
108
+ "model": state_dict,
109
+ "iteration": iteration,
110
+ "optimizer": optimizer.state_dict(),
111
+ "learning_rate": learning_rate,
112
+ },
113
+ checkpoint_path,
114
+ )
115
+
116
+
117
+ def save_safetensors(model, iteration, checkpoint_path, is_half=False, for_infer=False):
118
+ """
119
+ Save model with safetensors.
120
+ """
121
+ if hasattr(model, "module"):
122
+ state_dict = model.module.state_dict()
123
+ else:
124
+ state_dict = model.state_dict()
125
+ keys = []
126
+ for k in state_dict:
127
+ if "enc_q" in k and for_infer:
128
+ continue # noqa: E701
129
+ keys.append(k)
130
+
131
+ new_dict = (
132
+ {k: state_dict[k].half() for k in keys}
133
+ if is_half
134
+ else {k: state_dict[k] for k in keys}
135
+ )
136
+ new_dict["iteration"] = torch.LongTensor([iteration])
137
+ logger.info(f"Saved safetensors to {checkpoint_path}")
138
+ save_file(new_dict, checkpoint_path)
139
+
140
+
141
+ def load_safetensors(checkpoint_path, model, for_infer=False):
142
+ """
143
+ Load safetensors model.
144
+ """
145
+
146
+ tensors = {}
147
+ iteration = None
148
+ with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
149
+ for key in f.keys():
150
+ if key == "iteration":
151
+ iteration = f.get_tensor(key).item()
152
+ tensors[key] = f.get_tensor(key)
153
+ if hasattr(model, "module"):
154
+ result = model.module.load_state_dict(tensors, strict=False)
155
+ else:
156
+ result = model.load_state_dict(tensors, strict=False)
157
+ for key in result.missing_keys:
158
+ if key.startswith("enc_q") and for_infer:
159
+ continue
160
+ logger.warning(f"Missing key: {key}")
161
+ for key in result.unexpected_keys:
162
+ if key == "iteration":
163
+ continue
164
+ logger.warning(f"Unexpected key: {key}")
165
+ if iteration is None:
166
+ logger.info(f"Loaded '{checkpoint_path}'")
167
+ else:
168
+ logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})")
169
+ return model, iteration
170
+
171
+
172
+ def summarize(
173
+ writer,
174
+ global_step,
175
+ scalars={},
176
+ histograms={},
177
+ images={},
178
+ audios={},
179
+ audio_sampling_rate=22050,
180
+ ):
181
+ for k, v in scalars.items():
182
+ writer.add_scalar(k, v, global_step)
183
+ for k, v in histograms.items():
184
+ writer.add_histogram(k, v, global_step)
185
+ for k, v in images.items():
186
+ writer.add_image(k, v, global_step, dataformats="HWC")
187
+ for k, v in audios.items():
188
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
189
+
190
+
191
+ def is_resuming(dir_path):
192
+ # The JP-Extra version has no DUR (and may have WD instead), so judge resume status from G checkpoints only
193
+ g_list = glob.glob(os.path.join(dir_path, "G_*.pth"))
194
+ # d_list = glob.glob(os.path.join(dir_path, "D_*.pth"))
195
+ # dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth"))
196
+ return len(g_list) > 0
197
+
198
+
199
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
200
+ f_list = glob.glob(os.path.join(dir_path, regex))
201
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
202
+ try:
203
+ x = f_list[-1]
204
+ except IndexError:
205
+ raise ValueError(f"No checkpoint found in {dir_path} with regex {regex}")
206
+ return x
207
+
208
+
209
+ def plot_spectrogram_to_numpy(spectrogram):
210
+ global MATPLOTLIB_FLAG
211
+ if not MATPLOTLIB_FLAG:
212
+ import matplotlib
213
+
214
+ matplotlib.use("Agg")
215
+ MATPLOTLIB_FLAG = True
216
+ mpl_logger = logging.getLogger("matplotlib")
217
+ mpl_logger.setLevel(logging.WARNING)
218
+ import matplotlib.pylab as plt
219
+ import numpy as np
220
+
221
+ fig, ax = plt.subplots(figsize=(10, 2))
222
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
223
+ plt.colorbar(im, ax=ax)
224
+ plt.xlabel("Frames")
225
+ plt.ylabel("Channels")
226
+ plt.tight_layout()
227
+
228
+ fig.canvas.draw()
229
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
230
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
231
+ plt.close()
232
+ return data
233
+
234
+
235
+ def plot_alignment_to_numpy(alignment, info=None):
236
+ global MATPLOTLIB_FLAG
237
+ if not MATPLOTLIB_FLAG:
238
+ import matplotlib
239
+
240
+ matplotlib.use("Agg")
241
+ MATPLOTLIB_FLAG = True
242
+ mpl_logger = logging.getLogger("matplotlib")
243
+ mpl_logger.setLevel(logging.WARNING)
244
+ import matplotlib.pylab as plt
245
+ import numpy as np
246
+
247
+ fig, ax = plt.subplots(figsize=(6, 4))
248
+ im = ax.imshow(
249
+ alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
250
+ )
251
+ fig.colorbar(im, ax=ax)
252
+ xlabel = "Decoder timestep"
253
+ if info is not None:
254
+ xlabel += "\n\n" + info
255
+ plt.xlabel(xlabel)
256
+ plt.ylabel("Encoder timestep")
257
+ plt.tight_layout()
258
+
259
+ fig.canvas.draw()
260
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
261
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
262
+ plt.close()
263
+ return data
264
+
265
+
266
+ def load_wav_to_torch(full_path):
267
+ sampling_rate, data = read(full_path)
268
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
269
+
270
+
271
+ def load_filepaths_and_text(filename, split="|"):
272
+ with open(filename, encoding="utf-8") as f:
273
+ filepaths_and_text = [line.strip().split(split) for line in f]
274
+ return filepaths_and_text
275
+
276
+
277
+ def get_hparams(init=True):
278
+ parser = argparse.ArgumentParser()
279
+ parser.add_argument(
280
+ "-c",
281
+ "--config",
282
+ type=str,
283
+ default="./configs/base.json",
284
+ help="JSON file for configuration",
285
+ )
286
+ parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
287
+
288
+ args = parser.parse_args()
289
+ model_dir = os.path.join("./logs", args.model)
290
+
291
+ if not os.path.exists(model_dir):
292
+ os.makedirs(model_dir)
293
+
294
+ config_path = args.config
295
+ config_save_path = os.path.join(model_dir, "config.json")
296
+ if init:
297
+ with open(config_path, "r", encoding="utf-8") as f:
298
+ data = f.read()
299
+ with open(config_save_path, "w", encoding="utf-8") as f:
300
+ f.write(data)
301
+ else:
302
+ with open(config_save_path, "r", vencoding="utf-8") as f:
303
+ data = f.read()
304
+ config = json.loads(data)
305
+ hparams = HParams(**config)
306
+ hparams.model_dir = model_dir
307
+ return hparams
308
+
309
+
310
+ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
311
+ """Freeing up space by deleting saved ckpts
312
+
313
+ Arguments:
314
+ path_to_models -- Path to the model directory
315
+ n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
316
+ sort_by_time -- True -> chronologically delete ckpts
317
+ False -> lexicographically delete ckpts
318
+ """
319
+ import re
320
+
321
+ ckpts_files = [
322
+ f
323
+ for f in os.listdir(path_to_models)
324
+ if os.path.isfile(os.path.join(path_to_models, f))
325
+ ]
326
+
327
+ def name_key(_f):
328
+ return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
329
+
330
+ def time_key(_f):
331
+ return os.path.getmtime(os.path.join(path_to_models, _f))
332
+
333
+ sort_key = time_key if sort_by_time else name_key
334
+
335
+ def x_sorted(_x):
336
+ return sorted(
337
+ [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
338
+ key=sort_key,
339
+ )
340
+
341
+ to_del = [
342
+ os.path.join(path_to_models, fn)
343
+ for fn in (
344
+ x_sorted("G_")[:-n_ckpts_to_keep]
345
+ + x_sorted("D_")[:-n_ckpts_to_keep]
346
+ + x_sorted("WD_")[:-n_ckpts_to_keep]
347
+ + x_sorted("DUR_")[:-n_ckpts_to_keep]
348
+ )
349
+ ]
350
+
351
+ def del_info(fn):
352
+ return logger.info(f"Free up space by deleting ckpt {fn}")
353
+
354
+ def del_routine(x):
355
+ return [os.remove(x), del_info(x)]
356
+
357
+ [del_routine(fn) for fn in to_del]
358
+
359
+
360
+ def get_hparams_from_dir(model_dir):
361
+ config_save_path = os.path.join(model_dir, "config.json")
362
+ with open(config_save_path, "r", encoding="utf-8") as f:
363
+ data = f.read()
364
+ config = json.loads(data)
365
+
366
+ hparams = HParams(**config)
367
+ hparams.model_dir = model_dir
368
+ return hparams
369
+
370
+
371
+ def get_hparams_from_file(config_path):
372
+ # print("config_path: ", config_path)
373
+ with open(config_path, "r", encoding="utf-8") as f:
374
+ data = f.read()
375
+ config = json.loads(data)
376
+
377
+ hparams = HParams(**config)
378
+ return hparams
379
+
380
+
381
+ def check_git_hash(model_dir):
382
+ source_dir = os.path.dirname(os.path.realpath(__file__))
383
+ if not os.path.exists(os.path.join(source_dir, ".git")):
384
+ logger.warning(
385
+ "{} is not a git repository, therefore hash value comparison will be ignored.".format(
386
+ source_dir
387
+ )
388
+ )
389
+ return
390
+
391
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
392
+
393
+ path = os.path.join(model_dir, "githash")
394
+ if os.path.exists(path):
395
+ saved_hash = open(path).read()
396
+ if saved_hash != cur_hash:
397
+ logger.warning(
398
+ "git hash values are different. {}(saved) != {}(current)".format(
399
+ saved_hash[:8], cur_hash[:8]
400
+ )
401
+ )
402
+ else:
403
+ open(path, "w").write(cur_hash)
404
+
405
+
406
+ def get_logger(model_dir, filename="train.log"):
407
+ global logger
408
+ logger = logging.getLogger(os.path.basename(model_dir))
409
+ logger.setLevel(logging.DEBUG)
410
+
411
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
412
+ if not os.path.exists(model_dir):
413
+ os.makedirs(model_dir)
414
+ h = logging.FileHandler(os.path.join(model_dir, filename))
415
+ h.setLevel(logging.DEBUG)
416
+ h.setFormatter(formatter)
417
+ logger.addHandler(h)
418
+ return logger
419
+
420
+
421
+ class HParams:
422
+ def __init__(self, **kwargs):
423
+ for k, v in kwargs.items():
424
+ if isinstance(v, dict):
425
+ v = HParams(**v)
426
+ self[k] = v
427
+
428
+ def keys(self):
429
+ return self.__dict__.keys()
430
+
431
+ def items(self):
432
+ return self.__dict__.items()
433
+
434
+ def values(self):
435
+ return self.__dict__.values()
436
+
437
+ def __len__(self):
438
+ return len(self.__dict__)
439
+
440
+ def __getitem__(self, key):
441
+ return getattr(self, key)
442
+
443
+ def __setitem__(self, key, value):
444
+ return setattr(self, key, value)
445
+
446
+ def __contains__(self, key):
447
+ return key in self.__dict__
448
+
449
+ def __repr__(self):
450
+ return self.__dict__.__repr__()
451
+
452
+
453
+ def load_model(model_path, config_path):
454
+ from models import SynthesizerTrn  # local import; assumes SynthesizerTrn is defined in models.py
+ hps = get_hparams_from_file(config_path)
455
+ net = SynthesizerTrn(
456
+ # len(symbols),
457
+ 108,
458
+ hps.data.filter_length // 2 + 1,
459
+ hps.train.segment_size // hps.data.hop_length,
460
+ n_speakers=hps.data.n_speakers,
461
+ **hps.model,
462
+ ).to("cpu")
463
+ _ = net.eval()
464
+ _ = load_checkpoint(model_path, net, None, skip_optimizer=True)
465
+ return net
466
+
467
+
468
+ def mix_model(
469
+ network1, network2, output_path, voice_ratio=(0.5, 0.5), tone_ratio=(0.5, 0.5)
470
+ ):
471
+ if hasattr(network1, "module"):
472
+ state_dict1 = network1.module.state_dict()
473
+ state_dict2 = network2.module.state_dict()
474
+ else:
475
+ state_dict1 = network1.state_dict()
476
+ state_dict2 = network2.state_dict()
477
+ for k in state_dict1.keys():
478
+ if k not in state_dict2.keys():
479
+ continue
480
+ if "enc_p" in k:
481
+ state_dict1[k] = (
482
+ state_dict1[k].clone() * tone_ratio[0]
483
+ + state_dict2[k].clone() * tone_ratio[1]
484
+ )
485
+ else:
486
+ state_dict1[k] = (
487
+ state_dict1[k].clone() * voice_ratio[0]
488
+ + state_dict2[k].clone() * voice_ratio[1]
489
+ )
490
+ for k in state_dict2.keys():
491
+ if k not in state_dict1.keys():
492
+ state_dict1[k] = state_dict2[k].clone()
493
+ torch.save(
494
+ {"model": state_dict1, "iteration": 0, "optimizer": None, "learning_rate": 0},
495
+ output_path,
496
+ )
497
+
498
+
499
+ def get_steps(model_path):
500
+ matches = re.findall(r"\d+", model_path)
501
+ return matches[-1] if matches else None
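As a usage note, HParams wraps the nested JSON config so values can be read either as attributes or with dict-style access. A minimal sketch (not part of utils.py; the config values are made up):

from utils import HParams

hps = HParams(train={"learning_rate": 2e-4}, data={"sampling_rate": 44100, "hop_length": 512})
print(hps.train.learning_rate)       # 0.0002
print(hps["data"]["sampling_rate"])  # 44100
print("data" in hps, len(hps))       # True 2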