File size: 4,599 Bytes
7472a45
e64db30
a1e51f4
7472a45
 
 
 
a1e51f4
e64db30
5552636
5933aa3
271e600
 
 
 
 
6700bfc
a1e51f4
6700bfc
43a9654
6700bfc
 
 
 
 
 
 
a1e51f4
6700bfc
 
 
a1e51f4
6700bfc
 
a1e51f4
6700bfc
 
 
a1e51f4
6700bfc
 
a1e51f4
6700bfc
 
 
a1e51f4
6700bfc
 
a1e51f4
6700bfc
a1e51f4
6700bfc
 
a1e51f4
6700bfc
 
a1e51f4
6700bfc
 
 
 
 
 
 
43a9654
6700bfc
43a9654
6700bfc
 
 
 
 
 
 
 
5933aa3
 
 
6700bfc
5933aa3
 
 
43a9654
5933aa3
 
 
 
 
a1e51f4
 
43a9654
a1e51f4
6700bfc
 
 
0f56dc9
43a9654
 
5933aa3
a1e51f4
6700bfc
 
a1e51f4
 
43a9654
 
a1e51f4
271e600
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import html

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Fetch the cross-encoder checkpoint and its tokenizer once at import time;
# every request handled by the interface reuses these module-level objects.
model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
# `.eval()` returns the module itself, so loading and switching to inference
# mode (e.g. dropout disabled) can be done in one expression.
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Function to compute relevance scores (in logits) and dynamically adjust threshold
def get_relevance_score_and_excerpt(query, paragraph1, paragraph2, paragraph3, threshold_weight):
    # Handle empty input for paragraphs
    paragraphs = [p for p in [paragraph1, paragraph2, paragraph3] if p.strip()]
    
    if not query.strip() or not paragraphs:
        return "Please provide both a query and at least one document paragraph.", ""

    ranked_paragraphs = []
    
    # Process each paragraph and calculate its logits and highlighted text
    for paragraph in paragraphs:
        # Tokenize the input
        inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
        
        with torch.no_grad():
            output = model(**inputs, output_attentions=True)

        # Extract logits (no sigmoid applied)
        logit = output.logits.squeeze().item()
        base_relevance_score = logit  # Relevance score in logits

        # Dynamically adjust the attention threshold based on user weight
        dynamic_threshold = max(0.02, threshold_weight)

        # Extract attention scores (last layer)
        attention = output.attentions[-1]
        attention_scores = attention.mean(dim=1).mean(dim=0)

        query_tokens = tokenizer.tokenize(query)
        paragraph_tokens = tokenizer.tokenize(paragraph)

        query_len = len(query_tokens) + 2  # +2 for special tokens [CLS] and first [SEP]
        para_start_idx = query_len
        para_end_idx = len(inputs["input_ids"][0]) - 1

        if para_end_idx <= para_start_idx:
            continue

        para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)

        if para_attention_scores.numel() == 0:
            continue

        # Get indices of relevant tokens above dynamic threshold
        relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()

        # Reconstruct paragraph with bolded relevant tokens using HTML tags
        highlighted_text = ""
        for idx, token in enumerate(paragraph_tokens):
            if idx in relevant_indices:
                highlighted_text += f"<b>{token}</b> "
            else:
                highlighted_text += f"{token} "

        highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())

        ranked_paragraphs.append({
            "logit": logit,
            "highlighted_text": highlighted_text
        })
    
    # Sort paragraphs by logit (descending)
    ranked_paragraphs.sort(key=lambda x: x["logit"], reverse=True)

    # Prepare output: Combine scores and highlighted text in a readable format
    output_html = "<table border='1' style='width:100%; border-collapse: collapse;'>"
    output_html += "<tr><th style='padding: 8px;'>Relevance Score (Logits)</th><th style='padding: 8px;'>Highlighted Paragraph</th></tr>"

    for item in ranked_paragraphs:
        output_html += f"<tr><td style='padding: 8px; text-align: center;'>{round(item['logit'], 4)}</td>"
        output_html += f"<td style='padding: 8px;'>{item['highlighted_text']}</td></tr>"

    output_html += "</table>"

    return output_html

# Assemble the UI: a query box, three paragraph boxes (only the first is
# required), and an attention-threshold slider. Results are rendered as one
# HTML component, and `live=True` re-runs scoring whenever an input changes.
interface = gr.Interface(
    title="Cross-Encoder Attention Highlighting with Reranking",
    description="Adjust the attention threshold to control token highlighting sensitivity. Multiple paragraphs can be added and reranked based on their logits.",
    fn=get_relevance_score_and_excerpt,
    inputs=[
        gr.Textbox(label="Query", placeholder="Enter your search query..."),
        *(
            gr.Textbox(label=box_label, placeholder=hint, lines=4)
            for box_label, hint in (
                ("Document Paragraph 1", "Enter a paragraph to match..."),
                ("Document Paragraph 2 (optional)", "Enter another paragraph..."),
                ("Document Paragraph 3 (optional)", "Enter another paragraph..."),
            )
        ),
        gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Attention Threshold"),
    ],
    outputs=[gr.HTML(label="Ranked Paragraphs")],
    allow_flagging="never",
    live=True,
)

if __name__ == "__main__":
    interface.launch()