File size: 2,237 Bytes
cf349fd
503acf7
2cf3514
 
f1d50b1
587ab22
f1d50b1
98e7562
 
2cf3514
a1fc7fb
cf349fd
a1fc7fb
 
 
 
 
 
 
 
 
 
 
 
 
 
cf349fd
1991cb1
 
 
 
 
 
 
cf349fd
2cf3514
cf349fd
2cf3514
cf349fd
f1d50b1
2cf3514
0e0bacc
98e7562
 
 
 
 
 
 
 
 
 
 
a1fc7fb
98e7562
 
587ab22
f1d50b1
 
 
 
2cf3514
 
 
503acf7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import nmslib
import numpy as np
import streamlit as st
from transformers import AutoTokenizer, CLIPProcessor, ViTFeatureExtractor

from config import MODEL_LIST
from koclip import FlaxHybridCLIP
from global_session import GlobalState
from threading import Lock


def load_index(img_file):
    """Load the (filenames, nmslib index) pair for ``img_file``, serializing
    concurrent first loads.

    Streamlit can service several sessions at once; without a lock two
    sessions could race into ``load_index_cached`` and build the index
    twice. A per-file lock stored on the shared ``GlobalState`` object
    serializes that first load.

    Args:
        img_file: Path to the TSV file of per-image embeddings.

    Returns:
        The result of ``load_index_cached(img_file)``:
        a ``(filenames, nmslib_index)`` tuple.
    """
    state = GlobalState(img_file)
    # dict.setdefault is atomic in CPython, so exactly one Lock is ever
    # created per state object. The previous hasattr-then-assign pattern
    # was itself racy: two threads could both miss the attribute and each
    # install their own lock, defeating the serialization.
    lock = vars(state).setdefault("_lock", Lock())
    print(f"Locking loading of features : {img_file} to avoid concurrent caching.")

    with lock:
        cached_index = load_index_cached(img_file)

    # Fixed log text: this function loads features, not a model (the old
    # message was a copy-paste from load_model).
    print(f"Unlocking loading of features : {img_file} to avoid concurrent caching.")
    return cached_index


@st.cache(allow_output_mutation=True)
def load_index_cached(img_file):
    """Parse a TSV of ``filename<TAB>comma-separated-floats`` rows and build
    an HNSW cosine-similarity index over the embedding vectors.

    Args:
        img_file: Path to the TSV feature file.

    Returns:
        A ``(filenames, index)`` tuple: the list of image filenames in file
        order, and the nmslib index built over their embeddings.
    """
    names = []
    vectors = []
    with open(img_file, "r") as handle:
        for row in handle:
            fields = row.strip().split("\t")
            names.append(fields[0])
            vectors.append([float(value) for value in fields[1].split(",")])

    matrix = np.array(vectors)

    # Approximate nearest-neighbor index: HNSW graph, cosine distance.
    index = nmslib.init(method="hnsw", space="cosinesimil")
    index.addDataPointBatch(matrix)
    index.createIndex({"post": 2}, print_progress=True)
    return names, index


def load_model(model_name="koclip/koclip-base"):
    """Load the KoCLIP model/processor pair, serializing concurrent first loads.

    Streamlit can service several sessions at once; without a lock two
    sessions could race into ``load_model_cached`` and download/instantiate
    the model twice. A per-model lock stored on the shared ``GlobalState``
    object serializes that first load.

    Args:
        model_name: Hugging Face model identifier; defaults to
            ``"koclip/koclip-base"``.

    Returns:
        The result of ``load_model_cached(model_name)``:
        a ``(model, processor)`` tuple.
    """
    state = GlobalState(model_name)
    # dict.setdefault is atomic in CPython, so exactly one Lock is ever
    # created per state object. The previous hasattr-then-assign pattern
    # was itself racy: two threads could both miss the attribute and each
    # install their own lock, defeating the serialization.
    lock = vars(state).setdefault("_lock", Lock())
    print(f"Locking loading of model : {model_name} to avoid concurrent caching.")

    with lock:
        cached_model = load_model_cached(model_name)

    print(f"Unlocking loading of model : {model_name} to avoid concurrent caching.")
    return cached_model


@st.cache(allow_output_mutation=True)
def load_model_cached(model_name):
    """Instantiate a KoCLIP model and its matching processor.

    Args:
        model_name: One of the ``koclip/<name>`` identifiers listed in
            ``MODEL_LIST``.

    Returns:
        A ``(model, processor)`` tuple: the ``FlaxHybridCLIP`` model and a
        ``CLIPProcessor`` whose tokenizer is swapped for the Korean
        ``klue/roberta-large`` tokenizer (and, for the large variant, whose
        feature extractor matches the ViT-large vision tower).

    Raises:
        ValueError: If ``model_name`` is not a known KoCLIP model.
    """
    # Validate explicitly: `assert` is stripped under `python -O`, so it must
    # not be relied on for input validation.
    if model_name not in {f"koclip/{model}" for model in MODEL_LIST}:
        raise ValueError(f"Unknown model name: {model_name}")

    model = FlaxHybridCLIP.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # KoCLIP pairs the CLIP image pipeline with a Korean text tokenizer.
    processor.tokenizer = AutoTokenizer.from_pretrained("klue/roberta-large")
    if model_name == "koclip/koclip-large":
        # The large model uses a ViT-large vision tower, so its feature
        # extractor must match (different input resolution/patching).
        processor.feature_extractor = ViTFeatureExtractor.from_pretrained(
            "google/vit-large-patch16-224"
        )
    return model, processor