DocUA committed
Commit 45fc0a6 (parent: b3880cc)

refactoring + add run.py

Files changed (6):
  1. README.md +1 -1
  2. interface.py +8 -15
  3. main.py +43 -7
  4. requirements.txt +7 -8
  5. run.py +15 -0
  6. storage.py +43 -28
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: "Legal Position Search (without AI) BM25S long & short text"
+title: "Legal Position Search (without AI) BM25S long & short text + vector search ChromaDB"
 emoji: "⚖️"
 colorFrom: "blue"
 colorTo: "green"
interface.py CHANGED
@@ -1,23 +1,14 @@
 import gradio as gr
 import re
-from typing import Callable, Awaitable, Any, Tuple
+from typing import Callable, Any, Tuple
+import asyncio
 
-
-def create_gradio_interface(search_action: Callable[[str], Awaitable[Tuple[str, Any]]]) -> gr.Blocks:
-    """
-    Creates Gradio interface for legal search system.
-
-    Args:
-        search_action: Async function that performs the search and returns (output_text, nodes)
-
-    Returns:
-        gr.Blocks: Configured Gradio interface
-    """
+def create_gradio_interface(search_action: Callable) -> gr.Blocks:
     with gr.Blocks() as app:
         gr.Markdown("# Знаходьте правові позиції Верховного Суду")
 
         input_field = gr.Textbox(
-            label="Введіть текст для пошуку або посилання на судове рішення (у форматі https://reyestr.court.gov.ua/Review/{doc_id})",
+            label="Введіть текст або посилання на судове рішення",
             lines=1
         )
         search_button = gr.Button("Пошук", interactive=False)
@@ -25,8 +16,10 @@ def create_gradio_interface(search_action: Callable[[str], Awaitable[Tuple[str,
         search_output = gr.Markdown(label="Результат пошуку")
         state_nodes = gr.State()
 
+        async def async_wrapper(text):
+            return await search_action(text)
+
         def update_button_state(text: str) -> Tuple[gr.update, gr.update]:
-            """Updates button state and warning message based on input text."""
             text = text.strip()
             if not text:
                 return gr.update(value="Пошук", interactive=False), gr.update(visible=False)
@@ -41,7 +34,7 @@ def create_gradio_interface(search_action: Callable[[str], Awaitable[Tuple[str,
             return gr.update(value="Пошук за текстом", interactive=True), gr.update(visible=False)
 
         search_button.click(
-            fn=async_wrapper,
+            fn=async_wrapper,
             inputs=input_field,
             outputs=[search_output, state_nodes]
         )
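Note: a minimal sketch of how the refactored interface is meant to be driven, assuming search_action is an async callable returning (markdown_text, nodes) as the outputs list above implies; dummy_search is a placeholder and not part of this commit:

    from interface import create_gradio_interface

    async def dummy_search(text: str):
        # Placeholder for main_search_action: returns (markdown_text, nodes).
        return f"Results for: {text}", []

    demo = create_gradio_interface(dummy_search)
    demo.queue(max_size=1).launch()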
main.py CHANGED
@@ -5,11 +5,15 @@ from pathlib import Path
 
 import nest_asyncio
 import requests
+import chromadb
 from bs4 import BeautifulSoup
 from dotenv import load_dotenv
-from llama_index.core import Settings
+from llama_index.core import Settings, StorageContext
 from llama_index.core.retrievers import QueryFusionRetriever
+from llama_index.vector_stores.chroma import ChromaVectorStore
+from llama_index.core import VectorStoreIndex
 from llama_index.retrievers.bm25 import BM25Retriever
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
 
 from interface import create_gradio_interface
 from storage import StorageManager
@@ -18,22 +22,28 @@ from storage import StorageManager
 load_dotenv()
 
 # Basic settings
-Settings.similarity_top_k = 20  # type: ignore
+Settings.similarity_top_k = 20
 Settings.llm = None
+Settings.embed_model = HuggingFaceEmbedding(
+    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+)
 
 # Storage settings
 LOCAL_DIR = Path("Save_Index_Local")
 BUCKET_NAME = "legal-position"
 PREFIX_RETRIEVER = "Save_Index_Ivan/"
+CHROMA_DIR = "chroma_db_hf"
 
 # Index parameters
 INDEX_NAME_BM25_LONG = "bm25_retriever"
 INDEX_NAME_BM25_SHORT = "bm25_retriever_short"
 REQUIRED_FILES = [INDEX_NAME_BM25_LONG, INDEX_NAME_BM25_SHORT]
+REQUIRED_DIRS = [CHROMA_DIR]
 
 # Global retrievers
 retriever_bm25_long = None
 retriever_bm25_short = None
+retriever_chroma = None
 
 # Initialize nest_asyncio for async operations
 nest_asyncio.apply()
@@ -100,12 +110,12 @@ def initialize_components():
     )
 
     # Check and sync data
-    if not storage_manager.sync_data(REQUIRED_FILES):
+    if not storage_manager.sync_data(REQUIRED_FILES, REQUIRED_DIRS):
         raise FileNotFoundError("Failed to obtain required files")
 
-    global retriever_bm25_long, retriever_bm25_short
+    global retriever_bm25_long, retriever_bm25_short, retriever_chroma
 
-    # Initialize retrievers
+    # Initialize BM25 retrievers
     bm25_retriever_long = BM25Retriever.from_persist_dir(
         str(LOCAL_DIR / INDEX_NAME_BM25_LONG)
     )
@@ -120,8 +130,27 @@ def initialize_components():
         use_async=True,
     )
 
+    # Initialize ChromaDB
+    db_chroma = chromadb.PersistentClient(path=str(LOCAL_DIR / CHROMA_DIR))
+    chroma_collection = db_chroma.get_or_create_collection(name="legal_position")
+    chroma_vector_store = ChromaVectorStore(
+        chroma_collection=chroma_collection,
+        embedding_model=Settings.embed_model
+    )
+    storage_context = StorageContext.from_defaults(vector_store=chroma_vector_store)
+
+    # Create vector store index
+    vector_index = VectorStoreIndex.from_vector_store(
+        chroma_vector_store,
+        storage_context=storage_context,
+        embed_model=Settings.embed_model
+    )
+
+    retriever_chroma = vector_index.as_retriever(similarity_top_k=Settings.similarity_top_k)
+
+    # Create hybrid retriever for short texts
     retriever_bm25_short = QueryFusionRetriever(
-        [bm25_retriever_short],
+        [bm25_retriever_short, retriever_chroma],
         similarity_top_k=Settings.similarity_top_k,
         num_queries=1,
         use_async=True,
@@ -220,7 +249,10 @@ def main():
     if initialize_components():
         print("Components initialized successfully!")
         app = create_gradio_interface(main_search_action)
-        app.launch(share=True)
+        app.queue(max_size=1).launch(
+            show_error=True,
+            share=True
+        )
     else:
         print(
             "Failed to initialize components. Please check the paths and try again.",
@@ -228,6 +260,10 @@ def main():
         )
         sys.exit(1)
 
+if __name__ == "__main__":
+    # Remove nest_asyncio.apply()
+    main()
+
 
 if __name__ == "__main__":
     main()
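Note: a rough sketch of querying the hybrid BM25 + Chroma retriever once initialization has run; the query string is illustrative, and main_search_action presumably wraps a similar retrieve() call:

    import main

    # retriever_bm25_short is a module-level global set inside initialize_components().
    if main.initialize_components():
        nodes = main.retriever_bm25_short.retrieve("стягнення неустойки за договором підряду")
        for n in nodes[:5]:
            print(round(n.score or 0.0, 3), n.node.node_id, n.node.get_content()[:80])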
requirements.txt CHANGED
@@ -1,14 +1,13 @@
 llama-index
 llama-index-readers-file
-llama-index-vector-stores-faiss
 llama-index-retrievers-bm25
-openai
-faiss-cpu
-llama-index-embeddings-openai
-llama-index-llms-openai
-gradio
+llama_index-vector-stores-chroma
+llama-index-embeddings-huggingface
+
 beautifulsoup4
-nest-asyncio
 boto3
 python-dotenv
-openpyxl
+
+gradio==4.44.1
+nest-asyncio>=1.5.6
+uvicorn>=0.22.0
run.py ADDED
@@ -0,0 +1,15 @@
+from fastapi import FastAPI
+from main import initialize_components, main_search_action
+from interface import create_gradio_interface
+import uvicorn
+import gradio as gr
+
+app = FastAPI()
+
+if initialize_components():
+    print("Components initialized successfully!")
+    gr_app = create_gradio_interface(main_search_action)
+    app = gr.mount_gradio_app(app, gr_app, path="/")
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=7860)
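Note: with run.py added, the Space can be started either with python run.py or by pointing uvicorn at run:app (for example uvicorn run:app --host 0.0.0.0 --port 7860). A quick smoke test against the mounted Gradio app might look like this, assuming the default port above:

    import requests

    # Assumes run.py is already serving locally, e.g. started with `python run.py`.
    resp = requests.get("http://localhost:7860/")
    print(resp.status_code)  # expect 200 once Gradio is mounted at "/"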
storage.py CHANGED
@@ -1,7 +1,9 @@
+# storage.py
 import os
 from pathlib import Path
 import boto3
-from typing import Optional
+import shutil
+from typing import Optional, List
 
 
 class StorageManager:
@@ -28,31 +30,36 @@ class StorageManager:
             region_name=region_name
         )
 
-    def check_local_data(self, required_files: list[str]) -> bool:
+    def check_local_data(self, required_files: List[str], required_dirs: List[str] = None) -> bool:
         """
-        Check if all required files exist in local directory.
+        Check if all required files and directories exist locally.
 
         Args:
            required_files: List of required file names
+            required_dirs: List of required directory names
 
         Returns:
-            bool: True if all files exist, False otherwise
+            bool: True if all required data exists, False otherwise
         """
         self.local_dir.mkdir(parents=True, exist_ok=True)
 
+        # Check files
         for file_name in required_files:
             if not (self.local_dir / file_name).exists():
                 print(f"Missing required file: {file_name}")
                 return False
+
+        # Check directories
+        if required_dirs:
+            for dir_name in required_dirs:
+                if not (self.local_dir / dir_name).is_dir():
+                    print(f"Missing required directory: {dir_name}")
+                    return False
+
         return True
 
     def download_s3_file(self, s3_key: str, local_path: Path) -> bool:
-        """
-        Download single file from S3.
-
-        Returns:
-            bool: True if download successful, False otherwise
-        """
+        """Download single file from S3."""
         try:
             self.s3_client.download_file(self.bucket_name, s3_key, str(local_path))
             print(f"Downloaded: {s3_key} -> {local_path}")
@@ -61,24 +68,26 @@
             print(f"Error downloading {s3_key}: {str(e)}")
             return False
 
-    def download_s3_folder(self) -> bool:
+    def download_s3_folder(self, specific_prefix: str = None) -> bool:
         """
         Download entire folder from S3 to local directory.
 
-        Returns:
-            bool: True if download successful, False otherwise
+        Args:
+            specific_prefix: Optional specific prefix to download only a subfolder
         """
         try:
             if not self.use_s3:
                 raise ValueError("S3 credentials not configured")
 
+            prefix = f"{self.prefix}{specific_prefix}" if specific_prefix else self.prefix
+
             response = self.s3_client.list_objects_v2(
                 Bucket=self.bucket_name,
-                Prefix=self.prefix
+                Prefix=prefix
             )
 
             if 'Contents' not in response:
-                print(f"No files found in S3 bucket {self.bucket_name} with prefix {self.prefix}")
+                print(f"No files found in S3 bucket {self.bucket_name} with prefix {prefix}")
                 return False
 
             success = True
@@ -87,7 +96,8 @@
                 if s3_key.endswith('/'):
                     continue
 
-                local_file_path = self.local_dir / Path(s3_key).relative_to(self.prefix)
+                relative_path = Path(s3_key).relative_to(self.prefix)
+                local_file_path = self.local_dir / relative_path
                 local_file_path.parent.mkdir(parents=True, exist_ok=True)
 
                 if not self.download_s3_file(s3_key, local_file_path):
@@ -98,28 +108,33 @@
             print(f"Error downloading S3 folder: {str(e)}")
             return False
 
-    def sync_data(self, required_files: list[str]) -> bool:
+    def sync_data(self, required_files: List[str], required_dirs: List[str] = None) -> bool:
         """
         Check local data and sync from S3 if needed.
 
         Args:
             required_files: List of required file names
+            required_dirs: List of required directory names
 
         Returns:
-            bool: True if all required files are available after sync
+            bool: True if all required data is available after sync
         """
-        # First check if we have all files locally
-        if self.check_local_data(required_files):
-            print("All required files found locally")
+        if self.check_local_data(required_files, required_dirs):
+            print("All required files and directories found locally")
             return True
 
-        # If not all files exist locally and S3 is configured, try to download
         if self.use_s3:
-            print("Downloading required files from S3...")
-            if self.download_s3_folder():
-                # Verify files after download
-                return self.check_local_data(required_files)
-            return False
+            print("Downloading required data from S3...")
+            if not self.download_s3_folder():
+                return False
+
+            # If we have specific directories to sync
+            if required_dirs:
+                for dir_name in required_dirs:
+                    if not self.download_s3_folder(dir_name):
+                        return False
+
+            return self.check_local_data(required_files, required_dirs)
 
-        print("Missing required files and S3 is not configured")
+        print("Missing required data and S3 is not configured")
         return False
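Note: a minimal sketch of how main.py is expected to drive the updated sync_data(); the StorageManager constructor keyword names shown here are assumptions, since __init__ is not part of this diff:

    from pathlib import Path
    from storage import StorageManager

    # Constructor argument names below are assumed, not taken from this diff.
    manager = StorageManager(
        local_dir=Path("Save_Index_Local"),
        bucket_name="legal-position",
        prefix="Save_Index_Ivan/",
    )
    if not manager.sync_data(
        required_files=["bm25_retriever", "bm25_retriever_short"],
        required_dirs=["chroma_db_hf"],
    ):
        raise FileNotFoundError("Failed to obtain required files")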