|
import streamlit as st |
|
import time |
|
from transformers import pipeline |
|
import librosa |
|
import numpy as np |
|
import plotly.graph_objects as go |
|
import tempfile |
|
import os |
|
import soundfile as sf |
|
|
|
|
|
# Configure the browser tab title and the full-width layout.
# The emoji is spelled as a named Unicode escape so the source stays ASCII and
# the character cannot be mis-decoded into mojibake again.
st.set_page_config(
    page_title="\N{MUSICAL NOTE} Music Genre Classification",
    layout="wide",
)
|
|
|
|
|
# Stylesheet for the custom classes referenced by the heading, tagline and
# result markup rendered further down the page.
_PAGE_STYLES = """
<style>
.main-title {
    font-size: 3rem;
    color: #1DB954;
    text-align: center;
    padding: 2rem 0;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
}
.sub-title {
    font-size: 1.5rem;
    color: #191414;
    text-align: center;
    margin-bottom: 2rem;
}
.stAudio {
    margin: 2rem auto;
    display: block;
}
.genre-result {
    font-size: 2rem;
    font-weight: bold;
    text-align: center;
    color: #1DB954;
    margin: 1rem 0;
}
.prediction-time {
    font-size: 1.2rem;
    color: #191414;
    text-align: center;
}
</style>
"""

# Raw HTML must be explicitly allowed for the <style> element to be emitted.
st.markdown(_PAGE_STYLES, unsafe_allow_html=True)
|
|
|
@st.cache_resource
def load_model():
    """Build the Hugging Face audio-classification pipeline for genre prediction.

    Decorated with ``st.cache_resource`` so the pipeline object can be reused
    instead of being rebuilt every time the script runs.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        ``juangtzi/wav2vec2-base-finetuned-gtzan`` checkpoint.
    """
    task = "audio-classification"
    checkpoint = "juangtzi/wav2vec2-base-finetuned-gtzan"
    return pipeline(task, model=checkpoint)


# Shared classifier instance used by classify_audio().
pipe = load_model()
|
|
|
def convert_to_wav(audio_file):
    """Convert an uploaded audio file to a temporary WAV file on disk.

    Args:
        audio_file: Audio source handed to ``soundfile.read``; in this app
            that is the file-like object returned by the upload widget.

    Returns:
        Path of the newly written ``.wav`` file. The caller owns the file and
        is responsible for deleting it.

    Raises:
        Any error raised while decoding the input or writing the WAV is
        re-raised after the temporary file has been removed.
    """
    # Defensive rewind: the same file-like object may already have been read
    # elsewhere, and decoding must start from the first byte.
    if hasattr(audio_file, "seek"):
        audio_file.seek(0)

    # Only reserve a unique path here. The handle is closed before soundfile
    # writes to the path by name, so the file is never open twice at once.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
        wav_path = tmp_wav.name

    try:
        audio_data, samplerate = sf.read(audio_file)
        sf.write(wav_path, audio_data, samplerate)
    except BaseException:
        # delete=False means nothing else cleans this file up; without this
        # every failed conversion would leave a stray temp file behind.
        os.unlink(wav_path)
        raise
    return wav_path
|
|
|
def classify_audio(audio_file):
    """Classify an audio file's genre using the module-level ``pipe`` model.

    Args:
        audio_file: Audio source accepted by ``convert_to_wav``.

    Returns:
        A ``(outputs, prediction_time)`` tuple. ``outputs`` maps each label
        returned by the model to its score; ``prediction_time`` is the elapsed
        time in seconds, including the WAV conversion step.

    Raises:
        Any error from the conversion or from the model call propagates to
        the caller.
    """
    # perf_counter() is a monotonic clock, so the measured duration cannot be
    # skewed by system clock adjustments the way time.time() can.
    started = time.perf_counter()

    wav_file = convert_to_wav(audio_file)
    try:
        preds = pipe(wav_file)
        outputs = {p["label"]: p["score"] for p in preds}
        prediction_time = time.perf_counter() - started
        return outputs, prediction_time
    finally:
        # Remove the temporary WAV whether or not inference succeeded.
        os.unlink(wav_file)
|
|
|
|
|
# Page heading and tagline, styled through the main-title / sub-title classes.
# The emoji is a named Unicode escape so it cannot be corrupted into mojibake
# by a wrong file encoding.
st.markdown(
    "<h1 class='main-title'>\N{MUSICAL NOTE} Music Genre Classifier</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<p class='sub-title'>Upload a music file and let AI detect its genre!</p>",
    unsafe_allow_html=True,
)
|
|
|
|
|
# Sidebar panel describing the model and dataset behind the app.
_ABOUT_TEXT = """
This app uses a fine-tuned wav2vec2-base model to classify music genres.
Model: juangtzi/wav2vec2-base-finetuned-gtzan
Dataset: GTZAN
"""

_sidebar = st.sidebar
_sidebar.title("About")
_sidebar.info(_ABOUT_TEXT)
|
|
|
|
|
# Upload widget; the rest of the flow only runs once a file has been chosen.
uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])

if uploaded_file is not None:
    # Let the user preview the upload before classifying it.
    st.audio(uploaded_file)

    if st.button("Classify Genre"):
        # The headphone emoji is a named Unicode escape so it cannot be
        # corrupted into mojibake by a wrong file encoding.
        with st.spinner("Analyzing the music... \N{HEADPHONE}"):
            try:
                results, pred_time = classify_audio(uploaded_file)

                # The highest-scoring label is reported as the detected genre.
                top_genre = max(results, key=results.get)

                st.markdown(
                    f"<h2 class='genre-result'>Detected Genre: {top_genre.capitalize()}</h2>",
                    unsafe_allow_html=True,
                )
                st.markdown(
                    f"<p class='prediction-time'>Prediction Time: {pred_time:.2f} seconds</p>",
                    unsafe_allow_html=True,
                )

                # Bar chart of every label's score, on a transparent background.
                fig = go.Figure(
                    data=[
                        go.Bar(
                            x=list(results.keys()),
                            y=list(results.values()),
                            marker_color='#1DB954',
                        )
                    ]
                )
                fig.update_layout(
                    title="Genre Probabilities",
                    xaxis_title="Genre",
                    yaxis_title="Probability",
                    paper_bgcolor='rgba(0,0,0,0)',
                    plot_bgcolor='rgba(0,0,0,0)',
                )
                st.plotly_chart(fig, use_container_width=True)

                st.balloons()

            except Exception as e:
                # Top-level UI boundary: show the failure to the user instead
                # of letting the script crash with a traceback.
                st.error(f"An error occurred while processing the audio: {e}")
                st.info("Please try uploading the file again or use a different audio file.")
|
|
|
|
|
# Centered footer. The heart emoji (U+2764 followed by U+FE0F) is written with
# named Unicode escapes so it cannot be corrupted into mojibake by a wrong
# file encoding.
st.markdown(
    """
<div style='text-align: center; margin-top: 2rem;'>
<p>Created with \N{HEAVY BLACK HEART}\N{VARIATION SELECTOR-16} by AI. Powered by Streamlit and Hugging Face Transformers.</p>
</div>
""",
    unsafe_allow_html=True,
)
|
|