PunGrumpy committed on
Commit
7acbfbc
·
1 Parent(s): f38b9a7

✨ feat: add spotify to find audio features

Browse files
Files changed (3) hide show
  1. .gitignore +11 -0
  2. app.py +87 -62
  3. requirements.txt +1 -0
.gitignore ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment
2
+ .env
3
+
4
+ # Flags
5
+ flagged/
6
+
7
+ # Cache
8
+ .cache/
9
+ .cache
10
+ __pycache__/
11
+ gradio_cached_examples/
app.py CHANGED
@@ -1,27 +1,33 @@
 
1
  import torch
 
2
  import numpy as np
3
  import gradio as gr
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
 
7
  from transformers import AutoTokenizer, AutoModel
 
8
 
9
- REPO_NAME = "PunGrumpy/music-genre-classification"
10
 
11
- GENRE = {"edm": 0, "r&b": 1, "rap": 2, "rock": 3, "pop": 4}
12
- AUDIO_FEATURES = {
13
- "acousticness": 0,
14
- "danceability": 0,
15
- "energy": 0,
16
- "instrumentalness": 0,
17
- "key": 0,
18
- "liveness": 0,
19
- "loudness": 0,
20
- "mode": 0,
21
- "speechiness": 0,
22
- "tempo": 0,
23
- "valence": 0,
24
- }
 
 
 
 
25
 
26
 
27
  class LyricsAudioModelInference:
@@ -30,13 +36,38 @@ class LyricsAudioModelInference:
30
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
31
  self.num_labels = num_labels
32
  self.classifier = nn.Linear(
33
- self.model.config.hidden_size + len(AUDIO_FEATURES), num_labels
 
 
 
 
 
 
34
  )
35
 
36
- def predict_genre(self, lyrics: str, *audio_features) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  input_lyrics = self.tokenizer(
38
  lyrics, return_tensors="pt", padding=True, truncation=True, max_length=512
39
  )
 
40
 
41
  outputs = self.model(**input_lyrics)
42
  lyrics_embedding = outputs.last_hidden_state.mean(dim=1)
@@ -59,70 +90,64 @@ class LyricsAudioModelInference:
59
  for i in range(3):
60
  genre_idx = top3_genres.indices[0][i].item()
61
  genre_prob = top3_genres.values[0][i].item()
62
- genre_label = [key for key, value in GENRE.items() if value == genre_idx][0]
 
 
 
 
63
  result[genre_label] = genre_prob
64
 
 
65
  return result
66
 
67
 
68
- if __name__ == "__main__":
69
  iface = gr.Interface(
70
- fn=LyricsAudioModelInference(model_name=REPO_NAME).predict_genre,
 
71
  inputs=[
72
  gr.Textbox(
73
- lines=20,
74
  placeholder="Enter lyrics here...",
75
  label="Lyrics",
76
  ),
77
- gr.Slider(
78
- minimum=0,
79
- maximum=1,
80
- label="Acousticness",
81
- step=0.01,
82
- ),
83
- gr.Slider(
84
- minimum=0,
85
- maximum=1,
86
- label="Danceability",
87
- step=0.01,
88
  ),
89
- gr.Slider(minimum=0, maximum=1, label="Energy", step=0.01),
90
- gr.Slider(
91
- minimum=0,
92
- maximum=1,
93
- label="Instrumentalness",
94
- step=0.01,
95
  ),
96
- gr.Slider(minimum=0, maximum=11, label="Key", step=1),
97
- gr.Slider(minimum=0, maximum=1, label="Liveness", step=0.01),
98
- gr.Slider(minimum=-60, maximum=0, label="Loudness", step=1),
99
- gr.Slider(minimum=0, maximum=1, label="Mode", step=1),
100
- gr.Slider(minimum=0, maximum=1, label="Speechiness", step=0.01),
101
- gr.Slider(minimum=0, maximum=200, label="Tempo", step=1),
102
- gr.Slider(minimum=0, maximum=1, label="Valence", step=0.01),
103
  ],
104
- outputs=gr.Label(
105
- num_top_classes=3,
106
- label="Top 3 Predicted Genres",
107
- ),
108
- title="Music Genre Classifier",
109
  description="This model predicts the genre of a song based on its lyrics and audio features.",
110
  examples=[
111
  [
112
  "When the sun is rising over streets so barren...",
113
- 0.7050,
114
- 0.420,
115
- 0.247,
116
- 0.00349,
117
- 2,
118
- 0.1270,
119
- -13.370,
120
- 0,
121
- 0.0360,
122
- 88.071,
123
- 0.138,
 
 
 
 
124
  ],
125
  ],
 
126
  )
127
 
128
- iface.launch(debug=True, show_api=True, share=True, inline=True)
 
 
1
+ import os
2
  import torch
3
+ import spotipy
4
  import numpy as np
5
  import gradio as gr
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
 
9
  from transformers import AutoTokenizer, AutoModel
10
+ from spotipy.oauth2 import SpotifyClientCredentials
11
 
 
12
 
13
+ class ConfigApp:
14
+ REPO_NAME = "PunGrumpy/music-genre-classification"
15
+ GENRE = {"edm": 0, "r&b": 1, "rap": 2, "rock": 3, "pop": 4}
16
+ AUDIO_FEATURES = {
17
+ "acousticness": 0,
18
+ "danceability": 0,
19
+ "energy": 0,
20
+ "instrumentalness": 0,
21
+ "key": 0,
22
+ "liveness": 0,
23
+ "loudness": 0,
24
+ "mode": 0,
25
+ "speechiness": 0,
26
+ "tempo": 0,
27
+ "valence": 0,
28
+ }
29
+ SPOTIFY_CLIENT_ID = os.getenv("SPOTIFY_CLIENT_ID")
30
+ SPOTIFY_ACCESS_TOKEN = os.getenv("SPOTIFY_ACCESS_TOKEN")
31
 
32
 
33
  class LyricsAudioModelInference:
 
36
  self.tokenizer = AutoTokenizer.from_pretrained(model_name)
37
  self.num_labels = num_labels
38
  self.classifier = nn.Linear(
39
+ self.model.config.hidden_size + len(ConfigApp.AUDIO_FEATURES), num_labels
40
+ )
41
+ self.sp = spotipy.Spotify(
42
+ client_credentials_manager=SpotifyClientCredentials(
43
+ client_id=ConfigApp.SPOTIFY_CLIENT_ID,
44
+ client_secret=ConfigApp.SPOTIFY_ACCESS_TOKEN,
45
+ )
46
  )
47
 
48
+ def get_audio_features(self, spotify_track_link: str) -> list:
49
+ track_id = spotify_track_link.split("/")[-1].split("?")[0]
50
+ audio_features = self.sp.audio_features(track_id)
51
+ audio_features = [
52
+ audio_features[0][feature] for feature in ConfigApp.AUDIO_FEATURES
53
+ ]
54
+ return audio_features
55
+
56
+ def get_track_info(self, spotify_track_link: str) -> dict:
57
+ track_id = spotify_track_link.split("/")[-1].split("?")[0]
58
+ track_info = self.sp.track(track_id)
59
+ song_name = track_info.get("name", "Unknown")
60
+ artist_name = ", ".join(
61
+ [artist["name"] for artist in track_info.get("artists", [])]
62
+ )
63
+ print(f"Song Name: {song_name}, Artist Name: {artist_name}")
64
+ return {"Song Name": song_name, "Artist Name": artist_name}
65
+
66
+ def predict_genre(self, lyrics: str, spotify_track_link: str) -> dict:
67
  input_lyrics = self.tokenizer(
68
  lyrics, return_tensors="pt", padding=True, truncation=True, max_length=512
69
  )
70
+ audio_features = self.get_audio_features(spotify_track_link)
71
 
72
  outputs = self.model(**input_lyrics)
73
  lyrics_embedding = outputs.last_hidden_state.mean(dim=1)
 
90
  for i in range(3):
91
  genre_idx = top3_genres.indices[0][i].item()
92
  genre_prob = top3_genres.values[0][i].item()
93
+ genre_label = [
94
+ key.capitalize()
95
+ for key, value in ConfigApp.GENRE.items()
96
+ if value == genre_idx
97
+ ][0]
98
  result[genre_label] = genre_prob
99
 
100
+ # track_info = self.get_track_info(spotify_track_link)
101
  return result
102
 
103
 
104
+ with gr.Blocks() as demo:
105
  iface = gr.Interface(
106
+ api_name="Music Genre Classifier",
107
+ fn=LyricsAudioModelInference(model_name=ConfigApp.REPO_NAME).predict_genre,
108
  inputs=[
109
  gr.Textbox(
110
+ lines=5,
111
  placeholder="Enter lyrics here...",
112
  label="Lyrics",
113
  ),
114
+ gr.Textbox(
115
+ lines=1,
116
+ placeholder="Enter Spotify Track Link here...",
117
+ label="Spotify Track Link",
 
 
 
 
 
 
 
118
  ),
119
+ ],
120
+ outputs=[
121
+ gr.Label(
122
+ num_top_classes=3,
123
+ label="Top 3 Predicted Genres",
124
+ elem_id="top3-genres",
125
  ),
 
 
 
 
 
 
 
126
  ],
127
+ title="🎷 Music Genre Classifier",
 
 
 
 
128
  description="This model predicts the genre of a song based on its lyrics and audio features.",
129
  examples=[
130
  [
131
  "When the sun is rising over streets so barren...",
132
+ "https://open.spotify.com/track/2rGS4ipEZzldN0EpcfH3PK",
133
+ ],
134
+ [
135
+ "Tastes like strawberries On a summer evenin'...",
136
+ "https://open.spotify.com/track/6UelLqGlWMcVH1E5c4H7lY",
137
+ ],
138
+ [
139
+ """Da, da, da, da, da
140
+ It's the motherfuckin' D-O-double-G (Snoop Dogg!)
141
+ Da, da, da, da, da
142
+ You know I'm mobbin' with the D.R.E. (Yeah, yeah, yeah)
143
+ You know who's back up in this motherfucker! (What, what, what, what?)
144
+ So blaze the weed up then! (Blaze it up, blaze it up!)
145
+ Blaze that shit up, nigga... yeah 'Sup Snoop?""",
146
+ "https://open.spotify.com/track/4LwU4Vp6od3Sb08CsP99GC",
147
  ],
148
  ],
149
+ analytics_enabled=True,
150
  )
151
 
152
+
153
+ demo.launch(debug=True, show_api=True)
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  torch==2.2.1
2
  numpy==1.26.4
3
  gradio==4.21.0
 
4
  transformers==4.38.0
 
1
  torch==2.2.1
2
  numpy==1.26.4
3
  gradio==4.21.0
4
+ spotipy==2.23.0
5
  transformers==4.38.0