Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -112,9 +112,18 @@ if torch.__version__ >= "2":
|
|
112 |
model = torch.compile(model)
|
113 |
|
114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
def evaluate(
|
116 |
table,
|
117 |
question,
|
|
|
118 |
input=None,
|
119 |
temperature=0.1,
|
120 |
top_p=0.75,
|
@@ -124,26 +133,34 @@ def evaluate(
|
|
124 |
**kwargs,
|
125 |
):
|
126 |
prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
with torch.no_grad():
|
137 |
-
generation_output = model.generate(
|
138 |
-
input_ids=input_ids,
|
139 |
-
generation_config=generation_config,
|
140 |
-
return_dict_in_generate=True,
|
141 |
-
output_scores=True,
|
142 |
-
max_new_tokens=max_new_tokens,
|
143 |
)
|
144 |
-
|
145 |
-
|
146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
return output
|
148 |
|
149 |
|
@@ -151,23 +168,31 @@ def evaluate(
|
|
151 |
model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
|
152 |
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
|
153 |
|
154 |
-
def process_document(image, question):
|
155 |
# image = Image.open(image)
|
156 |
inputs = processor_deplot(images=image, text="Generate the underlying data table for the figure below:", return_tensors="pt").to(0, torch.bfloat16)
|
157 |
predictions = model_deplot.generate(**inputs, max_new_tokens=512)
|
158 |
table = processor_deplot.decode(predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
|
159 |
|
160 |
# send prompt+table to LLM
|
161 |
-
res = evaluate(table, question)
|
162 |
#return res + "\n\n" + res.split("A:")[-1]
|
163 |
-
|
|
|
|
|
|
|
164 |
|
165 |
description = "Demo for DePlot+LLM for QA and summarisation. [DePlot](https://arxiv.org/abs/2212.10505) is an image-to-text model that converts plots and charts into a textual sequence. The sequence then is used to prompt LLM for chain-of-thought reasoning. The current underlying LLM is [alpaca-lora](https://huggingface.co/spaces/tloen/alpaca-lora). To use it, simply upload your image and type a question or instruction and click 'submit', or click one of the examples to load them. Read more at the links below."
|
166 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2212.10505' target='_blank'>DePlot: One-shot visual language reasoning by plot-to-table translation</a></p>"
|
167 |
|
168 |
demo = gr.Interface(
|
169 |
fn=process_document,
|
170 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
171 |
outputs=[
|
172 |
gr.inputs.Textbox(
|
173 |
lines=8,
|
|
|
112 |
model = torch.compile(model)
|
113 |
|
114 |
|
115 |
+
## FLAN-UL2
# Settings for the hosted inference endpoint used by `query`.
# The access token is read from the API_TOKEN environment variable (None when unset).
TOKEN = os.environ.get("API_TOKEN")
API_URL = "https://api-inference.huggingface.co/models/google/flan-ul2"
# Attach the Authorization header only when a token is configured, so an unset
# API_TOKEN does not send the literal credential "Bearer None".
headers = {"Authorization": f"Bearer {TOKEN}"} if TOKEN else {}
|
119 |
+
def query(payload, timeout=120):
    """POST ``payload`` as JSON to ``API_URL`` and return the decoded JSON reply.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": prompt}``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The parsed JSON body of the response.

    Raises:
        requests.exceptions.Timeout: If the server does not answer in time.
    """
    # requests applies no timeout by default; without one a stalled endpoint
    # would block the calling handler indefinitely.
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
|
122 |
+
|
123 |
def evaluate(
|
124 |
table,
|
125 |
question,
|
126 |
+
llm="alpaca-lora",
|
127 |
input=None,
|
128 |
temperature=0.1,
|
129 |
top_p=0.75,
|
|
|
133 |
**kwargs,
|
134 |
):
|
135 |
prompt = _TEMPLATE + "\n" + _add_markup(table) + "\n" + "Q: " + question + "\n" + "A:"
|
136 |
+
if llm == "alpaca-lora":
|
137 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
138 |
+
input_ids = inputs["input_ids"].to(device)
|
139 |
+
generation_config = GenerationConfig(
|
140 |
+
temperature=temperature,
|
141 |
+
top_p=top_p,
|
142 |
+
top_k=top_k,
|
143 |
+
num_beams=num_beams,
|
144 |
+
**kwargs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
)
|
146 |
+
with torch.no_grad():
|
147 |
+
generation_output = model.generate(
|
148 |
+
input_ids=input_ids,
|
149 |
+
generation_config=generation_config,
|
150 |
+
return_dict_in_generate=True,
|
151 |
+
output_scores=True,
|
152 |
+
max_new_tokens=max_new_tokens,
|
153 |
+
)
|
154 |
+
s = generation_output.sequences[0]
|
155 |
+
output = tokenizer.decode(s)
|
156 |
+
elif llm == "flan-ul2":
|
157 |
+
output = query({
|
158 |
+
"inputs": prompt
|
159 |
+
})[0]["generated_text"]
|
160 |
+
|
161 |
+
else:
|
162 |
+
# Raise the error rather than merely constructing it; otherwise execution falls through to `return output` with `output` never assigned.
raise RuntimeError(f"No such LLM: {llm}")
|
163 |
+
|
164 |
return output
|
165 |
|
166 |
|
|
|
168 |
model_deplot = Pix2StructForConditionalGeneration.from_pretrained("google/deplot", torch_dtype=torch.bfloat16).to(0)
|
169 |
processor_deplot = Pix2StructProcessor.from_pretrained("google/deplot")
|
170 |
|
171 |
+
def process_document(llm, image, question):
    """Turn a chart image into a data table, then ask the chosen LLM about it.

    The image is run through the DePlot processor/model pair to generate a
    textual table; that table and the question are forwarded to ``evaluate``.

    Args:
        llm: Name of the LLM backend, passed through to ``evaluate``.
        image: Chart image handed to ``processor_deplot``.
        question: Question or instruction text.

    Returns:
        A two-item list ``[table, answer]``. For ``"alpaca-lora"`` the answer
        is the text after the last ``"A:"`` in the LLM output; for any other
        backend the LLM output is returned unchanged.
    """
    deplot_inputs = processor_deplot(
        images=image,
        text="Generate the underlying data table for the figure below:",
        return_tensors="pt",
    ).to(0, torch.bfloat16)
    generated = model_deplot.generate(**deplot_inputs, max_new_tokens=512)
    decoded = processor_deplot.decode(generated[0], skip_special_tokens=True)
    # The decoded text carries "<0x0A>" markers; turn them into real newlines.
    table = decoded.replace("<0x0A>", "\n")

    # Send prompt + table to the LLM.
    res = evaluate(table, question, llm=llm)
    answer = res.split("A:")[-1] if llm == "alpaca-lora" else res
    return [table, answer]
|
184 |
|
185 |
description = "Demo for DePlot+LLM for QA and summarisation. [DePlot](https://arxiv.org/abs/2212.10505) is an image-to-text model that converts plots and charts into a textual sequence. The sequence then is used to prompt LLM for chain-of-thought reasoning. The current underlying LLM is [alpaca-lora](https://huggingface.co/spaces/tloen/alpaca-lora). To use it, simply upload your image and type a question or instruction and click 'submit', or click one of the examples to load them. Read more at the links below."
|
186 |
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2212.10505' target='_blank'>DePlot: One-shot visual language reasoning by plot-to-table translation</a></p>"
|
187 |
|
188 |
demo = gr.Interface(
|
189 |
fn=process_document,
|
190 |
+
inputs=[
|
191 |
+
gr.Dropdown(
|
192 |
+
["alpaca-lora", "flan-ul2"], label="LLM", info="Will add more LLMs later!"
|
193 |
+
),
|
194 |
+
"image",
|
195 |
+
"text"],
|
196 |
outputs=[
|
197 |
gr.inputs.Textbox(
|
198 |
lines=8,
|