import os
import json
import torch
import streamlit as st
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from groq import Groq
# Initialize Groq client; read the API key from the environment rather than hard-coding a secret
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# BLIP captioning model and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
# DETR detector (assumed: the original uses `model` in get_objects() without defining it,
# and the COCO classes, transform, and output keys below match the torch.hub DETR demo)
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()
torch.set_grad_enabled(False)
# COCO classes and colors
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
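# Note: CLASSES follows DETR's COCO indexing, so 'N/A' marks class ids the detection set
# leaves unused; COLORS is a small palette that is cycled over the drawn boxes.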
# Image transformation for DETR: resize, convert to tensor, normalize with ImageNet statistics
transform = T.Compose([T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# Helper functions
def box_cxcywh_to_xyxy(x):
    # Convert boxes from (center_x, center_y, width, height) to (xmin, ymin, xmax, ymax)
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    # Scale relative [0, 1] boxes to absolute pixel coordinates of the original image
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b
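# Illustrative example: a relative box (cx=0.5, cy=0.5, w=0.2, h=0.4) becomes
# (0.4, 0.3, 0.6, 0.7), which rescale_bboxes then multiplies by the image width and height.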
def plot_results(pil_img, prob, boxes):
    # Draw the kept detections on the image and render the figure in the Streamlit page
    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    st.pyplot(fig)
    plt.close(fig)
def get_caption(img_path):
    # Generate a caption for the image at the given local path with BLIP
    raw_image = Image.open(img_path).convert('RGB')
    inputs = processor(raw_image, return_tensors="pt")
    out = caption_model.generate(**inputs)
    return str(processor.decode(out[0], skip_special_tokens=True))
def get_objects(img_path):
    # Run DETR on the image, keep detections above 0.9 confidence, plot them, and return class names
    im = Image.open(img_path).convert('RGB')
    img = transform(im).unsqueeze(0)
    outputs = model(img)
    # Drop the final "no object" class before thresholding
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.9
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    plot_results(im, probas[keep], bboxes_scaled)
    return [CLASSES[p.argmax()] for p in probas[keep]]
def get_tags(text, objects):
    # Ask the Groq LLM to extract search tags from the caption and detected objects as JSON
    system_prompt = """
    <SystemPrompt>
        Extract Tags from the provided text.
        The Tags will be used for search.
        <OutputFormat>
            Format the output in the following JSON structure
            {
                "tags" : [ *list of tags here* ]
            }
        </OutputFormat>
    </SystemPrompt>
    """
    try:
        user_prompt = f"Extract the Tags from this text:\n{text}\n{objects}"
        chat_completion = client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt},
                      {"role": "user", "content": user_prompt}],
            model="llama3-8b-8192", response_format={"type": "json_object"},
            stream=False
        )
        json_data = json.loads(chat_completion.choices[0].message.content)
        # Rough cost estimate: total tokens times a per-token price
        return json_data['tags'], chat_completion.usage.total_tokens * 0.00000005
    except Exception as e:
        st.error(f"Exception | get_tags | {str(e)}")
        return [], 0.0
# Main image-to-tags pipeline: caption, detect objects, then extract tags via the LLM
def image_to_tags(image):
    # `image` is a PIL image from the Streamlit uploader; Image.fromarray() in the original
    # expects a NumPy array and would fail here, so convert and save the PIL image directly
    image = image.convert("RGB")
    image.save("saved_image.png")
    generated_caption = get_caption('saved_image.png')
    objects = get_objects('saved_image.png')
    tags, cost = get_tags(generated_caption, ", ".join(objects))
    return ", ".join(tags), generated_caption, ", ".join(objects), cost
# Streamlit app
st.title("Image Tagging App")
st.write("Upload an image and get captions, object detection results, and associated tags.")
# Image upload
uploaded_image = st.file_uploader("Choose an Image", type=["jpg", "png", "jpeg"])
if uploaded_image is not None:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    # Generate tags, caption, objects, and cost
    tags, caption, objects, cost = image_to_tags(image)
    # Display results
    st.subheader("Predicted Tags:")
    st.write(tags)
    st.subheader("Caption:")
    st.write(caption)
    st.subheader("Objects Detected:")
    st.write(objects)
    st.subheader("Cost:")
    st.write(f"${cost:.6f}")