Spaces:
Runtime error
Runtime error
PunGrumpy
committed on
Commit
·
7acbfbc
1
Parent(s):
f38b9a7
✨ feat: add spotify to find audio features
Browse files- .gitignore +11 -0
- app.py +87 -62
- requirements.txt +1 -0
.gitignore
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Environment
|
2 |
+
.env
|
3 |
+
|
4 |
+
# Flags
|
5 |
+
flagged/
|
6 |
+
|
7 |
+
# Cache
|
8 |
+
.cache/
|
9 |
+
.cache
|
10 |
+
__pycache__/
|
11 |
+
gradio_cached_examples/
|
app.py
CHANGED
@@ -1,27 +1,33 @@
|
|
|
|
1 |
import torch
|
|
|
2 |
import numpy as np
|
3 |
import gradio as gr
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
|
7 |
from transformers import AutoTokenizer, AutoModel
|
|
|
8 |
|
9 |
-
REPO_NAME = "PunGrumpy/music-genre-classification"
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
"
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
class LyricsAudioModelInference:
|
@@ -30,13 +36,38 @@ class LyricsAudioModelInference:
|
|
30 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
31 |
self.num_labels = num_labels
|
32 |
self.classifier = nn.Linear(
|
33 |
-
self.model.config.hidden_size + len(AUDIO_FEATURES), num_labels
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
)
|
35 |
|
36 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
input_lyrics = self.tokenizer(
|
38 |
lyrics, return_tensors="pt", padding=True, truncation=True, max_length=512
|
39 |
)
|
|
|
40 |
|
41 |
outputs = self.model(**input_lyrics)
|
42 |
lyrics_embedding = outputs.last_hidden_state.mean(dim=1)
|
@@ -59,70 +90,64 @@ class LyricsAudioModelInference:
|
|
59 |
for i in range(3):
|
60 |
genre_idx = top3_genres.indices[0][i].item()
|
61 |
genre_prob = top3_genres.values[0][i].item()
|
62 |
-
genre_label = [
|
|
|
|
|
|
|
|
|
63 |
result[genre_label] = genre_prob
|
64 |
|
|
|
65 |
return result
|
66 |
|
67 |
|
68 |
-
|
69 |
iface = gr.Interface(
|
70 |
-
|
|
|
71 |
inputs=[
|
72 |
gr.Textbox(
|
73 |
-
lines=
|
74 |
placeholder="Enter lyrics here...",
|
75 |
label="Lyrics",
|
76 |
),
|
77 |
-
gr.
|
78 |
-
|
79 |
-
|
80 |
-
label="
|
81 |
-
step=0.01,
|
82 |
-
),
|
83 |
-
gr.Slider(
|
84 |
-
minimum=0,
|
85 |
-
maximum=1,
|
86 |
-
label="Danceability",
|
87 |
-
step=0.01,
|
88 |
),
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
label="
|
94 |
-
|
95 |
),
|
96 |
-
gr.Slider(minimum=0, maximum=11, label="Key", step=1),
|
97 |
-
gr.Slider(minimum=0, maximum=1, label="Liveness", step=0.01),
|
98 |
-
gr.Slider(minimum=-60, maximum=0, label="Loudness", step=1),
|
99 |
-
gr.Slider(minimum=0, maximum=1, label="Mode", step=1),
|
100 |
-
gr.Slider(minimum=0, maximum=1, label="Speechiness", step=0.01),
|
101 |
-
gr.Slider(minimum=0, maximum=200, label="Tempo", step=1),
|
102 |
-
gr.Slider(minimum=0, maximum=1, label="Valence", step=0.01),
|
103 |
],
|
104 |
-
|
105 |
-
num_top_classes=3,
|
106 |
-
label="Top 3 Predicted Genres",
|
107 |
-
),
|
108 |
-
title="Music Genre Classifier",
|
109 |
description="This model predicts the genre of a song based on its lyrics and audio features.",
|
110 |
examples=[
|
111 |
[
|
112 |
"When the sun is rising over streets so barren...",
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
124 |
],
|
125 |
],
|
|
|
126 |
)
|
127 |
|
128 |
-
|
|
|
|
1 |
+
import os
|
2 |
import torch
|
3 |
+
import spotipy
|
4 |
import numpy as np
|
5 |
import gradio as gr
|
6 |
import torch.nn as nn
|
7 |
import torch.nn.functional as F
|
8 |
|
9 |
from transformers import AutoTokenizer, AutoModel
|
10 |
+
from spotipy.oauth2 import SpotifyClientCredentials
|
11 |
|
|
|
12 |
|
13 |
+
class ConfigApp:
    """Static configuration shared by the inference wrapper and the Gradio UI.

    Every attribute is a class-level constant; the visible code only reads
    them through the class itself and never instantiates it.
    """

    # Model repository id, passed to the inference wrapper as ``model_name``.
    REPO_NAME = "PunGrumpy/music-genre-classification"

    # Genre name -> class index. The prediction code maps an index back to
    # its name, so the positions in this tuple are significant.
    GENRE = {
        name: index
        for index, name in enumerate(("edm", "r&b", "rap", "rock", "pop"))
    }

    # Audio feature names, in the order they are read from the Spotify
    # response. The visible code uses only the keys and the length of this
    # mapping; every value is the placeholder 0.
    AUDIO_FEATURES = dict.fromkeys(
        (
            "acousticness",
            "danceability",
            "energy",
            "instrumentalness",
            "key",
            "liveness",
            "loudness",
            "mode",
            "speechiness",
            "tempo",
            "valence",
        ),
        0,
    )

    # Spotify credentials read from the environment (None when unset).
    # NOTE: despite its name, SPOTIFY_ACCESS_TOKEN is handed to
    # SpotifyClientCredentials as the client *secret*.
    SPOTIFY_CLIENT_ID = os.environ.get("SPOTIFY_CLIENT_ID")
    SPOTIFY_ACCESS_TOKEN = os.environ.get("SPOTIFY_ACCESS_TOKEN")
|
31 |
|
32 |
|
33 |
class LyricsAudioModelInference:
|
|
|
36 |
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
37 |
self.num_labels = num_labels
|
38 |
self.classifier = nn.Linear(
|
39 |
+
self.model.config.hidden_size + len(ConfigApp.AUDIO_FEATURES), num_labels
|
40 |
+
)
|
41 |
+
self.sp = spotipy.Spotify(
|
42 |
+
client_credentials_manager=SpotifyClientCredentials(
|
43 |
+
client_id=ConfigApp.SPOTIFY_CLIENT_ID,
|
44 |
+
client_secret=ConfigApp.SPOTIFY_ACCESS_TOKEN,
|
45 |
+
)
|
46 |
)
|
47 |
|
48 |
+
def get_audio_features(self, spotify_track_link: str) -> list:
    """Fetch a track's Spotify audio features in model-input order.

    Args:
        spotify_track_link: Spotify track URL (a query string and a
            trailing slash are ignored), Spotify URI, or bare track id.

    Returns:
        One value per key of ``ConfigApp.AUDIO_FEATURES``, in that order.

    Raises:
        ValueError: If no track id can be extracted from the link, or if
            Spotify returns no audio features for the track.
    """
    # Strip the query string first and any trailing slash second, so a link
    # like ".../track/<id>/?si=..." yields the id rather than an empty string.
    link = (spotify_track_link or "").strip()
    track_id = link.split("?")[0].rstrip("/").split("/")[-1]
    if not track_id:
        raise ValueError("Could not extract a Spotify track id from the link.")

    response = self.sp.audio_features(track_id)
    # The response is indexed by position below; the Spotify batch endpoint
    # documents a null entry for ids it cannot resolve, so fail with a clear
    # message instead of a TypeError on None.
    if not response or response[0] is None:
        raise ValueError(
            f"No audio features available for Spotify track '{track_id}'."
        )

    track_features = response[0]
    return [track_features[name] for name in ConfigApp.AUDIO_FEATURES]
|
55 |
+
|
56 |
+
def get_track_info(self, spotify_track_link: str) -> dict:
    """Look up a track's title and artist names on Spotify.

    Args:
        spotify_track_link: Spotify track URL (a query string and a
            trailing slash are ignored), Spotify URI, or bare track id.

    Returns:
        A dict with keys ``"Song Name"`` and ``"Artist Name"``; multiple
        artists are joined with ``", "``. The song name falls back to
        ``"Unknown"`` when the response has no ``"name"`` field.

    Raises:
        ValueError: If no track id can be extracted from the link.
    """
    # Strip the query string first and any trailing slash second, so a link
    # like ".../track/<id>/?si=..." yields the id rather than an empty string.
    link = (spotify_track_link or "").strip()
    track_id = link.split("?")[0].rstrip("/").split("/")[-1]
    if not track_id:
        raise ValueError("Could not extract a Spotify track id from the link.")

    track_info = self.sp.track(track_id)
    song_name = track_info.get("name", "Unknown")
    artist_name = ", ".join(
        artist["name"] for artist in track_info.get("artists", [])
    )
    # Echo the resolved track to stdout, as the original implementation did.
    print(f"Song Name: {song_name}, Artist Name: {artist_name}")
    return {"Song Name": song_name, "Artist Name": artist_name}
|
65 |
+
|
66 |
+
def predict_genre(self, lyrics: str, spotify_track_link: str) -> dict:
|
67 |
input_lyrics = self.tokenizer(
|
68 |
lyrics, return_tensors="pt", padding=True, truncation=True, max_length=512
|
69 |
)
|
70 |
+
audio_features = self.get_audio_features(spotify_track_link)
|
71 |
|
72 |
outputs = self.model(**input_lyrics)
|
73 |
lyrics_embedding = outputs.last_hidden_state.mean(dim=1)
|
|
|
90 |
for i in range(3):
|
91 |
genre_idx = top3_genres.indices[0][i].item()
|
92 |
genre_prob = top3_genres.values[0][i].item()
|
93 |
+
genre_label = [
|
94 |
+
key.capitalize()
|
95 |
+
for key, value in ConfigApp.GENRE.items()
|
96 |
+
if value == genre_idx
|
97 |
+
][0]
|
98 |
result[genre_label] = genre_prob
|
99 |
|
100 |
+
# track_info = self.get_track_info(spotify_track_link)
|
101 |
return result
|
102 |
|
103 |
|
104 |
+
with gr.Blocks() as demo:
|
105 |
iface = gr.Interface(
|
106 |
+
api_name="Music Genre Classifier",
|
107 |
+
fn=LyricsAudioModelInference(model_name=ConfigApp.REPO_NAME).predict_genre,
|
108 |
inputs=[
|
109 |
gr.Textbox(
|
110 |
+
lines=5,
|
111 |
placeholder="Enter lyrics here...",
|
112 |
label="Lyrics",
|
113 |
),
|
114 |
+
gr.Textbox(
|
115 |
+
lines=1,
|
116 |
+
placeholder="Enter Spotify Track Link here...",
|
117 |
+
label="Spotify Track Link",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
),
|
119 |
+
],
|
120 |
+
outputs=[
|
121 |
+
gr.Label(
|
122 |
+
num_top_classes=3,
|
123 |
+
label="Top 3 Predicted Genres",
|
124 |
+
elem_id="top3-genres",
|
125 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
],
|
127 |
+
title="🎷 Music Genre Classifier",
|
|
|
|
|
|
|
|
|
128 |
description="This model predicts the genre of a song based on its lyrics and audio features.",
|
129 |
examples=[
|
130 |
[
|
131 |
"When the sun is rising over streets so barren...",
|
132 |
+
"https://open.spotify.com/track/2rGS4ipEZzldN0EpcfH3PK",
|
133 |
+
],
|
134 |
+
[
|
135 |
+
"Tastes like strawberries On a summer evenin'...",
|
136 |
+
"https://open.spotify.com/track/6UelLqGlWMcVH1E5c4H7lY",
|
137 |
+
],
|
138 |
+
[
|
139 |
+
"""Da, da, da, da, da
|
140 |
+
It's the motherfuckin' D-O-double-G (Snoop Dogg!)
|
141 |
+
Da, da, da, da, da
|
142 |
+
You know I'm mobbin' with the D.R.E. (Yeah, yeah, yeah)
|
143 |
+
You know who's back up in this motherfucker! (What, what, what, what?)
|
144 |
+
So blaze the weed up then! (Blaze it up, blaze it up!)
|
145 |
+
Blaze that shit up, nigga... yeah 'Sup Snoop?""",
|
146 |
+
"https://open.spotify.com/track/4LwU4Vp6od3Sb08CsP99GC",
|
147 |
],
|
148 |
],
|
149 |
+
analytics_enabled=True,
|
150 |
)
|
151 |
|
152 |
+
|
153 |
+
# Start serving the Blocks app bound to `demo` by the `with gr.Blocks()` statement;
# `debug` and `show_api` are passed straight through to Gradio's launch().
demo.launch(debug=True, show_api=True)
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
torch==2.2.1
|
2 |
numpy==1.26.4
|
3 |
gradio==4.21.0
|
|
|
4 |
transformers==4.38.0
|
|
|
1 |
torch==2.2.1
|
2 |
numpy==1.26.4
|
3 |
gradio==4.21.0
|
4 |
+
spotipy==2.23.0
|
5 |
transformers==4.38.0
|