Awell00 commited on
Commit
757c094
·
verified ·
1 Parent(s): b9d3d3a

feat: add inline comments throughout app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -21
app.py CHANGED
@@ -13,29 +13,59 @@ import spaces
13
  from pydub.exceptions import CouldntEncodeError
14
  from transformers import pipeline
15
 
 
16
  model = pipeline('text-generation', model='EleutherAI/gpt-neo-125M')
17
 
 
18
  OUTPUT_FOLDER = "separation_results/"
19
  INPUT_FOLDER = "input"
20
  download_path = ""
21
 
22
  def sanitize_filename(filename):
 
 
 
 
 
 
 
 
 
23
  return re.sub(r'[\\/*?:"<>|]', '_', filename)
24
 
25
  def delete_input_files(input_dir):
 
 
 
 
 
 
26
  wav_dir = Path(input_dir) / "wav"
27
  for wav_file in wav_dir.glob("*.wav"):
28
  wav_file.unlink()
29
  print(f"Deleted {wav_file}")
30
 
31
  def standardize_title(input_title):
 
 
 
 
 
 
 
 
 
 
32
  title_cleaned = re.sub(r"[\(\[].*?[\)\]]", "", input_title)
33
 
 
34
  unnecessary_words = ["official", "video", "hd", "4k", "lyrics", "music", "audio", "visualizer", "remix"]
35
  title_cleaned = re.sub(r"\b(?:{})\b".format("|".join(unnecessary_words)), "", title_cleaned, flags=re.IGNORECASE)
36
 
 
37
  parts = re.split(r"\s*-\s*|\s*,\s*", title_cleaned)
38
 
 
39
  if len(parts) >= 2:
40
  title_part = parts[-1].strip()
41
  artist_part = ', '.join(parts[:-1]).strip()
@@ -43,27 +73,38 @@ def standardize_title(input_title):
43
  artist_part = "Unknown Artist"
44
  title_part = title_cleaned.strip()
45
 
 
46
  if "with" in input_title.lower() or "feat" in input_title.lower():
47
  match = re.search(r"\((with|feat\.?) (.*?)\)", input_title, re.IGNORECASE)
48
  if match:
49
  additional_artist = match.group(2).strip()
50
  artist_part = f"{artist_part}, {additional_artist}" if artist_part != "Unknown Artist" else additional_artist
51
 
 
52
  artist_part = re.sub(r'\s+', ' ', artist_part).title()
53
  title_part = re.sub(r'\s+', ' ', title_part).title()
54
 
 
55
  standardized_output = f"{artist_part} - {title_part}"
56
 
57
  return standardized_output.strip()
58
 
59
  def handle_file_upload(file):
 
 
 
 
 
 
 
 
 
60
  if file is None:
61
  return None, "No file uploaded"
62
 
63
  filename = os.path.basename(file.name)
64
 
65
  formatted_title = standardize_title(filename)
66
-
67
  formatted_title = sanitize_filename(formatted_title.strip())
68
 
69
  input_path = os.path.join(INPUT_FOLDER, "wav", f"{formatted_title}.wav")
@@ -73,8 +114,21 @@ def handle_file_upload(file):
73
 
74
  return input_path, formatted_title
75
 
76
-
77
  def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  command = [
79
  "python", "inference.py",
80
  "--model_type", model_type,
@@ -87,6 +141,12 @@ def run_inference(model_type, config_path, start_check_point, input_dir, output_
87
  return subprocess.run(command, check=True, capture_output=True, text=True)
88
 
89
  def move_stems_to_parent(input_dir):
 
 
 
 
 
 
90
  for subdir, dirs, files in os.walk(input_dir):
91
  if subdir == input_dir:
92
  continue
@@ -94,42 +154,51 @@ def move_stems_to_parent(input_dir):
94
  parent_dir = os.path.dirname(subdir)
95
  song_name = os.path.basename(parent_dir)
96
 
 
97
  if 'htdemucs' in subdir:
98
- print(f"Processing htdemucs in {subdir}")
99
  bass_path = os.path.join(subdir, f"{song_name}_bass.wav")
100
  if os.path.exists(bass_path):
101
  new_bass_path = os.path.join(parent_dir, "bass.wav")
102
- print(f"Moving {bass_path} to {new_bass_path}")
103
  shutil.move(bass_path, new_bass_path)
104
  else:
105
  print(f"Bass file not found: {bass_path}")
 
 
106
  elif 'mel_band_roformer' in subdir:
107
- print(f"Processing mel_band_roformer in {subdir}")
108
  vocals_path = os.path.join(subdir, f"{song_name}_vocals.wav")
109
  if os.path.exists(vocals_path):
110
  new_vocals_path = os.path.join(parent_dir, "vocals.wav")
111
- print(f"Moving {vocals_path} to {new_vocals_path}")
112
  shutil.move(vocals_path, new_vocals_path)
113
  else:
114
  print(f"Vocals file not found: {vocals_path}")
 
 
115
  elif 'scnet' in subdir:
116
- print(f"Processing scnet in {subdir}")
117
  other_path = os.path.join(subdir, f"{song_name}_other.wav")
118
  if os.path.exists(other_path):
119
  new_other_path = os.path.join(parent_dir, "other.wav")
120
- print(f"Moving {other_path} to {new_other_path}")
121
  shutil.move(other_path, new_other_path)
122
  else:
123
  print(f"Other file not found: {other_path}")
 
 
124
  elif 'bs_roformer' in subdir:
125
- print(f"Processing bs_roformer in {subdir}")
126
  instrumental_path = os.path.join(subdir, f"{song_name}_other.wav")
127
  if os.path.exists(instrumental_path):
128
  new_instrumental_path = os.path.join(parent_dir, "instrumental.wav")
129
- print(f"Moving {instrumental_path} to {new_instrumental_path}")
130
  shutil.move(instrumental_path, new_instrumental_path)
131
 
132
  def combine_stems_for_all(input_dir, output_format):
 
 
 
 
 
 
 
 
 
 
133
  for subdir, _, _ in os.walk(input_dir):
134
  if subdir == input_dir:
135
  continue
@@ -144,20 +213,22 @@ def combine_stems_for_all(input_dir, output_format):
144
  "instrumental": os.path.join(subdir, "instrumental.wav")
145
  }
146
 
 
147
  if not all(os.path.exists(path) for path in stem_paths.values()):
148
  print(f"Skipping {subdir}, not all stems are present.")
149
  continue
150
 
 
151
  stems = {name: AudioSegment.from_file(path) for name, path in stem_paths.items()}
152
  combined = stems["vocals"].overlay(stems["bass"]).overlay(stems["others"]).overlay(stems["instrumental"])
153
 
154
- # Trim silence from the end of the combined audio
155
  trimmed_combined = trim_silence_at_end(combined)
156
 
157
- # Determine the output file format and codec
158
  output_file = os.path.join(subdir, f"{song_name}")
159
 
160
  try:
 
161
  if output_format == "m4a":
162
  trimmed_combined.export(output_file, format="ipod", codec="aac")
163
  else:
@@ -171,11 +242,15 @@ def combine_stems_for_all(input_dir, output_format):
171
 
172
  def trim_silence_at_end(audio_segment, silence_thresh=-50, chunk_size=10):
173
  """
174
- Trims silence at the end of an AudioSegment.
175
- :param audio_segment: The audio segment to trim.
176
- :param silence_thresh: The threshold in dB below which is considered silence.
177
- :param chunk_size: The size of the chunks in milliseconds that are checked for silence.
178
- :return: A trimmed AudioSegment with silence removed from the end.
 
 
 
 
179
  """
180
  silence_end = silence.detect_silence(audio_segment, min_silence_len=chunk_size, silence_thresh=silence_thresh)
181
 
@@ -186,6 +261,12 @@ def trim_silence_at_end(audio_segment, silence_thresh=-50, chunk_size=10):
186
  return audio_segment
187
 
188
  def delete_folders_and_files(input_dir):
 
 
 
 
 
 
189
  folders_to_delete = ['htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer']
190
  files_to_delete = ['bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav']
191
 
@@ -193,18 +274,21 @@ def delete_folders_and_files(input_dir):
193
  if root == input_dir:
194
  continue
195
 
 
196
  for folder in folders_to_delete:
197
  folder_path = os.path.join(root, folder)
198
  if os.path.isdir(folder_path):
199
  print(f"Deleting folder: {folder_path}")
200
  shutil.rmtree(folder_path)
201
 
 
202
  for file in files_to_delete:
203
  file_path = os.path.join(root, file)
204
  if os.path.isfile(file_path):
205
  print(f"Deleting file: {file_path}")
206
  os.remove(file_path)
207
 
 
208
  for root, dirs, files in os.walk(OUTPUT_FOLDER):
209
  for dir_name in dirs:
210
  if dir_name.endswith('_vocals'):
@@ -214,8 +298,17 @@ def delete_folders_and_files(input_dir):
214
 
215
  print("Cleanup completed.")
216
 
217
- @spaces.GPU(duration=120) # Adjust the duration as needed
218
  def process_audio(uploaded_file):
 
 
 
 
 
 
 
 
 
219
  try:
220
  yield "Processing audio...", None
221
 
@@ -226,6 +319,7 @@ def process_audio(uploaded_file):
226
  else:
227
  raise ValueError("Please upload a WAV file.")
228
 
 
229
  yield "Starting SCNet inference...", None
230
  proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER)
231
 
@@ -235,14 +329,15 @@ def process_audio(uploaded_file):
235
  yield "Starting HTDemucs inference...", None
236
  proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER)
237
 
 
238
  source_path = f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer/{formatted_title}_instrumental.wav'
239
  destination_path = f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer/{formatted_title}.wav'
240
-
241
  os.rename(source_path, destination_path)
242
 
243
  yield "Starting BS Roformer inference...", None
244
  proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer', OUTPUT_FOLDER)
245
 
 
246
  yield "Moving input files...", None
247
  delete_input_files(INPUT_FOLDER)
248
 
@@ -261,13 +356,12 @@ def process_audio(uploaded_file):
261
  logging.error(error_msg)
262
  yield error_msg, None
263
 
 
264
  with gr.Blocks() as demo:
265
  gr.Markdown("# Music Player and Processor")
266
 
267
  file_upload = gr.File(label="Upload WAV file", file_types=[".m4a"])
268
-
269
  process_button = gr.Button("Process Audio")
270
-
271
  log_output = gr.Textbox(label="Processing Log", interactive=False)
272
  processed_audio_output = gr.File(label="Processed Audio")
273
 
@@ -278,4 +372,5 @@ with gr.Blocks() as demo:
278
  show_progress=True
279
  )
280
 
 
281
  demo.launch()
 
13
  from pydub.exceptions import CouldntEncodeError
14
  from transformers import pipeline
15
 
16
+ # Initialize text generation model
17
  model = pipeline('text-generation', model='EleutherAI/gpt-neo-125M')
18
 
19
+ # Define constants
20
  OUTPUT_FOLDER = "separation_results/"
21
  INPUT_FOLDER = "input"
22
  download_path = ""
23
 
24
  def sanitize_filename(filename):
25
+ """
26
+ Remove special characters from filename to ensure it's valid across different file systems.
27
+
28
+ Args:
29
+ filename (str): The original filename
30
+
31
+ Returns:
32
+ str: Sanitized filename
33
+ """
34
  return re.sub(r'[\\/*?:"<>|]', '_', filename)
35
 
36
  def delete_input_files(input_dir):
37
+ """
38
+ Delete all WAV files in the input directory.
39
+
40
+ Args:
41
+ input_dir (str): Path to the input directory
42
+ """
43
  wav_dir = Path(input_dir) / "wav"
44
  for wav_file in wav_dir.glob("*.wav"):
45
  wav_file.unlink()
46
  print(f"Deleted {wav_file}")
47
 
48
  def standardize_title(input_title):
49
+ """
50
+ Standardize the title format by removing unnecessary words and rearranging artist and title.
51
+
52
+ Args:
53
+ input_title (str): The original title
54
+
55
+ Returns:
56
+ str: Standardized title in "Artist - Title" format
57
+ """
58
+ # Remove content within parentheses or brackets
59
  title_cleaned = re.sub(r"[\(\[].*?[\)\]]", "", input_title)
60
 
61
+ # Remove unnecessary words
62
  unnecessary_words = ["official", "video", "hd", "4k", "lyrics", "music", "audio", "visualizer", "remix"]
63
  title_cleaned = re.sub(r"\b(?:{})\b".format("|".join(unnecessary_words)), "", title_cleaned, flags=re.IGNORECASE)
64
 
65
+ # Split title into parts
66
  parts = re.split(r"\s*-\s*|\s*,\s*", title_cleaned)
67
 
68
+ # Determine artist and title parts
69
  if len(parts) >= 2:
70
  title_part = parts[-1].strip()
71
  artist_part = ', '.join(parts[:-1]).strip()
 
73
  artist_part = "Unknown Artist"
74
  title_part = title_cleaned.strip()
75
 
76
+ # Handle "with" or "feat" in the title
77
  if "with" in input_title.lower() or "feat" in input_title.lower():
78
  match = re.search(r"\((with|feat\.?) (.*?)\)", input_title, re.IGNORECASE)
79
  if match:
80
  additional_artist = match.group(2).strip()
81
  artist_part = f"{artist_part}, {additional_artist}" if artist_part != "Unknown Artist" else additional_artist
82
 
83
+ # Clean up and capitalize
84
  artist_part = re.sub(r'\s+', ' ', artist_part).title()
85
  title_part = re.sub(r'\s+', ' ', title_part).title()
86
 
87
+ # Combine artist and title
88
  standardized_output = f"{artist_part} - {title_part}"
89
 
90
  return standardized_output.strip()
91
 
92
  def handle_file_upload(file):
93
+ """
94
+ Handle file upload, standardize the filename, and copy it to the input folder.
95
+
96
+ Args:
97
+ file: Uploaded file object
98
+
99
+ Returns:
100
+ tuple: (input_path, formatted_title) or (None, error_message)
101
+ """
102
  if file is None:
103
  return None, "No file uploaded"
104
 
105
  filename = os.path.basename(file.name)
106
 
107
  formatted_title = standardize_title(filename)
 
108
  formatted_title = sanitize_filename(formatted_title.strip())
109
 
110
  input_path = os.path.join(INPUT_FOLDER, "wav", f"{formatted_title}.wav")
 
114
 
115
  return input_path, formatted_title
116
 
 
117
  def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
118
+ """
119
+ Run inference using the specified model and parameters.
120
+
121
+ Args:
122
+ model_type (str): Type of the model
123
+ config_path (str): Path to the model configuration
124
+ start_check_point (str): Path to the model checkpoint
125
+ input_dir (str): Input directory
126
+ output_dir (str): Output directory
127
+ device_ids (str): GPU device IDs to use
128
+
129
+ Returns:
130
+ subprocess.CompletedProcess: Result of the subprocess run
131
+ """
132
  command = [
133
  "python", "inference.py",
134
  "--model_type", model_type,
 
141
  return subprocess.run(command, check=True, capture_output=True, text=True)
142
 
143
  def move_stems_to_parent(input_dir):
144
+ """
145
+ Move generated stem files to their parent directories.
146
+
147
+ Args:
148
+ input_dir (str): Input directory containing stem folders
149
+ """
150
  for subdir, dirs, files in os.walk(input_dir):
151
  if subdir == input_dir:
152
  continue
 
154
  parent_dir = os.path.dirname(subdir)
155
  song_name = os.path.basename(parent_dir)
156
 
157
+ # Move bass stem
158
  if 'htdemucs' in subdir:
 
159
  bass_path = os.path.join(subdir, f"{song_name}_bass.wav")
160
  if os.path.exists(bass_path):
161
  new_bass_path = os.path.join(parent_dir, "bass.wav")
 
162
  shutil.move(bass_path, new_bass_path)
163
  else:
164
  print(f"Bass file not found: {bass_path}")
165
+
166
+ # Move vocals stem
167
  elif 'mel_band_roformer' in subdir:
 
168
  vocals_path = os.path.join(subdir, f"{song_name}_vocals.wav")
169
  if os.path.exists(vocals_path):
170
  new_vocals_path = os.path.join(parent_dir, "vocals.wav")
 
171
  shutil.move(vocals_path, new_vocals_path)
172
  else:
173
  print(f"Vocals file not found: {vocals_path}")
174
+
175
+ # Move other stem
176
  elif 'scnet' in subdir:
 
177
  other_path = os.path.join(subdir, f"{song_name}_other.wav")
178
  if os.path.exists(other_path):
179
  new_other_path = os.path.join(parent_dir, "other.wav")
 
180
  shutil.move(other_path, new_other_path)
181
  else:
182
  print(f"Other file not found: {other_path}")
183
+
184
+ # Move instrumental stem
185
  elif 'bs_roformer' in subdir:
 
186
  instrumental_path = os.path.join(subdir, f"{song_name}_other.wav")
187
  if os.path.exists(instrumental_path):
188
  new_instrumental_path = os.path.join(parent_dir, "instrumental.wav")
 
189
  shutil.move(instrumental_path, new_instrumental_path)
190
 
191
  def combine_stems_for_all(input_dir, output_format):
192
+ """
193
+ Combine all stems for each song in the input directory.
194
+
195
+ Args:
196
+ input_dir (str): Input directory containing song folders
197
+ output_format (str): Output audio format (e.g., 'm4a')
198
+
199
+ Returns:
200
+ str: Path to the combined audio file
201
+ """
202
  for subdir, _, _ in os.walk(input_dir):
203
  if subdir == input_dir:
204
  continue
 
213
  "instrumental": os.path.join(subdir, "instrumental.wav")
214
  }
215
 
216
+ # Skip if not all stems are present
217
  if not all(os.path.exists(path) for path in stem_paths.values()):
218
  print(f"Skipping {subdir}, not all stems are present.")
219
  continue
220
 
221
+ # Load and combine stems
222
  stems = {name: AudioSegment.from_file(path) for name, path in stem_paths.items()}
223
  combined = stems["vocals"].overlay(stems["bass"]).overlay(stems["others"]).overlay(stems["instrumental"])
224
 
225
+ # Trim silence at the end
226
  trimmed_combined = trim_silence_at_end(combined)
227
 
 
228
  output_file = os.path.join(subdir, f"{song_name}")
229
 
230
  try:
231
+ # Export combined audio
232
  if output_format == "m4a":
233
  trimmed_combined.export(output_file, format="ipod", codec="aac")
234
  else:
 
242
 
243
  def trim_silence_at_end(audio_segment, silence_thresh=-50, chunk_size=10):
244
  """
245
+ Trim silence at the end of an audio segment.
246
+
247
+ Args:
248
+ audio_segment (AudioSegment): Input audio segment
249
+ silence_thresh (int): Silence threshold in dB
250
+ chunk_size (int): Size of chunks to analyze in ms
251
+
252
+ Returns:
253
+ AudioSegment: Trimmed audio segment
254
  """
255
  silence_end = silence.detect_silence(audio_segment, min_silence_len=chunk_size, silence_thresh=silence_thresh)
256
 
 
261
  return audio_segment
262
 
263
  def delete_folders_and_files(input_dir):
264
+ """
265
+ Delete temporary folders and files after processing.
266
+
267
+ Args:
268
+ input_dir (str): Input directory to clean up
269
+ """
270
  folders_to_delete = ['htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer']
271
  files_to_delete = ['bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav']
272
 
 
274
  if root == input_dir:
275
  continue
276
 
277
+ # Delete specified folders
278
  for folder in folders_to_delete:
279
  folder_path = os.path.join(root, folder)
280
  if os.path.isdir(folder_path):
281
  print(f"Deleting folder: {folder_path}")
282
  shutil.rmtree(folder_path)
283
 
284
+ # Delete specified files
285
  for file in files_to_delete:
286
  file_path = os.path.join(root, file)
287
  if os.path.isfile(file_path):
288
  print(f"Deleting file: {file_path}")
289
  os.remove(file_path)
290
 
291
+ # Delete vocals folders
292
  for root, dirs, files in os.walk(OUTPUT_FOLDER):
293
  for dir_name in dirs:
294
  if dir_name.endswith('_vocals'):
 
298
 
299
  print("Cleanup completed.")
300
 
301
+ @spaces.GPU(duration=120)
302
  def process_audio(uploaded_file):
303
+ """
304
+ Main function to process the uploaded audio file.
305
+
306
+ Args:
307
+ uploaded_file: Uploaded file object
308
+
309
+ Yields:
310
+ tuple: (status_message, output_file_path)
311
+ """
312
  try:
313
  yield "Processing audio...", None
314
 
 
319
  else:
320
  raise ValueError("Please upload a WAV file.")
321
 
322
+ # Run inference for different models
323
  yield "Starting SCNet inference...", None
324
  proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER)
325
 
 
329
  yield "Starting HTDemucs inference...", None
330
  proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER)
331
 
332
+ # Rename instrumental file
333
  source_path = f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer/{formatted_title}_instrumental.wav'
334
  destination_path = f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer/{formatted_title}.wav'
 
335
  os.rename(source_path, destination_path)
336
 
337
  yield "Starting BS Roformer inference...", None
338
  proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer', OUTPUT_FOLDER)
339
 
340
+ # Clean up and organize files
341
  yield "Moving input files...", None
342
  delete_input_files(INPUT_FOLDER)
343
 
 
356
  logging.error(error_msg)
357
  yield error_msg, None
358
 
359
+ # Set up Gradio interface
360
  with gr.Blocks() as demo:
361
  gr.Markdown("# Music Player and Processor")
362
 
363
  file_upload = gr.File(label="Upload WAV file", file_types=[".m4a"])
 
364
  process_button = gr.Button("Process Audio")
 
365
  log_output = gr.Textbox(label="Processing Log", interactive=False)
366
  processed_audio_output = gr.File(label="Processed Audio")
367
 
 
372
  show_progress=True
373
  )
374
 
375
+ # Launch the Gradio app
376
  demo.launch()