Trent
Multi model select and local model loading
a41bdbc
raw
history blame
901 Bytes
import pandas as pd
import jax.numpy as jnp
from typing import List
# Cosine similarity is implemented with jax.numpy (JAX), not flax.
from backend.utils import load_model
def cos_sim(a, b, eps: float = 1e-12):
    """Return the cosine similarity between the vectors in ``a`` and ``b``.

    Every vector is L2-normalised along its last axis before the dot product,
    so entry ``[i, j]`` of the result is ``cos(a[i], b[j])``.

    Args:
        a: Array of shape ``(d,)`` or ``(m, d)``.
        b: Array of shape ``(d,)`` or ``(n, d)``.
        eps: Lower bound applied to each norm, so an all-zero vector yields a
            similarity of 0 instead of NaN.

    Returns:
        A scalar array when both inputs are 1-D; shape ``(m, n)`` when both are
        2-D; ``(n,)`` or ``(m,)`` when only one of them is 1-D.
    """
    # Normalise each row separately (axis=-1). Dividing by the norm of the
    # whole matrix instead would only be correct for a single vector.
    a_unit = a / jnp.maximum(jnp.linalg.norm(a, axis=-1, keepdims=True), eps)
    b_unit = b / jnp.maximum(jnp.linalg.norm(b, axis=-1, keepdims=True), eps)
    return jnp.matmul(a_unit, jnp.transpose(b_unit))
# Score each input text by its cosine similarity to the anchor text.
def text_similarity(anchor: str, inputs: List[str], model_name: str):
    """Rank ``inputs`` by cosine similarity to ``anchor``.

    Args:
        anchor: Reference text that every input is compared against.
        inputs: Candidate texts to score.
        model_name: Name passed to ``load_model``; the returned model must
            expose an ``encode`` method (presumably a sentence-transformers
            style encoder -- confirm against ``backend.utils``).

    Returns:
        A ``pandas.DataFrame`` with columns ``inputs`` and ``score`` (cosine
        similarity rounded to 3 decimals, as plain Python floats), sorted by
        ``score`` in descending order. The original positional index is kept.
        An empty ``inputs`` yields an empty frame with the same columns.
    """
    inputs = list(inputs)
    if not inputs:
        # Nothing to score; skip model loading and encoding of an empty batch.
        return pd.DataFrame({'inputs': [], 'score': []}, columns=['inputs', 'score'])

    model = load_model(model_name)
    # Give the anchor embedding a leading batch axis so it can be compared
    # against the batch of input embeddings (assumes encode(str) returns a
    # 1-D vector -- TODO confirm).
    anchor_emb = model.encode(anchor)[None, :]
    inputs_emb = model.encode(inputs)

    # cos_sim yields a (1, n) matrix here. Take row 0 rather than squeezing,
    # so a single input still gives a 1-element vector instead of a 0-d scalar
    # (which cannot be iterated).
    similarity = cos_sim(anchor_emb, inputs_emb)[0]
    # Convert to Python floats so the 'score' column is numeric rather than
    # holding jax array objects.
    scores = [round(float(s), 3) for s in similarity.tolist()]

    df = pd.DataFrame({'inputs': inputs, 'score': scores}, columns=['inputs', 'score'])
    return df.sort_values('score', ascending=False)