File size: 7,908 Bytes
55f37e6 87c000d 11bfc95 dd77523 4890cd8 87c000d 11bfc95 87c000d c8fe19a 87c000d ca1dc6f 55f37e6 ab08646 55f37e6 11bfc95 55f37e6 a86e986 55f37e6 a86e986 ca1dc6f ab08646 a86e986 55f37e6 0509b74 a86e986 87c000d c8fe19a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import gradio as gr
import json
from utils import *
from unidecode import unidecode
from transformers import AutoTokenizer
# Markdown/HTML rendered above the Gradio interface: project badges, usage
# notes, and a short overview of how semantic search works. Blank lines
# separate the sections so the markdown after the </div> HTML block is
# rendered as markdown instead of being absorbed into the HTML block.
description = """
<div>
<a style="display:inline-block" href='https://github.com/microsoft/muzic/tree/main/clamp'><img src='https://img.shields.io/github/stars/microsoft/muzic?style=social' /></a>
<a style='display:inline-block' href='https://ai-muzic.github.io/clamp/'><img src='https://img.shields.io/badge/website-CLaMP-ff69b4.svg' /></a>
<a style="display:inline-block" href="https://huggingface.co/datasets/sander-wood/wikimusictext"><img src="https://img.shields.io/badge/huggingface-dataset-ffcc66.svg"></a>
<a style="display:inline-block" href="https://arxiv.org/pdf/2304.11029.pdf"><img src="https://img.shields.io/badge/arXiv-2304.11029-b31b1b.svg"></a>
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/sander-wood/clamp_semantic_music_search?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-md-dark.svg" alt="Duplicate Space"></a>
</div>

## ℹ️ How to use this demo?
1. Enter a query in the text box.
2. Click "Submit" and wait for the result.
3. It will return the best-matching music score from the WikiMusicText dataset (1010 scores in total).

## ❕Notice
- The text box is case-sensitive.
- You can enter longer text for the text box, but the demo will only use the first 128 tokens.
- The returned results include the title, artist, genre, description, and the score in ABC notation.
- The genre and description may not be accurate, as they are collected from the web.
- The demo is based on CLaMP-S/512, a CLaMP model with 6-layer Transformer text/music encoders and a sequence length of 512.

## 🔠👉🎵 Semantic Music Search
Semantic search is a technique for retrieving music by open-domain queries, which differs from traditional keyword-based searches that depend on exact matches or meta-information. This involves two steps: 1) extracting music features from all scores in the library, and 2) transforming the query into a text feature. By calculating the similarities between the text feature and the music features, it can efficiently locate the score that best matches the user's query in the library.
"""
# Example queries offered in the Gradio UI. The first three describe music
# directly; the last three are scene descriptions, showing that free-form
# (non-musical) text can also be used as a query.
examples = [
    "Jazz standard in Minor key with a swing feel.",
    "Jazz standard in Major key with a fast tempo.",
    "Jazz standard in Blues form with a soulful melody.",
    "a painting of a starry night with the moon in the sky",
    "a green field with a blue sky and clouds",
    "a beach with a castle on top of it",
]
# --- Retrieval configuration -------------------------------------------------
# Pretrained CLaMP checkpoint used for both encoders.
CLAMP_MODEL_NAME = 'sander-wood/clamp-small-512'
# Queries are text; the searchable library (keys) is music.
QUERY_MODAL, KEY_MODAL = 'text', 'music'
# Number of results to retrieve.
TOP_N = 1
# Tokenizer checkpoint and maximum number of text tokens kept per query.
TEXT_MODEL_NAME, TEXT_LENGTH = 'distilroberta-base', 128
# All inference runs on the CPU.
device = torch.device("cpu")

# Pretrained CLaMP model, placed on `device` and switched to evaluation mode.
model = CLaMP.from_pretrained(CLAMP_MODEL_NAME).to(device)
model.eval()
# Maximum music sequence length declared by the checkpoint's config.
music_length = model.config.max_length

# Shared pre-processing helpers: ABC-score patchilizer, text tokenizer,
# and a row-wise softmax.
patchilizer = MusicPatchilizer()
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
softmax = torch.nn.Softmax(dim=1)
def compute_values(Q_e, K_e, t=1):
    """
    Score every key against the query as a softmax over scaled cosine similarities.

    Args:
        Q_e (torch.Tensor): Query embeddings, one row per query (2-D, [m, d]).
        K_e (torch.Tensor): Key embeddings, one row per key (2-D, [n, d]).
        t (float): Log-scale applied to the logits: cosine similarities are
            multiplied by exp(t) before the softmax. Ints are accepted.

    Returns:
        values (torch.Tensor): Softmax probabilities over the keys for each
            query, with size-1 dimensions squeezed out (shape [n] for a
            single query).
    """
    # L2-normalise rows so the matrix product below yields cosine similarities.
    Q_e = torch.nn.functional.normalize(Q_e, dim=1)
    K_e = torch.nn.functional.normalize(K_e, dim=1)
    # Build the scale from a float tensor: torch.tensor(1) would be int64, and
    # older PyTorch releases reject exp() on integer tensors.
    scale = torch.exp(torch.tensor(float(t)))
    # Scaled pairwise cosine similarities, shape [m, n].
    logits = torch.mm(Q_e, K_e.T) * scale
    # Normalise across keys (dim=1) so each query's scores sum to 1.
    values = torch.softmax(logits, dim=1)
    return values.squeeze()
def encoding_data(data, modal):
    """
    Turn raw inputs into 1-D id tensors ready for the CLaMP encoders.

    Args:
        data (list): Input strings (ABC scores for music, free text otherwise).
        modal (str): "music" selects the patchilizer; any other value selects
            the text tokenizer.

    Returns:
        ids_list (list): One flattened id tensor per input string.
    """
    if modal == "music":
        # Patchify each score (capped at music_length, with an EOS patch) and
        # flatten the patch grid into a single id sequence.
        return [
            torch.tensor(
                patchilizer.encode(score, music_length=music_length, add_eos_patch=True)
            ).reshape(-1)
            for score in data
        ]

    # Text: tokenize each string, truncating to TEXT_LENGTH tokens, and drop
    # the batch dimension added by return_tensors='pt'.
    return [
        tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=TEXT_LENGTH,
        )['input_ids'].squeeze(0)
        for text in data
    ]
def get_features(ids_list, modal):
    """
    Encode each id sequence with CLaMP and return one projected embedding per sequence.

    Args:
        ids_list (list): 1-D id tensors, e.g. as produced by ``encoding_data``.
        modal (str): "text" uses the text encoder and projection; any other
            value uses the music encoder and projection.

    Returns:
        features_list (torch.Tensor): Stacked features of shape
            (len(ids_list), hidden_size), located on ``device``.
    """
    features_list = []
    print("Extracting "+modal+" features...")
    with torch.no_grad():
        for ids in tqdm(ids_list):
            # Add a batch dimension and move the inputs to the model's device
            # (previously the music branch left them on the CPU).
            ids = ids.unsqueeze(0).to(device)
            if modal == "text":
                # One mask entry per token: a single sequence has no padding.
                masks = torch.ones(1, ids.size(1), dtype=torch.long, device=device)
                features = model.text_enc(ids, attention_mask=masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.text_proj(features)
            else:
                # Music ids are flattened patches, so there is one mask entry
                # per PATCH_LENGTH ids.
                masks = torch.ones(1, ids.size(1) // PATCH_LENGTH, dtype=torch.long, device=device)
                features = model.music_enc(ids, masks)['last_hidden_state']
                features = model.avg_pooling(features, masks)
                features = model.music_proj(features)
            # Drop the batch dimension before collecting.
            features_list.append(features[0])
    return torch.stack(features_list).to(device)
def semantic_music_search(query):
    """
    Find the library score that best matches a free-text query.

    Args:
        query (str): Query string describing the desired music.

    Returns:
        output (tuple): (title, artist, genre, description, ABC notation) of
            the best-matching score from wikimusictext.json. If the matched
            score has no entry in that file, the title slot holds the matched
            file name and the other four fields are empty strings, so the
            Gradio interface always receives five values.
    """
    # Precomputed key (music) features; the file name encodes music_length.
    with open(KEY_MODAL+"_key_cache_"+str(music_length)+".pth", 'rb') as f:
        key_cache = torch.load(f)

    print("\nQuery: "+query+"\n")

    # Encode the query (ASCII-transliterated with unidecode first).
    query_ids = encoding_data([unidecode(query)], QUERY_MODAL)
    query_feature = get_features(query_ids, QUERY_MODAL)

    key_filenames = key_cache["filenames"]
    key_features = key_cache["features"]

    # Score all keys; the last index of the ascending argsort is the best match.
    values = compute_values(query_feature, key_features)
    idx = torch.argsort(values)[-1]
    # Keep only the base name and drop its last four characters
    # (presumably a ".abc" extension — TODO confirm against the cache builder).
    filename = key_filenames[idx].split('/')[-1][:-4]

    # Explicit UTF-8 so the metadata loads regardless of the platform locale.
    with open("wikimusictext.json", 'r', encoding='utf-8') as f:
        wikimusictext = json.load(f)

    # First metadata entry whose title matches the retrieved file name.
    item = next((entry for entry in wikimusictext if entry['title'] == filename), None)
    if item is None:
        # Previously this fell through and returned None, which breaks the
        # five-output Gradio interface.
        print("No metadata found for: " + filename)
        return filename, "", "", "", ""

    print("Title: " + item['title'])
    print("Artist: " + item['artist'])
    print("Genre: " + item['genre'])
    print("Description: " + item['text'])
    print("ABC notation:\n" + item['music'])
    return item["title"], item["artist"], item["genre"], item["text"], item["music"]
# Output fields, in the same order as semantic_music_search's return tuple.
# gr.Textbox is used for outputs as well as the input: the legacy
# gr.outputs.* namespace is deprecated and was removed in Gradio 4.
output_title = gr.Textbox(label="Title")
output_artist = gr.Textbox(label="Artist")
output_genre = gr.Textbox(label="Genre")
output_description = gr.Textbox(label="Description")
output_abc = gr.Textbox(label="ABC notation")

# Build the demo UI and start the server.
gr.Interface(
    fn=semantic_music_search,
    inputs=gr.Textbox(lines=2, placeholder="Describe the music you want to search...", label="Query"),
    outputs=[output_title, output_artist, output_genre, output_description, output_abc],
    title="🗜️ CLaMP: Semantic Music Search",
    description=description,
    examples=examples).launch()