gauravchand11 committed on
Commit
a60593a
·
verified ·
1 Parent(s): 7ee88b4

Upload assistant.py

Browse files
Files changed (1) hide show
  1. assistant.py +36 -0
assistant.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+
# Hugging Face model id: the instruction-tuned ("-it") Gemma 2B chat model.
MODEL_NAME = "google/gemma-2b-it"
class LegalEaseAssistant:
    """Causal-LM wrapper for legal-text tasks (simplify / summarize /
    key-term extraction / risk analysis) built on a Hugging Face chat model.
    """

    def __init__(self, model_name=MODEL_NAME):
        """Load the tokenizer and model for *model_name* onto CPU.

        BUGFIX: the original passed ``load_in_8bit=True`` together with
        ``device_map="cpu"``; bitsandbytes 8-bit quantization requires a
        CUDA GPU, so that combination fails at load time. 8-bit loading is
        dropped, and float32 is used because float16 kernels are poorly
        supported on CPU.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="cpu",
            torch_dtype=torch.float32,
        )

    def generate_response(self, text, task_type):
        """Run one task over *text* and return the model's answer as a string.

        Parameters:
            text: the legal text to analyze.
            task_type: one of "simplify", "summary", "key_terms", "risk";
                any other value falls back to a generic analysis prompt.

        Returns:
            The generated continuation (prompt stripped), whitespace-trimmed.
        """
        task_prompts = {
            "simplify": f"Simplify the following legal text in clear, plain language:\n\n{text}\n\nSimplified explanation:",
            "summary": f"Provide a concise summary of the following legal document:\n\n{text}\n\nSummary:",
            "key_terms": f"Identify and explain key legal terms:\n\n{text}\n\nKey Terms:",
            "risk": f"Perform a risk analysis:\n\n{text}\n\nRisk Assessment:"
        }

        prompt = task_prompts.get(task_type, f"Analyze the following text:\n\n{text}\n\nAnalysis:")
        inputs = self.tokenizer(prompt, return_tensors="pt")
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=300,
            num_return_sequences=1,
            do_sample=True,  # sampling decode; output is non-deterministic
            temperature=0.7,
            top_p=0.9,
        )

        # BUGFIX: decode only the newly generated tokens instead of
        # string-splitting the full decode on the prompt's trailing label
        # ("Summary:", etc.), which broke whenever that label text recurred
        # inside the generation itself.
        prompt_len = inputs["input_ids"].shape[1]
        response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return response.strip()