import os
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch
import math
import streamlit as st
import matplotlib.pyplot as plt
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
from groq import Groq
import re
import json
# Initialize the Groq client (read the API key from the environment instead of hardcoding a secret)
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
client = Groq(api_key=GROQ_API_KEY)
# Initialize models and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
torch.set_grad_enabled(False)
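# The detection code below expects a `model` whose outputs contain 'pred_logits' and
# 'pred_boxes' in DETR's format, but the original file never defines it. A minimal
# sketch, assuming the pretrained DETR ResNet-50 checkpoint from torch.hub:
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()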
# COCO classes and colors
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
# Image transformation
transform = T.Compose([T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# Helper functions
def box_cxcywh_to_xyxy(x):
    # Convert boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    # Scale normalized boxes back to absolute pixel coordinates of the original image
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b
def plot_results(pil_img, prob, boxes):
    # Draw bounding boxes and class labels on the image and render the figure in Streamlit
    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    st.pyplot(fig)
    plt.close(fig)
def get_caption(img_path):
    # Generate a caption for the saved image with BLIP
    raw_image = Image.open(img_path).convert('RGB')
    inputs = processor(raw_image, return_tensors="pt")
    out = caption_model.generate(**inputs)
    return str(processor.decode(out[0], skip_special_tokens=True))
def get_objects(img_path):
    # Run the detection model, plot the detections, and return the predicted class names
    im = Image.open(img_path)
    img = transform(im).unsqueeze(0)
    outputs = model(img)
    # Drop the trailing "no object" logit and keep detections with confidence above 0.9
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    plot_results(im, probas[keep], bboxes_scaled)
    return [CLASSES[p.argmax()] for p in probas[keep]]
def get_tags(text, objects):
    system_prompt = """
    <SystemPrompt>
        Extract tags from the provided text.
        The tags will be used for search.
        <OutputFormat>
            Format the output in the following JSON structure:
            {
                "tags" : [*list of tags here*]
            }
        </OutputFormat>
    </SystemPrompt>
    """
    try:
        user_prompt = f"Extract the tags from this text:\n{text}\n{objects}"
        chat_completion = client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt},
                      {"role": "user", "content": user_prompt}],
            model="llama3-8b-8192",
            response_format={"type": "json_object"},
            stream=False,
        )
        json_data = json.loads(chat_completion.choices[0].message.content)
        # Rough cost estimate: total tokens times an assumed $0.05 per million tokens
        return json_data['tags'], chat_completion.usage.total_tokens * 0.00000005
    except Exception as e:
        st.error(f"Exception | get_tags | {str(e)}")
        return [], 0.0  # Fall back to empty results so the caller can still unpack the tuple
# Main image to tags function
def image_to_tags(image):
    # `image` is a PIL Image (from the uploader below); save it to disk so the
    # captioning and detection helpers can reload it from a path.
    image.convert("RGB").save("saved_image.png")
    generated_caption = get_caption('saved_image.png')
    objects = get_objects('saved_image.png')
    tags, cost = get_tags(generated_caption, ", ".join(objects))
    return ", ".join(tags), generated_caption, ", ".join(objects), cost
# Streamlit app
st.title("Image Tagging App")
st.write("Upload an image and get captions, object detection results, and associated tags.")
# Image upload
uploaded_image = st.file_uploader("Choose an Image", type=["jpg", "png", "jpeg"])
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    # Generate tags, caption, objects, and cost
    tags, caption, objects, cost = image_to_tags(image)
    # Display results
    st.subheader("Predicted Tags:")
    st.write(tags)
    st.subheader("Caption:")
    st.write(caption)
    st.subheader("Objects Detected:")
    st.write(objects)
    st.subheader("Cost:")
    st.write(f"${cost:.6f}")