nicolaus625 commited on
Commit
4fe117d
·
verified ·
1 Parent(s): 2a4f061

Update Readme file

Browse files
Files changed (1) hide show
  1. README.md +131 -0
README.md CHANGED
@@ -1,3 +1,134 @@
1
  ---
2
  license: cc-by-nc-4.0
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: cc-by-nc-4.0
3
+ language:
4
+ - en
5
+ library_name: transformers
6
+ tags:
7
+ - music
8
+ - art
9
  ---
10
+
11
+ ---
12
+ license: cc-by-4.0
13
+ language:
14
+ - en
15
+ tags:
16
+ - music
17
+ - art
18
+ ---
19
+ # Model Card for Model ID
20
+ ## Model Details
21
+ ### Model Description
22
+ The model consists of a music encoder ```MERT-v1-300M```, a natural language decoder ```vicuna-7b-delta-v0```, and a linear projection laer between the two.
23
+
24
+ This checkpoint of MusiLingo is developed on the MusicQA and can answer instructions with music raw audio, such as querying about the tempo, emotion, genre, tags or subjective feelings etc.
25
+ You can use the MusicQA dataset for the following demo. For the implementation of MusicQA, please refer to our [Github repo](https://github.com/zihaod/MusiLingo/blob/main/musilingo/datasets/datasets/musicqa_dataset.py).
26
+
27
+
28
+ ### Model Sources [optional]
29
+ - **Repository:** [GitHub repo](https://github.com/zihaod/MusiLingo)
30
+ - **Paper [optional]:** __[MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response](https://arxiv.org/abs/2309.08730)__
31
+ <!-- - **Demo [optional]:** [More Information Needed] -->
32
+
33
+
34
+
35
+ ## Getting Start
36
+ ```
37
+ from tqdm.auto import tqdm
38
+
39
+ import torch
40
+ from torch.utils.data import DataLoader
41
+ from transformers import Wav2Vec2FeatureExtractor
42
+ from transformers import StoppingCriteria, StoppingCriteriaList
43
+
44
+
45
+
46
+ class StoppingCriteriaSub(StoppingCriteria):
47
+ def __init__(self, stops=[], encounters=1):
48
+ super().__init__()
49
+ self.stops = stops
50
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
51
+ for stop in self.stops:
52
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
53
+ return True
54
+ return False
55
+
56
+ def answer(self, samples, stopping, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.5,
57
+ repetition_penalty=1.0, length_penalty=1, temperature=0.1, max_length=2000):
58
+ audio = samples["audio"].cuda()
59
+ audio_embeds, atts_audio = self.encode_audio(audio)
60
+ if 'instruction_input' in samples: # instruction dataset
61
+ #print('Instruction Batch')
62
+ instruction_prompt = []
63
+ for instruction in samples['instruction_input']:
64
+ prompt = '<Audio><AudioHere></Audio> ' + instruction
65
+ instruction_prompt.append(self.prompt_template.format(prompt))
66
+ audio_embeds, atts_audio = self.instruction_prompt_wrap(audio_embeds, atts_audio, instruction_prompt)
67
+ self.llama_tokenizer.padding_side = "right"
68
+ batch_size = audio_embeds.shape[0]
69
+ bos = torch.ones([batch_size, 1],
70
+ dtype=torch.long,
71
+ device=torch.device('cuda')) * self.llama_tokenizer.bos_token_id
72
+ bos_embeds = self.llama_model.model.embed_tokens(bos)
73
+ atts_bos = atts_audio[:, :1]
74
+ inputs_embeds = torch.cat([bos_embeds, audio_embeds], dim=1)
75
+ attention_mask = torch.cat([atts_bos, atts_audio], dim=1)
76
+ outputs = self.llama_model.generate(
77
+ inputs_embeds=inputs_embeds,
78
+ max_new_tokens=max_new_tokens,
79
+ stopping_criteria=stopping,
80
+ num_beams=num_beams,
81
+ do_sample=True,
82
+ min_length=min_length,
83
+ top_p=top_p,
84
+ repetition_penalty=repetition_penalty,
85
+ length_penalty=length_penalty,
86
+ temperature=temperature,
87
+ )
88
+ output_token = outputs[0]
89
+ if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
90
+ output_token = output_token[1:]
91
+ if output_token[0] == 1: # if there is a start token <s> at the beginning. remove it
92
+ output_token = output_token[1:]
93
+ output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
94
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
95
+ output_text = output_text.split('Assistant:')[-1].strip()
96
+ return output_text
97
+
98
+ processor = Wav2Vec2FeatureExtractor.from_pretrained("m-a-p/MERT-v1-330M",trust_remote_code=True)
99
+ ds = MusicQADataset(processor, f'{path}/data/music_data', 'Eval')
100
+ dl = DataLoader(
101
+ ds,
102
+ batch_size=1,
103
+ num_workers=0,
104
+ pin_memory=True,
105
+ shuffle=False,
106
+ drop_last=True,
107
+ collate_fn=ds.collater
108
+ )
109
+
110
+ stopping = StoppingCriteriaList([StoppingCriteriaSub([torch.tensor([835]).cuda(),
111
+ torch.tensor([2277, 29937]).cuda()])])
112
+
113
+ from transformers import AutoModel
114
+ model_musicqa = AutoModel.from_pretrained("m-a-p/MusiLingo-musicqa-v1")
115
+
116
+ for idx, sample in tqdm(enumerate(dl)):
117
+ ans = answer(Musilingo_musicqa.model, sample, stopping, length_penalty=100, temperature=0.1)
118
+ txt = sample['text_input'][0]
119
+ print(txt)
120
+ print(and)
121
+ ```
122
+
123
+ # Citing This Work
124
+
125
+ If you find the work useful for your research, please consider citing it using the following BibTeX entry:
126
+ ```
127
+ @inproceedings{deng2024musilingo,
128
+ title={MusiLingo: Bridging Music and Text with Pre-trained Language Models for Music Captioning and Query Response},
129
+ author={Deng, Zihao and Ma, Yinghao and Liu, Yudong and Guo, Rongchen and Zhang, Ge and Chen, Wenhu and Huang, Wenhao and Benetos, Emmanouil},
130
+ booktitle={Proceedings of the 2024 Annual Conference of the North American Chapter of the Association for Computational Linguistics (NAACL 2024)},
131
+ year={2024},
132
+ organization={Association for Computational Linguistics}
133
+ }
134
+ ```