import json

import torch
import streamlit as st
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from groq import Groq

# Initialize Groq API key and client
GROQ_API_KEY = "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"
client = Groq(api_key=GROQ_API_KEY)

# Initialize models and processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
# Object detector used by get_objects(); assumed to be DETR ResNet-50 from torch.hub,
# which matches the COCO classes and the pred_logits/pred_boxes post-processing below.
model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
model.eval()
torch.set_grad_enabled(False)

# COCO classes and colors
CLASSES = ['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# Image transformation
transform = T.Compose([T.Resize(800), T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Helper functions
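# Convert DETR boxes from (center_x, center_y, width, height) to (x_min, y_min, x_max, y_max)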
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

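# Scale normalized [0, 1] boxes to pixel coordinates of the original image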
def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

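# Draw detection boxes and class labels on the image and render the figure in Streamlit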
def plot_results(pil_img, prob, boxes):
    fig, ax = plt.subplots(figsize=(16, 10))
    ax.imshow(pil_img)
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='yellow', alpha=0.5))
    ax.axis('off')
    st.pyplot(fig)

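# Generate a BLIP caption for the image stored at the given local path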
def get_caption(img_url):
    raw_image = Image.open(img_url).convert('RGB')
    inputs = processor(raw_image, return_tensors="pt")
    out = caption_model.generate(**inputs)
    return str(processor.decode(out[0], skip_special_tokens=True))

def get_objects(path):
    im = Image.open(path).convert('RGB')
    img = transform(im).unsqueeze(0)
    outputs = model(img)
    # Class probabilities per query, dropping the trailing "no object" class
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    # Keep only confident detections
    keep = probas.max(-1).values > 0.9
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    plot_results(im, probas[keep], bboxes_scaled)
    return [CLASSES[p.argmax()] for p in probas[keep]]

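# Ask the Groq LLM to turn the caption text and detected objects into search tags (JSON output)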
def get_tags(text, objects):
    system_prompt = """
        <SystemPrompt>
                Extract Tags from the provided text.
                The Tags that will be used to search.
            <OutputFormat>
                Format the output in the following JSON structure
                {
                    "tags" : [* list of tags here*]
                }
            </OutputFormat>
    </SystemPrompt>
    """
    try:
        user_prompt = f"Extract the Tags from this text:\n{text}\n{objects}"
        chat_completion = client.chat.completions.create(
            messages=[{"role": "system", "content": system_prompt},
                      {"role": "user", "content": user_prompt}],
            model="llama3-8b-8192", response_format={"type": "json_object"},
            stream=False
        )
        json_data = json.loads(chat_completion.choices[0].message.content)
        return json_data['tags'], chat_completion.usage.total_tokens * 0.00000005
    except Exception as e:
        st.error(f"Exception | get_tags | {str(e)}")
        # Return a safe default so the caller's tuple unpacking does not fail
        return [], 0.0

# Main image to tags function
def image_to_tags(image):
    # The Streamlit caller passes a PIL image, so save it directly
    # (Image.fromarray expects a NumPy array and would fail here)
    image.save("saved_image.png")
    generated_caption = get_caption('saved_image.png')
    objects = get_objects('saved_image.png')
    tags, cost = get_tags(generated_caption, ", ".join(objects))
    return ", ".join(tags), generated_caption, ", ".join(objects), cost

# Streamlit app
st.title("Image Tagging App")
st.write("Upload an image and get captions, object detection results, and associated tags.")

# Image upload
uploaded_image = st.file_uploader("Choose an Image", type=["jpg", "png", "jpeg"])

if uploaded_image is not None:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image.', use_column_width=True)
    
    # Generate tags, caption, objects, and cost
    tags, caption, objects, cost = image_to_tags(image)
    
    # Display results
    st.subheader("Predicted Tags:")
    st.write(tags)
    
    st.subheader("Caption:")
    st.write(caption)
    
    st.subheader("Objects Detected:")
    st.write(objects)
    
    st.subheader("Cost:")
    st.write(f"${cost:.6f}")