add application files
- app.py +78 -0
- config.py +34 -0
- model/__init__.py +0 -0
- model/chat.py +31 -0
- model/controller.py +18 -0
- model/llm/llm.py +117 -0
- model/processor/case_crawler.py +113 -0
- model/processor/database_Chunker.ipynb +0 -0
- model/processor/law_provider.py +61 -0
- model/processor/pre_process.ipynb +0 -0
- model/processor/retrieval_rag_nlp_project.ipynb:Zone.Identifier +0 -0
- model/propmt/__init__.py +0 -0
- model/propmt/prompt_handler.py +16 -0
- model/rag/__init__.py +0 -0
- model/rag/rag_handler.py +89 -0
- requirements.txt +26 -0
app.py
ADDED
@@ -0,0 +1,78 @@
+import os
+import gradio as gr
+from model.controller import Controller
+import zipfile
+
+os.chdir("/home/user/app")
+
+# Download the preprocessed case data and the prebuilt ChromaDB collection from Google Drive
+os.system('wget -O processed_cases.csv "https://drive.usercontent.google.com/download?id=1jMuQtywo0mbj7ZHCCsyE8xurbSyVVCst&export=download&confirm=t&uuid=2f681c98-86f8-4159-9e03-673cdcbc7cb51"')
+os.system('wget -O chromadb_collection.zip "https://drive.usercontent.google.com/download?id=1gz5-gxSlySEtPTzL_VPQ9e8jxHFuL0ZJ&export=download&confirm=t&uuid=de946efb-47b3-435d-b432-3bd5c01c73fb"')
+
+with zipfile.ZipFile("chromadb_collection.zip", 'r') as zip_ref:
+    zip_ref.extractall()
+
+# The archive unpacks into content/; move the collection to the path RAG expects
+os.system('mv content/chromadb_collections chromadb_collections')
+os.system('rm -r content')
+
+bot = Controller()
+
+def chatbot_interface(user_input, chat_id=2311):
+    return bot.handle_message(chat_id, user_input)
+
+def validate_input(user_input):
+    if not user_input or user_input.strip() == "":
+        return False, "🚫 Please enter a valid legal question. It cannot be empty."
+    if len(user_input) < 5:
+        return False, "⚠️ Your question is too short. Please provide more details."
+    return True, None
+
+custom_css = """
+@font-face {
+    font-family: 'Vazir';
+    src: url('https://cdn.jsdelivr.net/gh/rastikerdar/vazir-font/vf/Vazir.woff2') format('woff2'),
+         url('https://cdn.jsdelivr.net/gh/rastikerdar/vazir-font/vf/Vazir.woff') format('woff');
+}
+.gradio-container {
+    background-color: #f9f9f9;
+}
+.chatbox, .inputbox {
+    font-family: 'Vazir', sans-serif;
+    font-size: 16px;
+}
+"""
+
+with gr.Blocks(css=custom_css) as interface:
+
+    gr.Markdown("""
+    <div style="text-align: center; font-family: 'Vazir';">
+        <h1 style="color: #4a90e2;">⚖️ RAG Law Chatbot ⚖️</h1>
+        <p style="font-size: 18px; color: #333;">Welcome to the legal chatbot! 👨‍⚖️👩‍⚖️<br>Ask any legal question, and our assistant will help you! 📜🏛️</p>
+    </div>
+    """)
+
+    # Organize the chatbot area in a column for vertical stacking
+    with gr.Column():
+        chatbot = gr.Chatbot(label="🧑‍⚖️ Legal Chatbot Assistant 🧑‍⚖️", height=400, elem_classes=["chatbox"])
+
+        # Use Row to align input and button horizontally
+        with gr.Row():
+            user_input = gr.Textbox(show_label=False, placeholder="Enter your law question here... ⚖️", container=True)
+            send_button = gr.Button("📤 Send")
+
+        # Chat update function to append new messages to the chatbot
+        def chat_update(user_message, history):
+            history = history or []
+
+            is_valid, validation_message = validate_input(user_message)
+            if not is_valid:
+                history.append((user_message, validation_message))
+                return history, ""
+
+            bot_reply = chatbot_interface(user_message)
+            history.append((user_message, bot_reply))
+            return history, ""
+
+        # Connect the button click to the chat update function
+        send_button.click(chat_update, [user_input, chatbot], [chatbot, user_input])
+
+interface.launch()
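Note: the click wiring is the whole chat loop — chat_update receives [user_input, chatbot] and returns [chatbot, user_input], appending one (user, bot) pair to the history and clearing the textbox in a single round trip. A standalone sketch of the same pattern (hypothetical echo bot, no repo dependencies):

    import gradio as gr

    def echo_update(user_message, history):
        history = history or []
        history.append((user_message, f"echo: {user_message}"))
        return history, ""  # second output clears the textbox

    with gr.Blocks() as demo:
        chat = gr.Chatbot()
        box = gr.Textbox()
        btn = gr.Button("Send")
        btn.click(echo_update, [box, chat], [chat, box])
    # demo.launch()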
config.py
ADDED
@@ -0,0 +1,34 @@
+gpt_3_5 = "gpt-3.5-turbo-instruct"
+gpt_mini = "gpt-4o-mini"
+
+aval_ai = {
+    "model": gpt_3_5,
+    "base_url": "https://api.avalai.ir/v1",
+
+}
+
+GILAS_CONFIG = {
+    "api_key": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwMzg5OTQ0NjgsImp0aSI6IjExNDg4MzAyMTE3NDA0MzY2ODc0NiIsImlhdCI6MTcyMzYzNDQ2OCwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyMzYzNDQ2OCwic3ViIjoiMTE0ODgzMDIxMTc0MDQzNjY4NzQ2In0.8hbh59BmwBcAfoH9nEB98_5BIuxzwUUb8fpHSKF1S_Q",
+    "model": "gpt-4o-mini",
+    "base_url": 'https://api.gilas.io/v1',
+}
+
+GILAS_API_KEYS = [
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwMzg5OTQ0NjgsImp0aSI6IjExNDg4MzAyMTE3NDA0MzY2ODc0NiIsImlhdCI6MTcyMzYzNDQ2OCwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyMzYzNDQ2OCwic3ViIjoiMTE0ODgzMDIxMTc0MDQzNjY4NzQ2In0.8hbh59BmwBcAfoH9nEB98_5BIuxzwUUb8fpHSKF1S_Q",
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwNDI1MzI3NTYsImp0aSI6IjEwNjg5OTE1MjQwNTM4MzY3Nzc2NyIsImlhdCI6MTcyNzE3Mjc1NiwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyNzE3Mjc1Niwic3ViIjoiMTA2ODk5MTUyNDA1MzgzNjc3NzY3In0.Jgfi7BWhpXFTYdHe73md5p932EP75wTD-CZQ6SfGkK8",
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwNDI1MzMzNzIsImp0aSI6IjEwNjg4MTE2MzAzOTkzMTg2MjY3NiIsImlhdCI6MTcyNzE3MzM3MiwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyNzE3MzM3Miwic3ViIjoiMTA2ODgxMTYzMDM5OTMxODYyNjc2In0.PhVdoRUdaCfHa4va-EtWP5o7KISCSdMjT5mWtc9cefo",
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwNDI1MzM0MDIsImp0aSI6IjExNTY3MDAwOTQyMjcyNTE3NDE1NCIsImlhdCI6MTcyNzE3MzQwMiwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyNzE3MzQwMiwic3ViIjoiMTE1NjcwMDA5NDIyNzI1MTc0MTU0In0.IRcnkiZJdKNPTE1nYXoeiVMfxj9xXHSvAxBLaBGC6yk",
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwNDI1MzM1MzEsImp0aSI6IjExMzk2NzY4OTcxNjg2NjYzNDk3MCIsImlhdCI6MTcyNzE3MzUzMSwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyNzE3MzUzMSwic3ViIjoiMTEzOTY3Njg5NzE2ODY2NjM0OTcwIn0.kHZZDlVnZsbnoSac0wtM3ezrPCkIBYVQSdkfbFsT_xs",
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwNDI1MzM1ODksImp0aSI6IjEwNzM3MDcyODA4NDQxMTk0MTQwOSIsImlhdCI6MTcyNzE3MzU4OSwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyNzE3MzU4OSwic3ViIjoiMTA3MzcwNzI4MDg0NDExOTQxNDA5In0.4qhnj6YhunOHoAMmosibf4CaopJqSlvwxvhB6671Suw",
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwNDI1MzQ5ODEsImp0aSI6IjEwNjE2NTI5NzI5MjAxODExMzgwMCIsImlhdCI6MTcyNzE3NDk4MSwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyNzE3NDk4MSwic3ViIjoiMTA2MTY1Mjk3MjkyMDE4MTEzODAwIn0.9QvgxTlDugcDwSa880B0hefhWjVfEzjTDX2ywgNORrc",
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwNDI1MzUwNTIsImp0aSI6IjExMzA3MTQ4ODA5OTA0OTQzMDI0MSIsImlhdCI6MTcyNzE3NTA1MiwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyNzE3NTA1Miwic3ViIjoiMTEzMDcxNDg4MDk5MDQ5NDMwMjQxIn0.Z8TNrz_LXCtFjE0BwBLCBqh03uTKZ6WWLptQA6zdy1Y",
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwNDI1MzUxMjUsImp0aSI6IjExMTU3MzA2NjkwODIzNjk4MjM1OSIsImlhdCI6MTcyNzE3NTEyNSwiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyNzE3NTEyNSwic3ViIjoiMTExNTczMDY2OTA4MjM2OTgyMzU5In0.eQIqXoSbsD19AJrQxCVh7T6tcLvCJ7TH3c8Ajso9CJU",
+    "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjIwNDI1NTM2NjcsImp0aSI6IjEwMTkyMDcyMjAwOTgxNDEwMDE5MiIsImlhdCI6MTcyNzE5MzY2NywiaXNzIjoiaHR0cHM6Ly9naWxhcy5pbyIsIm5iZiI6MTcyNzE5MzY2Nywic3ViIjoiMTAxOTIwNzIyMDA5ODE0MTAwMTkyIn0.WmYY-BbcsYcvgZmes_eH5AS-06imEDslcNPH41UOH-c",
+]
+
+OPENAI_CONFIG = {
+    "model": gpt_mini,
+}
+
+
+LLM_CONFIG = aval_ai
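Note: LLM_CONFIG on the last line selects the active backend; LLM_API_Call's fallback branch simply splats it into the langchain client. A sketch of that consumption (assumes an OPENAI_API_KEY-style credential is available in the environment):

    from langchain_openai import OpenAI
    from config import LLM_CONFIG

    llm = OpenAI(**LLM_CONFIG)  # picks up model and base_url from the selected config dict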
model/__init__.py
ADDED
File without changes
model/chat.py
ADDED
@@ -0,0 +1,31 @@
+from model.propmt.prompt_handler import *
+from model.llm.llm import *
+from model.rag.rag_handler import *
+from config import *
+
+class Chat:
+    def __init__(self, chat_id, rag_handler) -> None:
+        self.chat_id = chat_id
+        self.message_history = []
+        self.response_history = []
+        self.prompt_handler = Prompt()
+        self.llm = LLM_API_Call("gilas")
+        self.rag_handler = rag_handler
+
+    def response(self, message: str) -> str:
+        self.message_history.append(message)
+
+        info_list = self.rag_handler.get_information(message)
+        prompt = self.prompt_handler.get_prompt(message, info_list)
+        llm_response = self.llm.get_LLM_response(prompt=prompt)
+
+        # Build the final reply: the LLM answer followed by the retrieved cases
+        final_response = f"**Response**:\n{llm_response}\n\n"
+        if info_list:
+            final_response += "The following legal cases and information were retrieved and considered:\n"
+            for i, info in enumerate(info_list):
+                case_text = info['text'].replace("[end]", "")  # drop the "[end]" chunk marker
+                final_response += f"\n**Case {i+1}:** {info['title']}\n{case_text}\n"
+
+        self.response_history.append(final_response)
+
+        return final_response
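Note: Chat can also be exercised directly, bypassing Controller; a minimal sketch under the same data assumptions as app.py (processed_cases.csv and chromadb_collections/ present; the question is a placeholder):

    from model.rag.rag_handler import RAG
    from model.chat import Chat

    chat = Chat(chat_id=1, rag_handler=RAG())
    print(chat.response("سوال حقوقی نمونه"))  # markdown answer followed by the retrieved cases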
model/controller.py
ADDED
@@ -0,0 +1,18 @@
+from model.chat import *
+import sys
+import os
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+class Controller:
+    def __init__(self) -> None:
+        self.chat_dic = {}
+        self.rag_handler = RAG()  # one shared RAG index for all chats
+
+    def handle_message(self,
+                       chat_id: int,
+                       message: str) -> str:
+        # Create a Chat lazily the first time a chat_id is seen
+        if chat_id not in self.chat_dic:
+            self.chat_dic[chat_id] = Chat(chat_id=chat_id, rag_handler=self.rag_handler)
+        chat = self.chat_dic[chat_id]
+        return chat.response(message)
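Note: app.py drives Controller as a process-wide singleton with a fixed chat_id. A REPL-style sketch of the same flow (assumes the data files app.py downloads at startup are in place; the questions are placeholders):

    from model.controller import Controller

    bot = Controller()                                # builds the shared RAG index once
    print(bot.handle_message(1, "سوال حقوقی نمونه"))  # creates Chat 1 lazily
    print(bot.handle_message(1, "سوال بعدی"))         # reuses Chat 1, so history accumulates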
model/llm/llm.py
ADDED
@@ -0,0 +1,117 @@
+from langchain_openai import OpenAI
+import openai
+import sys
+import os
+import requests
+from json import JSONDecodeError
+import time
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
+
+from config import *
+
+
+class LLM_API_Call:
+    def __init__(self, type) -> None:
+        if type == "openai":
+            self.llm = OpenAI_API_Call(api_key=LLM_CONFIG["api_key"],  # expects an "api_key" entry in LLM_CONFIG
+                                       model=LLM_CONFIG["model"])
+        elif type == "gilas":
+            self.llm = Gilas_API_Call(api_keys=GILAS_API_KEYS,
+                                      model=GILAS_CONFIG["model"],
+                                      base_url=GILAS_CONFIG["base_url"])
+        else:
+            self.llm = OpenAI(
+                **LLM_CONFIG
+            )
+
+    def get_LLM_response(self, prompt: str) -> str:
+        return self.llm.invoke(prompt)
+
+
+class OpenAI_API_Call:
+
+    def __init__(self, api_key, model="gpt-4"):
+        self.api_key = api_key
+        openai.api_key = api_key
+        self.model = model
+        self.conversation = []
+
+    def add_message(self, role, content):
+        self.conversation.append({"role": role, "content": content})
+
+    def get_response(self):
+        # Uses the legacy (pre-1.0) openai SDK interface
+        response = openai.ChatCompletion.create(
+            model=self.model,
+            messages=self.conversation
+        )
+        return response['choices'][0]['message']['content']
+
+    def invoke(self, user_input):
+        self.add_message("user", user_input)
+
+        response = self.get_response()
+
+        self.add_message("assistant", response)
+
+        return response
+
+
+class Gilas_API_Call:
+    def __init__(self, api_keys, base_url, model="gpt-4o-mini"):
+        self.api_keys = api_keys
+        self.base_url = base_url
+        self.model = model
+        self.headers = {
+            "Content-Type": "application/json"
+        }
+        self.conversation = []
+        self.retry_wait_time = 30
+
+
+    def add_message(self, role, content):
+        self.conversation.append({"role": role, "content": content})
+
+    def get_response(self, api_key):
+        self.headers["Authorization"] = f"Bearer {api_key}"
+
+        data = {
+            "model": self.model,
+            "messages": self.conversation
+        }
+
+        response = requests.post(
+            url=f"{self.base_url}/chat/completions",
+            headers=self.headers,
+            json=data
+        )
+
+        if response.status_code == 200:
+            try:
+                return response.json()['choices'][0]['message']['content']
+            except (KeyError, IndexError, ValueError) as e:
+                raise Exception(f"Unexpected API response format: {e}")
+        else:
+            raise Exception(f"Gilas API call failed: {response.status_code} - {response.text}")
+
+    def invoke(self, user_input, max_retries=3):
+        self.add_message("user", user_input)
+
+        retries = 0
+        while retries < max_retries:
+            for i, api_key in enumerate(self.api_keys):
+                try:
+                    response = self.get_response(api_key)
+                    self.add_message("assistant", response)
+                    return response
+                except (JSONDecodeError, Exception) as e:
+                    print(f"Error encountered with API key {api_key}: {e}. Trying next key...")
+                    # Sleep before the next round once every key has failed
+                    if i == len(self.api_keys) - 1:
+                        print(f"All keys failed. Retrying oldest key after {self.retry_wait_time} seconds...")
+                        time.sleep(self.retry_wait_time)
+                        self.retry_wait_time += 30  # Increase wait time for next retry
+
+            retries += 1
+
+        raise Exception(f"Failed to get a valid response after {max_retries} retries.")
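Note: Chat instantiates this with LLM_API_Call("gilas"), so the key-rotation client is the path actually exercised. A minimal sketch (the prompt is a placeholder):

    from model.llm.llm import LLM_API_Call

    llm = LLM_API_Call("gilas")                 # rotates through GILAS_API_KEYS on failure
    print(llm.get_LLM_response(prompt="سلام"))  # one-shot call returning the assistant text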
model/processor/case_crawler.py
ADDED
@@ -0,0 +1,113 @@
+import requests
+from bs4 import BeautifulSoup
+import os
+import warnings
+from tqdm import tqdm
+
+class Crawler:
+    # Used to separate individual votes when they are concatenated into one string
+    vote_splitter = " |split| "
+
+    def __init__(self, base_url: str, list_url: str,
+                 base_vote_url: str, models_path: str, result_path: str):
+        if base_url == "":
+            self.base_url = "https://ara.jri.ac.ir/"
+        else:
+            self.base_url = base_url
+
+        if list_url == "":
+            self.list_url = "https://ara.jri.ac.ir/Judge/Index"
+        else:
+            self.list_url = list_url
+
+        if base_vote_url == "":
+            self.base_vote_url = "https://ara.jri.ac.ir/Judge/Text/"
+        else:
+            self.base_vote_url = base_vote_url
+
+        if models_path == "":
+            self.models_path = "Models/"
+        else:
+            self.models_path = models_path
+        self.pos_model_path = os.path.join(self.models_path, "postagger.model")
+        self.chunker_path = os.path.join(self.models_path, "chunker.model")
+
+        if result_path == "":
+            self.result_path = "Resource/"
+        else:
+            self.result_path = result_path
+
+        self.merges_vote_path = os.path.join(self.result_path, 'merged_vote.txt')
+        self.clean_vote_path = os.path.join(self.result_path, 'clean_vote.txt')
+        self.clean_vote_path_csv = os.path.join(self.result_path, 'clean_vote.csv')
+        self.selected_vote_path = os.path.join(self.result_path, 'selected_vote.txt')
+        self.law_list_path = os.path.join(self.result_path, 'law_list.txt')
+        self.law_clean_list_path = os.path.join(self.result_path, 'law_clean_list.txt')
+        self.vote_stop_path = os.path.join(self.result_path, "vote_stopwords.txt")
+        self.law_stop_path = os.path.join(self.result_path, "law_stopwords.txt")
+
+    @staticmethod
+    def check_valid_vote(html_soup: BeautifulSoup) -> bool:
+        # Extract the title to detect invalid votes
+        h1_element = html_soup.find('h1', class_='Title3D')
+        if h1_element is None:
+            return False
+        span_text = h1_element.find('span').text  # Text within the <span> tag
+        full_text = h1_element.text  # Full text within the <h1> element
+        text_after_span = full_text.split(span_text)[-1].strip()  # Text after the </span> tag
+        return len(text_after_span) > 0
+
+    @staticmethod
+    def html_data_extractor(html_soup: BeautifulSoup, vote_splitter: str) -> str:
+        vote_text = html_soup.find('div', id='treeText', class_='BackText')
+        title = html_soup.find('h1', class_='Title3D')
+        info = html_soup.find('td', valign="top", class_="font-size-small")
+        # Append vote_splitter so individual votes can be separated later
+        vote_df = str(title) + str(info) + str(vote_text) + vote_splitter
+        return vote_df
+
+    def vote_crawler(self, start: int, end: int, separator: int):
+        counter = 0  # Counts successfully crawled votes
+        result_list = []
+        warnings.filterwarnings("ignore")
+        # Request each vote page in turn
+        for i in tqdm(range(start, end)):
+            # Flush every `separator` records to a .txt file
+            if (counter % separator == 0 and counter > 0) or i == end - 1:
+                with open(os.path.join(self.result_path, f'vote{i}.txt'), "w", encoding='utf-8') as text_file:
+                    text_file.write(''.join(result_list))
+                result_list = []
+            url = self.base_vote_url + f"{i}"
+            response = requests.get(url, verify=False)
+            # Force UTF-8 so Persian text decodes correctly
+            response.encoding = 'utf-8'
+            resp_text = response.text
+            html_soup = BeautifulSoup(resp_text, 'html.parser')
+            if response.ok and self.check_valid_vote(html_soup):
+                counter += 1
+                vote_df = self.html_data_extractor(html_soup, self.vote_splitter)
+                result_list.append(vote_df)
+
+    def merge_out_txt(self) -> None:
+        # Combine the per-batch vote .txt files into the single merged file
+        with open(self.merges_vote_path, 'w', encoding='utf-8') as outfile:
+            for filename in os.listdir(self.result_path):
+                if filename.startswith("vote") and filename.endswith('.txt'):  # Only merge vote .txt files
+                    with open(os.path.join(self.result_path, filename), 'r', encoding='utf-8') as infile:
+                        outfile.write(infile.read())
+
+if __name__ == "__main__":
+    models_path = input("Enter the models path (initial value = Models/): ")
+    result_path = input("Enter the result path (initial value = Resource/): ")
+    base_url = input("Enter the base URL (initial value = https://ara.jri.ac.ir/): ")
+    list_url = input("Enter the list URL (initial value = https://ara.jri.ac.ir/Judge/Index): ")
+    base_vote_url = input("Enter the base vote URL (initial value = https://ara.jri.ac.ir/Judge/Text/): ")
+
+    crawler_instance = Crawler(models_path=models_path, result_path=result_path, base_url=base_url, list_url=list_url, base_vote_url=base_vote_url)
+    start = int(input("Enter the start value for vote crawling: "))
+    end = int(input("Enter the end value for vote crawling: "))
+    separator = int(input("Enter the separator value for vote crawling: "))
+
+    crawler_instance.vote_crawler(start=start, end=end, separator=separator)
+    crawler_instance.merge_out_txt()
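Note: the interactive prompts are optional; per the constructor, empty strings fall back to the built-in defaults. A non-interactive sketch (the ID range is hypothetical, and the Resource/ directory must already exist):

    from model.processor.case_crawler import Crawler

    crawler = Crawler(base_url="", list_url="", base_vote_url="",
                      models_path="", result_path="")         # "" selects the defaults
    crawler.vote_crawler(start=1000, end=1100, separator=50)  # crawl vote IDs 1000-1099
    crawler.merge_out_txt()                                   # combine the batch files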
model/processor/database_Chunker.ipynb
ADDED
The diff for this file is too large to render.
model/processor/law_provider.py
ADDED
@@ -0,0 +1,61 @@
+import pandas as pd
+import re
+
+class LawTxetPreProcessor():
+
+    def __init__(self, law_texts: list) -> None:
+        self._law_texets = law_texts
+        self._law_name_df = pd.DataFrame(columns=["law_index", "law_name"])
+        self._madeh_df = pd.DataFrame(columns=["law_index", "madeh_index", "madeh_text"])
+        self._is_df = False
+
+    def build_df(self):
+        title_list = []
+        madeh_list = []
+        madeh_index = []
+        law_index = []
+        counter = 0
+        for text in self._law_texets:
+            title = self.title_extractor(text)
+            title_list.append(title)
+            temp_madeh_list = self.madeh_extractor(text, title == "قانون اساسی جمهوری اسلامی ایران")
+            law_index.extend([counter for i in temp_madeh_list])
+            madeh_index.extend([i+1 for i in range(len(temp_madeh_list))])
+            madeh_list.extend(temp_madeh_list)
+            counter += 1
+        law_index_list = [i for i in range(counter)]
+        self._madeh_df = pd.DataFrame({"law_index": law_index,
+                                       "madeh_index": madeh_index,
+                                       "madeh_text": madeh_list})
+        self._law_name_df = pd.DataFrame({"law_index": law_index_list,
+                                          "law_name": title_list})
+        self._is_df = True  # mark the DataFrames as built so get_df() does not rebuild
+
+    def title_extractor(self, law_text: str) -> str:
+        first_newline_index = law_text.find('\n')
+        return law_text[:first_newline_index]
+
+    def madeh_extractor(self, law_text: str, is_asl: bool = False) -> list:
+        result = []
+        # Constitution articles start with "اصل"; ordinary statutes use "ماده"
+        pattern = r"(^.{0,1}اصل )" if is_asl else r"(^.{0,1}ماده)"
+        removed_regex = r"❯.*\n"
+        notvalid_pattern = r"(^.{0,1}ماده.*مکرر\n)"
+        cleaned_text = re.sub(removed_regex, "", law_text)
+        matches = re.finditer(pattern, cleaned_text, flags=re.MULTILINE)
+        not_valid_matches = re.finditer(notvalid_pattern, cleaned_text, flags=re.MULTILINE)
+        indices = [match.start() for match in matches]
+        not_valid_indices = [match.start() for match in not_valid_matches]
+        valid_indices = [item for item in indices if item not in not_valid_indices]
+        for i in range(len(valid_indices)):
+            start = valid_indices[i]
+            if i != len(valid_indices)-1:
+                end = valid_indices[i+1]
+                result.append(cleaned_text[start:end])
+            else:
+                result.append(cleaned_text[start:])
+        return result
+
+
+    def get_df(self) -> pd.DataFrame:
+        if not self._is_df:
+            self.build_df()
+        return self._law_name_df, self._madeh_df
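Note: a sketch of how this preprocessor is fed (the law text is a fabricated two-article example; real inputs come from the preprocessing notebooks):

    from model.processor.law_provider import LawTxetPreProcessor

    laws = ["قانون نمونه\nماده 1 - متن ماده اول\nماده 2 - متن ماده دوم"]
    processor = LawTxetPreProcessor(laws)
    law_names_df, madeh_df = processor.get_df()  # one row per law, one row per ماده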
model/processor/pre_process.ipynb
ADDED
The diff for this file is too large to render.
model/processor/retrieval_rag_nlp_project.ipynb:Zone.Identifier
ADDED
Binary file (27 Bytes).
model/propmt/__init__.py
ADDED
File without changes
model/propmt/prompt_handler.py
ADDED
@@ -0,0 +1,16 @@
+from typing import List
+
+class Prompt:
+
+    def get_prompt(self, message: str, info_list: List) -> str:
+        prompt = f"As a user, I want to ask you the following legal question:\n{message}\n\n"
+
+        if info_list:
+            prompt += "Here are some relevant legal cases and information you should consider:\n"
+            for i, info in enumerate(info_list):
+                prompt += f"case {i+1}:\n{info['title']}\n{info['text']}\n"
+
+        prompt += "\nBased on the provided information, please respond in Persian (Farsi) with a concise legal analysis. " \
+                  "Ensure that your response is as summarized and clear as possible. (one paragraph)"
+
+        return prompt
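Note: info_list here has the same shape that RAG.get_information returns (dicts with "title" and "text" keys). A quick sketch with a fabricated case:

    from model.propmt.prompt_handler import Prompt

    info_list = [{"title": "Sample case", "text": "Case body..."}]
    print(Prompt().get_prompt("Is X legal?", info_list))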
model/rag/__init__.py
ADDED
File without changes
model/rag/rag_handler.py
ADDED
@@ -0,0 +1,89 @@
+from typing import List
+import chromadb
+from transformers import AutoTokenizer, AutoModel
+from chromadb.config import Settings
+import torch
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+import os
+from hazm import *
+
+
+class RAG:
+    def __init__(self,
+                 model_name: str = "HooshvareLab/bert-base-parsbert-uncased",
+                 collection_name: str = "legal_cases",
+                 persist_directory: str = "chromadb_collections/",
+                 top_k: int = 2
+                 ) -> None:
+
+        self.cases_df = pd.read_csv('processed_cases.csv')
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModel.from_pretrained(model_name)
+        self.normalizer = Normalizer()
+        self.top_k = top_k
+
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.model.to(self.device)
+
+        self.client = chromadb.PersistentClient(path=persist_directory)
+
+        self.collection = self.client.get_collection(name=collection_name)
+
+    def query_pre_process(self, query: str) -> str:
+        return self.normalizer.normalize(query)
+
+    def embed_single_text(self, text: str) -> np.ndarray:
+        # Mean-pool the last hidden state into a single fixed-size embedding
+        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
+        inputs = {key: value.to(self.device) for key, value in inputs.items()}
+        with torch.no_grad():
+            outputs = self.model(**inputs)
+        return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
+
+    def extract_case_title_from_df(self, case_id: str) -> str:
+        case_id_int = int(case_id.split("_")[1])
+
+        try:
+            case_title = self.cases_df.loc[case_id_int, 'title']
+            return case_title
+        except KeyError:
+            return "Case ID not found in DataFrame."
+
+    def extract_case_text_from_df(self, case_id: str) -> str:
+        case_id_int = int(case_id.split("_")[1])
+
+        try:
+            case_text = self.cases_df.loc[case_id_int, 'text']
+            return case_text
+        except KeyError:
+            return "Case ID not found in DataFrame."
+
+    def retrieve_relevant_cases(self, query_text: str) -> List[dict]:
+        normalized_query_text = self.query_pre_process(query_text)
+
+        query_embedding = self.embed_single_text(normalized_query_text)
+        query_embedding_list = query_embedding.tolist()
+
+        results = self.collection.query(
+            query_embeddings=[query_embedding_list],
+            n_results=self.top_k
+        )
+
+        retrieved_cases = []
+        for i in range(len(results['metadatas'][0])):
+            case_id = results['ids'][0][i]
+            case_text = self.extract_case_text_from_df(case_id)
+            case_title = self.extract_case_title_from_df(case_id)
+            retrieved_cases.append({
+                "text": case_text,
+                "title": case_title
+            })
+
+        return retrieved_cases
+
+    def get_information(self, query: str) -> List[dict]:
+        return self.retrieve_relevant_cases(query)
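Note: retrieval normalizes the query with hazm, embeds it with mean-pooled ParsBERT, and queries the persisted Chroma collection. A sketch (assumes processed_cases.csv and chromadb_collections/ exist; the query is a placeholder):

    from model.rag.rag_handler import RAG

    rag = RAG(top_k=2)
    for case in rag.get_information("سوال حقوقی نمونه"):
        print(case["title"])  # titles of the two most similar cases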
requirements.txt
ADDED
@@ -0,0 +1,26 @@
+#dataset
+datasets
+pandas
+numpy
+indexed_gzip
+# json
+matrix-nio[e2e]
+opsdroid
+python-dotenv
+
+BeautifulSoup4
+requests
+tqdm
+
+hazm
+spacy
+
+rank_bm25
+openai
+gradio
+
+langchain_openai
+sentence-transformers
+chromadb
+rarfile
+patool