|
import gradio as gr |
|
import json |
|
from utils import * |
|
from unidecode import unidecode |
|
from transformers import AutoTokenizer |
|
|
|
# Markdown/HTML blurb passed to gr.Interface(description=...): project badges,
# usage steps, caveats and a short explanation of semantic music search.
description = """
<div>
<a style="display:inline-block" href='https://github.com/microsoft/muzic/tree/main/clamp'><img src='https://img.shields.io/github/stars/microsoft/muzic?style=social' /></a>
<a style='display:inline-block' href='https://ai-muzic.github.io/clamp/'><img src='https://img.shields.io/badge/website-CLaMP-ff69b4.svg' /></a>
<a style="display:inline-block" href="https://huggingface.co/datasets/sander-wood/wikimusictext"><img src="https://img.shields.io/badge/huggingface-dataset-ffcc66.svg"></a>
<a style="display:inline-block" href="https://arxiv.org/pdf/2304.11029.pdf"><img src="https://img.shields.io/badge/arXiv-2304.11029-b31b1b.svg"></a>
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/sander-wood/clamp_semantic_music_search?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-md-dark.svg" alt="Duplicate Space"></a>
</div>

## ℹ️ How to use this demo?
1. Enter a query in the text box.
2. Click "Submit" and wait for the result.
3. It will return the most matching music score from the WikiMusicText dataset (1010 scores in total).

## ❕Notice
- The text box is case-sensitive.
- You can enter longer text for the text box, but the demo will only use the first 128 tokens.
- The returned results include the title, artist, genre, description, and the score in ABC notation.
- The genre and description may not be accurate, as they are collected from the web.
- The demo is based on CLaMP-S/512, a CLaMP model with 6-layer Transformer text/music encoders and a sequence length of 512.

## 🔠👉🎵 Semantic Music Search
Semantic search is a technique for retrieving music by open-domain queries, which differs from traditional keyword-based searches that depend on exact matches or meta-information. This involves two steps: 1) extracting music features from all scores in the library, and 2) transforming the query into a text feature. By calculating the similarities between the text feature and the music features, it can efficiently locate the score that best matches the user's query in the library.

"""
|
# Sample queries offered as clickable examples in the Gradio interface.
# Each string is submitted verbatim as the search query.
examples = [
    "Jazz standard in Minor key with a swing feel.",
    "Jazz standard in Major key with a fast tempo.",
    "Jazz standard in Blues form with a soulful melody.",
    "a painting of a starry night with the moon in the sky",
    "a green field with a blue sky and clouds",
    "a beach with a castle on top of it",
]
|
|
|
# ---- Retrieval settings ----
CLAMP_MODEL_NAME = 'sander-wood/clamp-small-512'  # identifier handed to CLaMP.from_pretrained
TEXT_MODEL_NAME = 'distilroberta-base'  # identifier handed to AutoTokenizer.from_pretrained
QUERY_MODAL = 'text'  # modality of the user query
KEY_MODAL = 'music'  # modality of the cached library entries (also prefixes the cache filename)
TOP_N = 1  # NOTE(review): not referenced anywhere in the visible code
TEXT_LENGTH = 128  # max_length passed to the tokenizer (longer queries are truncated)
device = torch.device("cpu")

# ---- Model: loaded once at import time, placed on `device`, switched to eval mode ----
model = CLaMP.from_pretrained(CLAMP_MODEL_NAME).to(device)
model.eval()
music_length = model.config.max_length  # also used to build the key-cache filename

# ---- Pre-processing helpers shared by the functions below ----
patchilizer = MusicPatchilizer()
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
softmax = torch.nn.Softmax(dim=1)
|
|
|
def compute_values(Q_e, K_e, t=1):
    """
    Compute softmax-normalised cosine similarities between queries and keys

    Both inputs are L2-normalised along dim 1, so the matrix product yields
    cosine similarities. These are multiplied by exp(t) and turned into a
    probability distribution over the keys for every query.

    Args:
        Q_e (torch.Tensor): Query embeddings, 2-D (num_queries, dim)
        K_e (torch.Tensor): Key embeddings, 2-D (num_keys, dim)
        t (float): Log of the scale applied to the similarities before the
            softmax (the logits are multiplied by exp(t))

    Returns:
        values (torch.Tensor): Probabilities over the keys, with shape
            (num_keys,) for a single query and (num_queries, num_keys)
            otherwise
    """
    Q_e = torch.nn.functional.normalize(Q_e, dim=1)
    K_e = torch.nn.functional.normalize(K_e, dim=1)

    # Build the scale in the dtype/device of the embeddings so that an integer
    # `t` (the default) never produces an integer tensor.
    scale = torch.exp(torch.as_tensor(t, dtype=Q_e.dtype, device=Q_e.device))
    logits = torch.mm(Q_e, K_e.T) * scale
    values = torch.softmax(logits, dim=1)

    # Drop only the query axis (and only when it has size 1). A bare squeeze()
    # would also collapse the key axis for a one-entry library and return a
    # 0-d tensor that cannot be ranked or indexed.
    return values.squeeze(0)
|
|
|
|
|
def encoding_data(data, modal):
    """
    Turn raw strings into 1-D id tensors for the requested modality

    Args:
        data (list): List of strings
        modal (str): "music" selects the patchilizer; any other value is
            handled as text and goes through the tokenizer

    Returns:
        ids_list (list): One id tensor per input string
    """
    if modal == "music":
        # Patch-encode every piece and flatten the patches into one id sequence.
        return [
            torch.tensor(
                patchilizer.encode(piece, music_length=music_length, add_eos_patch=True)
            ).reshape(-1)
            for piece in data
        ]

    # Text: tokenize with truncation to TEXT_LENGTH and drop the batch axis
    # that return_tensors='pt' adds.
    return [
        tokenizer(
            sentence,
            return_tensors='pt',
            truncation=True,
            max_length=TEXT_LENGTH,
        )['input_ids'].squeeze(0)
        for sentence in data
    ]
|
|
|
|
|
def get_features(ids_list, modal):
    """
    Get the features from the CLaMP model

    Every id tensor is run through the encoder of the requested modality
    with a batch size of one, average-pooled with an all-ones mask and
    projected by the matching projection layer. Inputs and masks are moved
    to `device` for both modalities so they live where the model lives.

    Args:
        ids_list (list): List of 1-D id tensors (as produced by encoding_data)
        modal (str): "text" selects the text encoder; any other value is
            handled as music

    Returns:
        features_list (torch.Tensor): Tensor of features with a shape of (batch_size, hidden_size)
    """
    features_list = []
    print("Extracting "+modal+" features...")
    with torch.no_grad():
        for ids in tqdm(ids_list):
            # Add the batch axis and place the ids on the model's device.
            ids = ids.unsqueeze(0).to(device)
            if modal == "text":
                # One mask entry per token; nothing is padded at batch size 1.
                masks = torch.ones(1, ids.shape[1], dtype=torch.long, device=device)
                features = model.text_enc(ids, attention_mask=masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.text_proj(features)
            else:
                # One mask entry per patch: the flat id sequence holds
                # PATCH_LENGTH ids for every patch.
                masks = torch.ones(1, ids.shape[1] // PATCH_LENGTH, dtype=torch.long, device=device)
                features = model.music_enc(ids, masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.music_proj(features)

            features_list.append(features[0])

    return torch.stack(features_list).to(device)
|
|
|
|
|
def semantic_music_search(query):
    """
    Semantic music search

    Encodes the query text, compares it with the cached music features and
    looks up the metadata of the best-matching score in wikimusictext.json.

    Args:
        query (str): Query string

    Returns:
        output (tuple): Five strings (title, artist, genre, description,
            ABC notation) of the best match. If the matched file has no
            entry in wikimusictext.json, the file's base name is returned
            as the title together with an explanatory description, so the
            interface always receives five values.
    """
    with open(KEY_MODAL+"_key_cache_"+str(music_length)+".pth", 'rb') as f:
        # map_location keeps the cache loadable on `device` even when the
        # tensors were saved from a different device.
        key_cache = torch.load(f, map_location=device)
    print("\nQuery: "+query+"\n")

    # The query is transliterated to ASCII before tokenization.
    query_ids = encoding_data([unidecode(query)], QUERY_MODAL)
    query_feature = get_features(query_ids, QUERY_MODAL)

    key_filenames = key_cache["filenames"]
    key_features = key_cache["features"]

    values = compute_values(query_feature, key_features)
    # Only the single best match is needed, so argmax replaces a full sort;
    # flattening also copes with a library that holds just one entry.
    idx = int(torch.argmax(values.reshape(-1)))
    # Base name of the matched file without its 4-character extension.
    filename = key_filenames[idx].split('/')[-1][:-4]

    # Explicit encoding: the metadata is web-collected text and must not
    # depend on the platform's default codec.
    with open("wikimusictext.json", 'r', encoding="utf-8") as f:
        wikimusictext = json.load(f)

    for item in wikimusictext:
        if item['title'] != filename:
            continue
        print("Title: " + item['title'])
        print("Artist: " + item['artist'])
        print("Genre: " + item['genre'])
        print("Description: " + item['text'])
        print("ABC notation:\n" + item['music'])
        return item["title"], item["artist"], item["genre"], item["text"], item["music"]

    # No metadata entry carries this title: still hand back five values
    # instead of an implicit None, which the five-output interface cannot render.
    print("No metadata found for: " + filename)
    return filename, "", "", "No metadata found for this score.", ""
|
|
|
# Output widgets, one per value returned by semantic_music_search.
# gr.Textbox is used for the outputs as well as the input: the legacy
# gr.outputs namespace is deprecated, while gr.Textbox is the same component
# class the query box below already relies on.
output_title = gr.Textbox(label="Title")
output_artist = gr.Textbox(label="Artist")
output_genre = gr.Textbox(label="Genre")
output_description = gr.Textbox(label="Description")
output_abc = gr.Textbox(label="ABC notation")

# Build the interface and start serving it; launch() runs at module level,
# exactly as before.
demo = gr.Interface(
    fn=semantic_music_search,
    inputs=gr.Textbox(lines=2, placeholder="Describe the music you want to search...", label="Query"),
    outputs=[output_title, output_artist, output_genre, output_description, output_abc],
    title="🗜️ CLaMP: Semantic Music Search",
    description=description,
    examples=examples,
)
demo.launch()