|
import numpy as np |
|
import torch |
|
import pandas as pd |
|
import translate |
|
import gradio as gr |
|
|
|
# Sentence table loaded once at import time; knn() selects rows from it by
# position, so it is assumed to be row-aligned with `embeddings` — TODO confirm
# against the script that produced both files.
data = pd.read_csv("./embedding_data.csv")

# Pre-computed embedding matrix; row i is scored against the query and then
# used as a positional index into `data`.
embeddings = np.load("./embeddings.npy")
|
|
|
def normalize_vector(v):
    """Scale ``v`` to unit Euclidean length.

    Args:
        v: Array-like accepted by ``np.linalg.norm``.

    Returns:
        ``v / ||v||``; when the norm is exactly zero, ``v`` itself is
        returned unchanged (same object) to avoid dividing by zero.
    """
    length = np.linalg.norm(v)
    # A zero vector has no direction, so hand it back untouched.
    return v if length == 0 else v / length
|
|
|
|
|
def embed_one(model, tokenizer, text, normalize=True):
    """Embed a single text by mean-pooling the encoder's last hidden state.

    Args:
        model: Object exposing ``model.model.encoder``; the encoder is called
            with the tokenizer output unpacked as keyword arguments.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True)``.
        text: The text to embed.
        normalize: When true, the vector is passed through
            ``normalize_vector`` before being returned.

    Returns:
        A NumPy array: the first row of the pooled encoder output.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True)

    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        hidden = model.model.encoder(**encoded).last_hidden_state
        # Plain (unweighted) mean over axis 1 — presumably the token axis of
        # a (batch, tokens, hidden) tensor; confirm against the model.
        pooled = hidden.mean(axis=1)

    # Only one text was encoded, so keep just the first row.
    vector = pooled.detach().numpy()[0]

    return normalize_vector(vector) if normalize else vector
|
|
|
|
|
def knn(query_embedding, embeddings, df, k=5, hebrew=True):
    """Return the ``k`` rows of ``df`` whose embeddings score highest.

    Similarity is the dot product between each row of ``embeddings`` and
    ``query_embedding`` (this equals cosine similarity only if both sides are
    unit-normalized, which this function does not check).

    Args:
        query_embedding: Query vector; a 1-D array or a single-row 2-D array.
        embeddings: Matrix with one embedding per row; row ``i`` is used as a
            positional index into ``df``.
        df: DataFrame with at least the columns ``"arabic"``,
            ``"validated"`` and ``"hebrew"`` or ``"english"``.
        k: Number of neighbours to return. Values ``<= 0`` yield an empty
            result; values larger than the number of rows return every row.
        hebrew: Select the ``"hebrew"`` column when true, otherwise
            ``"english"``.

    Returns:
        DataFrame with columns ``["arabic", <hebrew|english>, "validated"]``,
        ordered from most to least similar.
    """
    sims = np.dot(embeddings, query_embedding.T)

    # Sort descending first, then take the head. Slicing with ``[-k:]`` would
    # return *every* row for k == 0 (because -0 == 0) and an arbitrary tail
    # for negative k.
    ranked = np.argsort(sims, axis=0)[::-1]
    select = ranked[:max(k, 0)].ravel()

    text_column = "hebrew" if hebrew else "english"
    return df.iloc[select][["arabic", text_column, "validated"]]
|
|
|
def run_knn(text, k=5):
    """Look up the ``k`` stored sentences nearest to an Arabic query text.

    The query is printed to stdout, embedded with the ``model_from_ar`` /
    ``tokenizer_from_ar`` pair from the ``translate`` module, and matched
    against the module-level ``embeddings`` and ``data``.

    Args:
        text: Query text to embed.
        k: Number of neighbours to return.

    Returns:
        Whatever ``knn`` returns with ``hebrew=True``.
    """
    print(text)

    ar_model = translate.model_from_ar
    ar_tokenizer = translate.tokenizer_from_ar
    query_vector = embed_one(ar_model, ar_tokenizer, text)

    return knn(query_vector, embeddings, data, k=k, hebrew=True)
|
|
|
|
|
def style_dataframe(df):
    """Return a Styler for ``df`` with right-aligned header and cells.

    This module previously defined ``style_dataframe`` twice; the second
    definition silently replaced the first, so only that one ever ran. The
    dead first definition has been dropped and the effective behaviour kept.

    Args:
        df: Object exposing a pandas-style ``.style`` accessor.

    Returns:
        The Styler after ``set_table_styles`` (``thead`` and ``.index_name``
        right-aligned) and ``set_properties`` (``text-align: right``) have
        been applied, in that order.
    """
    header_rules = [
        {'selector': 'thead', 'props': [('text-align', 'right')]},
        {'selector': '.index_name', 'props': [('text-align', 'right')]},
    ]
    styler = df.style.set_table_styles(header_rules)
    return styler.set_properties(**{'text-align': 'right'})
|
|
|
|
|
def update_df(hidden_arabic):
    """Build the results table shown in the UI for an Arabic query.

    Fetches the 100 nearest rows via ``run_knn``, replaces the ``validated``
    values with a check or cross mark, renames the columns to their Hebrew
    display labels, styles the frame, and wraps it in a visible
    ``gr.DataFrame``.

    Args:
        hidden_arabic: Query text forwarded to ``run_knn``.

    Returns:
        ``gr.DataFrame`` holding the styled results, with ``visible=True``.
    """
    def _mark(flag):
        # Truthy -> check mark, falsy -> cross mark.
        return "✅" if flag else "❌"

    matches = run_knn(hidden_arabic, 100)
    matches["validated"] = matches["validated"].apply(_mark)

    # Hebrew display labels: "verified", "Arabic", "Hebrew".
    display_names = {
        "validated": "מאומת",
        "arabic": "ערבית",
        "hebrew": "עברית",
    }
    styled = style_dataframe(matches.rename(columns=display_names))

    return gr.DataFrame(value=styled, visible=True)