alibabasglab committed · Commit 87c9b2c · verified · 1 Parent(s): 18ff1c0

Upload networks.py

Files changed (1):
  1. networks.py +603 -0

networks.py ADDED
@@ -0,0 +1,603 @@
"""
Authors: Shengkui Zhao, Zexu Pan
"""

import torch
import torch.nn as nn
import soundfile as sf
import os
import subprocess
import librosa
from tqdm import tqdm
import numpy as np
from pydub import AudioSegment
from utils.decode import decode_one_audio
from dataloader.dataloader import DataReader

MAX_WAV_VALUE = 32768.0

class SpeechModel:
    """
    The SpeechModel class is a base class designed to handle speech processing tasks,
    such as loading, processing, and decoding audio data. It initializes the computational
    device (CPU or GPU) and holds model-related attributes. The class is flexible and intended
    to be extended by specific speech models for tasks like speech enhancement, speech separation,
    and target speaker extraction.

    Attributes:
    - args: Argument parser object that contains configuration settings.
    - device: The device (CPU or GPU) on which the model will run.
    - model: The actual model used for speech processing tasks (to be loaded by subclasses).
    - name: A placeholder for the model's name.
    - data: A dictionary to store any additional data related to the model, such as audio input.
    """

    def __init__(self, args):
        """
        Initializes the SpeechModel class by determining the computation device
        (GPU or CPU) to be used for running the model, based on system availability.

        Args:
        - args: Argument parser object containing settings such as whether to use CUDA (GPU).
        """
        # Check if a GPU is available
        if torch.cuda.is_available():
            # Find the GPU with the most free memory using a custom method
            free_gpu_id = self.get_free_gpu()
            if free_gpu_id is not None:
                args.use_cuda = 1
                torch.cuda.set_device(free_gpu_id)
                self.device = torch.device('cuda')
            else:
                # If no free GPU could be identified, fall back to the CPU
                #print("No GPU found. Using CPU.")
                args.use_cuda = 0
                self.device = torch.device('cpu')
        else:
            # If no GPU is detected, use the CPU
            args.use_cuda = 0
            self.device = torch.device('cpu')

        self.args = args
        self.model = None
        self.name = None
        self.data = {}
        self.print = False

    def get_free_gpu(self):
        """
        Identifies the GPU with the most free memory using 'nvidia-smi' and returns its index.

        This function queries the available GPUs on the system and determines which one has
        the highest amount of free memory. It uses the `nvidia-smi` command-line tool to gather
        GPU memory usage data. If successful, it returns the index of the GPU with the most free memory.
        If the query fails or an error occurs, it returns None.

        Returns:
            int: Index of the GPU with the most free memory, or None if no GPU is found or an error occurs.
        """
        try:
            # Run nvidia-smi to query GPU memory usage and free memory
            result = subprocess.run(['nvidia-smi', '--query-gpu=memory.used,memory.free', '--format=csv,nounits,noheader'], stdout=subprocess.PIPE)
            gpu_info = result.stdout.decode('utf-8').strip().split('\n')

            free_gpu = None
            max_free_memory = 0
            for i, info in enumerate(gpu_info):
                used, free = map(int, info.split(','))
                if free > max_free_memory:
                    max_free_memory = free
                    free_gpu = i
            return free_gpu
        except Exception as e:
            print(f"Error finding free GPU: {e}")
            return None

    def download_model(self, model_name):
        checkpoint_dir = self.args.checkpoint_dir
        from huggingface_hub import snapshot_download
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        print(f'downloading checkpoint for {model_name}')
        try:
            snapshot_download(repo_id=f'alibabasglab/{model_name}', local_dir=checkpoint_dir)
            return True
        except Exception:
            # The download failed (e.g., no network access or the repository was not found)
            return False

    def load_model(self):
        """
        Loads a pre-trained model checkpoint from a specified directory. It checks for
        the best model ('last_best_checkpoint') in the checkpoint directory. If a model is
        found, it loads the model state into the current model instance.

        If no checkpoint is found, it tries to download the model from Hugging Face.
        If the download fails, it prints a warning message.

        Steps:
        - Search for the best model checkpoint or the most recent one.
        - Load the model's state dictionary from the checkpoint file.

        Raises:
        - FileNotFoundError: If neither 'last_best_checkpoint' nor 'last_checkpoint' files are found.
        """
        # Define the path for the best model checkpoint
        best_name = os.path.join(self.args.checkpoint_dir, 'last_best_checkpoint')
        # Check if the last best checkpoint exists
        if not os.path.isfile(best_name):
            if not self.download_model(self.name):
                # If downloading is unsuccessful
                print(f'Warning: Downloading model {self.name} is not successful. Please try again or manually download from https://huggingface.co/alibabasglab/{self.name}/tree/main !')
                return

        if isinstance(self.model, nn.ModuleList):
            with open(best_name, 'r') as f:
                model_name = f.readline().strip()
                checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
                self._load_model(self.model[0], checkpoint_path, model_key='mossformer')
                model_name = f.readline().strip()
                checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
                self._load_model(self.model[1], checkpoint_path, model_key='generator')
        else:
            # Read the model's checkpoint name from the file
            with open(best_name, 'r') as f:
                model_name = f.readline().strip()
            # Form the full path to the model's checkpoint
            checkpoint_path = os.path.join(self.args.checkpoint_dir, model_name)
            self._load_model(self.model, checkpoint_path, model_key='model')

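    # Note (illustrative; inferred from the readline() calls in load_model above, not part of
    # the original file): 'last_best_checkpoint' is expected to be a small text file whose
    # first line names the checkpoint file to load, e.g.
    #
    #     model.ckpt-120-0.8500.pt
    #
    # For nn.ModuleList models (such as MossFormer2_SR_48K), a second line names the
    # generator checkpoint.
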
    def _load_model(self, model, checkpoint_path, model_key=None):
        # Load the checkpoint file into memory (map_location ensures compatibility with different devices)
        checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
        # Load the model's state dictionary (weights and biases) into the current model
        if model_key in checkpoint:
            pretrained_model = checkpoint[model_key]
        else:
            pretrained_model = checkpoint
        state = model.state_dict()
        for key in state.keys():
            if key in pretrained_model and state[key].shape == pretrained_model[key].shape:
                state[key] = pretrained_model[key]
            elif key.replace('module.', '') in pretrained_model and state[key].shape == pretrained_model[key.replace('module.', '')].shape:
                state[key] = pretrained_model[key.replace('module.', '')]
            elif 'module.'+key in pretrained_model and state[key].shape == pretrained_model['module.'+key].shape:
                state[key] = pretrained_model['module.'+key]
            elif self.print:
                print(f'{key} not loaded')
        model.load_state_dict(state)

    def decode(self):
        """
        Decodes the input audio data using the loaded model and ensures the output matches the original audio length.

        This method processes the audio through a speech model (e.g., for enhancement, separation, etc.),
        and truncates the resulting audio to match the original input's length. The method supports multiple speakers
        if the model handles multi-speaker audio.

        Returns:
            output_audio: The decoded audio after processing, truncated to the input audio length.
                          If multi-speaker audio is processed, a list of truncated audio outputs per speaker is returned.
        """
        # Decode the audio using the loaded model on the given device (e.g., CPU or GPU)
        output_audios = []
        for i in range(len(self.data['audio'])):
            output_audio = decode_one_audio(self.model, self.device, self.data['audio'][i], self.args)
            # Ensure the decoded output matches the length of the input audio
            if isinstance(output_audio, list):
                # If multi-speaker audio (a list of outputs), truncate each speaker's audio to the input length
                for spk in range(self.args.num_spks):
                    output_audio[spk] = output_audio[spk][:self.data['audio_len']]
            else:
                # Single output, truncate to the input audio length
                output_audio = output_audio[:self.data['audio_len']]
            output_audios.append(output_audio)

        if isinstance(output_audios[0], list):
            output_audios_np = []
            for spk in range(self.args.num_spks):
                output_audio_buf = []
                for i in range(len(output_audios)):
                    output_audio_buf.append(output_audios[i][spk])
                    #output_audio_buf = np.vstack((output_audio_buf, output_audios[i][spk])).T
                output_audios_np.append(np.array(output_audio_buf))
        else:
            output_audios_np = np.array(output_audios)
        return output_audios_np

    def process(self, input_path, online_write=False, output_path=None):
        """
        Load and process audio files from the specified input path. Optionally,
        write the output audio files to the specified output directory.

        Args:
            input_path (str): Path to the input audio files or folder.
            online_write (bool): Whether to write the processed audio to disk in real-time.
            output_path (str): Optional path for writing output files. If None, output
                               will be stored in self.result.

        Returns:
            dict or ndarray: Processed audio results either as a dictionary or as a single array,
                             depending on the number of audio files processed.
                             Returns None if online_write is enabled.
        """

        self.result = {}
        self.args.input_path = input_path
        data_reader = DataReader(self.args)  # Initialize a data reader to load the audio files

        # Check if online writing is enabled
        if online_write:
            output_wave_dir = self.args.output_dir  # Set the default output directory
            if isinstance(output_path, str):  # If a specific output path is provided, use it
                output_wave_dir = os.path.join(output_path, self.name)
            # Create the output directory if it does not exist
            if not os.path.isdir(output_wave_dir):
                os.makedirs(output_wave_dir)

        num_samples = len(data_reader)  # Get the total number of samples to process
        print(f'Running {self.name} ...')  # Display the model being used

        if self.args.task == 'target_speaker_extraction':
            from utils.video_process import process_tse
            assert online_write == True
            process_tse(self.args, self.model, self.device, data_reader, output_wave_dir)
        else:
            # Disable gradient calculation for better efficiency during inference
            with torch.no_grad():
                for idx in tqdm(range(num_samples)):  # Loop over all audio samples
                    self.data = {}
                    # Read the audio, waveform ID, and audio length from the data reader
                    input_audio, wav_id, input_len, scalars, audio_info = data_reader[idx]
                    # Store the input audio and metadata in self.data
                    self.data['audio'] = input_audio
                    self.data['id'] = wav_id
                    self.data['audio_len'] = input_len
                    self.data.update(audio_info)

                    # Perform the audio decoding/processing
                    output_audios = self.decode()

                    # Perform audio renormalization
                    if not isinstance(output_audios, list):
                        if len(scalars) > 1:
                            for i in range(len(scalars)):
                                output_audios[:, i] = output_audios[:, i] * scalars[i]
                        else:
                            output_audios = output_audios * scalars[0]

                    if online_write:
                        # If online writing is enabled, save the output audio to files
                        if isinstance(output_audios, list):
                            # In case of multi-speaker output, save each speaker's output separately
                            for spk in range(self.args.num_spks):
                                output_file = os.path.join(output_wave_dir, wav_id.replace('.'+self.data['ext'], f'_s{spk+1}.'+self.data['ext']))
                                self.write_audio(output_file, key=None, spk=spk, audio=output_audios)
                        else:
                            # Single-speaker or standard output
                            output_file = os.path.join(output_wave_dir, wav_id)
                            self.write_audio(output_file, key=None, spk=None, audio=output_audios)
                    else:
                        # If not writing to disk, store the output in the result dictionary
                        self.result[wav_id] = output_audios

        # Return the processed results if not writing to disk
        if not online_write:
            if len(self.result) == 1:
                # If there is only one result, return it directly
                return next(iter(self.result.values()))
            else:
                # Otherwise, return the entire result dictionary
                return self.result

    def write_audio(self, output_path, key=None, spk=None, audio=None):
        """
        This function writes an audio signal to an output file, applying necessary transformations
        such as resampling, channel handling, and format conversion based on the provided parameters
        and the instance's internal settings.

        Args:
            output_path (str): The file path where the audio will be saved.
            key (str, optional): The key used to retrieve audio from the internal result dictionary
                                 if audio is not provided.
            spk (str, optional): A specific speaker identifier, used to extract a particular speaker's
                                 audio from a multi-speaker dataset or result.
            audio (numpy.ndarray, optional): A numpy array containing the audio data to be written.
                                             If provided, key and spk are ignored.
        """

        if audio is not None:
            if spk is not None:
                result_ = audio[spk]
            else:
                result_ = audio
        else:
            if spk is not None:
                result_ = self.result[key][spk]
            else:
                result_ = self.result[key]

        if self.data['sample_rate'] != self.args.sampling_rate:
            # Resample the model output back to the original file's sample rate
            if self.data['channels'] == 2:
                left_channel = librosa.resample(result_[0, :], orig_sr=self.args.sampling_rate, target_sr=self.data['sample_rate'])
                right_channel = librosa.resample(result_[1, :], orig_sr=self.args.sampling_rate, target_sr=self.data['sample_rate'])
                result = np.vstack((left_channel, right_channel)).T
            else:
                result = librosa.resample(result_[0, :], orig_sr=self.args.sampling_rate, target_sr=self.data['sample_rate'])
        else:
            if self.data['channels'] == 2:
                left_channel = result_[0, :]
                right_channel = result_[1, :]
                result = np.vstack((left_channel, right_channel)).T
            else:
                result = result_[0, :]

        if self.data['sample_width'] == 4:  ## 32-bit samples
            MAX_WAV_VALUE = 2147483648.0
            np_type = np.int32
        elif self.data['sample_width'] == 2:  ## 16-bit samples
            MAX_WAV_VALUE = 32768.0
            np_type = np.int16
        else:
            self.data['sample_width'] = 2  ## fall back to 16-bit samples
            MAX_WAV_VALUE = 32768.0
            np_type = np.int16

        result = result * MAX_WAV_VALUE
        result = result.astype(np_type)
        audio_segment = AudioSegment(
            result.tobytes(),                        # Raw audio data as bytes
            frame_rate=self.data['sample_rate'],     # Sample rate
            sample_width=self.data['sample_width'],  # No. bytes per sample
            channels=self.data['channels']           # No. channels
        )
        audio_format = 'ipod' if self.data['ext'] in ['m4a', 'aac'] else self.data['ext']
        audio_segment.export(output_path, format=audio_format)

    def write(self, output_path, add_subdir=False, use_key=False):
        """
        Write the processed audio results to the specified output path.

        Args:
            output_path (str): The directory or file path where processed audio will be saved. If not
                               provided, defaults to self.args.output_dir.
            add_subdir (bool): If True, appends the model name as a subdirectory to the output path.
            use_key (bool): If True, uses the result dictionary's keys (audio file IDs) for filenames.

        Returns:
            None: Outputs are written to disk, no data is returned.
        """

        # Ensure the output path is a string. If not provided, use the default output directory
        if not isinstance(output_path, str):
            output_path = self.args.output_dir

        # If add_subdir is enabled, create a subdirectory for the model name
        if add_subdir:
            if os.path.isfile(output_path):
                print(f'File exists: {output_path}, remove it and try again!')
                return
            output_path = os.path.join(output_path, self.name)
            if not os.path.isdir(output_path):
                os.makedirs(output_path)

        # Ensure proper directory setup when using keys for filenames
        if use_key and not os.path.isdir(output_path):
            if os.path.exists(output_path):
                print(f'File exists: {output_path}, remove it and try again!')
                return
            os.makedirs(output_path)
        # If not using keys and the output path is a directory, check for conflicts
        if not use_key and os.path.isdir(output_path):
            print(f'Directory exists: {output_path}, remove it and try again!')
            return

        # Iterate over the results dictionary to write the processed audio to disk
        for key in self.result:
            if use_key:
                # If using keys, format filenames based on the result dictionary's keys (audio IDs)
                if isinstance(self.result[key], list):  # For multi-speaker outputs
                    for spk in range(self.args.num_spks):
                        output_file = os.path.join(output_path, key.replace('.'+self.data['ext'], f'_s{spk+1}.'+self.data['ext']))
                        self.write_audio(output_file, key, spk)
                else:
                    output_file = os.path.join(output_path, key)
                    self.write_audio(output_file, key)
            else:
                # If not using keys, write audio to the specified output path directly
                if isinstance(self.result[key], list):  # For multi-speaker outputs
                    for spk in range(self.args.num_spks):
                        output_file = output_path.replace('.'+self.data['ext'], f'_s{spk+1}.'+self.data['ext'])
                        self.write_audio(output_file, key, spk)
                else:
                    self.write_audio(output_path, key)

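# Intended call pattern for the task-specific subclasses below (an illustrative sketch,
# not part of the original file; the exact set of fields in `args` is defined by the
# project's argument parser):
#
#     model = CLS_MossFormer2_SS_16K(args)
#     results = model.process('path/to/mixtures')              # outputs kept in memory (model.result)
#     model.write('separated_outputs', add_subdir=True, use_key=True)
#
# Passing online_write=True to process() instead writes each file to disk as soon as it
# has been decoded, without accumulating results in memory.
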
# The model classes for specific sub-tasks

class CLS_FRCRN_SE_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements a speech enhancement model using
    the FRCRN architecture for 16 kHz speech enhancement.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_FRCRN_SE_16K, self).__init__(args)

        # Import the FRCRN speech enhancement model for 16 kHz
        from models.frcrn_se.frcrn import FRCRN_SE_16K

        # Initialize the model
        self.model = FRCRN_SE_16K(args).model
        self.name = 'FRCRN_SE_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()

class CLS_MossFormer2_SE_48K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormer2 architecture for
    48 kHz speech enhancement.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormer2_SE_48K, self).__init__(args)

        # Import the MossFormer2 speech enhancement model for 48 kHz
        from models.mossformer2_se.mossformer2_se_wrapper import MossFormer2_SE_48K

        # Initialize the model
        self.model = MossFormer2_SE_48K(args).model
        self.name = 'MossFormer2_SE_48K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()

class CLS_MossFormer2_SR_48K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormer2 architecture for
    48 kHz speech super-resolution.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormer2_SR_48K, self).__init__(args)

        # Import the MossFormer2 speech super-resolution model for 48 kHz
        from models.mossformer2_sr.mossformer2_sr_wrapper import MossFormer2_SR_48K

        # Initialize the model (MossFormer module and generator)
        self.model = nn.ModuleList()
        self.model.append(MossFormer2_SR_48K(args).model_m)
        self.model.append(MossFormer2_SR_48K(args).model_g)
        self.name = 'MossFormer2_SR_48K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move models to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            for model in self.model:
                model.to(self.device)

        # Set the models to evaluation mode (no gradient calculation)
        for model in self.model:
            model.eval()
        self.model[1].remove_weight_norm()

class CLS_MossFormerGAN_SE_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormerGAN architecture for
    16 kHz speech enhancement, utilizing GAN-based speech processing.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormerGAN_SE_16K, self).__init__(args)

        # Import the MossFormerGAN speech enhancement model for 16 kHz
        from models.mossformer_gan_se.generator import MossFormerGAN_SE_16K

        # Initialize the model
        self.model = MossFormerGAN_SE_16K(args).model
        self.name = 'MossFormerGAN_SE_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()

class CLS_MossFormer2_SS_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements the MossFormer2 architecture for
    16 kHz speech separation.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_MossFormer2_SS_16K, self).__init__(args)

        # Import the MossFormer2 speech separation model for 16 kHz
        from models.mossformer2_ss.mossformer2 import MossFormer2_SS_16K

        # Initialize the model
        self.model = MossFormer2_SS_16K(args).model
        self.name = 'MossFormer2_SS_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()

class CLS_AV_MossFormer2_TSE_16K(SpeechModel):
    """
    A subclass of SpeechModel that implements an audio-visual (AV) model using
    the AV-MossFormer2 architecture for target speaker extraction (TSE) at 16 kHz.
    This model leverages both audio and visual cues to perform speaker extraction.

    Args:
        args (Namespace): The argument parser containing model configurations and paths.
    """

    def __init__(self, args):
        # Initialize the parent SpeechModel class
        super(CLS_AV_MossFormer2_TSE_16K, self).__init__(args)

        # Import the AV-MossFormer2 model for 16 kHz target speaker extraction
        from models.av_mossformer2_tse.av_mossformer2 import AV_MossFormer2_TSE_16K

        # Initialize the model
        self.model = AV_MossFormer2_TSE_16K(args).model
        self.name = 'AV_MossFormer2_TSE_16K'

        # Load pre-trained model checkpoint
        self.load_model()

        # Move model to the appropriate device (GPU/CPU)
        if args.use_cuda == 1:
            self.model.to(self.device)

        # Set the model to evaluation mode (no gradient calculation)
        self.model.eval()
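
if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file). The exact argument fields are
    # defined by the project's argument parser / config files; the values below are
    # assumptions for illustration only.
    from argparse import Namespace

    args = Namespace(
        checkpoint_dir='checkpoints/MossFormer2_SE_48K',  # where the checkpoint is stored or downloaded to
        output_dir='outputs',                             # default output directory
        task='speech_enhancement',                        # task name used by process()
        sampling_rate=48000,                              # model sampling rate
        num_spks=1,                                       # number of output speakers
        use_cuda=1,                                       # overwritten in SpeechModel.__init__
    )
    model = CLS_MossFormer2_SE_48K(args)
    # Enhance a folder of noisy recordings, writing each enhanced file as soon as it is decoded.
    model.process('path/to/noisy_wavs', online_write=True, output_path='outputs')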