"""Gradio demo: rerank paragraphs with a cross-encoder and highlight high-attention tokens."""
import html

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load model and tokenizer once at import time; every request reuses them.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.eval()

# Lower bound applied to the user-supplied attention threshold.
MIN_ATTENTION_THRESHOLD = 0.02


def _highlight_spans(text, spans):
    """Return *text* as HTML-escaped markup with every ``(start, end)`` char span bolded.

    Overlapping or touching spans are merged so ``<b>`` tags never nest.
    Characters outside every span are emitted escaped but otherwise unchanged.
    """
    merged = []
    for start, end in sorted(spans):
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            merged.append([start, end])

    pieces = []
    cursor = 0
    for start, end in merged:
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(f"<b>{html.escape(text[start:end])}</b>")
        cursor = end
    pieces.append(html.escape(text[cursor:]))
    return "".join(pieces)


def _score_paragraph(query, paragraph, threshold):
    """Score one (query, paragraph) pair and highlight high-attention paragraph tokens.

    Returns:
        A ``(logit, highlighted_html)`` tuple. The logit is the raw model output
        (no sigmoid). The HTML is the escaped original paragraph with
        above-threshold tokens wrapped in ``<b>``.
    """
    # Offsets map each token back to characters of its own input string, so the
    # highlight is applied to the original paragraph text (keeping its case,
    # spacing and punctuation) instead of to re-joined word pieces.
    # NOTE: offsets and sequence_ids() require a "fast" tokenizer, which
    # AutoTokenizer returns by default when one exists for the model.
    inputs = tokenizer(
        query,
        paragraph,
        return_tensors="pt",
        truncation=True,
        return_offsets_mapping=True,
    )
    # The model does not accept offset_mapping, so remove it before the forward pass.
    offsets = inputs.pop("offset_mapping")[0].tolist()
    # sequence_ids: 0 = query token, 1 = paragraph token, None = special token.
    # Unlike counting query tokens, this stays correct when truncation shortens the query.
    para_positions = [i for i, seq_id in enumerate(inputs.sequence_ids(0)) if seq_id == 1]

    with torch.no_grad():
        output = model(**inputs, output_attentions=True)

    # Extract logits (no sigmoid applied)
    logit = output.logits.squeeze().item()

    if not para_positions:
        # Truncation left no paragraph tokens: keep the score, skip highlighting.
        return logit, html.escape(paragraph)

    # Last-layer attention averaged over heads (dim 1) and the single-item batch
    # (dim 0), giving a (seq_len, seq_len) matrix.
    attention_scores = output.attentions[-1].mean(dim=1).mean(dim=0)
    start, end = para_positions[0], para_positions[-1] + 1
    # For each paragraph token, mean attention it receives from paragraph tokens.
    para_attention_scores = attention_scores[start:end, start:end].mean(dim=0)
    relevant = (para_attention_scores > threshold).nonzero(as_tuple=True)[0].tolist()

    spans = [tuple(offsets[start + j]) for j in relevant]
    spans = [(s, e) for s, e in spans if e > s]  # drop zero-width spans
    return logit, _highlight_spans(paragraph, spans)


def _render_table(ranked):
    """Render ``(logit, highlighted_html)`` rows as an HTML table, in the given order."""
    rows = "".join(
        f"<tr><td>{round(logit, 4)}</td><td>{highlighted}</td></tr>"
        for logit, highlighted in ranked
    )
    return (
        "<table>"
        "<tr><th>Relevance Score (Logits)</th><th>Highlighted Paragraph</th></tr>"
        f"{rows}"
        "</table>"
    )


def get_relevance_score_and_excerpt(query, paragraph1, paragraph2, paragraph3, threshold_weight):
    """Rank the non-empty paragraphs by cross-encoder logit and highlight attended tokens.

    Args:
        query: The search query.
        paragraph1, paragraph2, paragraph3: Candidate paragraphs; blank ones are ignored.
        threshold_weight: Attention threshold for highlighting, clamped to at
            least ``MIN_ATTENTION_THRESHOLD``.

    Returns:
        An HTML string: either a table of paragraphs sorted by descending logit,
        or a prompt message when the query or all paragraphs are empty.
    """
    paragraphs = [p for p in (paragraph1, paragraph2, paragraph3) if p and p.strip()]
    if not (query and query.strip()) or not paragraphs:
        # The interface has a single HTML output, so return a single string.
        return "Please provide both a query and at least one document paragraph."

    # Dynamically adjust the attention threshold based on user weight
    threshold = max(MIN_ATTENTION_THRESHOLD, threshold_weight)

    ranked = [_score_paragraph(query, paragraph, threshold) for paragraph in paragraphs]
    # Sort paragraphs by logit (descending)
    ranked.sort(key=lambda item: item[0], reverse=True)
    return _render_table(ranked)


# Define Gradio interface with a slider for threshold adjustment and multiple paragraphs input
interface = gr.Interface(
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        gr.Textbox(label="Document Paragraph 1", placeholder="Enter a paragraph to match...", lines=4),
        gr.Textbox(label="Document Paragraph 2 (optional)", placeholder="Enter another paragraph...", lines=4),
        gr.Textbox(label="Document Paragraph 3 (optional)", placeholder="Enter another paragraph...", lines=4),
        gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Attention Threshold"),
    ],
    outputs=[
        gr.HTML(label="Ranked Paragraphs"),
    ],
    title="Cross-Encoder Attention Highlighting with Reranking",
    description="Adjust the attention threshold to control token highlighting sensitivity. Multiple paragraphs can be added and reranked based on their logits.",
    allow_flagging="never",
    live=True,
)

if __name__ == "__main__":
    interface.launch()