import logging
import os
import pickle
from datetime import datetime, timezone
from pathlib import Path

import chromadb
import matplotlib.pyplot as plt
from datasets import load_dataset
from tqdm import tqdm
def set_directories():
    """Create (if needed) the local data and vector-store directories.

    Returns:
        tuple[Path, Path]: path of the pickled dataset file and the
        Chroma persistence directory.
    """
    curr_dir = Path(os.getcwd())
    data_dir = curr_dir / 'data'
    data_pickle_path = data_dir / 'data_set.pkl'
    # NOTE(review): 'vectore_storage' spelling kept so existing installs
    # keep finding their stored vectors.
    vectordb_dir = curr_dir / 'vectore_storage'
    chroma_dir = vectordb_dir / 'chroma'
    # Path.mkdir(exist_ok=True) replaces the race-prone exists()+mkdir pair,
    # and `directory` avoids shadowing the builtin `dir`.
    for directory in (data_dir, vectordb_dir, chroma_dir):
        directory.mkdir(parents=True, exist_ok=True)
    return data_pickle_path, chroma_dir
def load_data(data_pickle_path, dataset="vipulmaheshwari/GTA-Image-Captioning-Dataset"):
    """Load the image-captioning dataset, caching it locally as a pickle.

    Args:
        data_pickle_path: file path of the cached pickle.
        dataset: Hugging Face dataset identifier downloaded on a cache miss.

    Returns:
        The dataset object (downloaded or unpickled from the cache).
    """
    if not os.path.exists(data_pickle_path):
        # f-prefixes removed: these strings contain no placeholders (F541).
        print("Data set hasn't been loaded. Loading from the datasets library and save it as a pickle.")
        data_set = load_dataset(dataset)
        with open(data_pickle_path, 'wb') as outfile:
            pickle.dump(data_set, outfile)
    else:
        print("Data set already exists in the local drive. Loading it.")
        # NOTE(review): pickle.load is acceptable only because this file is
        # written by us above — never point data_pickle_path at untrusted data.
        with open(data_pickle_path, 'rb') as infile:
            data_set = pickle.load(infile)
    return data_set
def get_embeddings(data, model):
    """Encode every image and build matching string ids.

    Args:
        data: sized sequence of images accepted by ``model.encode``.
        model: encoder exposing ``encode(image)`` returning an array-like
            with a ``tolist()`` method.

    Returns:
        tuple[list[str], list[list]]: ids ("image 0", "image 1", ...) and the
        corresponding embedding vectors as plain lists.
    """
    ids = []
    embeddings = []
    # enumerate replaces the zip(range(len(...))) pattern and the
    # builtin-shadowing loop name `id`; total= lets tqdm show real progress
    # (zip of a bare generator gave it no length).
    for index, image in tqdm(enumerate(data), total=len(data)):
        ids.append(f"image {index}")
        embeddings.append(model.encode(image).tolist())
    return ids, embeddings
def get_collection(chroma_dir, model, collection_name, data):
    """Open (or create) a persistent Chroma collection and populate it.

    Args:
        chroma_dir: directory used by Chroma's persistent client.
        model: encoder handed to get_embeddings on a cache miss.
        collection_name: name of the collection to fetch or create.
        data: sequence of images; its length decides whether to (re)embed.

    Returns:
        The Chroma collection holding one embedding per image.
    """
    # str(chroma_dir), not chroma_dir.__str__() — never call dunders directly.
    client = chromadb.PersistentClient(path=str(chroma_dir))
    collection = client.get_or_create_collection(name=collection_name)
    # A count mismatch means the store is empty or stale — rebuild embeddings.
    if collection.count() != len(data):
        print("Adding embeddings to the collection.")
        ids, embeddings = get_embeddings(data, model)
        collection.add(
            ids=ids,
            embeddings=embeddings
        )
    else:
        print("Embeddings are already added to the collection.")
    return collection
def get_result(collection, data_set, query, model, n_results=2):
    """Query the vector store and return the best-matching image and caption.

    Args:
        collection: Chroma collection holding the image embeddings.
        data_set: dataset exposing data_set['train']['image'] and ['text'].
        query: free-text search string.
        model: text encoder compatible with the stored image embeddings.
        n_results: number of neighbours to request (only the top hit is used).

    Returns:
        tuple: (image, caption) of the most relevant match.
    """
    results = collection.query(
        query_embeddings=model.encode([query]),
        n_results=n_results  # bug fix: parameter was ignored (hard-coded 2)
    )
    # Ids look like "image 42" — recover the integer index of the top hit.
    img_id = int(results['ids'][0][0].split('image ')[-1])
    image = data_set['train']['image'][img_id]
    text = data_set['train']['text'][img_id]
    return image, text
def show_image(image, text, query):
    """Render *image* in a non-blocking window and echo the query plus the
    image's original caption to stdout."""
    # Interactive mode so plt.show() returns immediately instead of blocking.
    plt.ion()
    plt.axis("off")
    plt.imshow(image)
    plt.show()
    for line in (f"User query: {query}", f"Original description: {text}\n"):
        print(line)
def get_logger():
    """Configure root logging to a UTC-dated file and return a module logger.

    Returns:
        logging.Logger: logger whose records go to ./log/YYYYMMDD.log.
    """
    log_path = "./log/"
    # Path.mkdir(exist_ok=True) replaces the race-prone exists()+mkdir pair.
    Path(log_path).mkdir(exist_ok=True)
    # Timezone-aware replacement for the deprecated datetime.utcnow().
    cur_date = datetime.now(timezone.utc).strftime("%Y%m%d")
    log_filename = f"{log_path}{cur_date}.log"
    # NOTE(review): basicConfig is a no-op if the root logger was already
    # configured elsewhere — call get_logger() before any other logging setup.
    logging.basicConfig(
        filename=log_filename,
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S")
    logger = logging.getLogger(__name__)
    return logger
def initialization(logger):
    """Run the full startup sequence and return the search components.

    Args:
        logger: configured logger; every status line is mirrored to it.

    Returns:
        tuple: (collection, data_set, model, logger) ready for querying.
    """
    def announce(message):
        # Each status line goes to both stdout and the log file.
        print(message)
        logger.info(message)

    separator = "-------------------------------------------------------"
    announce("Initializing...")
    announce(separator)
    announce("Importing functions...")
    # Deferred imports keep these heavy dependencies off module import time.
    from sentence_transformers import SentenceTransformer
    from utils.utils import set_directories, load_data, get_collection, get_result, show_image
    announce("Set directories...")
    data_pickle_path, chroma_dir = set_directories()
    announce("Loading data...")
    data_set = load_data(data_pickle_path)
    announce("Loading CLIP model...")
    model = SentenceTransformer("sentence-transformers/clip-ViT-L-14")
    announce("Getting vector embeddings...")
    collection = get_collection(chroma_dir, model, collection_name='image_vectors', data=data_set['train']['image'])
    announce(separator)
    announce("Initialization completed! Ready for search.")
    return collection, data_set, model, logger