Raj-Maharajwala committed on
Commit 5858332 · verified · 1 Parent(s): 0db5f78

Upload 2 files

inference_open-insurance-llm-gguf.py ADDED
@@ -0,0 +1,345 @@
+ import os
+ import time
+ import logging
+ import sys
+ import psutil
+ import datetime
+ import traceback
+ import multiprocessing
+ from pathlib import Path
+ from llama_cpp import Llama
+ from typing import Optional, Dict, Any
+ from dataclasses import dataclass
+ from rich.console import Console
+ from rich.logging import RichHandler
+ from contextlib import contextmanager
+ from rich.traceback import install
+ from rich.theme import Theme
+ from huggingface_hub import hf_hub_download
+ # from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
+ # Install rich traceback handler
+ install(show_locals=True)
+
+ @dataclass
+ class ModelConfig:
+     model_name: str = "Raj-Maharajwala/Open-Insurance-LLM-Llama3-8B-GGUF"
+     model_file: str = "open-insurance-llm-q4_k_m.gguf"
+     # model_file: str = "open-insurance-llm-q8_0.gguf"
+     # model_file: str = "open-insurance-llm-q5_k_m.gguf"
+     max_tokens: int = 1000
+     top_k: int = 15
+     top_p: float = 0.2
+     repeat_penalty: float = 1.2
+     num_beams: int = 4  # note: beam search is not exposed by llama-cpp-python's create_completion; only passed to Llama() as an extra kwarg
+     n_gpu_layers: int = -2  # -1 for complete GPU offload
+     temperature: float = 0.1  # coherent (0.1) vs. creative (0.8)
+     n_ctx: int = 2048  # 2048-8192, up to Llama 3's full context capacity
+     n_batch: int = 256
+     verbose: bool = False
+     use_mmap: bool = False
+     use_mlock: bool = True
+     offload_kqv: bool = True
+
+ class CustomFormatter(logging.Formatter):
+     """Enhanced formatter with detailed context for different log levels"""
+     FORMATS = {
+         logging.DEBUG: "🔍 %(asctime)s - %(name)s - [%(filename)s:%(lineno)d] - %(levelname)s - %(message)s",
+         logging.INFO: "ℹ️ %(asctime)s - %(name)s - [%(funcName)s] - %(levelname)s - %(message)s",
+         logging.WARNING: "⚠️ %(asctime)s - %(name)s - [%(funcName)s] - %(levelname)s - %(message)s\nContext: %(pathname)s",
+         logging.ERROR: "❌ %(asctime)s - %(name)s - [%(funcName)s:%(lineno)d] - %(levelname)s - %(message)s",
+         logging.CRITICAL: """🚨 %(asctime)s - %(name)s - %(levelname)s
+ Location: %(pathname)s:%(lineno)d
+ Function: %(funcName)s
+ Process: %(process)d
+ Thread: %(thread)d
+ Message: %(message)s
+ Memory: %(memory).2fMB
+ """
+     }
+
+     def format(self, record):
+         # Add memory usage information (MB) so the CRITICAL format can report it
+         if not hasattr(record, 'memory'):
+             record.memory = psutil.Process().memory_info().rss / (1024 * 1024)
+
+         log_fmt = self.FORMATS.get(record.levelno)
+         formatter = logging.Formatter(log_fmt, datefmt='%Y-%m-%d %H:%M:%S')
+
+         # Append duration metrics if available; modify record.msg (not record.message)
+         # so the addition survives Formatter.format(), which rebuilds record.message
+         if hasattr(record, 'duration'):
+             record.msg = f"{record.getMessage()}\nDuration: {record.duration:.2f}s"
+             record.args = None
+
+         return formatter.format(record)
+
+ def setup_logging(log_dir: str = "logs") -> logging.Logger:
+     """Enhanced logging setup with multiple handlers and log files"""
+     Path(log_dir).mkdir(exist_ok=True)
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+     log_path = Path(log_dir) / f"l_{timestamp}"
+     log_path.mkdir(exist_ok=True)
+
+     # Create logger
+     logger = logging.getLogger("InsuranceLLM")
+     # Clear any existing handlers
+     logger.handlers.clear()
+     logger.setLevel(logging.DEBUG)
+
+     # Create handlers with level-specific files
+     handlers = {
+         'debug': (logging.FileHandler(log_path / f"debug_{timestamp}.log"), logging.DEBUG),
+         'info': (logging.FileHandler(log_path / f"info_{timestamp}.log"), logging.INFO),
+         'error': (logging.FileHandler(log_path / f"error_{timestamp}.log"), logging.ERROR),
+         'critical': (logging.FileHandler(log_path / f"critical_{timestamp}.log"), logging.CRITICAL),
+         'console': (RichHandler(
+             console=Console(theme=custom_theme),  # custom_theme is defined at module level below, before this is called
+             show_time=True,
+             show_path=False,
+             enable_link_path=True
+         ), logging.INFO)
+     }
+
+     # Configure handlers
+     formatter = CustomFormatter()
+     for handler, level in handlers.values():
+         handler.setLevel(level)
+         handler.setFormatter(formatter)
+         logger.addHandler(handler)
+
+     # Log startup information (appears only once per session)
+     logger.info(f"Starting new session {timestamp}")
+     logger.info(f"Log directory: {log_dir}")
+     return logger
+
+
+ # Custom theme configuration
+ custom_theme = Theme({
+     "info": "bold cyan",
+     "warning": "bold yellow",
+     "error": "bold red",
+     "critical": "bold white on red",
+     "success": "bold green",
+     "timestamp": "bold magenta",
+     "metrics": "bold blue",
+     "memory": "bold yellow",
+     "performance": "bold cyan",
+ })
+
+ console = Console(theme=custom_theme)
+
+ class PerformanceMetrics:
+     def __init__(self):
+         self.start_time = time.time()
+         self.tokens = 0
+         self.response_times = []
+         self.last_reset = self.start_time
+
+     def reset_timer(self):
+         """Reset the timer for individual response measurements"""
+         self.last_reset = time.time()
+
+     def update(self, tokens: int):
+         self.tokens += tokens
+         response_time = time.time() - self.last_reset
+         self.response_times.append(response_time)
+
+     @property
+     def elapsed_time(self) -> float:
+         return time.time() - self.start_time
+
+     @property
+     def last_response_time(self) -> float:
+         return self.response_times[-1] if self.response_times else 0
+
+ class InsuranceLLM:
+     def __init__(self, config: ModelConfig):
+         self.config = config
+         self.llm_ctx: Optional[Llama] = None
+         self.metrics = PerformanceMetrics()
+         self.logger = setup_logging()
+
+         nvidia_llama3_chatqa_system = (
+             "This is a chat between a user and an artificial intelligence assistant. "
+             "The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. "
+             "The assistant should also indicate when the answer cannot be found in the context. "
+         )
+         enhanced_system_message = (
+             "You are an expert from the insurance domain with extensive insurance knowledge and "
+             "professional writing skills, especially about insurance policies. "
+             "Your name is OpenInsuranceLLM, and you were developed by Raj Maharajwala. "
+             "You are willing to help answer the user's query with a detailed explanation. "
+             "In your explanation, leverage your deep insurance expertise, such as relevant insurance policies, "
+             "complex coverage plans, or other pertinent insurance concepts. Use precise insurance terminology while "
+             "still aiming to make the explanation clear and accessible to a general audience."
+         )
+         self.full_system_message = nvidia_llama3_chatqa_system + enhanced_system_message
+
+     @contextmanager
+     def timer(self, description: str):
+         start_time = time.time()
+         yield
+         elapsed_time = time.time() - start_time
+         self.logger.info(f"{description}: {elapsed_time:.2f}s")
+
+     def download_model(self) -> str:
+         try:
+             with console.status("[bold green]Downloading model..."):
+                 model_path = hf_hub_download(
+                     self.config.model_name,
+                     filename=self.config.model_file,
+                     local_dir=os.path.join(os.getcwd(), 'gguf_dir')
+                 )
+             self.logger.info(f"Model downloaded successfully to {model_path}")
+             return model_path
+         except Exception as e:
+             self.logger.error(f"Error downloading model: {str(e)}")
+             raise
+
+     def load_model(self) -> None:
+         try:
+             # self.check_metal_support()
+             quantized_path = os.path.join(os.getcwd(), "gguf_dir")
+             directory = Path(quantized_path)
+
+             # Reuse a previously downloaded GGUF if present, otherwise fetch it
+             try:
+                 model_path = str(list(directory.glob(self.config.model_file))[0])
+             except IndexError:
+                 model_path = self.download_model()
+
+             with console.status("[bold green]Loading model..."):
+                 self.llm_ctx = Llama(
+                     model_path=model_path,
+                     n_gpu_layers=self.config.n_gpu_layers,
+                     n_ctx=self.config.n_ctx,
+                     n_batch=self.config.n_batch,
+                     num_beams=self.config.num_beams,
+                     verbose=self.config.verbose,
+                     use_mlock=self.config.use_mlock,
+                     use_mmap=self.config.use_mmap,
+                     offload_kqv=self.config.offload_kqv
+                 )
+             self.logger.info("Model loaded successfully")
+
+         except Exception as e:
+             self.logger.error(f"Error loading model: {str(e)}")
+             raise
+
+     def get_prompt(self, question: str, context: str = "") -> str:
+         if context:
+             return (
+                 f"System: {self.full_system_message}\n\n"
+                 f"User: Context: {context}\nQuestion: {question}\n\n"
+                 "Assistant:"
+             )
+         return (
+             f"System: {self.full_system_message}\n\n"
+             f"User: Question: {question}\n\n"
+             "Assistant:"
+         )
+
+
+     def generate_response(self, prompt: str) -> Dict[str, Any]:
+         if not self.llm_ctx:
+             raise RuntimeError("Model not loaded. Call load_model() first.")
+
+         try:
+             response = {"text": "", "tokens": 0}
+
+             # Print the assistant label before streaming the completion
+             # print("Assistant: ", end="", flush=True)
+             console.print("\n[bold cyan]Assistant: [/bold cyan]", end="")
+
+             # Initialize complete response
+             complete_response = ""
+
+             for chunk in self.llm_ctx.create_completion(
+                 prompt,
+                 max_tokens=self.config.max_tokens,
+                 top_k=self.config.top_k,
+                 top_p=self.config.top_p,
+                 temperature=self.config.temperature,
+                 repeat_penalty=self.config.repeat_penalty,
+                 stream=True
+             ):
+                 text_chunk = chunk["choices"][0]["text"]
+                 response["text"] += text_chunk
+                 response["tokens"] += 1
+
+                 # Append to complete response
+                 complete_response += text_chunk
+
+                 # Use simple print for streaming output
+                 print(text_chunk, end="", flush=True)
+
+             # Print final newline
+             print()
+
+             return response
+
+         except RuntimeError as e:
+             if "llama_decode returned -3" in str(e):
+                 self.logger.error("Memory allocation failed. Try reducing context window or batch size")
+             raise
+
+     def run_inference_loop(self):
+         try:
+             self.load_model()
+             console.print("\n[bold green]Welcome to Open-Insurance-LLM![/bold green]")
+             console.print("Enter your questions (type '/bye', 'exit', or 'quit' to end the session)\n")
+             console.print("Optional: You can provide context by typing 'context:' followed by your context, then 'question:' followed by your question\n")
+             memory_used = psutil.Process().memory_info().rss / 1024 / 1024
+             console.print(f"[dim]Memory usage: {memory_used:.2f} MB[/dim]")
+             while True:
+                 try:
+                     user_input = console.input("[bold cyan]User:[/bold cyan] ").strip()
+
+                     if user_input.lower() in ["exit", "/bye", "quit"]:
+                         console.print(f"[dim]Total tokens this session: {self.metrics.tokens}[/dim]")
+                         console.print(f"[dim]Total session time: {self.metrics.elapsed_time:.2f}s[/dim]")
+                         console.print("\n[bold green]Thank you for using OpenInsuranceLLM![/bold green]")
+                         break
+
+                     # Optional "context: ... question: ..." input format (markers matched case-insensitively)
+                     context = ""
+                     question = user_input
+                     lowered = user_input.lower()
+                     if "context:" in lowered and "question:" in lowered:
+                         c_idx = lowered.index("context:") + len("context:")
+                         q_idx = lowered.index("question:")
+                         context = user_input[c_idx:q_idx].strip()
+                         question = user_input[q_idx + len("question:"):].strip()
+
+                     prompt = self.get_prompt(question, context)
+
+                     # Reset timer before generation
+                     self.metrics.reset_timer()
+
+                     # Generate response
+                     response = self.generate_response(prompt)
+
+                     # Update metrics after generation
+                     self.metrics.update(response["tokens"])
+
+                     # Print metrics
+                     console.print(
+                         f"[dim]Average tokens/sec: {response['tokens'] / (self.metrics.last_response_time if self.metrics.last_response_time != 0 else 1):.2f} ||[/dim]",
+                         f"[dim]Tokens generated: {response['tokens']} ||[/dim]",
+                         f"[dim]Response time: {self.metrics.last_response_time:.2f}s[/dim]", end="\n\n\n")
+
+                 except KeyboardInterrupt:
+                     console.print("\n[yellow]Input interrupted. Type '/bye', 'exit', or 'quit' to quit.[/yellow]")
+                     continue
+                 except Exception as e:
+                     self.logger.error(f"Error processing input: {str(e)}")
+                     console.print(f"\n[red]Error: {str(e)}[/red]")
+                     continue
+
+         except Exception as e:
+             self.logger.error(f"Fatal error in inference loop: {str(e)}")
+             console.print(f"\n[red]Fatal error: {str(e)}[/red]")
+         finally:
+             if self.llm_ctx:
+                 del self.llm_ctx
+
+ def main():
+     if hasattr(multiprocessing, "set_start_method"):
+         multiprocessing.set_start_method("spawn", force=True)
+     try:
+         config = ModelConfig()
+         llm = InsuranceLLM(config)
+         llm.run_inference_loop()
+     except KeyboardInterrupt:
+         console.print("\n[yellow]Program interrupted by user[/yellow]")
+     except Exception as e:
+         error_msg = f"Application error: {str(e)}"
+         logging.error(error_msg)
+         console.print(f"\n[red]{error_msg}[/red]")
+
+ if __name__ == "__main__":
+     main()
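
For reference, a minimal sketch of driving the classes above programmatically instead of through the interactive loop. The module name and question are illustrative only; it assumes the script is saved under an importable name and that the GGUF file is already in ./gguf_dir or can be downloaded:

    # Hypothetical usage sketch (module name is an assumption, not part of the commit)
    from open_insurance_inference import ModelConfig, InsuranceLLM

    config = ModelConfig(max_tokens=512)   # override any defaults as needed
    llm = InsuranceLLM(config)
    llm.load_model()                       # finds or downloads the GGUF, then loads it
    prompt = llm.get_prompt("What does a standard homeowners policy typically exclude?")
    result = llm.generate_response(prompt) # streams tokens to the console
    print(f"\n{result['tokens']} tokens generated")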
inference_requirements.txt ADDED
@@ -0,0 +1,31 @@
+ black==24.10.0
+ certifi==2024.8.30
+ charset-normalizer==3.4.0
+ click==8.1.7
+ diskcache==5.6.3
+ filelock==3.16.1
+ fsspec==2024.10.0
+ huggingface-hub==0.26.2
+ idna==3.10
+ iniconfig==2.0.0
+ isort==5.13.2
+ Jinja2==3.1.4
+ llama_cpp_python==0.3.2
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ mdurl==0.1.2
+ mypy-extensions==1.0.0
+ numpy==2.1.3
+ packaging==24.2
+ pathspec==0.12.1
+ platformdirs==4.3.6
+ pluggy==1.5.0
+ psutil==6.1.0
+ Pygments==2.18.0
+ pytest==8.3.3
+ PyYAML==6.0.2
+ requests==2.32.3
+ rich==13.9.4
+ tqdm==4.67.0
+ typing_extensions==4.12.2
+ urllib3==2.2.3
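
These pins can be installed with pip install -r inference_requirements.txt. Note that the prebuilt llama_cpp_python wheel may be CPU-only, in which case the n_gpu_layers and offload_kqv settings in ModelConfig have no effect until the package is reinstalled with the appropriate backend (e.g. Metal or CUDA). A small check, assuming the low-level binding exposes llama_supports_gpu_offload (present in recent llama-cpp-python releases):

    # Hedged sketch: verify the installed build can offload layers to the GPU
    import llama_cpp

    print(llama_cpp.__version__)                   # expected: 0.3.2 per the pin above
    print(llama_cpp.llama_supports_gpu_offload())  # False suggests a CPU-only build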