Spaces:
Build error
Build error
File size: 2,237 Bytes
cf349fd 503acf7 2cf3514 f1d50b1 587ab22 f1d50b1 98e7562 2cf3514 a1fc7fb cf349fd a1fc7fb cf349fd 1991cb1 cf349fd 2cf3514 cf349fd 2cf3514 cf349fd f1d50b1 2cf3514 0e0bacc 98e7562 a1fc7fb 98e7562 587ab22 f1d50b1 2cf3514 503acf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor
from config import MODEL_LIST
from koclip import FlaxHybridCLIP
from global_session import GlobalState
from threading import Lock
def load_index(img_file):
    """Return the cached ``(filenames, index)`` pair for ``img_file``.

    The actual loading is delegated to ``load_index_cached``; this wrapper
    only makes sure the call happens while holding a lock stored on the
    ``GlobalState`` object obtained for ``img_file``.

    Args:
        img_file: Path of the feature file, forwarded unchanged to
            ``load_index_cached``.

    Returns:
        Whatever ``load_index_cached(img_file)`` returns.
    """
    # NOTE(review): presumably GlobalState hands back the same object for the
    # same key across sessions, which is what makes the lock meaningful --
    # confirm against global_session.
    state = GlobalState(img_file)
    if not hasattr(state, '_lock'):
        # Lazily attach the lock the first time this key is seen.
        state._lock = Lock()
    print(f"Locking loading of features : {img_file} to avoid concurrent caching.")
    with state._lock:
        cached_index = load_index_cached(img_file)
    # The original message said "model" here (copy-paste from load_model);
    # this function loads features, so the log now says so.
    print(f"Unlocking loading of features : {img_file} to avoid concurrent caching.")
    return cached_index
@st.cache(allow_output_mutation=True)
def load_index_cached(img_file):
    """Parse an embedding file and build an nmslib HNSW index from it.

    Every non-blank line of ``img_file`` is split on tabs: the first column
    is taken as the filename and the second column as a comma-separated list
    of floats (the embedding). Lines that are empty after stripping
    whitespace are skipped instead of failing.

    Args:
        img_file: Path of the tab-separated text file to read.

    Returns:
        A ``(filenames, index)`` tuple: the list of filenames in file order,
        and an nmslib index (``method="hnsw"``, ``space="cosinesimil"``)
        holding the embeddings, which were added in that same order.

    Raises:
        IndexError: If a non-blank line has no tab-separated second column.
        ValueError: If an embedding component cannot be parsed as a float.
    """
    filenames, embeddings = [], []
    with open(img_file, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line (e.g. a trailing empty line at the end of the
                # file) carries no record; previously it raised IndexError
                # on cols[1].
                continue
            cols = line.split("\t")
            filenames.append(cols[0])
            embeddings.append([float(x) for x in cols[1].split(",")])
    embeddings = np.array(embeddings)
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(embeddings)
    index.createIndex({"post": 2}, print_progress=True)
    return filenames, index
def load_model(model_name="koclip/koclip-base"):
    """Return the cached result of ``load_model_cached`` for ``model_name``.

    The delegated call runs while holding a lock kept on the ``GlobalState``
    object obtained for ``model_name``; the lock is created on first use.

    Args:
        model_name: Model identifier, forwarded unchanged to
            ``load_model_cached``. Defaults to ``"koclip/koclip-base"``.

    Returns:
        Whatever ``load_model_cached(model_name)`` returns.
    """
    shared = GlobalState(model_name)
    if not hasattr(shared, '_lock'):
        # First request for this key: attach a fresh lock to the state.
        shared._lock = Lock()

    print(f"Locking loading of model : {model_name} to avoid concurrent caching.")
    with shared._lock:
        loaded = load_model_cached(model_name)
    print(f"Unlocking loading of model : {model_name} to avoid concurrent caching.")

    return loaded
@st.cache(allow_output_mutation=True)
def load_model_cached(model_name):
    """Load a KoCLIP model together with its processor.

    ``model_name`` must be ``"koclip/<name>"`` for some ``<name>`` in
    ``MODEL_LIST``; this is checked with an ``assert``. The processor starts
    from ``openai/clip-vit-base-patch32``, gets its tokenizer replaced by the
    ``klue/roberta-large`` tokenizer, and, for ``"koclip/koclip-large"`` only,
    gets its feature extractor replaced by the ``google/vit-large-patch16-224``
    one.

    Args:
        model_name: Identifier passed to ``FlaxHybridCLIP.from_pretrained``.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        AssertionError: If ``model_name`` is not one of the allowed names
            (only when assertions are enabled).
    """
    allowed_names = {f"koclip/{model}" for model in MODEL_LIST}
    assert model_name in allowed_names

    clip_model = FlaxHybridCLIP.from_pretrained(model_name)

    # Start from the stock CLIP processor, then swap in the Korean tokenizer.
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")

    needs_large_extractor = model_name == "koclip/koclip-large"
    if needs_large_extractor:
        large_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
        clip_processor.feature_extractor = large_extractor

    return clip_model, clip_processor
|