SCBconsulting commited on
Commit
9d35238
·
verified ·
1 Parent(s): 6f26572

Update utils/risk_detector.py

Browse files
Files changed (1) hide show
  1. utils/risk_detector.py +44 -13
utils/risk_detector.py CHANGED
@@ -5,27 +5,58 @@ from transformers import pipeline
5
  # ⚖️ Load zero-shot classification model
6
  classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
7
 
8
- # 🎯 Define risk-related labels
9
- labels = ["Indemnity", "Exclusivity", "Termination", "Jurisdiction", "Confidentiality"]
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def detect_risks(text, verbose=False):
12
  """
13
- Classify clauses into predefined legal risk categories.
14
- If verbose=True, include detailed scores for each label.
15
 
16
  Returns:
17
- - List of (label, score) tuples (default)
18
- - Or dict of full model output if verbose
19
  """
20
  if not text.strip():
21
  return []
22
 
23
- result = classifier(text[:1000], candidate_labels=labels, multi_label=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  if verbose:
26
- return {
27
- "sequence": result["sequence"],
28
- "predictions": list(zip(result["labels"], result["scores"]))
29
- }
30
-
31
- return list(zip(result["labels"], result["scores"]))
 
 
 
5
  # ⚖️ Load zero-shot classification model
6
  classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
7
 
8
+ # 🎯 Define risk-related labels (can expand as needed)
9
+ labels = ["Indemnity", "Exclusivity", "Termination", "Jurisdiction", "Confidentiality", "Fees"]
10
+
11
+ # Optional fallback suggestions
12
+ fallbacks = {
13
+ "Indemnity": "Consider adding a mutual indemnification clause or capping liability.",
14
+ "Exclusivity": "Suggest clarifying duration and scope of exclusivity.",
15
+ "Termination": "Check for balanced termination rights and notice period.",
16
+ "Jurisdiction": "Ensure forum is neutral or matches your operational base.",
17
+ "Confidentiality": "Include a clear definition of confidential information and duration.",
18
+ "Fees": "Ensure clarity on payment structure, late fees, and reimbursement terms."
19
+ }
20
+
21
+ # ========== Core Function ==========
22
 
23
  def detect_risks(text, verbose=False):
24
  """
25
+ Detect and classify legal risks across multiple clauses.
 
26
 
27
  Returns:
28
+ - List of tuples (clause_text, label, score, fallback) if verbose=True
29
+ - Otherwise: List of (label, score) tuples aggregated
30
  """
31
  if not text.strip():
32
  return []
33
 
34
+ # Break into clauses (simple split by period, can be improved)
35
+ clauses = [c.strip() for c in text.split(".") if len(c.strip()) > 20]
36
+ all_results = []
37
+
38
+ for clause in clauses:
39
+ result = classifier(clause[:1000], candidate_labels=labels, multi_label=True)
40
+ top_labels = list(zip(result["labels"], result["scores"]))
41
+
42
+ if verbose:
43
+ top_risks = [(lbl, score) for lbl, score in top_labels if score >= 0.5]
44
+ for lbl, score in top_risks:
45
+ all_results.append({
46
+ "clause": clause,
47
+ "label": lbl,
48
+ "score": round(score, 3),
49
+ "suggestion": fallbacks.get(lbl, "")
50
+ })
51
+ else:
52
+ all_results.extend(top_labels)
53
 
54
  if verbose:
55
+ return all_results
56
+ else:
57
+ # Return aggregated top risks (non-verbose mode)
58
+ from collections import Counter
59
+ agg = Counter()
60
+ for lbl, score in all_results:
61
+ agg[lbl] += score
62
+ return sorted(agg.items(), key=lambda x: x[1], reverse=True)