Spaces:
Running
Running
from transformers import BertTokenizer, BertModel | |
import torch | |
from sklearn.metrics.pairwise import cosine_similarity | |
import numpy as np | |
# Load the pretrained BERT tokenizer and encoder once, at import time; the
# embedding function in this file reads these module-level names.
bert_model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
# Module.eval() returns the module itself, so loading and switching to
# evaluation mode can be chained into one statement.
model = BertModel.from_pretrained(bert_model_name).eval()
# Sentence embeddings from mean-pooled BERT token states
def get_bert_embeddings(texts):
    """Return one mean-pooled BERT embedding per entry of ``texts``.

    Every text is tokenized (with truncation enabled) and run through the
    module-level ``model`` on its own; its embedding is the mean of the
    final-layer hidden states over the token axis.

    Args:
        texts: Iterable of texts to embed.

    Returns:
        A NumPy array built from the per-text embedding vectors, in input
        order. An empty iterable yields an empty array.
    """
    def _embed(text):
        # Encode a single text and pool its token states into one vector.
        encoded = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
        hidden_states = model(**encoded).last_hidden_state
        # Average over the token axis (dim=1), then squeeze away singleton
        # axes (the batch axis here, since one text is encoded per call).
        return hidden_states.mean(dim=1).squeeze().numpy()

    # Inference only, so gradient tracking is switched off for the whole pass.
    with torch.no_grad():
        vectors = [_embed(text) for text in texts]
    return np.array(vectors)
# Pairwise cosine similarity between two embedding collections
def compute_similarity(embeddings1, embeddings2):
    """Return the cosine-similarity matrix between two sets of embeddings.

    Thin wrapper over scikit-learn's ``cosine_similarity``; the arguments are
    forwarded unchanged.

    Args:
        embeddings1: First collection of embedding vectors (one per row).
        embeddings2: Second collection of embedding vectors (one per row).

    Returns:
        The result of ``cosine_similarity(embeddings1, embeddings2)``; per
        scikit-learn's documentation, entry ``[i][j]`` compares row ``i`` of
        ``embeddings1`` with row ``j`` of ``embeddings2``.
    """
    similarity_matrix = cosine_similarity(embeddings1, embeddings2)
    return similarity_matrix
# Pick the candidate summary closest to a reference paragraph
def compare_summaries(paragraph, paragraphs, *, return_score=False):
    """Return the entry of ``paragraphs`` most similar to ``paragraph``.

    Both the reference paragraph and every candidate are embedded with
    ``get_bert_embeddings``; candidates are ranked by cosine similarity to
    the reference and the top one is returned.

    Args:
        paragraph: Reference text to compare against.
        paragraphs: Non-empty, indexable sequence of candidate summaries.
        return_score: When true, also return the winning similarity score.
            Defaults to ``False`` so existing callers keep receiving only
            the summary.

    Returns:
        The most similar summary, or a ``(summary, score)`` tuple when
        ``return_score`` is true. If several candidates tie for the highest
        score, the first one wins (``np.argmax`` semantics).

    Raises:
        ValueError: If ``paragraphs`` is empty.
    """
    # Fail early with a clear message; otherwise an empty candidate list only
    # surfaces as an array-shape error inside the similarity computation.
    if len(paragraphs) == 0:
        raise ValueError("paragraphs must contain at least one summary to compare against")

    # Embed the reference on its own and the candidates as a group.
    paragraph_embedding = get_bert_embeddings([paragraph])[0]
    summaries_embeddings = get_bert_embeddings(paragraphs)

    # The reference is wrapped in a list so the similarity call receives a
    # one-row collection; row 0 of the result holds its score per candidate.
    similarities = compute_similarity([paragraph_embedding], summaries_embeddings)[0]

    # Plain int so the index works with any sequence type.
    most_similar_index = int(np.argmax(similarities))
    most_similar_summary = paragraphs[most_similar_index]

    if return_score:
        return most_similar_summary, float(similarities[most_similar_index])
    return most_similar_summary