import functools
import random

import torch

from dataframe import *
from model import *
# Image files available to the search backends. The display caption for each
# entry is derived from its file name with the extension removed.
images = [
    "Girl.jpg",
    "Cat In Hat.jpg",
    "Cat In The Hat.jpg",
    "Fox In Socks.jpg",
    "Green Eggs And Ham.jpg",
    "Green Eggs with Ham.jpg",
    "Grinch.jpg",
    "Horton.jpg",
    "Lorax.jpg",
    "Thing1 and Thing2.jpg",
    "Turtle.jpg",
    "One Fish.jpg",
    "Two Fish.jpg",
]
def search1(search_prompt: str, num_results: int = 4):
    """
    Return a random selection of images to display.

    This is a placeholder backend: ``search_prompt`` is accepted so the
    signature matches ``search2``, but it is not used to rank or filter.

    Args:
        search_prompt: The user's query (unused).
        num_results: Maximum number of images to return. Defaults to 4.

    Returns:
        A list of ``(file_name, caption)`` tuples drawn without replacement
        from the module-level ``images`` list, where ``caption`` is the file
        name with its final extension removed. If fewer than ``num_results``
        images exist, all of them are returned (in random order) instead of
        raising ``ValueError``.
    """
    # Clamp so a short image list doesn't make random.sample raise.
    count = min(num_results, len(images))
    # Split on the LAST dot only, so a title such as "Dr. Seuss.jpg" keeps
    # its full caption ("Dr. Seuss") rather than being cut at the first dot.
    return [(name, name.rsplit(".", 1)[0]) for name in random.sample(images, count)]
@functools.lru_cache(maxsize=None)
def _load_clip(model_id: str, device: str):
    """Return ``get_model_info(model_id, device)``, computed once per argument pair.

    The only callers pass a fixed model ID and one of two device strings, so
    the cache holds at most a couple of entries.
    """
    return get_model_info(model_id, device)


def search2(search_prompt: str, num_results: int = 4):
    """
    Return the images that best match ``search_prompt`` using a CLIP model.

    The model, processor and tokenizer come from ``get_model_info`` and are
    memoised by ``_load_clip``, so they are obtained once per process and
    device instead of on every query.

    Args:
        search_prompt: Free-text query to match against the image data.
        num_results: Number of results requested from ``get_top_N_images``
            (passed as ``top_K``). Defaults to 4.

    Returns:
        Whatever ``get_top_N_images`` returns. Presumably a list shaped like
        ``search1``'s output; confirm against that helper.
    """
    # Prefer the GPU when one is visible to torch.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_ID = "openai/clip-vit-base-patch32"
    # The processor is part of get_model_info's return value but is not
    # needed for text-to-image ranking here.
    model, _processor, tokenizer = _load_clip(model_ID, device)
    # NOTE(review): get_image_data() still runs on every query; cache it too
    # if the underlying image set is static.
    image_data_df = get_image_data()
    return get_top_N_images(
        search_prompt,
        data=image_data_df,
        model=model,
        tokenizer=tokenizer,
        device=device,
        top_K=num_results,
    )