ShahzainHaider committed
Commit d2e2636 · verified · 1 Parent(s): e1e5a9b

Upload folder using huggingface_hub

Files changed (3)
  1. .gitattributes +1 -0
  2. demo.py +79 -159
  3. saved_image.png +3 -0
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  sample_data/mnist_test.csv filter=lfs diff=lfs merge=lfs -text
  sample_data/mnist_train_small.csv filter=lfs diff=lfs merge=lfs -text
+ saved_image.png filter=lfs diff=lfs merge=lfs -text
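The added rule routes saved_image.png through Git LFS, matching the existing *tfevents* and sample_data patterns above it; it is the kind of line that `git lfs track "saved_image.png"` appends to .gitattributes, so the PNG is committed as an LFS pointer rather than a regular blob.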
demo.py CHANGED
@@ -3,14 +3,8 @@ from PIL import Image
  from transformers import BlipProcessor, BlipForConditionalGeneration
  import torch
  import math
- import gradio as gr
-
+ import streamlit as st
  import matplotlib.pyplot as plt
- # %config InlineBackend.figure_format = 'retina'
-
- import ipywidgets as widgets
- from IPython.display import display, clear_output
-
  from torch import nn
  from torchvision.models import resnet50
  import torchvision.transforms as T
@@ -19,51 +13,26 @@ from groq import Groq
  import re
  import json

+ # Initialize Groq API key and client
  GROQ_API_KEY = "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"
-
  client = Groq(api_key = GROQ_API_KEY)

-
+ # Initialize models and processor
  processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
  caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
- torch.set_grad_enabled(False);
-
+ torch.set_grad_enabled(False)

- # COCO classes
- CLASSES = [
-     'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
-     'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
-     'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
-     'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
-     'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
-     'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
-     'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
-     'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
-     'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
-     'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
-     'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
-     'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
-     'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
-     'toothbrush'
- ]
-
- # # colors for visualization
- COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
-           [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
-
- # standard PyTorch mean-std input image normalization
- transform = T.Compose([
-     T.Resize(800),
-     T.ToTensor(),
-     T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
- ])
-
- # for output bounding box post-processing
+ # COCO classes and colors
+ CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
+
+ # Image transformation
+ transform = T.Compose([T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
+
+ # Helper functions
  def box_cxcywh_to_xyxy(x):
      x_c, y_c, w, h = x.unbind(1)
-     b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
-          (x_c + 0.5 * w), (y_c + 0.5 * h)]
+     b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
      return torch.stack(b, dim=1)

  def rescale_bboxes(out_bbox, size):
@@ -79,137 +48,88 @@ def plot_results(pil_img, prob, boxes)
      classes_predicted = []
      colors = COLORS * 100
      for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
-         ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
-                                    fill=False, color=c, linewidth=3))
+         ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
          cl = p.argmax()
          text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
-         (CLASSES[cl])
          classes_predicted.append(CLASSES[cl])
-         ax.text(xmin, ymin, text, fontsize=15,
-                 bbox=dict(facecolor='yellow', alpha=0.5))
+         ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
      plt.axis('off')
-     plt.show()
-     return list(set(classes_predicted))
-
- model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
- model.eval();
-
+     st.pyplot(plt)

  def get_caption(img_url):
-     raw_image = Image.open(img_url).convert('RGB')
-     inputs = processor(raw_image, return_tensors="pt")
-
-     out = caption_model.generate(**inputs)
-     print(processor.decode(out[0], skip_special_tokens=True))
-     return str(processor.decode(out[0], skip_special_tokens=True))
-
+     raw_image = Image.open(img_url).convert('RGB')
+     inputs = processor(raw_image, return_tensors="pt")
+     out = caption_model.generate(**inputs)
+     return str(processor.decode(out[0], skip_special_tokens=True))

  def get_objects(url):
-     # url = '/content/saved_image.png'
-     im = Image.open(url)
-
-     # mean-std normalize the input image (batch-size: 1)
-     img = transform(im).unsqueeze(0)
-
-     # propagate through the model
-     outputs = model(img)
-
-     # keep only predictions with 0.7+ confidence
-     probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
-     keep = probas.max(-1).values > 0.9
-
-     # convert boxes from [0; 1] to image scales
-     bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
-
-     res = plot_results(im, probas[keep], bboxes_scaled)
-     # print(res)
-     return res
-
-
- system_prompt = """
- <SystemPrompt>
- Extract Tags from the provided text.
- The Tags that will be used to search.
-
- <OutputFormat>
- Format the output in the following JSON structure
-
- {
- "tags" : [* list of tags here*]
- }
-
- </OutputFormat>
-
- </SystemPrompt>
-
- """
+     im = Image.open(url)
+     img = transform(im).unsqueeze(0)
+     outputs = model(img)
+     probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
+     keep = probas.max(-1).values > 0.9
+     bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
+     plot_results(im, probas[keep], bboxes_scaled)
+     return [CLASSES[p.argmax()] for p in probas[keep]]
+
  def get_tags(text, objects):
-     try:
-
-         user_prompt = f"""
-         Extract the Tags from thei text:
-
-         {text}
-
-         {objects}
-
+     system_prompt = """
+     <SystemPrompt>
+     Extract Tags from the provided text.
+     The Tags that will be used to search.
+     <OutputFormat>
+     Format the output in the following JSON structure
+     {
+     "tags" : [* list of tags here*]
+     }
+     </OutputFormat>
+     </SystemPrompt>
      """
-
-         chat_completion = client.chat.completions.create(
-
-             messages=[
-                 {
-                     "role": "system",
-                     "content": system_prompt,
-                 },
-                 {
-                     "role": "user",
-                     "content": user_prompt,
-                 }
-             ],
-             model="llama3-8b-8192",
-             response_format={"type": "json_object"},
-
-             stream=False,
-         )
-
-         print(chat_completion)
-
-         json_data = json.loads(chat_completion.choices[0].message.content)
-         return json_data['tags'], chat_completion.usage.total_tokens * 0.00000005
-
-     except Exception as e:
-         print(f"Exception | get_tags | {str(e)}")
-
-
- # Image processing function
+     try:
+         user_prompt = f"Extract the Tags from this text:\n{text}\n{objects}"
+         chat_completion = client.chat.completions.create(
+             messages=[{"role": "system", "content": system_prompt},
+                       {"role": "user", "content": user_prompt}],
+             model="llama3-8b-8192", response_format={"type": "json_object"},
+             stream=False
+         )
+         json_data = json.loads(chat_completion.choices[0].message.content)
+         return json_data['tags'], chat_completion.usage.total_tokens * 0.00000005
+     except Exception as e:
+         st.error(f"Exception | get_tags | {str(e)}")
+
+ # Main image to tags function
  def image_to_tags(image):
-     # tags = "shahzain, haider"
      image = Image.fromarray(image)
      image.save("saved_image.png")
-
      generated_caption = get_caption('saved_image.png')
-     print(generated_caption)
-
      objects = get_objects('saved_image.png')
-
      tags, cost = get_tags(generated_caption, ", ".join(objects))
-
-     return ", ".join(tags) , generated_caption , ", ".join(objects), cost
-     # return "", "", ""
-
- # Define Gradio interface
- app = gr.Interface(
-     fn=image_to_tags,
-     inputs=gr.Image(type="numpy", label="Upload an Image"),
-     outputs=[
-         gr.Label(num_top_classes=5, label="Predicted Tags"),
-         gr.Textbox(label="Caption"),
-         gr.Textbox(label="Object Detection"),
-         gr.Textbox(label="Cost")
-
-     ], title="Image Tagging App"
- )
-
- # Launch the app
- app.launch(debug = True, share=True)
+     return ", ".join(tags), generated_caption, ", ".join(objects), cost
+
+ # Streamlit app
+ st.title("Image Tagging App")
+ st.write("Upload an image and get captions, object detection results, and associated tags.")
+
+ # Image upload
+ uploaded_image = st.file_uploader("Choose an Image", type=["jpg", "png", "jpeg"])
+
+ if uploaded_image is not None:
+     image = Image.open(uploaded_image)
+     st.image(image, caption='Uploaded Image.', use_column_width=True)
+
+     # Generate tags, caption, objects, and cost
+     tags, caption, objects, cost = image_to_tags(image)
+
+     # Display results
+     st.subheader("Predicted Tags:")
+     st.write(tags)
+
+     st.subheader("Caption:")
+     st.write(caption)
+
+     st.subheader("Objects Detected:")
+     st.write(objects)
+
+     st.subheader("Cost:")
+     st.write(f"${cost:.6f}")
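
One thing to note about the new demo.py: `get_objects` still calls `model(img)`, but the `torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)` line from the previous Gradio version is removed and does not reappear among the added lines. If the detector is not loaded elsewhere in the file, a minimal sketch of reinstating it (reusing the same DETR checkpoint as the removed lines) would be:

    # Sketch only: reuse the DETR-ResNet50 hub checkpoint from the removed Gradio version.
    import torch

    # Download/load the pretrained detector and switch it to inference mode,
    # so get_objects() can call model(img) as before.
    model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
    model.eval()

With that in place, the Streamlit app can be served locally with `streamlit run demo.py`.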
saved_image.png ADDED

Git LFS Details

  • SHA256: 3babe4b1623ed5bf450949a781d5295b6b166c373e58a045d9f10130615d4d53
  • Pointer size: 132 Bytes
  • Size of remote file: 1.59 MB
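
The oid above is the SHA-256 digest of the file's contents, which is what Git LFS stores in the 132-byte pointer committed in place of the 1.59 MB image. A quick local check (a sketch, assuming saved_image.png sits in the repository root) is:

    # Sketch: recompute the SHA-256 that Git LFS records as the object id.
    import hashlib

    with open("saved_image.png", "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()

    # Expected: 3babe4b1623ed5bf450949a781d5295b6b166c373e58a045d9f10130615d4d53
    print(digest)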