Awell00 commited on
Commit
6bd24ce
·
verified ·
1 Parent(s): ad69837

fix: use an LLM for filename parsing instead of a regex match

Browse files
Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -11,6 +11,9 @@ from inference import proc_folder_direct
11
  from pathlib import Path
12
  import spaces
13
  from pydub.exceptions import CouldntEncodeError
 
 
 
14
 
15
  OUTPUT_FOLDER = "separation_results/"
16
  INPUT_FOLDER = "input"
@@ -25,23 +28,27 @@ def delete_input_files(input_dir):
25
  wav_file.unlink()
26
  print(f"Deleted {wav_file}")
27
 
 
 
 
 
 
 
 
 
 
 
 
28
  def handle_file_upload(file):
29
  if file is None:
30
  return None, "No file uploaded"
31
 
32
  filename = os.path.basename(file.name)
 
 
 
33
 
34
- # This regex captures both "Artist - Title" and "Title - Artist" formats
35
- match = re.match(
36
- r'^(.*? - .*?)\s*(?:[\(\[].*?[\)\]])?$', filename
37
- )
38
-
39
- if match:
40
- # This directly captures the "Artist - Title" part
41
- formatted_title = match.group(1).strip()
42
- else:
43
- # If no match, fallback to the original filename
44
- formatted_title = sanitize_filename(filename.strip())
45
 
46
  input_path = os.path.join(INPUT_FOLDER, "wav", f"{formatted_title}.wav")
47
  os.makedirs(os.path.dirname(input_path), exist_ok=True)
@@ -51,6 +58,7 @@ def handle_file_upload(file):
51
 
52
  return input_path, formatted_title
53
 
 
54
  def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
55
  command = [
56
  "python", "inference.py",
 
11
  from pathlib import Path
12
  import spaces
13
  from pydub.exceptions import CouldntEncodeError
14
+ from transformers import pipeline
15
+
16
+ model = pipeline('text-generation', model='EleutherAI/gpt-neo-125M')
17
 
18
  OUTPUT_FOLDER = "separation_results/"
19
  INPUT_FOLDER = "input"
 
28
  wav_file.unlink()
29
  print(f"Deleted {wav_file}")
30
 
31
def analyze_filename_with_llm(filename):
    """Extract an 'Artist - Title' string from an audio filename via the local LLM.

    Args:
        filename: Basename of the uploaded file (e.g. "Artist - Song.wav").

    Returns:
        str: The first line of the model's continuation, expected to be in
        'Artist - Title' form. May be empty if the model generates nothing.
    """
    # BUG FIX: the original prompt ended with a literal "(unknown)" placeholder
    # instead of interpolating the filename, so the model had nothing to parse.
    prompt = (
        "Extract the artist and song title from the following filename and "
        f"format it as 'Artist - Title':\n\n{filename}"
    )

    # Generate deterministically (do_sample=False) with the local model.
    # NOTE(review): max_length counts prompt + new tokens; with a long filename
    # this may leave few tokens for the answer — consider max_new_tokens instead.
    response = model(prompt, max_length=50, do_sample=False)[0]['generated_text']

    # BUG FIX: `generated_text` from a text-generation pipeline includes the
    # prompt itself, so taking the first line of the raw response returned the
    # prompt's first line, never the model's answer. Strip the prompt prefix
    # before extracting the first non-empty line of the continuation.
    continuation = response[len(prompt):] if response.startswith(prompt) else response
    lines = [line.strip() for line in continuation.strip().split('\n') if line.strip()]
    artist_title = lines[0] if lines else ""

    return artist_title
41
+
42
  def handle_file_upload(file):
43
  if file is None:
44
  return None, "No file uploaded"
45
 
46
  filename = os.path.basename(file.name)
47
+
48
+ # Use LLM to analyze the filename and return the formatted title
49
+ formatted_title = analyze_filename_with_llm(filename)
50
 
51
+ formatted_title = sanitize_filename(formatted_title.strip())
 
 
 
 
 
 
 
 
 
 
52
 
53
  input_path = os.path.join(INPUT_FOLDER, "wav", f"{formatted_title}.wav")
54
  os.makedirs(os.path.dirname(input_path), exist_ok=True)
 
58
 
59
  return input_path, formatted_title
60
 
61
+
62
  def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
63
  command = [
64
  "python", "inference.py",