Spaces:
Runtime error
Runtime error
File size: 1,731 Bytes
6a3aaf8 6831777 6a3aaf8 88d581a 6a3aaf8 6831777 fb53c32 6a3aaf8 5d35937 6a3aaf8 8182466 6a3aaf8 88d581a e30b005 6a3aaf8 e30b005 c4f4d05 66d0fee 88d581a 6a3aaf8 c4f4d05 efae79d 6a3aaf8 c4f4d05 6a3aaf8 66d0fee c4f4d05 66d0fee c4f4d05 66d0fee 6a3aaf8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import gradio as gr
import datasets
import faiss
import os
from transformers import pipeline
# Access token for the private Hugging Face resources, read from the
# CLARIN_KNEXT environment variable (None when the variable is unset).
auth_token = os.getenv("CLARIN_KNEXT")

# Example query, shown as the textbox placeholder and used as the default
# argument of predict().
sample_text = (
    "Europejscy astronomowie odkryli planetę "
    "pozasłoneczną pochodzącą spoza naszej galaktyki, czyli "
    "[unused0] Drogi Mlecznej [unused1]. Obserwacji dokonali "
    "2,2-metrowym teleskopem MPG/ESO."
)

# Multi-line input widget handed to the Gradio interface.
textbox = gr.Textbox(
    label="Type your query here.",
    placeholder=sample_text,
    lines=10,
)
def load_index(index_data: str = "clarin-knext/entity-linking-index"):
    """Load the entity lookup table and the FAISS index.

    Args:
        index_data: Name of the Hugging Face dataset whose ``train`` split
            provides the ``entities`` and ``texts`` columns.

    Returns:
        A ``(lookup, faiss_index)`` tuple. ``lookup`` maps each row position
        of the ``train`` split to an ``(entity, text)`` tuple; ``faiss_index``
        is read from ``./encoder.faissindex`` with the memory-map I/O flag.
    """
    train_split = datasets.load_dataset(
        index_data, use_auth_token=auth_token
    )['train']
    # Key every (entity, text) pair by its row position.
    # NOTE(review): presumably these positions match the ids stored in the
    # FAISS index -- confirm against how the index was built.
    lookup = dict(enumerate(zip(train_split['entities'], train_split['texts'])))
    faiss_index = faiss.read_index("./encoder.faissindex", faiss.IO_FLAG_MMAP)
    return lookup, faiss_index
def load_model(model_name: str = "clarin-knext/entity-linking-encoder"):
    """Build the ``feature-extraction`` pipeline for ``model_name``.

    Args:
        model_name: Model identifier passed to ``transformers.pipeline``.

    Returns:
        The pipeline object, created with the module-level ``auth_token``.
    """
    return pipeline(
        "feature-extraction",
        model=model_name,
        use_auth_token=auth_token,
    )
# Load the encoder pipeline and the (lookup table, FAISS index) pair once at
# import time; predict() reads both of these module-level globals.
model = load_model()
index = load_index()
def predict(text: str = sample_text, top_k: int = 3) -> str:
    """Return the nearest index entries for ``text`` as a printable report.

    The query is embedded with the module-level ``model`` pipeline and looked
    up in the module-level FAISS index.

    Args:
        text: Query text.
            NOTE(review): the sample query wraps one phrase in
            ``[unused0]`` / ``[unused1]``; presumably the encoder expects the
            mention to be marked this way -- confirm.
        top_k: Number of neighbours requested from the FAISS index.

    Returns:
        One line per hit, formatted as ``"<lookup entry>: <score>"`` and
        joined with newlines. Hits that FAISS could not fill (label ``-1``)
        are left out, so fewer than ``top_k`` lines may be returned.
    """
    # Append 252 literal "[PAD]" markers to the query.
    # NOTE(review): presumably this forces a fixed-length input matching how
    # the indexed embeddings were produced -- TODO confirm.
    text = text + '[PAD]' * 252
    index_data, faiss_index = index
    # takes only the [CLS] embedding (for now)
    query = model(text, return_tensors='pt')[0][0].numpy().reshape(1, -1)
    scores, indices = faiss_index.search(query, top_k)
    # A single query row was searched, so only row 0 carries results.
    hits = zip(indices[0].tolist(), scores[0].tolist())
    return "\n".join(
        f"{index_data[entry_id]}: {score}"
        for entry_id, score in hits
        # FAISS reports -1 when it finds fewer than top_k neighbours; such a
        # label has no entry in the lookup table.
        if entry_id != -1
    )
# Keep ``demo`` bound to the Interface itself and start serving in a separate
# step, so the name does not end up holding the return value of launch().
demo = gr.Interface(fn=predict, inputs=textbox, outputs="text")
demo.launch()