Awell00 commited on
Commit
889d37b
·
verified ·
1 Parent(s): 1dc6872

feat!: create app.py to initialize Gradio interface with main function

Browse files
Files changed (1) hide show
  1. app.py +243 -0
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yt_dlp
2
+ import re
3
+ import subprocess
4
+ import os
5
+ import shutil
6
+ from pydub import AudioSegment
7
+ import gradio as gr
8
+ import traceback
9
+ import logging
10
+ from inference import proc_folder_direct
11
+ from pathlib import Path
12
+
13
+ OUTPUT_FOLDER = "separation_results/"
14
+ INPUT_FOLDER = "input"
15
+ download_path = ""
16
+
17
+ def sanitize_filename(filename):
18
+ return re.sub(r'[\\/*?:"<>|]', '_', filename)
19
+
20
+ def delete_input_files(input_dir):
21
+ wav_dir = Path(input_dir) / "wav"
22
+ for wav_file in wav_dir.glob("*.wav"):
23
+ wav_file.unlink()
24
+ print(f"Deleted {wav_file}")
25
+
26
+ def download_youtube_audio_by_title(query, state=True):
27
+ if state:
28
+ delete_input_files(INPUT_FOLDER)
29
+
30
+ ydl_opts = {
31
+ 'quiet': True,
32
+ 'default_search': 'ytsearch',
33
+ 'noplaylist': True,
34
+ 'format': 'bestaudio/best',
35
+ 'outtmpl': './input/wav/%(title)s.%(ext)s',
36
+ 'postprocessors': [{
37
+ 'key': 'FFmpegExtractAudio',
38
+ 'preferredcodec': 'wav',
39
+ }],
40
+ }
41
+
42
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
43
+ search_results = ydl.extract_info(query, download=False)
44
+ video_info = search_results['entries'][0]
45
+ video_url = video_info['webpage_url']
46
+ video_title = video_info['title']
47
+
48
+ match = re.match(r'^(.*? - .*?)(?: \[.*\]|\(.*\))?$', video_title)
49
+ formatted_title = match.group(1) if match else video_title
50
+
51
+ formatted_title = sanitize_filename(formatted_title.strip())
52
+
53
+ ydl_opts['outtmpl'] = f'./input/wav/{formatted_title}.%(ext)s'
54
+
55
+ if state:
56
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
57
+ ydl.download([video_url])
58
+ return f'./input/wav/{formatted_title}.wav'
59
+
60
+ return formatted_title
61
+
62
+ def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
63
+ command = [
64
+ "python", "inference.py",
65
+ "--model_type", model_type,
66
+ "--config_path", config_path,
67
+ "--start_check_point", start_check_point,
68
+ "--INPUT_FOLDER", input_dir,
69
+ "--store_dir", output_dir,
70
+ "--device_ids", device_ids
71
+ ]
72
+ return subprocess.run(command, check=True, capture_output=True, text=True)
73
+
74
+ def move_stems_to_parent(input_dir):
75
+ for subdir, dirs, files in os.walk(input_dir):
76
+ if subdir == input_dir:
77
+ continue
78
+
79
+ parent_dir = os.path.dirname(subdir)
80
+ song_name = os.path.basename(parent_dir)
81
+
82
+ if 'htdemucs' in subdir:
83
+ print(f"Processing htdemucs in {subdir}")
84
+ bass_path = os.path.join(subdir, f"{song_name}_bass.wav")
85
+ if os.path.exists(bass_path):
86
+ new_bass_path = os.path.join(parent_dir, "bass.wav")
87
+ print(f"Moving {bass_path} to {new_bass_path}")
88
+ shutil.move(bass_path, new_bass_path)
89
+ else:
90
+ print(f"Bass file not found: {bass_path}")
91
+ elif 'mel_band_roformer' in subdir:
92
+ print(f"Processing mel_band_roformer in {subdir}")
93
+ vocals_path = os.path.join(subdir, f"{song_name}_vocals.wav")
94
+ if os.path.exists(vocals_path):
95
+ new_vocals_path = os.path.join(parent_dir, "vocals.wav")
96
+ print(f"Moving {vocals_path} to {new_vocals_path}")
97
+ shutil.move(vocals_path, new_vocals_path)
98
+ else:
99
+ print(f"Vocals file not found: {vocals_path}")
100
+ elif 'scnet' in subdir:
101
+ print(f"Processing scnet in {subdir}")
102
+ other_path = os.path.join(subdir, f"{song_name}_other.wav")
103
+ if os.path.exists(other_path):
104
+ new_other_path = os.path.join(parent_dir, "other.wav")
105
+ print(f"Moving {other_path} to {new_other_path}")
106
+ shutil.move(other_path, new_other_path)
107
+ else:
108
+ print(f"Other file not found: {other_path}")
109
+ elif 'bs_roformer' in subdir:
110
+ print(f"Processing bs_roformer in {subdir}")
111
+ instrumental_path = os.path.join(subdir, f"{song_name}_other.wav")
112
+ if os.path.exists(instrumental_path):
113
+ new_instrumental_path = os.path.join(parent_dir, "instrumental.wav")
114
+ print(f"Moving {instrumental_path} to {new_instrumental_path}")
115
+ shutil.move(instrumental_path, new_instrumental_path)
116
+ else:
117
+ print(f"Instrumental file not found: {instrumental_path}")
118
+
119
+ def combine_stems_for_all(input_dir):
120
+ for subdir, _, _ in os.walk(input_dir):
121
+ if subdir == input_dir:
122
+ continue
123
+
124
+ song_name = os.path.basename(subdir)
125
+ print(f"Processing {subdir}")
126
+
127
+ stem_paths = {
128
+ "vocals": os.path.join(subdir, "vocals.wav"),
129
+ "bass": os.path.join(subdir, "bass.wav"),
130
+ "others": os.path.join(subdir, "other.wav"),
131
+ "instrumental": os.path.join(subdir, "instrumental.wav")
132
+ }
133
+
134
+ if not all(os.path.exists(path) for path in stem_paths.values()):
135
+ print(f"Skipping {subdir}, not all stems are present.")
136
+ continue
137
+
138
+ stems = {name: AudioSegment.from_file(path) for name, path in stem_paths.items()}
139
+ combined = stems["vocals"].overlay(stems["bass"]).overlay(stems["others"]).overlay(stems["instrumental"])
140
+
141
+ output_file = os.path.join(subdir, f"{song_name}.MDS.wav")
142
+ combined.export(output_file, format="wav")
143
+ print(f"Exported combined stems to {output_file}")
144
+
145
+ def delete_folders_and_files(input_dir):
146
+ folders_to_delete = ['htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer']
147
+ files_to_delete = ['bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav']
148
+
149
+ for root, dirs, files in os.walk(input_dir, topdown=False):
150
+ if root == input_dir:
151
+ continue
152
+
153
+ for folder in folders_to_delete:
154
+ folder_path = os.path.join(root, folder)
155
+ if os.path.isdir(folder_path):
156
+ print(f"Deleting folder: {folder_path}")
157
+ shutil.rmtree(folder_path)
158
+
159
+ for file in files_to_delete:
160
+ file_path = os.path.join(root, file)
161
+ if os.path.isfile(file_path):
162
+ print(f"Deleting file: {file_path}")
163
+ os.remove(file_path)
164
+
165
+ for root, dirs, files in os.walk(OUTPUT_FOLDER):
166
+ for dir_name in dirs:
167
+ if dir_name.endswith('_vocals'):
168
+ dir_path = os.path.join(root, dir_name)
169
+ print(f"Deleting folder: {dir_path}")
170
+ shutil.rmtree(dir_path)
171
+
172
+ print("Cleanup completed.")
173
+
174
+ def process_audio(song_title):
175
+ try:
176
+ yield "Finding audio...", None
177
+ if title_input == "":
178
+ raise ValueError("Please enter a song title.")
179
+
180
+ formatted_title = download_youtube_audio_by_title(song_title, False)
181
+
182
+ yield "Starting SCNet inference...", None
183
+ proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER)
184
+
185
+ yield "Starting Mel Band Roformer inference...", None
186
+ proc_folder_direct("mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml", "results/model_mel_band_roformer_vocals.ckpt", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER, extract_instrumental=True)
187
+ yield "Starting HTDemucs inference...", None
188
+ proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER)
189
+
190
+ source_path = f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer/{formatted_title}_instrumental.wav'
191
+ destination_path = f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer/{formatted_title}.wav'
192
+
193
+ os.rename(source_path, destination_path)
194
+
195
+ yield "Starting BS Roformer inference...", None
196
+ proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer', OUTPUT_FOLDER)
197
+
198
+ yield "Moving input files...", None
199
+ delete_input_files(INPUT_FOLDER)
200
+
201
+ yield "Moving stems to parent...", None
202
+ move_stems_to_parent(OUTPUT_FOLDER)
203
+
204
+ yield "Combining stems...", None
205
+ combine_stems_for_all(OUTPUT_FOLDER)
206
+
207
+ yield "Cleaning up...", None
208
+ delete_folders_and_files(OUTPUT_FOLDER)
209
+
210
+
211
+ yield f"Audio processing completed successfully.", f'{OUTPUT_FOLDER}{formatted_title}/{formatted_title}.MDS.wav'
212
+ except Exception as e:
213
+ error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
214
+ logging.error(error_msg)
215
+ yield error_msg, None
216
+
217
+ with gr.Blocks() as demo:
218
+ gr.Markdown("# Music Player and Processor")
219
+
220
+ with gr.Row():
221
+ title_input = gr.Textbox(label="Enter Song Title")
222
+ play_button = gr.Button("Play")
223
+
224
+ audio_output = gr.Audio(label="Audio Player")
225
+
226
+ process_button = gr.Button("Process Audio")
227
+
228
+ log_output = gr.Textbox(label="Processing Log", interactive=False)
229
+ processed_audio_output = gr.Audio(label="Processed Audio")
230
+
231
+ play_button.click(
232
+ fn=download_youtube_audio_by_title,
233
+ inputs=title_input,
234
+ outputs=audio_output
235
+ )
236
+
237
+ process_button.click(
238
+ fn=process_audio,
239
+ inputs=title_input,
240
+ outputs=[log_output, processed_audio_output],
241
+ show_progress=True
242
+ )
243
+ demo.launch()