Upload folder using huggingface_hub
- .gitattributes +1 -0
- demo.py +79 -159
- saved_image.png +3 -0
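
The commit pushes the whole folder with huggingface_hub, as the title says. A minimal sketch of that kind of upload, assuming the Space id your-username/image-tagging-app and the current directory as the folder (both placeholders):

from huggingface_hub import HfApi

# Illustrative only: push a local folder to a Space repo.
# The token comes from `huggingface-cli login` credentials or the HF_TOKEN env var.
api = HfApi()
api.upload_folder(
    folder_path=".",                            # local folder with demo.py, saved_image.png, ...
    repo_id="your-username/image-tagging-app",  # placeholder Space id
    repo_type="space",
    commit_message="Upload folder using huggingface_hub",
)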
.gitattributes
CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 sample_data/mnist_test.csv filter=lfs diff=lfs merge=lfs -text
 sample_data/mnist_train_small.csv filter=lfs diff=lfs merge=lfs -text
+saved_image.png filter=lfs diff=lfs merge=lfs -text
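
The added rule is what `git lfs track "saved_image.png"` writes into .gitattributes, so the new PNG is stored as an LFS object rather than a regular blob. A minimal sketch of appending the same rule programmatically, assuming the script runs at the repo root:

# Illustrative only: add the LFS tracking rule shown in the hunk above.
with open(".gitattributes", "a", encoding="utf-8") as attrs:
    attrs.write("saved_image.png filter=lfs diff=lfs merge=lfs -text\n")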
demo.py
CHANGED
@@ -3,14 +3,8 @@ from PIL import Image
 from transformers import BlipProcessor, BlipForConditionalGeneration
 import torch
 import math
-import
-
+import streamlit as st
 import matplotlib.pyplot as plt
-# %config InlineBackend.figure_format = 'retina'
-
-import ipywidgets as widgets
-from IPython.display import display, clear_output
-
 from torch import nn
 from torchvision.models import resnet50
 import torchvision.transforms as T
@@ -19,51 +13,26 @@ from groq import Groq
 import re
 import json

+# Initialize Groq API key and client
 GROQ_API_KEY = "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"
-
 client = Groq(api_key = GROQ_API_KEY)

-
+# Initialize models and processor
 processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
 caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
+torch.set_grad_enabled(False)

+# COCO classes and colors
+CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
+COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

-
-
-# COCO classes
-CLASSES = [
-    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
-    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
-    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
-    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
-    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
-    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
-    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
-    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
-    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
-    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
-    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
-    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
-    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
-    'toothbrush'
-]
+# Image transformation
+transform = T.Compose([T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

-#
-COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
-          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
-
-# standard PyTorch mean-std input image normalization
-transform = T.Compose([
-    T.Resize(800),
-    T.ToTensor(),
-    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-])
-
-# for output bounding box post-processing
+# Helper functions
 def box_cxcywh_to_xyxy(x):
     x_c, y_c, w, h = x.unbind(1)
-    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
-         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
     return torch.stack(b, dim=1)

 def rescale_bboxes(out_bbox, size):
@@ -79,137 +48,88 @@ def plot_results(pil_img, prob, boxes):
     classes_predicted = []
     colors = COLORS * 100
     for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
-        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
-                                   fill=False, color=c, linewidth=3))
+        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
         cl = p.argmax()
         text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
-        (CLASSES[cl])
         classes_predicted.append(CLASSES[cl])
-        ax.text(xmin, ymin, text, fontsize=15,
-                bbox=dict(facecolor='yellow', alpha=0.5))
+        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
     plt.axis('off')
-
-    return list(set(classes_predicted))
-
-model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
-model.eval();
-
+    st.pyplot(plt)

 def get_caption(img_url):
-
-
-
-
-    print(processor.decode(out[0], skip_special_tokens=True))
-    return str(processor.decode(out[0], skip_special_tokens=True))
-
+    raw_image = Image.open(img_url).convert('RGB')
+    inputs = processor(raw_image, return_tensors="pt")
+    out = caption_model.generate(**inputs)
+    return str(processor.decode(out[0], skip_special_tokens=True))

 def get_objects(url):
-
-
-
-
-
-
-
-
-
-    # keep only predictions with 0.7+ confidence
-    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
-    keep = probas.max(-1).values > 0.9
-
-    # convert boxes from [0; 1] to image scales
-    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
-
-    res = plot_results(im, probas[keep], bboxes_scaled)
-    # print(res)
-    return res
-
-
-system_prompt = """
-<SystemPrompt>
-Extract Tags from the provided text.
-The Tags that will be used to search.
-
-<OutputFormat>
-Format the output in the following JSON structure
-
-{
-"tags" : [* list of tags here*]
-}
-
-</OutputFormat>
+    im = Image.open(url)
+    img = transform(im).unsqueeze(0)
+    outputs = model(img)
+    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
+    keep = probas.max(-1).values > 0.9
+    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
+    plot_results(im, probas[keep], bboxes_scaled)
+    return [CLASSES[p.argmax()] for p in probas[keep]]

-</SystemPrompt>
-
-"""
 def get_tags(text, objects):
-
-
-
-
-
-
-
-
-
+    system_prompt = """
+    <SystemPrompt>
+    Extract Tags from the provided text.
+    The Tags that will be used to search.
+    <OutputFormat>
+    Format the output in the following JSON structure
+    {
+    "tags" : [* list of tags here*]
+    }
+    </OutputFormat>
+    </SystemPrompt>
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            response_format={"type": "json_object"},
-
-            stream=False,
-        )
-
-        print(chat_completion)
-
-        json_data = json.loads(chat_completion.choices[0].message.content)
-        return json_data['tags'], chat_completion.usage.total_tokens * 0.00000005
-
-    except Exception as e:
-        print(f"Exception | get_tags | {str(e)}")
-
-
-# Image processing function
+    try:
+        user_prompt = f"Extract the Tags from this text:\n{text}\n{objects}"
+        chat_completion = client.chat.completions.create(
+            messages=[{"role": "system", "content": system_prompt},
+                      {"role": "user", "content": user_prompt}],
+            model="llama3-8b-8192", response_format={"type": "json_object"},
+            stream=False
+        )
+        json_data = json.loads(chat_completion.choices[0].message.content)
+        return json_data['tags'], chat_completion.usage.total_tokens * 0.00000005
+    except Exception as e:
+        st.error(f"Exception | get_tags | {str(e)}")
+
+# Main image to tags function
 def image_to_tags(image):
-    # tags = "shahzain, haider"
     image = Image.fromarray(image)
     image.save("saved_image.png")
-
     generated_caption = get_caption('saved_image.png')
-    print(generated_caption)
-
     objects = get_objects('saved_image.png')
-
     tags, cost = get_tags(generated_caption, ", ".join(objects))
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    return ", ".join(tags), generated_caption, ", ".join(objects), cost
+
+# Streamlit app
+st.title("Image Tagging App")
+st.write("Upload an image and get captions, object detection results, and associated tags.")
+
+# Image upload
+uploaded_image = st.file_uploader("Choose an Image", type=["jpg", "png", "jpeg"])
+
+if uploaded_image is not None:
+    image = Image.open(uploaded_image)
+    st.image(image, caption='Uploaded Image.', use_column_width=True)
+
+    # Generate tags, caption, objects, and cost
+    tags, caption, objects, cost = image_to_tags(image)
+
+    # Display results
+    st.subheader("Predicted Tags:")
+    st.write(tags)
+
+    st.subheader("Caption:")
+    st.write(caption)
+
+    st.subheader("Objects Detected:")
+    st.write(objects)
+
+    st.subheader("Cost:")
+    st.write(f"${cost:.6f}")
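
The rewrite drops the notebook-only pieces (ipywidgets, IPython.display) and turns demo.py into a Streamlit app, normally launched with `streamlit run demo.py`. For reference, the BLIP captioning step it builds on can be exercised on its own; a minimal sketch, assuming a local test image at sample.jpg (placeholder path):

# Illustrative only: the captioning step from demo.py, run standalone.
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

raw_image = Image.open("sample.jpg").convert("RGB")  # placeholder image path
inputs = processor(raw_image, return_tensors="pt")
out = caption_model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))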
saved_image.png
ADDED
(binary image file, stored with Git LFS)