# Scrape artifact from the Hugging Face repo page (commit metadata), preserved as comments:
# rayistern's picture
# Added embedding handler
# 5901795 verified
from transformers import AutoModel, AutoTokenizer
import torch
class EndpointHandler():
    """Hugging Face Inference Endpoints handler returning token-level embeddings.

    Each request's text is prefixed with a fixed domain prompt before being
    embedded, so the resulting vectors are contextualized for Chassidic
    philosophy content.
    """

    def __init__(self, path=""):
        """Load tokenizer and model from the endpoint's model repository.

        Parameters
        ----------
        path : str
            Filesystem path (or repo id) passed by the Inference Endpoints
            runtime; forwarded to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModel.from_pretrained(path)
        # Pin eval mode: with dropout active, repeated calls on the same
        # text could yield different embeddings. (Fix: original never
        # called eval().)
        self.model.eval()

    def __call__(self, data):
        """Compute embeddings for one request payload.

        Parameters
        ----------
        data : dict
            Standard Inference Endpoints request body; the text is read
            from the ``'inputs'`` key (raises ``KeyError`` if absent).

        Returns
        -------
        dict
            ``{'embeddings': ...}`` — the model's last hidden state as
            nested Python lists. NOTE(review): ``squeeze()`` drops the
            batch dimension, so this is per-token output of shape
            (seq_len, hidden_dim); no pooling is applied — confirm callers
            expect token-level embeddings.
        """
        inputs = data['inputs']
        # Fixed prompt prepended to every input to steer the contextual
        # representation toward the target domain.
        prompt = "Contextual understanding of the following text, from the perspective of Chassidic philosophy: "
        combined_input = prompt + inputs
        # Truncate at 512 tokens (common encoder limit); padding is a no-op
        # for a single sequence but harmless.
        encoded_input = self.tokenizer(combined_input, return_tensors='pt', padding=True, truncation=True, max_length=512)
        # Inference only — no gradient bookkeeping needed.
        with torch.no_grad():
            outputs = self.model(**encoded_input)
        # Serialize the last hidden layer to plain lists for the JSON response.
        embeddings = outputs.last_hidden_state.squeeze().tolist()
        return {'embeddings': embeddings}