|
Para un uso sencillo del modelo, utilice el siguiente código:
|
|
|
```py |
|
#! pip install transformers |
|
#! pip install torch |
|
#! pip install datasets |
|
``` |
|
|
|
```py |
|
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from datasets import load_dataset |
|
import numpy as np |
|
import torch |
|
|
|
|
|
# Load the songs dataset from the Hugging Face Hub.
dataset = load_dataset("manoh2f2/songs_resampled")

# Convert the chosen split into a pandas DataFrame.
split_name = 'train'
df_resampled = dataset[split_name].to_pandas()

# Load the fine-tuned tokenizer and sequence-classification model.
tokenizer = AutoTokenizer.from_pretrained("manoh2f2/recommend_songs")
model = AutoModelForSequenceClassification.from_pretrained("manoh2f2/recommend_songs")

# Define a prompt
prompt = "I am happy"

# Tokenize the prompt. `truncation=True` makes the 256-token cap explicit;
# passing `max_length` alone makes transformers warn and fall back to
# truncation implicitly.
encoded_prompt = tokenizer(prompt, return_tensors='pt', max_length=256, truncation=True)

# Run inference without tracking gradients.
with torch.no_grad():
    model_output = model(**encoded_prompt)

# Take the argmax over the class axis, then read the single prompt's result.
predicted_emotion_index = torch.argmax(model_output.logits, dim=-1)[0].item()

# Map the index back to an emotion label using the DataFrame.
# NOTE(review): this assumes the model's class indices follow the order in
# which emotions first appear in the 'emotions' column (pandas `unique()`
# preserves appearance order). Confirm against how labels were encoded at
# training time, or use `model.config.id2label` if it stores real names.
predicted_emotion_label = df_resampled['emotions'].unique()[predicted_emotion_index]

# Keep only the songs tagged with the predicted emotion.
result = df_resampled[df_resampled['emotions'] == predicted_emotion_label]

# Fail with a readable message rather than numpy's "low >= high" error.
if result.empty:
    raise ValueError(f"No songs found for emotion {predicted_emotion_label!r}")

# Pick one random row, so the song and the artist come from the same record.
recommended_row = result.sample(n=1).iloc[0]
recommended_song = recommended_row['song']
recommended_artist = recommended_row['artist']

# Print the results
print(f"Prompt: {prompt}")
print(f"Predicted Emotion: {predicted_emotion_label}")
print(f"Recommended Song: {recommended_song} - {recommended_artist}")
|
``` |