# synclm-demo / utils /risk_detector.py
# SCBconsulting's picture
# Update utils/risk_detector.py
# 9d35238 verified
# utils/risk_detector.py
from transformers import pipeline
# ⚖️ Zero-shot classification pipeline, built once when this module is
# imported and reused by detect_risks() for every clause.
classifier = pipeline(
    "zero-shot-classification",
    model="typeform/distilbert-base-uncased-mnli",
)
# 🎯 Candidate risk categories every clause is scored against
# (extend the list as needed).
labels = [
    "Indemnity",
    "Exclusivity",
    "Termination",
    "Jurisdiction",
    "Confidentiality",
    "Fees",
]
# Optional remediation hints, keyed by risk label. A label with no entry
# here gets an empty suggestion in detect_risks() verbose output.
fallbacks = {
    "Indemnity": "Consider adding a mutual indemnification clause or capping liability.",
    "Exclusivity": "Suggest clarifying duration and scope of exclusivity.",
    "Termination": "Check for balanced termination rights and notice period.",
    "Jurisdiction": "Ensure forum is neutral or matches your operational base.",
    "Confidentiality": "Include a clear definition of confidential information and duration.",
    "Fees": "Ensure clarity on payment structure, late fees, and reimbursement terms.",
}
# ========== Core Function ==========
def detect_risks(text, verbose: bool = False, *, threshold: float = 0.5) -> list:
    """
    Detect and classify legal risks across the clauses of a text.

    The text is split on periods; fragments that are 20 characters or
    shorter after stripping are dropped. Each remaining clause is truncated
    to 1000 characters and scored against the module-level ``labels`` by the
    module-level zero-shot ``classifier`` in multi-label mode.

    Args:
        text: Text to analyse. ``None``, empty, or whitespace-only input
            yields an empty list.
        verbose: Selects the shape of the result (see Returns).
        threshold: Minimum score a label needs in order to be reported in
            verbose mode. Not used when ``verbose`` is False.

    Returns:
        - If ``verbose`` is True: a list of dicts, one per (clause, label)
          pair whose score is at least ``threshold``, with the keys
          ``"clause"``, ``"label"``, ``"score"`` (rounded to 3 decimals)
          and ``"suggestion"`` (the ``fallbacks`` entry for the label, or
          ``""`` when there is none).
        - Otherwise: a list of ``(label, total_score)`` tuples, where
          ``total_score`` is the label's score summed over every clause,
          sorted from highest to lowest total.
    """
    from collections import Counter

    # Guard against None as well as empty / whitespace-only input.
    if not text or not text.strip():
        return []

    # Break into clauses (simple split by period, can be improved).
    clauses = [
        clause
        for clause in (part.strip() for part in text.split("."))
        if len(clause) > 20
    ]

    detailed = []       # verbose output: one dict per reported (clause, label)
    totals = Counter()  # non-verbose output: per-label score summed over clauses

    for clause in clauses:
        # Only the first 1000 characters of each clause are sent to the model.
        result = classifier(clause[:1000], candidate_labels=labels, multi_label=True)
        for lbl, score in zip(result["labels"], result["scores"]):
            if not verbose:
                totals[lbl] += score
            elif score >= threshold:
                detailed.append({
                    "clause": clause,
                    "label": lbl,
                    "score": round(score, 3),
                    "suggestion": fallbacks.get(lbl, ""),
                })

    if verbose:
        return detailed

    # Aggregated risks, highest cumulative score first.
    return sorted(totals.items(), key=lambda item: item[1], reverse=True)