jbetker commited on
Commit
22c8aaf
·
1 Parent(s): 07a6edc

is this from tortoise?

Browse files
Files changed (3) hide show
  1. api.py +20 -1
  2. is_this_from_tortoise.py +14 -0
  3. models/classifier.py +14 -9
api.py CHANGED
@@ -8,6 +8,7 @@ import torch.nn.functional as F
8
  import progressbar
9
  import torchaudio
10
 
 
11
  from models.cvvp import CVVP
12
  from models.diffusion_decoder import DiffusionTts
13
  from models.autoregressive import UnifiedVoice
@@ -24,7 +25,7 @@ from utils.tokenizer import VoiceBpeTokenizer, lev_distance
24
  pbar = None
25
 
26
 
27
- def download_models():
28
  """
29
  Call to download all the models that Tortoise uses.
30
  """
@@ -50,6 +51,8 @@ def download_models():
50
  pbar.finish()
51
  pbar = None
52
  for model_name, url in MODELS.items():
 
 
53
  if os.path.exists(f'.models/{model_name}'):
54
  continue
55
  print(f'Downloading {model_name} from {url}...')
@@ -145,6 +148,22 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_sa
145
  return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  class TextToSpeech:
149
  """
150
  Main entry point into Tortoise.
 
8
  import progressbar
9
  import torchaudio
10
 
11
+ from models.classifier import AudioMiniEncoderWithClassifierHead
12
  from models.cvvp import CVVP
13
  from models.diffusion_decoder import DiffusionTts
14
  from models.autoregressive import UnifiedVoice
 
25
  pbar = None
26
 
27
 
28
+ def download_models(specific_models=None):
29
  """
30
  Call to download all the models that Tortoise uses.
31
  """
 
51
  pbar.finish()
52
  pbar = None
53
  for model_name, url in MODELS.items():
54
+ if specific_models is not None and model_name not in specific_models:
55
+ continue
56
  if os.path.exists(f'.models/{model_name}'):
57
  continue
58
  print(f'Downloading {model_name} from {url}...')
 
148
  return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
149
 
150
 
151
+ def classify_audio_clip(clip):
152
+ """
153
+ Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise.
154
+ :param clip: torch tensor containing audio waveform data (get it from load_audio)
155
+ :return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
156
+ """
157
+ download_models(['classifier'])
158
+ classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
159
+ resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
160
+ dropout=0, kernel_size=5, distribute_zero_label=False)
161
+ classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu')))
162
+ clip = clip.cpu().unsqueeze(0)
163
+ results = F.softmax(classifier(clip), dim=-1)
164
+ return results[0][0]
165
+
166
+
167
  class TextToSpeech:
168
  """
169
  Main entry point into Tortoise.
is_this_from_tortoise.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ from api import classify_audio_clip
4
+ from utils.audio import load_audio
5
+
6
+ if __name__ == '__main__':
7
+ parser = argparse.ArgumentParser()
8
+ parser.add_argument('--clip', type=str, help='Path to an audio clip to classify.', default="results/favorite_riding_hood.mp3")
9
+ args = parser.parse_args()
10
+
11
+ clip = load_audio(args.clip, 24000)
12
+ clip = clip[:, :220000]
13
+ prob = classify_audio_clip(clip)
14
+ print(f"This classifier thinks there is a {prob*100}% chance that this clip was generated from Tortoise.")
models/classifier.py CHANGED
@@ -1,4 +1,9 @@
1
  import torch
 
 
 
 
 
2
 
3
 
4
  class ResBlock(nn.Module):
@@ -27,7 +32,7 @@ class ResBlock(nn.Module):
27
  self.in_layers = nn.Sequential(
28
  normalization(channels),
29
  nn.SiLU(),
30
- conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
31
  )
32
 
33
  self.updown = up or down
@@ -46,18 +51,18 @@ class ResBlock(nn.Module):
46
  nn.SiLU(),
47
  nn.Dropout(p=dropout),
48
  zero_module(
49
- conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
50
  ),
51
  )
52
 
53
  if self.out_channels == channels:
54
  self.skip_connection = nn.Identity()
55
  elif use_conv:
56
- self.skip_connection = conv_nd(
57
  dims, channels, self.out_channels, kernel_size, padding=padding
58
  )
59
  else:
60
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
61
 
62
  def forward(self, x):
63
  if self.do_checkpoint:
@@ -94,21 +99,21 @@ class AudioMiniEncoder(nn.Module):
94
  kernel_size=3):
95
  super().__init__()
96
  self.init = nn.Sequential(
97
- conv_nd(1, spec_dim, base_channels, 3, padding=1)
98
  )
99
  ch = base_channels
100
  res = []
101
  self.layers = depth
102
  for l in range(depth):
103
  for r in range(resnet_blocks):
104
- res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size))
105
- res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=downsample_factor))
106
  ch *= 2
107
  self.res = nn.Sequential(*res)
108
  self.final = nn.Sequential(
109
  normalization(ch),
110
  nn.SiLU(),
111
- conv_nd(1, ch, embedding_dim, 1)
112
  )
113
  attn = []
114
  for a in range(attn_blocks):
@@ -118,7 +123,7 @@ class AudioMiniEncoder(nn.Module):
118
 
119
  def forward(self, x):
120
  h = self.init(x)
121
- h = sequential_checkpoint(self.res, self.layers, h)
122
  h = self.final(h)
123
  for blk in self.attn:
124
  h = checkpoint(blk, h)
 
1
  import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.utils.checkpoint import checkpoint
5
+
6
+ from models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
7
 
8
 
9
  class ResBlock(nn.Module):
 
32
  self.in_layers = nn.Sequential(
33
  normalization(channels),
34
  nn.SiLU(),
35
+ nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
36
  )
37
 
38
  self.updown = up or down
 
51
  nn.SiLU(),
52
  nn.Dropout(p=dropout),
53
  zero_module(
54
+ nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
55
  ),
56
  )
57
 
58
  if self.out_channels == channels:
59
  self.skip_connection = nn.Identity()
60
  elif use_conv:
61
+ self.skip_connection = nn.Conv1d(
62
  dims, channels, self.out_channels, kernel_size, padding=padding
63
  )
64
  else:
65
+ self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
66
 
67
  def forward(self, x):
68
  if self.do_checkpoint:
 
99
  kernel_size=3):
100
  super().__init__()
101
  self.init = nn.Sequential(
102
+ nn.Conv1d(spec_dim, base_channels, 3, padding=1)
103
  )
104
  ch = base_channels
105
  res = []
106
  self.layers = depth
107
  for l in range(depth):
108
  for r in range(resnet_blocks):
109
+ res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
110
+ res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
111
  ch *= 2
112
  self.res = nn.Sequential(*res)
113
  self.final = nn.Sequential(
114
  normalization(ch),
115
  nn.SiLU(),
116
+ nn.Conv1d(ch, embedding_dim, 1)
117
  )
118
  attn = []
119
  for a in range(attn_blocks):
 
123
 
124
  def forward(self, x):
125
  h = self.init(x)
126
+ h = self.res(h)
127
  h = self.final(h)
128
  for blk in self.attn:
129
  h = checkpoint(blk, h)