jgwill commited on
Commit
31532c1
Β·
1 Parent(s): 872656b

imported:ChatMusician-space

Browse files
Files changed (6) hide show
  1. README.md +8 -6
  2. app.py +298 -0
  3. index.html +0 -19
  4. packages.txt +2 -0
  5. requirements.txt +3 -0
  6. style.css +0 -28
README.md CHANGED
@@ -1,11 +1,13 @@
1
  ---
2
- title: Orpheuscm01
3
- emoji: πŸ†
4
- colorFrom: red
5
- colorTo: pink
6
- sdk: static
 
 
7
  pinned: false
8
- license: mit
9
  ---
10
 
11
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: ChatMusician
3
+ emoji: πŸ’»
4
+ colorFrom: gray
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.19.2
8
+ app_file: app.py
9
  pinned: false
10
+ license: cc
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import copy
4
+ import time
5
+ import logging
6
+ import subprocess
7
+ from uuid import uuid4
8
+ import gradio as gr
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer
11
+ from transformers.generation import GenerationConfig
12
+ import spaces
13
+
14
+ os.environ['QT_QPA_PLATFORM']='offscreen'
15
+
16
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
17
+ torch.backends.cuda.enable_flash_sdp(False)
18
+
19
+ # log_dir
20
+ os.makedirs("logs", exist_ok=True)
21
+ os.makedirs("tmp", exist_ok=True)
22
+ logging.basicConfig(
23
+ filename=f'logs/chatmusician_server_{time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(time.time()))}.log',
24
+ level=logging.WARNING,
25
+ format='%(asctime)s [%(levelname)s]: %(message)s',
26
+ datefmt='%Y-%m-%d %H:%M:%S'
27
+ )
28
+
29
+ MODEL_PATH = 'm-a-p/ChatMusician'
30
+
31
+
32
+ def get_uuid():
33
+ return str(uuid4())
34
+
35
+
36
+ # todo
37
+ def log_conversation(conversation_id, history, messages, response, generate_kwargs):
38
+ timestamp = time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime(time.time()))
39
+ data = {
40
+ "conversation_id": conversation_id,
41
+ "timestamp": timestamp,
42
+ "history": history,
43
+ "messages": messages,
44
+ "response": response,
45
+ "generate_kwargs": generate_kwargs,
46
+ }
47
+ logging.critical(f"{data}")
48
+
49
+
50
+ def _parse_text(text):
51
+ lines = text.split("\n")
52
+ lines = [line for line in lines if line != ""]
53
+ count = 0
54
+ for i, line in enumerate(lines):
55
+ if "```" in line:
56
+ count += 1
57
+ items = line.split("`")
58
+ if count % 2 == 1:
59
+ lines[i] = f'<pre><code class="language-{items[-1]}">'
60
+ else:
61
+ lines[i] = f"<br></code></pre>"
62
+ else:
63
+ if i > 0:
64
+ if count % 2 == 1:
65
+ line = line.replace("`", r"\`")
66
+ line = line.replace("<", "&lt;")
67
+ line = line.replace(">", "&gt;")
68
+ line = line.replace(" ", "&nbsp;")
69
+ line = line.replace("*", "&ast;")
70
+ line = line.replace("_", "&lowbar;")
71
+ line = line.replace("-", "&#45;")
72
+ line = line.replace(".", "&#46;")
73
+ line = line.replace("!", "&#33;")
74
+ line = line.replace("(", "&#40;")
75
+ line = line.replace(")", "&#41;")
76
+ line = line.replace("$", "&#36;")
77
+ lines[i] = "<br>" + line
78
+ text = "".join(lines)
79
+ return text
80
+
81
+
82
+ def convert_history_to_text(task_history):
83
+ history_cp = copy.deepcopy(task_history)
84
+ text = "".join(
85
+ [f"Human: {item[0]} </s> Assistant: {item[1]} </s> " for item in history_cp[:-1] if item[0]]
86
+ )
87
+ text += f"Human: {history_cp[-1][0]} </s> Assistant: "
88
+ return text
89
+
90
+ # todo
91
+ def postprocess_abc(text, conversation_id):
92
+ os.makedirs(f"tmp/{conversation_id}", exist_ok=True)
93
+ abc_pattern = r'(X:\d+\n(?:[^\n]*\n)+)'
94
+ abc_notation = re.findall(abc_pattern, text+'\n')
95
+ print(f'extract abc block: {abc_notation}')
96
+ if abc_notation:
97
+ ts = time.time()
98
+ # Write the ABC text to a temporary file
99
+ tmp_abc = f"tmp/{conversation_id}/{ts}.abc"
100
+ with open(tmp_abc, "w") as abc_file:
101
+ abc_file.write(abc_notation[0])
102
+ # Convert abc notation to midi
103
+ tmp_midi = f'tmp/{conversation_id}/{ts}.mid'
104
+ subprocess.run(["abc2midi", str(tmp_abc), "-o", tmp_midi])
105
+ # Convert abc notation to SVG
106
+ svg_file = f'tmp/{conversation_id}/{ts}.svg'
107
+ audio_file = f'tmp/{conversation_id}/{ts}.mp3'
108
+ subprocess.run(["musescore", "-o", svg_file, tmp_midi], capture_output=True, text=True)
109
+ subprocess.run(["musescore","-o", audio_file, tmp_midi])
110
+ return svg_file.replace(".svg", "-1.svg"), audio_file
111
+ else:
112
+ return None, None
113
+
114
+
115
+ def _launch_demo(model, tokenizer):
116
+
117
+ @spaces.GPU
118
+ def predict(_chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id):
119
+ query = task_history[-1][0]
120
+ print("User: " + _parse_text(query))
121
+ # model generation
122
+ messages = convert_history_to_text(task_history)
123
+ inputs = tokenizer(messages, return_tensors="pt", add_special_tokens=False)
124
+ generation_config = GenerationConfig(
125
+ temperature=float(temperature),
126
+ top_p = float(top_p),
127
+ top_k = top_k,
128
+ repetition_penalty = float(repetition_penalty),
129
+ max_new_tokens=1536,
130
+ min_new_tokens=5,
131
+ do_sample=True,
132
+ num_beams=1,
133
+ num_return_sequences=1
134
+ )
135
+ response = model.generate(
136
+ input_ids=inputs["input_ids"].to(model.device),
137
+ attention_mask=inputs['attention_mask'].to(model.device),
138
+ eos_token_id=tokenizer.eos_token_id,
139
+ pad_token_id=tokenizer.eos_token_id,
140
+ generation_config=generation_config,
141
+ )
142
+ response = tokenizer.decode(response[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
143
+ _chatbot[-1] = (_parse_text(query), _parse_text(response))
144
+ task_history[-1] = (_parse_text(query), response)
145
+ # log
146
+ log_conversation(conversation_id, task_history, messages, _chatbot[-1][1], generation_config.to_json_string())
147
+ return _chatbot, task_history
148
+
149
+ def process_and_render_abc(_chatbot, task_history, conversation_id):
150
+ svg_file, wav_file = None, None
151
+ try:
152
+ svg_file, wav_file = postprocess_abc(task_history[-1][1], conversation_id)
153
+ except Exception as e:
154
+ logging.error(e)
155
+
156
+ if svg_file and wav_file:
157
+ if os.path.exists(svg_file) and os.path.exists(wav_file):
158
+ logging.critical(f"generate: svg: {svg_file} wav: {wav_file}")
159
+ print(f"generate:\n{svg_file}\n{wav_file}")
160
+ _chatbot.append((None, (str(wav_file),)))
161
+ _chatbot.append((None, (str(svg_file),)))
162
+ else:
163
+ logging.error(f"fail to convert: {svg_file[:-4]}.musicxml")
164
+ return _chatbot
165
+
166
+ def add_text(history, task_history, text):
167
+ history = history + [(_parse_text(text), None)]
168
+ task_history = task_history + [(text, None)]
169
+ return history, task_history, ""
170
+
171
+ def reset_user_input():
172
+ return gr.update(value="")
173
+
174
+ def reset_state(task_history):
175
+ task_history.clear()
176
+ return []
177
+
178
+ with gr.Blocks() as demo:
179
+ conversation_id = gr.State(get_uuid)
180
+ gr.Markdown(
181
+ """<h1><center>Chat Musician</center></h1>"""
182
+ )
183
+ gr.Markdown("""\
184
+ <center><font size=4><a href="https://ezmonyi.github.io/ChatMusician/">🌐 DemoPage</a>&nbsp |
185
+ &nbsp<a href="https://github.com/hf-lin/ChatMusician">πŸ’» Github</a>&nbsp |
186
+ &nbsp<a href="http://arxiv.org/abs/2402.16153">πŸ“– arXiv</a>&nbsp |
187
+ &nbsp<a href="https://huggingface.co/datasets/m-a-p/MusicTheoryBench">πŸ€— Benchmark</a>&nbsp |
188
+ &nbsp<a href="https://huggingface.co/datasets/m-a-p/MusicPile">πŸ€— Pretrain Dataset</a>&nbsp |
189
+ &nbsp<a href="https://huggingface.co/datasets/m-a-p/MusicPile-sft">πŸ€— SFT Dataset</a></center>""")
190
+ gr.Markdown("""\
191
+ <center><font size=4>πŸ’‘Note: The music clips on this page is auto-converted using musescore2 which may not be perfect,
192
+ and we recommend using better software for analysis.</center>""")
193
+
194
+ chatbot = gr.Chatbot(label='ChatMusician', elem_classes="control-height", height=750)
195
+ query = gr.Textbox(lines=2, label='Input')
196
+ task_history = gr.State([])
197
+
198
+ with gr.Row():
199
+ submit_btn = gr.Button("πŸš€ Submit (发送)")
200
+ empty_bin = gr.Button("🧹 Clear History (ζΈ…ι™€εŽ†ε²)")
201
+ # regen_btn = gr.Button("πŸ€”οΈ Regenerate (重试)")
202
+ gr.Examples(
203
+ examples=[
204
+ ["Create music by following the alphabetic representation of the assigned musical structure and the given motif.\n'ABCA';X:1\nL:1/16\nM:2/4\nK:A\n['E2GB d2c2 B2A2', 'D2 C2E2 A2c2']"],
205
+ ["Develop a melody using the given chord pattern.\n'C', 'C', 'G/D', 'D', 'G', 'C', 'G', 'G', 'C', 'C', 'F', 'C/G', 'G7', 'C'"],
206
+ ["Create sheet music in ABC notation from the provided text.\nAlternative title: \nThe Legacy\nKey: G\nMeter: 6/8\nNote Length: 1/8\nRhythm: Jig\nOrigin: English\nTranscription: John Chambers"],
207
+ ],
208
+ inputs=query
209
+ )
210
+ with gr.Row():
211
+ with gr.Accordion("Advanced Options:", open=False):
212
+ with gr.Row():
213
+ with gr.Column():
214
+ with gr.Row():
215
+ temperature = gr.Slider(
216
+ label="Temperature",
217
+ value=0.2,
218
+ minimum=0.0,
219
+ maximum=10.0,
220
+ step=0.1,
221
+ interactive=True,
222
+ info="Higher values produce more diverse outputs",
223
+ )
224
+ with gr.Column():
225
+ with gr.Row():
226
+ top_p = gr.Slider(
227
+ label="Top-p (nucleus sampling)",
228
+ value=0.9,
229
+ minimum=0.0,
230
+ maximum=1,
231
+ step=0.01,
232
+ interactive=True,
233
+ info=(
234
+ "Sample from the smallest possible set of tokens whose cumulative probability "
235
+ "exceeds top_p. Set to 1 to disable and sample from all tokens."
236
+ ),
237
+ )
238
+ with gr.Column():
239
+ with gr.Row():
240
+ top_k = gr.Slider(
241
+ label="Top-k",
242
+ value=40,
243
+ minimum=0.0,
244
+ maximum=200,
245
+ step=1,
246
+ interactive=True,
247
+ info="Sample from a shortlist of top-k tokens β€” 0 to disable and sample from all tokens.",
248
+ )
249
+ with gr.Column():
250
+ with gr.Row():
251
+ repetition_penalty = gr.Slider(
252
+ label="Repetition Penalty",
253
+ value=1.1,
254
+ minimum=1.0,
255
+ maximum=2.0,
256
+ step=0.1,
257
+ interactive=True,
258
+ info="Penalize repetition β€” 1.0 to disable.",
259
+ )
260
+
261
+ submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history], queue=False).then(
262
+ predict,
263
+ [chatbot, task_history, temperature, top_p, top_k, repetition_penalty, conversation_id],
264
+ [chatbot, task_history],
265
+ show_progress=True,
266
+ queue=True
267
+ ).then(process_and_render_abc, [chatbot, task_history, conversation_id], [chatbot])
268
+ submit_btn.click(reset_user_input, [], [query])
269
+ empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
270
+
271
+ gr.Markdown(
272
+ "Disclaimer: The model can produce factually incorrect output, and should not be relied on to produce "
273
+ "factually accurate information. The model was trained on various public datasets; while great efforts "
274
+ "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
275
+ "biased, or otherwise offensive outputs.",
276
+ elem_classes=["disclaimer"],
277
+ )
278
+
279
+ return demo
280
+
281
+
282
+ tokenizer = AutoTokenizer.from_pretrained(
283
+ MODEL_PATH
284
+ )
285
+
286
+ model = AutoModelForCausalLM.from_pretrained(
287
+ MODEL_PATH,
288
+ device_map='cuda',
289
+ torch_dtype=torch.float16
290
+ ).eval()
291
+
292
+ model.generation_config = GenerationConfig.from_pretrained(
293
+ MODEL_PATH
294
+ )
295
+
296
+ app = _launch_demo(model, tokenizer)
297
+
298
+ app.queue().launch()
index.html DELETED
@@ -1,19 +0,0 @@
1
- <!doctype html>
2
- <html>
3
- <head>
4
- <meta charset="utf-8" />
5
- <meta name="viewport" content="width=device-width" />
6
- <title>My static Space</title>
7
- <link rel="stylesheet" href="style.css" />
8
- </head>
9
- <body>
10
- <div class="card">
11
- <h1>Welcome to your static Space!</h1>
12
- <p>You can modify this app directly by editing <i>index.html</i> in the Files and versions tab.</p>
13
- <p>
14
- Also don't forget to check the
15
- <a href="https://huggingface.co/docs/hub/spaces" target="_blank">Spaces documentation</a>.
16
- </p>
17
- </div>
18
- </body>
19
- </html>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ abcmidi
2
+ musescore
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==2.2.1
2
+ transformers==4.32.0
3
+ accelerate==0.25.0
style.css DELETED
@@ -1,28 +0,0 @@
1
- body {
2
- padding: 2rem;
3
- font-family: -apple-system, BlinkMacSystemFont, "Arial", sans-serif;
4
- }
5
-
6
- h1 {
7
- font-size: 16px;
8
- margin-top: 0;
9
- }
10
-
11
- p {
12
- color: rgb(107, 114, 128);
13
- font-size: 15px;
14
- margin-bottom: 10px;
15
- margin-top: 5px;
16
- }
17
-
18
- .card {
19
- max-width: 620px;
20
- margin: 0 auto;
21
- padding: 16px;
22
- border: 1px solid lightgray;
23
- border-radius: 16px;
24
- }
25
-
26
- .card p:last-child {
27
- margin-bottom: 0;
28
- }