Salma Mayorquin commited on
Commit
0945ad6
·
1 Parent(s): a6e329b

initial commit

Browse files
Files changed (5) hide show
  1. README.md +3 -3
  2. app.py +83 -0
  3. examples/warehouse_1.jpg +0 -0
  4. examples/warehouse_2.jpg +0 -0
  5. requirements.txt +28 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: SpaceLLaVA
3
- emoji: 📚
4
- colorFrom: green
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.20.1
8
  app_file: app.py
 
1
  ---
2
  title: SpaceLLaVA
3
+ emoji: 🛸
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.20.1
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import base64
3
+ import numpy as np
4
+ import torch
5
+ import matplotlib
6
+ import matplotlib.cm
7
+ import gradio as gr
8
+ from PIL import Image
9
+
10
+ from llama_cpp import Llama
11
+ from llama_cpp.llama_chat_format import Llava15ChatHandler
12
+
13
+ # Converts an image input (PIL Image or file path) into a base64 data URI
14
+ def image_to_base64_data_uri(image_input):
15
+ if isinstance(image_input, str):
16
+ with open(image_input, "rb") as img_file:
17
+ base64_data = base64.b64encode(img_file.read()).decode('utf-8')
18
+ elif isinstance(image_input, Image.Image):
19
+ buffer = io.BytesIO()
20
+ image_input.save(buffer, format="PNG")
21
+ base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
22
+ else:
23
+ raise ValueError("Unsupported input type. Input must be a file path or a PIL.Image.Image instance.")
24
+ return f"data:image/png;base64,{base64_data}"
25
+
26
+ class Llava:
27
+ def __init__(self, mmproj="model/mmproj-model-f16.gguf", model_path="model/ggml-model-q4_0.gguf", gpu=False):
28
+ chat_handler = Llava15ChatHandler(clip_model_path=mmproj, verbose=True)
29
+ n_gpu_layers = 0
30
+ if gpu:
31
+ n_gpu_layers = -1
32
+ self.llm = Llama(model_path=model_path, chat_handler=chat_handler, n_ctx=2048, logits_all=True, n_gpu_layers=n_gpu_layers)
33
+
34
+ def run_inference(self, image, prompt):
35
+ data_uri = image_to_base64_data_uri(image)
36
+ res = self.llm.create_chat_completion(
37
+ messages=[
38
+ {"role": "system", "content": "You are an assistant who perfectly describes images."},
39
+ {
40
+ "role": "user",
41
+ "content": [
42
+ {"type": "image_url", "image_url": {"url": data_uri}},
43
+ {"type": "text", "text": prompt}
44
+ ]
45
+ }
46
+ ]
47
+ )
48
+ return res["choices"][0]["message"]["content"]
49
+
50
+ # Initialize the model
51
+ llm_model = Llava()
52
+
53
+ title_and_links_markdown = """
54
+ # 🛸SpaceLLaVA🌋: A spatial reasoning multi-modal model
55
+ This space hosts our initial release of LLaVA 1.5 LoRA tuned for spatial reasoning using data generated with [VQASynth](https://github.com/remyxai/VQASynth).
56
+ Upload an image and ask a question.
57
+
58
+ [Model](https://huggingface.co/remyxai/SpaceLLaVA) | [Code](https://github.com/remyxai/VQASynth) | [Paper](https://spatial-vlm.github.io)
59
+ """
60
+
61
+ def predict(image, prompt):
62
+ result = llm_model.run_inference(image, prompt)
63
+ return result
64
+
65
+ image_input = gr.inputs.Image(type="pil", label="Input Image")
66
+ text_input = gr.inputs.Textbox(label="Prompt")
67
+
68
+ # Initialize interface with examples
69
+ iface = gr.Interface(
70
+ fn=predict,
71
+ inputs=[image_input, text_input],
72
+ outputs="text",
73
+ title="Llava Model Inference",
74
+ description="Input an image and a prompt to receive a description."
75
+ )
76
+
77
+ examples = [
78
+ ["examples/warehouse_1.jpg", "Is the man wearing gray pants to the left of the pile of boxes on a pallet?"],
79
+ ["examples/warehouse_2.jpg", "Is the forklift taller than the shelves of boxes?"],
80
+ ]
81
+
82
+ iface.examples = examples
83
+ iface.launch()
examples/warehouse_1.jpg ADDED
examples/warehouse_2.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pip
2
+ einops
3
+ fastapi
4
+ gradio==3.35.2
5
+ markdown2[all]
6
+ numpy
7
+ requests
8
+ sentencepiece
9
+ tokenizers>=0.12.1
10
+ torch==2.0.1
11
+ torchvision==0.15.2
12
+ uvicorn
13
+ wandb
14
+ shortuuid
15
+ pillow
16
+ httpx==0.24.0
17
+ deepspeed==0.9.5
18
+ peft==0.4.0
19
+ transformers==4.31.0
20
+ accelerate==0.21.0
21
+ bitsandbytes==0.41.0
22
+ scikit-learn==1.2.2
23
+ sentencepiece==0.1.99
24
+ einops==0.6.1
25
+ einops-exts==0.0.4
26
+ llama-cpp-python==0.2.55
27
+ timm==0.6.13
28
+ gradio_client==0.2.9