suric commited on
Commit
705e089
·
1 Parent(s): 9e2a478

optimize the prompt

Browse files
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
 
3
  import gradio as gr
4
 
5
- from gradio_components.image import generate_caption
6
  from gradio_components.prediction import predict, transcribe
7
 
8
  theme = gr.themes.Glass(
@@ -83,7 +83,7 @@ def generate_prompt(difficulty, style):
83
  "Medum": "player who has 2-3 years experience",
84
  "Hard": "player who has more than 4 years experiences",
85
  }
86
- prompt = "piano only music for a {} to pratice with the touch of {}".format(
87
  _DIFFICULTY_MAPPIN[difficulty], style
88
  )
89
  return prompt
@@ -106,6 +106,16 @@ def toggle_melody_condition(melody_condition):
106
  )
107
 
108
 
 
 
 
 
 
 
 
 
 
 
109
  def show_caption(show_caption_condition, description, prompt):
110
  if show_caption_condition:
111
  return (
@@ -145,6 +155,17 @@ def show_caption(show_caption_condition, description, prompt):
145
  )
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
148
  def post_submit(show_caption, model_path, image_input):
149
  _, description, prompt = generate_caption(image_input, model_path)
150
  return (
@@ -210,12 +231,36 @@ def UI():
210
  )
211
  if style == "Others":
212
  style = gr.Textbox(label="Type your music genre")
213
- prompt = generate_prompt(difficulty.value, style.value)
214
  customize = gr.Checkbox(
215
- label="Customize the prompt", interactive=True
 
 
 
 
 
 
 
 
 
 
 
 
216
  )
217
- if customize:
218
- prompt = gr.Textbox(label="Type your prompt")
 
 
 
 
 
 
 
 
 
 
 
 
 
219
  with gr.Column():
220
  with gr.Row():
221
  melody = gr.Audio(
@@ -385,12 +430,12 @@ def UI():
385
  fn=post_submit,
386
  inputs=[show_prompt, model_path, image_input],
387
  outputs=[description, prompt, generate],
388
- )
389
  show_prompt.change(
390
  fn=show_caption,
391
  inputs=[show_prompt, description, prompt],
392
  outputs=[description, prompt, generate],
393
- )
394
  transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
395
  generate.click(
396
  fn=predict,
 
2
 
3
  import gradio as gr
4
 
5
+ from gradio_components.image import generate_caption, improve_prompt
6
  from gradio_components.prediction import predict, transcribe
7
 
8
  theme = gr.themes.Glass(
 
83
  "Medum": "player who has 2-3 years experience",
84
  "Hard": "player who has more than 4 years experiences",
85
  }
86
+ prompt = "piano only music for a {} to practice with the touch of {}".format(
87
  _DIFFICULTY_MAPPIN[difficulty], style
88
  )
89
  return prompt
 
106
  )
107
 
108
 
109
+ def toggle_custom_prompt(customize, difficulty, style):
110
+ if customize:
111
+ return gr.Textbox(label="Type your prompt", interactive=True, visible=True)
112
+ else:
113
+ prompt = generate_prompt(difficulty, style)
114
+ return gr.Textbox(
115
+ label="Generated Prompt", value=prompt, interactive=False, visible=True
116
+ )
117
+
118
+
119
  def show_caption(show_caption_condition, description, prompt):
120
  if show_caption_condition:
121
  return (
 
155
  )
156
 
157
 
158
+ def optimize_fn(prompt):
159
+ message_object, prompt = improve_prompt(prompt)
160
+ return prompt
161
+
162
+
163
+ def display_prompt(prompt):
164
+ return gr.Textbox(
165
+ label="Generated Prompt", value=prompt, interactive=False, visible=True
166
+ )
167
+
168
+
169
  def post_submit(show_caption, model_path, image_input):
170
  _, description, prompt = generate_caption(image_input, model_path)
171
  return (
 
231
  )
232
  if style == "Others":
233
  style = gr.Textbox(label="Type your music genre")
 
234
  customize = gr.Checkbox(
235
+ label="Customize the prompt", interactive=True, value=False
236
+ )
237
+ _init_prompt = generate_prompt(difficulty.value, style.value)
238
+ prompt = gr.Textbox(
239
+ label="",
240
+ value=_init_prompt,
241
+ interactive=False,
242
+ visible=False,
243
+ )
244
+ customize.change(
245
+ fn=toggle_custom_prompt,
246
+ inputs=[customize, difficulty, style],
247
+ outputs=prompt,
248
  )
249
+ print(prompt)
250
+ with gr.Column():
251
+ optimize = gr.Button(
252
+ "Optimize the prompt", interactive=True
253
+ )
254
+ with gr.Column():
255
+ show_prompt = gr.Button("Show the prompt", interactive=True)
256
+ prompt_text = gr.Textbox(
257
+ "Optimized Prompt", interactive=False, visible=False
258
+ )
259
+ optimize.click(optimize_fn, inputs=[prompt], outputs=prompt)
260
+ show_prompt.click(
261
+ display_prompt, inputs=[prompt], outputs=prompt_text
262
+ )
263
+
264
  with gr.Column():
265
  with gr.Row():
266
  melody = gr.Audio(
 
430
  fn=post_submit,
431
  inputs=[show_prompt, model_path, image_input],
432
  outputs=[description, prompt, generate],
433
+ )
434
  show_prompt.change(
435
  fn=show_caption,
436
  inputs=[show_prompt, description, prompt],
437
  outputs=[description, prompt, generate],
438
+ )
439
  transcribe_button.click(transcribe, inputs=[output_audio], outputs=d)
440
  generate.click(
441
  fn=predict,
gradio_components/image.py CHANGED
@@ -28,6 +28,36 @@ Try to make the prompt simple and concise with only 1-2 sentences
28
  Make sure the ouput is in JSON fomat, with two items `description` and `prompt`
29
  """
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def generate_caption(image_file, model_file, progress=gr.Progress()):
33
  if model_file == "facebook/audiogen-medium":
 
28
  Make sure the ouput is in JSON fomat, with two items `description` and `prompt`
29
  """
30
 
31
+ PROMPT_IMPROVEMENT_GENERATE_PROMPT = """
32
+ You are an export llm prompt enginner, you will be helping the user to improve their prompts. here are some examples of good prompts
33
+ - "90s rock song with electric guitar and heavy drums"
34
+ - "An 80s driving pop song with heavy drums and synth pads in the background"
35
+ - "An energetic hip-hop music piece, with synth sounds and strong bass. There is a rhythmic hi-hat patten in the drums."
36
+ - "A grand orchestral arrangement with thunderous percussion, epic brass fanfares, and soaring strings, creating a cinematic atmosphere fit for a heroic battle."
37
+ - "Classic reggae track with an electronic guitar solo"
38
+
39
+ You will be provided with a prompt and you need to improve it. Make sure the prompt is simple and concise with only 1-2 sentences. The output should be in JSON format, with one item `prompt`
40
+ """
41
+
42
+
43
+ def improve_prompt(prompt):
44
+ message = client.messages.create(
45
+ model="claude-3-opus-20240229",
46
+ max_tokens=1024,
47
+ system=PROMPT_IMPROVEMENT_GENERATE_PROMPT,
48
+ messages=[
49
+ {
50
+ "role": "user",
51
+ "content": [
52
+ {"type": "text", "text": prompt},
53
+ ],
54
+ }
55
+ ],
56
+ )
57
+ message_object = json.loads(message.content[0].text)
58
+ prompt = message_object["prompt"]
59
+ return message_object, prompt
60
+
61
 
62
  def generate_caption(image_file, model_file, progress=gr.Progress()):
63
  if model_file == "facebook/audiogen-medium":
gradio_components/prediction.py CHANGED
@@ -8,7 +8,7 @@ import gradio as gr
8
  import torch
9
  from audiocraft.data.audio import audio_write
10
  from audiocraft.data.audio_utils import convert_audio
11
- from audiocraft.models import MusicGen, AudioGen
12
  from basic_pitch import ICASSP_2022_MODEL_PATH
13
  from transformers import AutoModelForSeq2SeqLM
14
 
@@ -69,10 +69,7 @@ def _do_predictions(
69
  else:
70
  if model_file == "facebook/audiogen-medium":
71
  # audio condition
72
- outputs = model.generate(
73
- texts,
74
- progress=progress
75
- )
76
  else:
77
  # text only
78
  outputs = model.generate(texts, progress=progress)
 
8
  import torch
9
  from audiocraft.data.audio import audio_write
10
  from audiocraft.data.audio_utils import convert_audio
11
+ from audiocraft.models import AudioGen, MusicGen
12
  from basic_pitch import ICASSP_2022_MODEL_PATH
13
  from transformers import AutoModelForSeq2SeqLM
14
 
 
69
  else:
70
  if model_file == "facebook/audiogen-medium":
71
  # audio condition
72
+ outputs = model.generate(texts, progress=progress)
 
 
 
73
  else:
74
  # text only
75
  outputs = model.generate(texts, progress=progress)