hanzla committed on
Commit
2e92739
1 Parent(s): b8fcf34

chat interface

Files changed (1)
  1. app.py +27 -14
app.py CHANGED
@@ -4,21 +4,22 @@ import re
 import gradio as gr
 from threading import Thread
 from transformers import TextIteratorStreamer, AutoTokenizer, AutoModelForCausalLM
-
 import subprocess
+
+# Install flash-attn for faster inference
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
+# Model and tokenizer setup
 model_id = "vikhyatk/moondream2"
 revision = "2024-04-02"
 tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
 moondream = AutoModelForCausalLM.from_pretrained(
     model_id, trust_remote_code=True, revision=revision,
     torch_dtype=torch.bfloat16, device_map={"": "cuda"},
-    attn_implementation="flash_attention_2"
-)
+    attn_implementation="flash_attention_2")
 moondream.eval()
 
-
+# Function to generate responses
 @spaces.GPU(duration=10)
 def answer_question(img, prompt):
     image_embeds = moondream.encode_image(img)
@@ -33,14 +34,13 @@ def answer_question(img, prompt):
         },
     )
     thread.start()
-
     buffer = ""
     for new_text in streamer:
         buffer += new_text
         yield buffer.strip()
 
-
-with gr.Blocks(theme="Glass") as demo:
+# Create the Gradio interface
+with gr.Blocks(theme="Monochrome") as demo:
     gr.Markdown(
         """
         # AskMoondream: Moondream 2 Demonstration Space
@@ -48,13 +48,26 @@ with gr.Blocks(theme="Glass") as demo:
         Modularity AI presents this open source huggingface space for running fast experimental inferences on Moondream2.
         """
     )
-    with gr.Row():
-        prompt = gr.Textbox(label="Input", value="Describe this image.", scale=4)
-        submit = gr.Button("Submit")
+
+    # Chatbot layout
+    chatbot = gr.Chatbot()
+
+    # Image upload and prompt input
     with gr.Row():
         img = gr.Image(type="pil", label="Upload an Image")
-        output = gr.TextArea(label="Response")
-    submit.click(answer_question, [img, prompt], output)
-    prompt.submit(answer_question, [img, prompt], output)
+        prompt = gr.Textbox(label="Your Question", placeholder="Ask something about the image...", show_label=False)
+
+    # Send message button
+    send_btn = gr.Button("Send")
+
+    # Function to send message and get response
+    def send_message(history, prompt):
+        history.append((prompt, None))
+        response = answer_question(img.value, prompt)
+        history.append((None, response))
+        return history, ""  # Clear the input box
+
+    send_btn.click(send_message, [chatbot, prompt], [chatbot, prompt])
+    prompt.submit(send_message, [chatbot, prompt], [chatbot, prompt])
 
-demo.queue().launch()
+demo.queue().launch()
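Note on the new handler: `answer_question` contains `yield`, so the call `answer_question(img.value, prompt)` in `send_message` returns a generator object rather than the finished text, and `img.value` holds only the component's initial value, not the image uploaded at event time. Below is a minimal sketch of one way the same chat layout could stream replies; it assumes the components defined in the diff (`chatbot`, `prompt`, `img`, `send_btn`), tuple-style Chatbot history, and Gradio's chained `.then()` events. `add_user_message` and `bot_response` are hypothetical helpers introduced here for illustration, not part of this commit, and the snippet would sit inside the existing `with gr.Blocks(...)` block.

# Sketch only: alternative wiring for the chat flow (assumes the
# chatbot/prompt/img/send_btn components and answer_question from above).
def add_user_message(history, message):
    # Append the user's turn and clear the textbox
    return history + [(message, None)], ""

def bot_response(history, image):
    # Iterate the streaming generator so partial text fills the last chat turn
    question = history[-1][0]
    for partial in answer_question(image, question):
        history[-1] = (question, partial)
        yield history

send_btn.click(add_user_message, [chatbot, prompt], [chatbot, prompt]).then(
    bot_response, [chatbot, img], chatbot
)
prompt.submit(add_user_message, [chatbot, prompt], [chatbot, prompt]).then(
    bot_response, [chatbot, img], chatbot
)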