File size: 5,302 Bytes
dc81f01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b47611f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import logging
import os
import pickle
from datetime import datetime, timezone
from pathlib import Path

import chromadb
import google.generativeai as genai
import matplotlib.pyplot as plt
from datasets import load_dataset
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


def set_directories():
    """Create (if missing) the data and vector-store directories under the CWD.

    Returns:
        tuple[Path, Path]: (path to the dataset pickle file,
        path to the Chroma persistent-storage directory).
    """
    curr_dir = Path(os.getcwd())

    data_dir = curr_dir / 'data'
    data_pickle_path = data_dir / 'data_set.pkl'

    chroma_dir = curr_dir / 'vector_storage' / 'chroma'

    # mkdir(parents=True, exist_ok=True) is race-free, creates the
    # intermediate vector_storage directory automatically, and avoids
    # shadowing the `dir` builtin as the original loop variable did.
    for directory in (data_dir, chroma_dir):
        directory.mkdir(parents=True, exist_ok=True)

    return data_pickle_path, chroma_dir


def load_data(data_pickle_path, dataset="vipulmaheshwari/GTA-Image-Captioning-Dataset"):
    """Load the image-captioning dataset, caching it locally as a pickle.

    Args:
        data_pickle_path: Path to the local pickle cache file.
        dataset: Hugging Face dataset identifier to download on a cache miss.

    Returns:
        The dataset object, either unpickled from the cache or freshly
        downloaded (and then written to the cache).
    """
    if os.path.exists(data_pickle_path):
        print("Data set already exists in the local drive. Loading it.")
        # NOTE(review): unpickling is only safe because this cache file is
        # produced by this same script, not by untrusted input.
        with open(data_pickle_path, 'rb') as infile:
            data_set = pickle.load(infile)
    else:
        print("Data set hasn't been loaded. Loading from the datasets library and save it as a pickle.")
        data_set = load_dataset(dataset)
        with open(data_pickle_path, 'wb') as outfile:
            pickle.dump(data_set, outfile)

    return data_set


def get_embeddings(data, model):
    """Encode every image in *data* and build the matching Chroma ids.

    Args:
        data: Sequence of images accepted by ``model.encode``.
        model: Encoder exposing ``encode(image)`` that returns an
            array-like with a ``tolist()`` method (e.g. SentenceTransformer).

    Returns:
        tuple[list[str], list[list[float]]]: ("image <idx>" ids and the
        corresponding embedding vectors, in dataset order).
    """
    ids = []
    embeddings = []
    # enumerate replaces the zip(range(len(data)), data) pattern, avoids
    # shadowing the `id` builtin, and lets tqdm report a total when the
    # sequence is sized.
    for idx, image in enumerate(tqdm(data)):
        ids.append(f"image {idx}")
        embeddings.append(model.encode(image).tolist())

    return ids, embeddings


def get_collection(chroma_dir, model, collection_name, data):
    """Return a persistent Chroma collection with one embedding per image.

    Embeddings are (re)computed only when the collection's count does not
    match the dataset size, so repeated runs reuse the stored vectors.

    Args:
        chroma_dir: Directory (Path) for Chroma's persistent storage.
        model: Encoder passed through to ``get_embeddings``.
        collection_name: Name of the collection to fetch or create.
        data: Sequence of images to embed.

    Returns:
        The Chroma collection object.
    """
    # str() is the idiomatic way to stringify a Path (not .__str__()).
    client = chromadb.PersistentClient(path=str(chroma_dir))
    collection = client.get_or_create_collection(name=collection_name)

    if collection.count() == len(data):
        print("Embeddings are already added to the collection.")
    else:
        print("Adding embeddings to the collection.")
        ids, embeddings = get_embeddings(data, model)
        collection.add(
            ids=ids,
            embeddings=embeddings
        )

    return collection


def get_search_result(collection, data_set, query, model, n_results=2):
    """Find the image (and caption) most similar to a free-text query.

    Args:
        collection: Chroma collection holding the image embeddings.
        data_set: Dataset whose 'train' split exposes 'image' and 'text'.
        query: The user's free-text search query.
        model: Encoder with an ``encode`` method shared by text and images.
        n_results: Number of nearest neighbours to request from the store.

    Returns:
        tuple: (best-matching image, its original caption).
    """
    results = collection.query(
        query_embeddings=model.encode([query]),
        n_results=n_results  # bug fix: was hard-coded to 2, ignoring the parameter
    )

    # Ids were stored as "image <idx>"; recover the integer index of the top hit.
    img_id = int(results['ids'][0][0].split('image ')[-1])

    # Look up the image and its human-written caption in the train split.
    image = data_set['train']['image'][img_id]
    text = data_set['train']['text'][img_id]

    return image, text


def show_image(image, text, query):
    """Display the retrieved image interactively and echo the query/caption.

    Uses pyplot's implicit global figure state, so the call order below
    (interactive mode -> axis off -> draw -> show) is significant.

    Args:
        image: Image object displayable by ``plt.imshow``.
        text: The image's original caption from the dataset.
        query: The user query that retrieved this image.
    """
    plt.ion()
    plt.axis("off")
    plt.imshow(image)
    plt.show()
    print(f"User query: {query}")
    print(f"Original description: {text}\n")
    

def get_logger():
    """Configure root logging to a date-stamped file and return a logger.

    Creates ./log/ if needed and writes to ./log/<YYYYMMDD>.log. Note that
    ``logging.basicConfig`` only takes effect on the first call per process.

    Returns:
        logging.Logger: A logger named after this module.
    """
    log_dir = Path("./log")
    # exist_ok avoids the check-then-create race of os.path.exists + mkdir.
    log_dir.mkdir(exist_ok=True)

    # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC now.
    cur_date = datetime.now(timezone.utc).strftime("%Y%m%d")
    log_filename = log_dir / f"{cur_date}.log"

    logging.basicConfig(
        filename=log_filename,
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S")

    return logging.getLogger(__name__)


def get_image_description(image):
    """Ask the Gemini vision model for a literal description of *image*.

    Loads GOOGLE_API_KEY from the environment (populated via .env), then
    prompts gemini-pro-vision at temperature 0 for deterministic output.

    Args:
        image: Image object accepted by the Gemini SDK (e.g. a PIL image).

    Returns:
        str: The model's textual description of the image.
    """
    _ = load_dotenv()
    genai.configure(api_key=os.environ['GOOGLE_API_KEY'])

    model = genai.GenerativeModel(
        "gemini-pro-vision",
        generation_config={"temperature": 0.0},
    )

    prompt = """
    Describe what you explicitly see in the given image in detail.
    Begin your description with "In this image," or "This image is about," to provide context.
    Your response should be a hard description of the given image without any thoughts or suggestions.
    """

    response = model.generate_content([prompt, image])
    return response.text


def _announce(logger, message):
    """Echo a status message to both stdout and the log file."""
    print(message)
    logger.info(message)


def initialization(logger):
    """Set up directories, data, the CLIP model, and the vector collection.

    Args:
        logger: Logger used to mirror progress messages into the log file.

    Returns:
        tuple: (chroma collection, dataset, SentenceTransformer model, logger).
    """
    _announce(logger, "Initializing...")
    _announce(logger, "-------------------------------------------------------")

    _announce(logger, "Set directories...")
    data_pickle_path, chroma_dir = set_directories()

    _announce(logger, "Loading data...")
    data_set = load_data(data_pickle_path)

    _announce(logger, "Loading CLIP model...")
    # CLIP embeds images and text into a shared vector space, which is
    # what makes text-query -> image retrieval possible later on.
    model = SentenceTransformer("sentence-transformers/clip-ViT-L-14")

    _announce(logger, "Getting vector embeddings...")
    collection = get_collection(chroma_dir, model, collection_name='image_vectors', data=data_set['train']['image'])

    _announce(logger, "-------------------------------------------------------")
    _announce(logger, "Initialization completed! Ready for search.")

    return collection, data_set, model, logger