wilwork committed on
Commit
0f56dc9
·
verified ·
1 Parent(s): c9d7e8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -8,10 +8,10 @@ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  model.eval()
10
 
11
- # Function to compute relevance score and dynamically adjust threshold
12
  def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
13
  if not query.strip() or not paragraph.strip():
14
- return "Please provide both a query and a document paragraph.", "", ""
15
 
16
  # Tokenize the input
17
  inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
@@ -19,13 +19,12 @@ def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
19
  with torch.no_grad():
20
  output = model(**inputs, output_attentions=True)
21
 
22
- # Extract logits and calculate base relevance score
23
  logit = output.logits.squeeze().item()
24
- base_relevance_score = torch.sigmoid(torch.tensor(logit)).item()
25
 
26
- # Calculate dynamic threshold using sigmoid-based formula
27
- sigmoid_factor = 1 / (1 + torch.exp(torch.tensor(-5 * (base_relevance_score - 0.5)))).item()
28
- dynamic_threshold = max(0.02, threshold_weight * sigmoid_factor)
29
 
30
  # Extract attention scores (last layer)
31
  attention = output.attentions[-1]
@@ -39,12 +38,12 @@ def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
39
  para_end_idx = len(inputs["input_ids"][0]) - 1
40
 
41
  if para_end_idx <= para_start_idx:
42
- return round(base_relevance_score, 4), round(dynamic_threshold, 4), "No relevant tokens extracted."
43
 
44
  para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
45
 
46
  if para_attention_scores.numel() == 0:
47
- return round(base_relevance_score, 4), round(dynamic_threshold, 4), "No relevant tokens extracted."
48
 
49
  # Get indices of relevant tokens above dynamic threshold
50
  relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
@@ -59,7 +58,7 @@ def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
59
 
60
  highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())
61
 
62
- return round(base_relevance_score, 4), round(dynamic_threshold, 4), highlighted_text
63
 
64
  # Define Gradio interface with a slider for threshold adjustment
65
  interface = gr.Interface(
@@ -67,15 +66,14 @@ interface = gr.Interface(
67
  inputs=[
68
  gr.Textbox(label="Query", placeholder="Enter your search query..."),
69
  gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
70
- gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Threshold Weight")
71
  ],
72
  outputs=[
73
- gr.Textbox(label="Relevance Score"),
74
- gr.Textbox(label="Dynamic Threshold"),
75
  gr.HTML(label="Highlighted Document Paragraph")
76
  ],
77
  title="Cross-Encoder Attention Highlighting",
78
- description="Adjust the threshold weight to influence dynamic token highlighting based on relevance.",
79
  allow_flagging="never",
80
  live=True
81
  )
 
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  model.eval()
10
 
11
+ # Function to compute relevance score (in logits) and dynamically adjust threshold
12
  def get_relevance_score_and_excerpt(query, paragraph, threshold_weight):
13
  if not query.strip() or not paragraph.strip():
14
+ return "Please provide both a query and a document paragraph.", ""
15
 
16
  # Tokenize the input
17
  inputs = tokenizer(query, paragraph, return_tensors="pt", truncation=True, padding=True)
 
19
  with torch.no_grad():
20
  output = model(**inputs, output_attentions=True)
21
 
22
+ # Extract logits (no sigmoid applied)
23
  logit = output.logits.squeeze().item()
24
+ base_relevance_score = logit # Relevance score in logits
25
 
26
+ # Dynamically adjust the attention threshold based on user weight (no relevance score influence)
27
+ dynamic_threshold = max(0.02, threshold_weight)
 
28
 
29
  # Extract attention scores (last layer)
30
  attention = output.attentions[-1]
 
38
  para_end_idx = len(inputs["input_ids"][0]) - 1
39
 
40
  if para_end_idx <= para_start_idx:
41
+ return round(base_relevance_score, 4), "No relevant tokens extracted."
42
 
43
  para_attention_scores = attention_scores[para_start_idx:para_end_idx, para_start_idx:para_end_idx].mean(dim=0)
44
 
45
  if para_attention_scores.numel() == 0:
46
+ return round(base_relevance_score, 4), "No relevant tokens extracted."
47
 
48
  # Get indices of relevant tokens above dynamic threshold
49
  relevant_indices = (para_attention_scores > dynamic_threshold).nonzero(as_tuple=True)[0].tolist()
 
58
 
59
  highlighted_text = tokenizer.convert_tokens_to_string(highlighted_text.split())
60
 
61
+ return round(base_relevance_score, 4), highlighted_text
62
 
63
  # Define Gradio interface with a slider for threshold adjustment
64
  interface = gr.Interface(
 
66
  inputs=[
67
  gr.Textbox(label="Query", placeholder="Enter your search query..."),
68
  gr.Textbox(label="Document Paragraph", placeholder="Enter a paragraph to match..."),
69
+ gr.Slider(minimum=0.02, maximum=0.5, value=0.1, step=0.01, label="Attention Threshold")
70
  ],
71
  outputs=[
72
+ gr.Textbox(label="Relevance Score (Logits)"),
 
73
  gr.HTML(label="Highlighted Document Paragraph")
74
  ],
75
  title="Cross-Encoder Attention Highlighting",
76
+ description="Adjust the attention threshold to control token highlighting sensitivity.",
77
  allow_flagging="never",
78
  live=True
79
  )