Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
from __future__ import annotations | |
import os | |
import string | |
import gradio as gr | |
import PIL.Image | |
import torch | |
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration | |
# Markdown heading rendered at the top of the demo page.
DESCRIPTION = "# [BLIP-2 test](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"

# Use the first CUDA device when one is available; otherwise run on CPU and
# tell the user so in the page description.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
    DESCRIPTION += "\n<p>Running on CPU.</p>"

# Checkpoint used for both the processor and the generation model.
MODEL_ID = "Salesforce/instructblip-flan-t5-xl"

# Both objects are loaded once at import time and shared by every request.
processor = InstructBlipProcessor.from_pretrained(MODEL_ID)
model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_ID)
model = model.to(device)
def answer_ad_listing_question(
    image: PIL.Image.Image,
    title: str,
) -> str:
    """Ask the model for the product type and species shown in an ad listing.

    Builds a fixed instruction prompt around ``title``, runs it together with
    ``image`` through the module-level ``processor`` and ``model``, and
    returns the decoded text of the first generated sequence.

    Args:
        image: Listing photo as a PIL image.
        title: Title text of the advertisement; interpolated into the prompt.

    Returns:
        The model's response with special tokens removed and surrounding
        whitespace stripped.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Gradio passes None when the image component is empty; fail with a
    # readable message instead of an error from inside the processor.
    if image is None:
        raise gr.Error("Please upload an image of the listing before analyzing.")

    # The prompt template with the provided title
    prompt = f"""Given an ad listing with the title '{title}' and image, answer the following questions without any explanation or extra text:
Identify the species mentioned in the text, including specific names, e.g., 'Nile crocodile' instead of just 'crocodile'.
Select the product type from the following options: Animal fibers, Animal parts (bone or bone-like), Animal parts (fleshy), Coral product, Egg, Extract, Food, Ivory products, Live, Medicine, Nests, Organs and tissues, Powder, Scales or spines, Shells, Skin or leather products, Taxidermy, Insects.
The response should be in the format:
"Product Type: [type]
Species: [species]"
"""

    # Cast the inputs to the dtype the model was actually loaded in. The model
    # is loaded without an explicit torch_dtype, so hard-coding float16 here
    # would feed half-precision tensors to a model of a different precision.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, model.dtype)
    generated_ids = model.generate(**inputs)

    # A single image/prompt pair was submitted, so only the first decoded
    # sequence is relevant.
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0].strip()
def postprocess_output(output: str) -> str:
    """Return ``output`` unchanged.

    Pass-through hook for post-processing model text. An earlier step that
    appended a trailing period when the text did not already end in
    punctuation is disabled, so the text is handed back as-is.

    Args:
        output: Text to post-process.

    Returns:
        The same text, unmodified.
    """
    return output
# Build the demo page: inputs, output, and the two action buttons.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        # Inputs: the listing photo (handed to the handler as a PIL image)
        # and the advertisement title.
        image = gr.Image(type="pil")
        ad_title = gr.Textbox(label="Advertisement Title", placeholder="Enter the title here", lines=1)

        # Output: the text returned by the model.
        answer_output = gr.Textbox(label="Analysis", show_label=True, placeholder="Response.")

        # Submit and clear buttons
        with gr.Row():
            submit_button = gr.Button("Analyze Listing", variant="primary")
            clear_button = gr.Button("Clear")

    # "Analyze Listing" runs the model on the current image and title and
    # writes the response into the analysis box.
    submit_button.click(
        fn=answer_ad_listing_question,
        inputs=[image, ad_title],
        outputs=answer_output,
    )

    # "Clear" resets every field. The image component is reset with None
    # (an empty string is not a valid image value); the textboxes are reset
    # with empty strings. Runs outside the queue so it responds immediately.
    clear_button.click(
        fn=lambda: (None, "", ""),
        inputs=None,
        outputs=[image, ad_title, answer_output],
        queue=False,
    )

if __name__ == "__main__":
    # Limit the number of queued requests to 10, then start the server.
    demo.queue(max_size=10).launch()