testwildlife / app.py
ki1207's picture
Update app.py
bfc1711 verified
#!/usr/bin/env python
from __future__ import annotations
import os
import string
import gradio as gr
import PIL.Image
import torch
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
# Hugging Face checkpoint shared by the processor and the model.
MODEL_ID = "Salesforce/instructblip-flan-t5-xl"

# Markdown header rendered at the top of the demo page.
DESCRIPTION = "# [BLIP-2 test](https://github.com/salesforce/LAVIS/tree/main/projects/blip2)"

# Use the first CUDA device when one is present; otherwise fall back to the
# CPU and say so in the page header.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
    DESCRIPTION += "\n<p>Running on CPU.</p>"

# Loaded once at import time so every request reuses the same objects.
processor = InstructBlipProcessor.from_pretrained(MODEL_ID)
model = InstructBlipForConditionalGeneration.from_pretrained(MODEL_ID)
model = model.to(device)
def answer_ad_listing_question(
    image: PIL.Image.Image,
    title: str,
) -> str:
    """Ask the model for the product type and species of an ad listing.

    Args:
        image: Photo attached to the listing, as a PIL image. ``None`` is
            tolerated and short-circuits with a notice instead of crashing.
        title: Advertisement title; it is interpolated into the prompt.

    Returns:
        The decoded model output with surrounding whitespace stripped, or a
        short notice when no image was supplied.
    """
    if image is None:
        # The image slot may be empty when the button is clicked; the
        # processor cannot build pixel inputs from None.
        return "Please upload an image to analyze."

    # A missing title would otherwise be rendered into the prompt as "None".
    title = (title or "").strip()

    # The prompt template with the provided title
    prompt = f"""Given an ad listing with the title '{title}' and image, answer the following questions without any explanation or extra text:
Identify the species mentioned in the text, including specific names, e.g., 'Nile crocodile' instead of just 'crocodile'.
Select the product type from the following options: Animal fibers, Animal parts (bone or bone-like), Animal parts (fleshy), Coral product, Egg, Extract, Food, Ivory products, Live, Medicine, Nests, Organs and tissues, Powder, Scales or spines, Shells, Skin or leather products, Taxidermy, Insects.
The response should be in the format:
"Product Type: [type]
Species: [species]"
"""
    # Cast floating-point inputs to the dtype the model was actually loaded
    # with. Hard-coding float16 mismatches a model loaded at default precision.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, model.dtype)
    # Give the two-line answer an explicit token budget rather than relying on
    # the default generation length, which may be too short for it.
    generated_ids = model.generate(**inputs, max_new_tokens=64)
    result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
    return result
def postprocess_output(output: str) -> str:
    """Return the model output unchanged.

    Args:
        output: Text produced by the model.

    Returns:
        The same string that was passed in.
    """
    # NOTE: appending a trailing "." when the text does not already end in
    # punctuation is intentionally disabled, so this is a pass-through.
    return output
# Build the demo page: listing inputs, the analysis output, and two buttons.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Group():
        # Image and ad title input
        image = gr.Image(type="pil")
        ad_title = gr.Textbox(label="Advertisement Title", placeholder="Enter the title here", lines=1)
        # Output section
        answer_output = gr.Textbox(label="Analysis", show_label=True, placeholder="Response.")
        # Submit and clear buttons
        with gr.Row():
            submit_button = gr.Button("Analyze Listing", variant="primary")
            clear_button = gr.Button("Clear")

    # Run the model when "Analyze Listing" is clicked.
    submit_button.click(
        fn=answer_ad_listing_question,
        inputs=[image, ad_title],  # Only the image and ad title are inputs
        outputs=answer_output,
    )
    # Reset every field when "Clear" is clicked. The image component is
    # cleared with None: an empty string is not a valid image value, whereas
    # the two textboxes are reset to "".
    clear_button.click(
        fn=lambda: (None, "", ""),
        inputs=None,
        outputs=[image, ad_title, answer_output],
        queue=False,  # clearing is instant, so it bypasses the request queue
    )
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    # Enable the request queue with max_size=10, then launch the app.
    demo.queue(max_size=10)
    demo.launch()