Spaces:
Running
Running
# utils/risk_detector.py
from transformers import pipeline

# Load the zero-shot classification model once, at import time, so every
# call to detect_risks() reuses the same pipeline object.
classifier = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")

# Risk-related candidate labels scored by the zero-shot classifier
# (can be expanded as needed).
labels = ["Indemnity", "Exclusivity", "Termination", "Jurisdiction", "Confidentiality", "Fees"]

# Optional remediation suggestion for each risk label, attached to
# verbose-mode results by detect_risks().
fallbacks = {
    "Indemnity": "Consider adding a mutual indemnification clause or capping liability.",
    "Exclusivity": "Suggest clarifying duration and scope of exclusivity.",
    "Termination": "Check for balanced termination rights and notice period.",
    "Jurisdiction": "Ensure forum is neutral or matches your operational base.",
    "Confidentiality": "Include a clear definition of confidential information and duration.",
    "Fees": "Ensure clarity on payment structure, late fees, and reimbursement terms.",
}
# ========== Core Function ==========
def detect_risks(text, verbose=False, *, threshold=0.5):
    """
    Detect and classify legal risks across the clauses of a contract text.

    The text is split into clauses on "." and each clause longer than 20
    characters is scored against the module-level ``labels`` by the
    zero-shot ``classifier`` with ``multi_label=True``.

    Args:
        text: Contract text to analyse. ``None``, empty, or whitespace-only
            text yields an empty list.
        verbose: Selects the return format (see Returns).
        threshold: Minimum score a label needs on a clause to be reported in
            verbose mode. Defaults to 0.5. Not applied in aggregated mode.

    Returns:
        - If ``verbose`` is True: a list of dicts, one per (clause, label)
          pair whose score is at least ``threshold``, with keys
          ``"clause"``, ``"label"``, ``"score"`` (rounded to 3 decimals) and
          ``"suggestion"`` (the matching ``fallbacks`` entry, or "").
        - Otherwise: a list of ``(label, total_score)`` tuples, where
          ``total_score`` is the sum of that label's scores over all clauses,
          sorted by total score, highest first.
    """
    if not text or not text.strip():
        return []

    # Naive clause segmentation: splitting on "." also breaks on decimals and
    # abbreviations (e.g. "Sec. 3"); fragments of 20 chars or fewer are dropped.
    clauses = [c.strip() for c in text.split(".") if len(c.strip()) > 20]

    if verbose:
        findings = []
        for clause in clauses:
            for lbl, score in _classify_clause(clause):
                if score >= threshold:
                    findings.append({
                        "clause": clause,
                        "label": lbl,
                        "score": round(score, 3),
                        "suggestion": fallbacks.get(lbl, "")
                    })
        return findings

    # Aggregated mode: sum each label's score over every clause (no threshold).
    from collections import Counter
    totals = Counter()
    for clause in clauses:
        for lbl, score in _classify_clause(clause):
            totals[lbl] += score
    # most_common() with no argument returns all items sorted by value, descending.
    return totals.most_common()


def _classify_clause(clause):
    """Return ``(label, score)`` pairs from the zero-shot classifier for one clause."""
    # Only the first 1000 characters are classified, presumably to stay within
    # the model's input length and bound inference cost — confirm.
    result = classifier(clause[:1000], candidate_labels=labels, multi_label=True)
    return list(zip(result["labels"], result["scores"]))