"""Cached loaders for the KoCLIP demo: an nmslib image-embedding index and the
FlaxHybridCLIP model/processor pair. Each loader is guarded by a per-resource
lock so concurrent Streamlit sessions do not race while warming the @st.cache."""

from threading import Lock

import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor

from config import MODEL_LIST
from global_session import GlobalState
from koclip import FlaxHybridCLIP


def load_index(img_file):
    """Load the nmslib index for `img_file`, serializing the first (cache-miss)
    load behind a lock stored on the shared per-file GlobalState."""
    state = GlobalState(img_file)
    if not hasattr(state, "_lock"):
        state._lock = Lock()
    print(f"Locking feature loading for {img_file} to avoid concurrent caching.")
    with state._lock:
        cached_index = load_index_cached(img_file)
    print(f"Unlocked feature loading for {img_file}.")
    return cached_index


@st.cache(allow_output_mutation=True)
def load_index_cached(img_file):
    """Parse a TSV of `filename\\t<comma-separated floats>` rows and build an
    HNSW cosine-similarity index over the embeddings."""
    filenames, embeddings = [], []
    with open(img_file, "r") as f:
        for line in f:
            cols = line.strip().split("\t")
            filename = cols[0]
            embedding = [float(x) for x in cols[1].split(",")]
            filenames.append(filename)
            embeddings.append(embedding)
    embeddings = np.array(embeddings)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(embeddings)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index


def load_model(model_name="koclip/koclip-base"):
    """Load the KoCLIP model and processor, serializing the first (cache-miss)
    load behind a lock stored on the shared per-model GlobalState."""
    state = GlobalState(model_name)
    if not hasattr(state, "_lock"):
        state._lock = Lock()
    print(f"Locking model loading for {model_name} to avoid concurrent caching.")
    with state._lock:
        cached_model = load_model_cached(model_name)
    print(f"Unlocked model loading for {model_name}.")
    return cached_model


@st.cache(allow_output_mutation=True)
def load_model_cached(model_name):
    assert model_name in {f"koclip/{model}" for model in MODEL_LIST}
    model = FlaxHybridCLIP.from_pretrained(model_name)
    # Start from the OpenAI CLIP processor, then swap in the Korean tokenizer
    # (and, for the large model, the matching ViT feature extractor).
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor
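

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal end-to-end text-to-image query under stated assumptions: a
# hypothetical "features.tsv" dump in the TSV format load_index_cached expects,
# and that FlaxHybridCLIP exposes get_text_features() as in the transformers
# hybrid-CLIP example. Running outside a Streamlit session may emit @st.cache
# warnings; treat this as a sketch, not a tested entry point.
if __name__ == "__main__":
    model, processor = load_model("koclip/koclip-base")
    filenames, index = load_index("features.tsv")  # hypothetical feature dump
    # Encode a Korean text query into the shared embedding space.
    tokens = processor.tokenizer(["귀여운 고양이"], return_tensors="jax", padding=True)
    text_vec = np.asarray(
        model.get_text_features(
            input_ids=tokens["input_ids"], attention_mask=tokens["attention_mask"]
        )
    )
    # Retrieve the 5 nearest image embeddings by cosine distance.
    ids, dists = index.knnQuery(text_vec[0], k=5)
    for i, d in zip(ids, dists):
        print(filenames[i], d)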