# curacel-demo-2 / app.py
# Hugging Face Space by mattraj — commit fe44ad8 ("Update app.py", verified), 4.03 kB.
import gradio as gr
import PIL.Image
import transformers
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import torch
import os
import string
import functools
import re
import numpy as np
import spaces
from PIL import Image, ImageDraw
import re
model_id = "mattraj/curacel-autodamage-1"
# Outline / label colours, cycled per detection when annotating an image.
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the fine-tuned checkpoint in bfloat16, put it in evaluation mode
# (eval() and to() both return the module itself) and move it to `device`.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
model = model.eval().to(device)
processor = PaliGemmaProcessor.from_pretrained(model_id)

###### Transformers Inference
# PaliGemma encodes each box coordinate as one of this many location bins
# (tokens <loc0000> .. <loc1023>).
_LOC_BINS = 1024
# One detection in PaliGemma output: four location tokens in
# (ymin, xmin, ymax, xmax) order followed by the label; detections are
# separated by ";".
_DETECTION_PATTERN = re.compile(
    r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*([^;<]*)"
)


@spaces.GPU
def infer(
    image: PIL.Image.Image,
    text: str,
    max_new_tokens: int = 2048
) -> tuple:
    """Run the fine-tuned PaliGemma model and draw the detected boxes.

    Args:
        image: Picture from the Gradio image component, or ``None`` when the
            user has not uploaded one.
        text: Prompt, e.g. ``"detect fender-dent ; paint-chip"``.
        max_new_tokens: Maximum number of tokens generated after the prompt.

    Returns:
        ``(result, annotated_image)``: the decoded model answer (prompt not
        included) and an RGB copy of ``image`` with one labelled rectangle per
        detection parsed from the answer. When ``image`` is ``None``, a short
        message and ``None`` are returned instead.
    """
    if image is None:
        # Gradio passes None when the button is clicked without an upload.
        return "Please upload an image.", None

    inputs = processor(
        text=text,
        images=image,
        return_tensors="pt",
        padding="longest",
        do_convert_rgb=True,
    ).to(device).to(dtype=model.dtype)
    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        # max_new_tokens (not max_length): the budget must not include the
        # image tokens and prompt that precede the answer.
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; decode only the continuation.
    result = processor.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)

    # convert() always returns a new image; RGB makes the hex colours below
    # draw correctly whatever the upload's mode (P, L, RGBA, ...).
    annotated_image = image.convert("RGB")
    draw = ImageDraw.Draw(annotated_image)
    width, height = annotated_image.size
    for idx, (box, label) in enumerate(extract_bounding_boxes(result, width, height)):
        color = COLORS[idx % len(COLORS)]
        draw.rectangle(box, outline=color, width=3)
        draw.text((box[0], box[1]), label, fill=color)
    return result, annotated_image


def extract_bounding_boxes(result, width=_LOC_BINS, height=_LOC_BINS):
    """Extract bounding boxes and labels from PaliGemma detection output.

    A detection looks like ``<loc0123><loc0045><loc0678><loc0901> fender-dent``:
    four location bins (0..1023) in ``ymin, xmin, ymax, xmax`` order followed
    by the label. Multiple detections are separated by ``;``. Text that does
    not match this shape is ignored.

    Args:
        result: Decoded model output.
        width: Width in pixels used to scale x coordinates. With the default,
            coordinates stay in the 1024-bin space.
        height: Height in pixels used to scale y coordinates.

    Returns:
        A list of ``((x1, y1, x2, y2), label)`` tuples with float coordinates,
        normalized so ``x1 <= x2`` and ``y1 <= y2`` (as ``ImageDraw.rectangle``
        requires). Example:
        ``[((45.0, 123.0, 901.0, 678.0), "fender-dent")]``.
    """
    x_scale = width / _LOC_BINS
    y_scale = height / _LOC_BINS
    bounding_boxes = []
    for ymin, xmin, ymax, xmax, label in _DETECTION_PATTERN.findall(result):
        # sorted() guards against the model emitting an inverted box.
        x_lo, x_hi = sorted((int(xmin) * x_scale, int(xmax) * x_scale))
        y_lo, y_hi = sorted((int(ymin) * y_scale, int(ymax) * y_scale))
        bounding_boxes.append(((x_lo, y_lo, x_hi, y_hi), label.strip()))
    return bounding_boxes

######## Demo
INTRO_TEXT = """## Curacel Auto Damage demo\n\n
Finetuned from: google/paligemma-3b-pt-448
"""

# Example row: a sample photo plus a "detect" prompt listing damage classes
# separated by " ; " (presumably the checkpoint's full label set — confirm).
examples = [["./car-1.png", "detect Front-Windscreen-Damage ; Headlight-Damage ; Major-Rear-Bumper-Dent ; Rear-windscreen-Damage ; RunningBoard-Dent ; Sidemirror-Damage ; Signlight-Damage ; Taillight-Damage ; bonnet-dent ; doorouter-dent ; doorouter-scratch ; fender-dent ; front-bumper-dent ; front-bumper-scratch ; medium-Bodypanel-Dent ; paint-chip ; paint-trace ; pillar-dent ; quaterpanel-dent ; rear-bumper-dent ; rear-bumper-scratch ; roof-dent"]]

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(INTRO_TEXT)
    # Single tab holding one column: inputs, outputs, run button, examples.
    with gr.Tab("Text Generation"), gr.Column():
        image = gr.Image(type="pil")
        text_input = gr.Text(label="Input Text")
        text_output = gr.Text(label="Text Output")
        output_image = gr.Image(label="Annotated Image")
        chat_btn = gr.Button()

        # infer(image, text) -> (text, annotated image)
        chat_inputs = [image, text_input]
        chat_outputs = [text_output, output_image]
        chat_btn.click(fn=infer, inputs=chat_inputs, outputs=chat_outputs)

        gr.Markdown("")  # empty spacer above the examples table
        gr.Examples(examples=examples, inputs=chat_inputs)

#########
if __name__ == "__main__":
    demo.queue(max_size=10).launch(debug=True)