# Fitness_QA_Bot / utils/embedding_utils.py
# Uploaded by lstetson via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 0685af6 (verified); original file size 757 bytes.
import torch
from transformers import AutoTokenizer, AutoModel
from chromadb import Documents, EmbeddingFunction, Embeddings
# Hugging Face Hub checkpoint whose tokenizer and encoder back the embedding
# function in this module. Both are loaded once, at import time.
model_name = "YituTech/conv-bert-base"
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_name,
)
model = AutoModel.from_pretrained(
    pretrained_model_name_or_path=model_name,
)
class MyEmbeddingFunction(EmbeddingFunction[Documents]):
    """Chroma embedding function backed by the module-level tokenizer/model.

    Each document is encoded on its own. It is represented by the mean of the
    model's last hidden states over all of its tokens, with the tokenizer's
    special tokens included.
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Embed every document in *input*.

        Args:
            input: Documents to embed. The parameter keeps the name ``input``
                because that is the name Chroma's ``EmbeddingFunction``
                interface uses.

        Returns:
            One embedding per document, in input order. Each embedding is a
            plain list of floats.
        """
        # Longer token sequences overflow the model's position embeddings and
        # raise inside the forward pass, so clip to what the model accepts.
        max_length = model.config.max_position_embeddings
        embeddings_list = []
        for text in input:
            # Texts are encoded one at a time, not as a padded batch.
            # NOTE(review): padding could leak into ConvBERT's convolutional
            # span mixing, so per-text encoding keeps results unchanged.
            tokens = tokenizer(
                text,
                return_tensors="pt",
                truncation=True,
                max_length=max_length,
            )
            with torch.no_grad():
                outputs = model(**tokens)
            # A single text yields a batch of one. Average over the token axis,
            # then take the only batch entry.
            embedding = outputs.last_hidden_state.mean(dim=1)[0]
            # Plain Python lists are accepted by every Chroma version. Some
            # releases reject numpy arrays when validating embeddings.
            embeddings_list.append(embedding.tolist())
        return embeddings_list