Hjgugugjhuhjggg committed
Commit c875a14 · verified · 1 Parent(s): 543d52b

Update app.py

Files changed (1)
  1. app.py +26 -15
app.py CHANGED
@@ -1,7 +1,7 @@
 from llama_cpp import Llama
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import uvicorn
-from fastapi import FastAPI, HTTPException, Request
+from fastapi import FastAPI, HTTPException
 import os
 from dotenv import load_dotenv
 from pydantic import BaseModel
@@ -15,16 +15,20 @@ from faker import Faker
 import gradio as gr
 from threading import Thread
 
+# Download NLTK resources
 nltk.download('punkt')
 nltk.download('stopwords')
 
+# Load environment variables from .env file
 load_dotenv()
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 
+# Set up logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s')
 
 fake = Faker()
 
+# Global data structure to hold models and configurations
 global_data = {
     'models': {},
     'tokens': {
@@ -45,31 +49,32 @@ global_data = {
     'model_params': {},
 }
 
+# Model configurations
 model_configs = [
     {"repo_id": "Ffftdtd5dtft/Meta-Llama-3.1-70B-Q2_K-GGUF", "filename": "meta-llama-3.1-70b-q2_k.gguf", "name": "meta-llama-3.1-70b", "seed": 42, "n_ctx": 1024},
     {"repo_id": "Ffftdtd5dtft/gemma-2-27b-Q2_K-GGUF", "filename": "gemma-2-27b-q2_k.gguf", "name": "gemma-2-27b", "seed": 42, "n_ctx": 1024}
 ]
 
+# Function to load model
 def load_model(model_config):
     model_name = model_config['name']
     if model_name not in global_data['models']:
         try:
             device = "cuda" if torch.cuda.is_available() else "cpu"
 
-            # Manually define the context parameters
             context_params = {
                 "seed": model_config.get('seed', 42),
                 "n_ctx": model_config.get('n_ctx', 1024)
             }
 
-            # Initialize the model with context parameters
+            # Initialize model
             model = Llama.from_pretrained(
                 repo_id=model_config['repo_id'],
                 filename=model_config['filename'],
                 use_auth_token=HUGGINGFACE_TOKEN,
                 verbose=True,
                 device=device,
-                context_params=context_params # Pass context params directly
+                context_params=context_params
             )
 
             global_data['models'][model_name] = model
@@ -79,11 +84,11 @@ def load_model(model_config):
             logging.critical(f"CRITICAL ERROR loading model '{model_name}': {e}", exc_info=True)
             return None
 
-# Load all models
+# Load all models at the start
 for config in model_configs:
     load_model(config)
 
-# Class for the incoming request
+# Pydantic model to validate incoming requests
 class ChatRequest(BaseModel):
     message: str
 
@@ -91,7 +96,7 @@ class ChatRequest(BaseModel):
 def normalize_input(input_text):
     return input_text.strip()
 
-# Function to remove duplicate sentences based on similarity
+# Function to remove duplicate sentences
 def remove_duplicates(text, similarity_threshold=0.85):
     sentences = sent_tokenize(text)
     unique_sentences = []
@@ -106,12 +111,13 @@ def remove_duplicates(text, similarity_threshold=0.85):
             unique_sentences.append(sentence)
     return " ".join(unique_sentences)
 
-# GPU task function with error handling and CPU fallback
+# Function to handle model response generation with GPU fallback
 @spaces.GPU(duration=0)
 def generate_model_response(model, inputs, model_config):
     try:
         if model is None:
             return []
+
         responses = []
         model_metadata = global_data['model_metadata'].get(model_config['name'], {})
         stop_tokens = [global_data['tokens'].get('eos', '<|end_of_text|>')]
@@ -140,7 +146,7 @@ def generate_model_response(model, inputs, model_config):
 # FastAPI app
 app = FastAPI()
 
-# POST endpoint to handle chat requests
+# FastAPI POST endpoint to handle chat requests
 @app.post("/chat")
 async def chat(request: ChatRequest):
     input_text = normalize_input(request.message)
@@ -152,7 +158,7 @@ async def chat(request: ChatRequest):
     response = generate_model_response(model_instance, input_text, model_configs[0])
     return {"response": response[0] if response else "No response generated."}
 
-# Gradio Interface for testing the model
+# Gradio interface for model testing
 def gradio_interface(input_text):
     model_name = "meta-llama-3.1-70b"
     model_instance = global_data['models'].get(model_name, None)
@@ -161,14 +167,19 @@ def gradio_interface(input_text):
     response = generate_model_response(model_instance, input_text, model_configs[0])
     return response[0] if response else "No response generated."
 
-# Gradio Interface setup
+# Gradio interface setup
 def start_gradio_interface():
     gr.Interface(fn=gradio_interface, inputs="text", outputs="text").launch(share=True)
 
-# Run Gradio in a separate thread to avoid blocking FastAPI
-gradio_thread = Thread(target=start_gradio_interface)
-gradio_thread.start()
+# Run Gradio in a separate thread
+def start_gradio():
+    gradio_thread = Thread(target=start_gradio_interface)
+    gradio_thread.daemon = True # Ensures the thread will exit when the main program exits
+    gradio_thread.start()
+
+# Start the Gradio interface
+start_gradio()
 
-# Run the FastAPI app
+# Run FastAPI app using uvicorn
 if __name__ == "__main__":
     uvicorn.run(app, host="0.0.0.0", port=7860)
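
Note: as a quick way to exercise the updated app, the minimal sketch below posts a message to the /chat route that app.py exposes once uvicorn is running on port 7860. The host URL, prompt text, and the `requests` dependency are placeholders and assumptions for illustration, not part of this commit.

# Minimal client sketch (hypothetical; not part of this commit).
# Assumes the app is reachable on localhost:7860 and `requests` is installed.
import requests

payload = {"message": "Hello, which model are you?"}  # placeholder prompt matching ChatRequest
resp = requests.post("http://localhost:7860/chat", json=payload)  # placeholder host/port
resp.raise_for_status()
print(resp.json().get("response"))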