Spaces:
Runtime error
Runtime error
File size: 901 Bytes
6ae27e8 a41bdbc 6ae27e8 a41bdbc 6ae27e8 a41bdbc 6ae27e8 a41bdbc 6ae27e8 a41bdbc 6ae27e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import pandas as pd
import jax.numpy as jnp
from typing import List
# Cosine similarity, implemented with jax.numpy.
from backend.utils import load_model
def cos_sim(a, b):
    """Return the pairwise cosine similarity between the rows of ``a`` and ``b``.

    Every vector (taken along the last axis) is L2-normalized on its own
    before the dot product. For ``a`` of shape (m, d) and ``b`` of shape
    (n, d), the result has shape (m, n) and entry [i, j] = cos(a[i], b[j]).
    Two 1-D vectors of shape (d,) yield a scalar.

    Args:
        a: Array-like of vectors, shape (d,) or (m, d).
        b: Array-like of vectors, shape (d,) or (n, d).

    Returns:
        A jax array of cosine similarities. A zero-norm vector produces
        NaN entries, because the normalization divides by zero.
    """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    # Normalize each row separately. Dividing by the norm of the whole
    # matrix would scale every score by the norm of *all* rows together.
    a_unit = a / jnp.linalg.norm(a, axis=-1, keepdims=True)
    b_unit = b / jnp.linalg.norm(b, axis=-1, keepdims=True)
    return jnp.matmul(a_unit, jnp.transpose(b_unit))
# Rank input texts by the cosine similarity of their embeddings to an anchor text.
def text_similarity(anchor: str, inputs: List[str], model_name: str):
    """Score each text in ``inputs`` by cosine similarity to ``anchor``.

    Args:
        anchor: The reference text.
        inputs: The texts to compare against ``anchor``.
        model_name: Name passed to ``load_model``; the returned object must
            provide an ``encode`` method. Presumably a sentence-embedding
            model -- verify against ``backend.utils.load_model``.

    Returns:
        A pandas DataFrame with columns ``'inputs'`` (the text) and
        ``'score'`` (cosine similarity as a float rounded to 3 decimals),
        sorted by score in descending order. If ``inputs`` is empty, an
        empty DataFrame with the same columns is returned and no model is
        loaded.
    """
    texts = list(inputs)
    if not texts:
        # Nothing to score, so skip the (potentially expensive) model load.
        return pd.DataFrame({'inputs': [], 'score': []}, columns=['inputs', 'score'])

    model = load_model(model_name)
    # Embed the anchor. The [None, :] indexing assumes encode(str) returns
    # a 1-D vector; this adds a leading axis so the anchor is a (1, d) matrix.
    anchor_emb = model.encode(anchor)[None, :]
    inputs_emb = model.encode(texts)

    # Take row 0 of the (1, n) similarity matrix. Indexing (unlike squeeze)
    # keeps a 1-D result even when there is exactly one input text.
    similarity = cos_sim(anchor_emb, inputs_emb)[0]
    # Convert to Python floats so the 'score' column is numeric, not JAX objects.
    scores = [round(float(value), 3) for value in similarity]

    df = pd.DataFrame({'inputs': texts, 'score': scores}, columns=['inputs', 'score'])
    return df.sort_values('score', ascending=False)
|