File size: 4,830 Bytes
b47611f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import logging
from datetime import datetime
from pathlib import Path
import pickle
from tqdm import tqdm
from datasets import load_dataset
import chromadb
import matplotlib.pyplot as plt


def set_directories():
    """Create the working directories under the current directory.

    The following layout is ensured relative to the current working
    directory:

        data/
        vectore_storage/chroma/

    Directory creation is idempotent: existing directories are left as they
    are, and missing parents are created as needed.

    Returns:
        tuple: ``(data_pickle_path, chroma_dir)`` as ``pathlib.Path`` objects,
        where ``data_pickle_path`` is ``data/data_set.pkl`` (the file itself is
        NOT created here) and ``chroma_dir`` is ``vectore_storage/chroma``.
    """
    base_dir = Path.cwd()

    data_dir = base_dir / 'data'
    data_pickle_path = data_dir / 'data_set.pkl'

    # The folder name (including its spelling) is kept as-is so that stores
    # persisted by earlier runs are still found.
    vectordb_dir = base_dir / 'vectore_storage'
    chroma_dir = vectordb_dir / 'chroma'

    # exist_ok avoids the check-then-create race; parents=True also creates
    # vectordb_dir on the way to chroma_dir.
    for directory in (data_dir, chroma_dir):
        directory.mkdir(parents=True, exist_ok=True)

    return data_pickle_path, chroma_dir


def load_data(data_pickle_path, dataset="vipulmaheshwari/GTA-Image-Captioning-Dataset"):
    """Return the data set, using a local pickle file as a cache.

    Args:
        data_pickle_path: Path (str or path-like) of the pickle cache file.
        dataset: Name passed to ``load_dataset`` when no cache file exists.

    Returns:
        The object stored in the pickle file when it exists; otherwise the
        object returned by ``load_dataset(dataset)``, which is also written
        to ``data_pickle_path`` for later runs.
    """
    # Fast path: a cached copy is already on disk.
    # NOTE: unpickling can execute arbitrary code, so the cache file must
    # come from a trusted source (normally an earlier run of this function).
    if os.path.exists(data_pickle_path):
        print("Data set already exists in the local drive. Loading it.")
        with open(data_pickle_path, 'rb') as cache_file:
            return pickle.load(cache_file)

    # Slow path: fetch the data set and cache it for the next call.
    print("Data set hasn't been loaded. Loading from the datasets library and save it as a pickle.")
    fetched = load_dataset(dataset)
    with open(data_pickle_path, 'wb') as cache_file:
        pickle.dump(fetched, cache_file)
    return fetched


def get_embeddings(data, model):
    """Encode every item of ``data`` and build matching string ids.

    Args:
        data: Sized iterable of items (images) accepted by ``model.encode``.
        model: Object with an ``encode(item)`` method whose result exposes
            ``.tolist()`` (e.g. a NumPy array).

    Returns:
        tuple: ``(ids, embeddings)`` where ``ids[i]`` is ``"image <i>"`` and
        ``embeddings[i]`` is the encoded item ``i`` converted to a list.
    """
    ids = []
    embeddings = []
    # Passing total lets tqdm show a real progress bar; wrapping a bare
    # enumerate/zip iterator alone gives it no length to work with.
    for index, image in tqdm(enumerate(data), total=len(data)):
        ids.append(f"image {index}")
        embeddings.append(model.encode(image).tolist())

    return ids, embeddings


def get_collection(chroma_dir, model, collection_name, data):
    """Open (or create) a persistent Chroma collection and fill it if needed.

    Args:
        chroma_dir: Directory of the persistent Chroma store.
        model: Encoder forwarded to ``get_embeddings``.
        collection_name: Name of the collection to get or create.
        data: Sized iterable of items to embed.

    Returns:
        The Chroma collection. Embeddings for ``data`` are added only when
        the number of stored records differs from ``len(data)``.
    """
    chroma_client = chromadb.PersistentClient(path=str(chroma_dir))
    collection = chroma_client.get_or_create_collection(name=collection_name)

    # A matching record count is treated as "already populated".
    if collection.count() == len(data):
        print("Embeddings are already added to the collection.")
        return collection

    print("Adding embeddings to the collection.")
    ids, embeddings = get_embeddings(data, model)
    collection.add(ids=ids, embeddings=embeddings)
    return collection


def get_result(collection, data_set, query, model, n_results=2):
    """Return the image and caption that best match a text query.

    Args:
        collection: Vector store exposing
            ``query(query_embeddings=..., n_results=...)`` and returning a
            mapping whose ``'ids'`` entry is a list of id lists.
        data_set: Mapping with ``data_set['train']['image']`` and
            ``data_set['train']['text']`` indexable by integer position.
        query: Text query to embed.
        model: Encoder with an ``encode(list_of_str)`` method.
        n_results: Number of nearest neighbours requested from the store.
            Only the top hit is returned.

    Returns:
        tuple: ``(image, text)`` for the most relevant record.

    Raises:
        IndexError: If the store returns no hits for the query.
    """
    # Honour the caller's n_results (it used to be hard-coded to 2).
    results = collection.query(
        query_embeddings=model.encode([query]),
        n_results=n_results
    )

    top_hits = results['ids'][0]
    if not top_hits:
        raise IndexError("The vector store returned no results for the query.")

    # Ids have the form "image <n>"; recover the integer position.
    img_id = int(top_hits[0].split('image ')[-1])

    # NOTE(review): column-then-index access may materialize the whole column
    # for some dataset types; confirm whether row access would be cheaper.
    image = data_set['train']['image'][img_id]
    text = data_set['train']['text'][img_id]

    return image, text


def show_image(image, text, query):
    """Display ``image`` with matplotlib and print the query and caption.

    Args:
        image: Image object accepted by ``plt.imshow``.
        text: Original description (caption) of the image.
        query: The user's search query.
    """
    # Interactive mode is switched on before drawing, then the axes are
    # hidden and the image is rendered.
    plt.ion()
    plt.axis("off")
    plt.imshow(image)
    plt.show()

    for line in (f"User query: {query}", f"Original description: {text}\n"):
        print(line)
    

def get_logger():
    """Configure file logging under ``./log/`` and return a module logger.

    The log file is named after the current UTC date (``YYYYMMDD.log``).
    ``logging.basicConfig`` configures the root logger, so it only takes
    effect if the root logger has no handlers yet.

    Returns:
        logging.Logger: The logger named after this module.
    """
    # Local import: the file-level import only brings in ``datetime``.
    from datetime import timezone

    log_path = "./log/"
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs(log_path, exist_ok=True)

    # Timezone-aware replacement for the deprecated datetime.utcnow().
    cur_date = datetime.now(timezone.utc).strftime("%Y%m%d")
    log_filename = f"{log_path}{cur_date}.log"

    logging.basicConfig(
        filename=log_filename,
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S")

    return logging.getLogger(__name__)


def initialization(logger):
    """Run the start-up sequence and return the objects needed for search.

    Each step is announced both on stdout and through ``logger``. The steps
    are: import the helpers, set up directories, load the data set, load the
    CLIP model, and build (or reuse) the vector collection.

    Args:
        logger: Logger with an ``info(message)`` method.

    Returns:
        tuple: ``(collection, data_set, model, logger)``.
    """
    divider = "-------------------------------------------------------"

    def announce(message):
        # Mirror every progress message to stdout and to the log.
        print(message)
        logger.info(message)

    announce("Initializing...")
    announce(divider)

    # Heavy imports are deferred until here so the step can be announced first.
    announce("Importing functions...")
    from sentence_transformers import SentenceTransformer
    from utils.utils import set_directories, load_data, get_collection, get_result, show_image

    announce("Set directories...")
    data_pickle_path, chroma_dir = set_directories()

    announce("Loading data...")
    data_set = load_data(data_pickle_path)

    announce("Loading CLIP model...")
    model = SentenceTransformer("sentence-transformers/clip-ViT-L-14")

    announce("Getting vector embeddings...")
    collection = get_collection(
        chroma_dir,
        model,
        collection_name='image_vectors',
        data=data_set['train']['image'],
    )

    announce(divider)
    announce("Initialization completed! Ready for search.")

    return collection, data_set, model, logger