# File size: 4,599 Bytes
# Blame: 7472a45 e64db30 a1e51f4 7472a45 a1e51f4 e64db30 5552636 5933aa3 271e600 6700bfc a1e51f4 6700bfc 43a9654 6700bfc a1e51f4 6700bfc a1e51f4 6700bfc a1e51f4 6700bfc a1e51f4 6700bfc a1e51f4 6700bfc a1e51f4 6700bfc a1e51f4 6700bfc a1e51f4 6700bfc a1e51f4 6700bfc a1e51f4 6700bfc 43a9654 6700bfc 43a9654 6700bfc 5933aa3 6700bfc 5933aa3 43a9654 5933aa3 a1e51f4 43a9654 a1e51f4 6700bfc 0f56dc9 43a9654 5933aa3 a1e51f4 6700bfc a1e51f4 43a9654 a1e51f4 271e600
# Lines: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
import html

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Load the cross-encoder and its tokenizer once, at import time; the scoring
# function below reads these module-level objects on every call.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode can be chained into one statement.
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()
def _score_paragraph(query, paragraph):
    """Run the cross-encoder on one (query, paragraph) pair.

    Returns:
        A tuple ``(logit, spans, scores)``:
        ``logit`` is the raw relevance logit (no sigmoid applied);
        ``spans`` is a list of ``(start, end)`` character offsets into
        *paragraph*, one per paragraph token that survived truncation;
        ``scores`` is a 1-D tensor with one attention score per entry of
        ``spans`` (empty when no paragraph token survived).
    """
    # return_offsets_mapping and sequence_ids() are only available on fast
    # tokenizers, which AutoTokenizer loads by default.
    inputs = tokenizer(
        query,
        paragraph,
        return_tensors="pt",
        truncation=True,
        return_offsets_mapping=True,
    )
    # The model does not accept offset_mapping, so remove it before the call.
    offsets = inputs.pop("offset_mapping")[0].tolist()
    # sequence_ids marks special tokens with None, query tokens with 0 and
    # paragraph tokens with 1 -- robust even when truncation shortens the query.
    sequence_ids = inputs.sequence_ids(0)

    with torch.no_grad():
        output = model(**inputs, output_attentions=True)
    logit = output.logits.squeeze().item()

    para_positions = [i for i, seq_id in enumerate(sequence_ids) if seq_id == 1]
    if not para_positions:
        return logit, [], torch.empty(0)
    start, end = para_positions[0], para_positions[-1] + 1

    # Last-layer attention is (batch, heads, seq, seq); take the single batch
    # item and average over heads.
    attention_scores = output.attentions[-1][0].mean(dim=0)
    # Paragraph->paragraph block, averaged over the attending (row) tokens:
    # one score per paragraph token.
    scores = attention_scores[start:end, start:end].mean(dim=0)
    spans = [tuple(offsets[i]) for i in range(start, end)]
    return logit, spans, scores


def _highlight_paragraph(paragraph, spans, relevant):
    """Return *paragraph* as escaped HTML with the *relevant* token spans bolded.

    Args:
        paragraph: The original paragraph text, rendered verbatim.
        spans: ``(start, end)`` character offsets of tokens, in text order.
        relevant: Set of indices into *spans* to wrap in ``<b>...</b>``.
    """
    pieces = []
    cursor = 0
    for idx, (start, end) in enumerate(spans):
        if idx not in relevant or end <= start:
            continue
        start = max(start, cursor)  # guard against overlapping offsets
        # Escape user text so it cannot inject markup into the HTML output.
        pieces.append(html.escape(paragraph[cursor:start]))
        pieces.append(f"<b>{html.escape(paragraph[start:end])}</b>")
        cursor = end
    pieces.append(html.escape(paragraph[cursor:]))
    return "".join(pieces)


def _render_table(ranked_paragraphs):
    """Render ``{"logit", "highlighted_text"}`` rows as an HTML table string."""
    parts = [
        "<table border='1' style='width:100%; border-collapse: collapse;'>",
        "<tr><th style='padding: 8px;'>Relevance Score (Logits)</th><th style='padding: 8px;'>Highlighted Paragraph</th></tr>",
    ]
    for item in ranked_paragraphs:
        parts.append(f"<tr><td style='padding: 8px; text-align: center;'>{round(item['logit'], 4)}</td>")
        parts.append(f"<td style='padding: 8px;'>{item['highlighted_text']}</td></tr>")
    parts.append("</table>")
    return "".join(parts)


def get_relevance_score_and_excerpt(query, paragraph1, paragraph2, paragraph3, threshold_weight):
    """Rank up to three paragraphs against *query* and highlight attended tokens.

    Args:
        query: The search query.
        paragraph1, paragraph2, paragraph3: Candidate paragraphs; empty,
            whitespace-only or ``None`` entries are ignored.
        threshold_weight: Attention threshold above which a paragraph token is
            bolded; values below 0.02 are clamped up to 0.02.

    Returns:
        A single HTML string: a table of paragraphs sorted by descending
        relevance logit, or a plain message when the query or all paragraphs
        are missing.
    """
    paragraphs = [p for p in (paragraph1, paragraph2, paragraph3) if p and p.strip()]
    if not (query and query.strip()) or not paragraphs:
        # One output component -> return one value, same type as the success path.
        return "Please provide both a query and at least one document paragraph."

    # Clamp to 0.02 (the UI slider's minimum) so the threshold never reaches zero.
    dynamic_threshold = max(0.02, threshold_weight)

    ranked_paragraphs = []
    for paragraph in paragraphs:
        logit, spans, scores = _score_paragraph(query, paragraph)
        relevant = set((scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist())
        # Paragraphs with no scorable tokens are still ranked, just not highlighted.
        ranked_paragraphs.append({
            "logit": logit,
            "highlighted_text": _highlight_paragraph(paragraph, spans, relevant),
        })

    ranked_paragraphs.sort(key=lambda item: item["logit"], reverse=True)
    return _render_table(ranked_paragraphs)
# Gradio UI: a query box, one required and two optional paragraph boxes, and an
# attention-threshold slider feed get_relevance_score_and_excerpt; its single
# return value is rendered by one HTML component. With live=True Gradio
# re-runs the function whenever an input changes, so there is no submit step.
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(
            label="Query",
            placeholder="Enter your search query...",
        ),
        gr.Textbox(
            label="Document Paragraph 1",
            placeholder="Enter a paragraph to match...",
            lines=4,
        ),
        gr.Textbox(
            label="Document Paragraph 2 (optional)",
            placeholder="Enter another paragraph...",
            lines=4,
        ),
        gr.Textbox(
            label="Document Paragraph 3 (optional)",
            placeholder="Enter another paragraph...",
            lines=4,
        ),
        gr.Slider(
            minimum=0.02,
            maximum=0.5,
            value=0.1,
            step=0.01,
            label="Attention Threshold",
        ),
    ],
    outputs=[gr.HTML(label="Ranked Paragraphs")],
    live=True,
    allow_flagging="never",
    title="Cross-Encoder Attention Highlighting with Reranking",
    description="Adjust the attention threshold to control token highlighting sensitivity. Multiple paragraphs can be added and reranked based on their logits.",
)
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script, not on import.
    interface.launch()