SCBconsulting commited on
Commit
404a876
Β·
verified Β·
1 Parent(s): 8521187

Update utils/fallback_suggester.py

Browse files
Files changed (1) hide show
  1. utils/fallback_suggester.py +22 -6
utils/fallback_suggester.py CHANGED
@@ -3,24 +3,40 @@
3
  import json
4
  from sentence_transformers import SentenceTransformer, util
5
 
 
6
  model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
7
 
 
8
  with open("fallback_clauses.json", "r", encoding="utf-8") as f:
9
  clause_bank = json.load(f)
10
 
 
11
  clause_labels = list(clause_bank.keys())
12
  clause_texts = list(clause_bank.values())
13
  clause_embeddings = model.encode(clause_texts, convert_to_tensor=True)
14
 
15
- def suggest_fallback(input_clause):
 
 
 
 
 
 
 
 
 
 
16
  if not input_clause or len(input_clause.strip()) == 0:
17
- return "No input clause provided."
18
 
19
  input_embedding = model.encode(input_clause, convert_to_tensor=True)
20
  scores = util.cos_sim(input_embedding, clause_embeddings)[0]
21
- best_idx = scores.argmax().item()
22
 
23
- label = clause_labels[best_idx]
24
- suggestion = clause_texts[best_idx]
 
 
 
25
 
26
- return f"πŸ”Ή {label} β†’ {suggestion}"
 
3
  import json
4
  from sentence_transformers import SentenceTransformer, util
5
 
6
+ # πŸ” Load pre-trained semantic similarity model
7
  model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
8
 
9
+ # πŸ“š Load fallback clause database
10
  with open("fallback_clauses.json", "r", encoding="utf-8") as f:
11
  clause_bank = json.load(f)
12
 
13
+ # πŸ”‘ Extract clause labels and text
14
  clause_labels = list(clause_bank.keys())
15
  clause_texts = list(clause_bank.values())
16
  clause_embeddings = model.encode(clause_texts, convert_to_tensor=True)
17
 
18
+ def suggest_fallback(input_clause: str, top_k: int = 3):
19
+ """
20
+ Suggest top-k fallback clauses based on semantic similarity.
21
+
22
+ Args:
23
+ input_clause (str): The clause to analyze.
24
+ top_k (int): Number of fallback suggestions to return.
25
+
26
+ Returns:
27
+ str: Formatted fallback suggestions.
28
+ """
29
  if not input_clause or len(input_clause.strip()) == 0:
30
+ return "⚠️ No input clause provided."
31
 
32
  input_embedding = model.encode(input_clause, convert_to_tensor=True)
33
  scores = util.cos_sim(input_embedding, clause_embeddings)[0]
34
+ top_indices = scores.topk(k=min(top_k, len(clause_labels))).indices.tolist()
35
 
36
+ results = []
37
+ for idx in top_indices:
38
+ label = clause_labels[idx]
39
+ suggestion = clause_texts[idx]
40
+ results.append(f"πŸ”Ή {label}:\n{suggestion}")
41
 
42
+ return "\n\n".join(results)