Uhhy committed on
Commit a17dc9a
1 Parent(s): e6dda1e

Update app.py

Files changed (1)
app.py +29 -32
app.py CHANGED
@@ -1,11 +1,11 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from llama_cpp import Llama
-from concurrent.futures import ProcessPoolExecutor, as_completed
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from tqdm import tqdm
 import uvicorn
 from dotenv import load_dotenv
 from difflib import SequenceMatcher
-from tqdm import tqdm
 import threading
 
 load_dotenv()
@@ -13,30 +13,23 @@ load_dotenv()
 app = FastAPI()
 
 # Model configuration
-models = [
+model_configs = [
     {"repo_id": "Ffftdtd5dtft/gpt2-xl-Q2_K-GGUF", "filename": "gpt2-xl-q2_k.gguf"},
     {"repo_id": "Ffftdtd5dtft/Meta-Llama-3.1-8B-Instruct-Q2_K-GGUF", "filename": "meta-llama-3.1-8b-instruct-q2_k.gguf"},
     {"repo_id": "Ffftdtd5dtft/gemma-2-9b-it-Q2_K-GGUF", "filename": "gemma-2-9b-it-q2_k.gguf"},
     {"repo_id": "Ffftdtd5dtft/gemma-2-27b-Q2_K-GGUF", "filename": "gemma-2-27b-q2_k.gguf"},
 ]
 
-# Function to load a model
+# Load a model
 def load_model(model_config):
     return Llama.from_pretrained(repo_id=model_config['repo_id'], filename=model_config['filename'])
 
-# Load models in parallel
+# Load all models simultaneously
 def load_all_models():
-    with ProcessPoolExecutor() as executor:
-        future_to_model = {executor.submit(load_model, model): model for model in models}
-        loaded_models = {}
-        for future in as_completed(future_to_model):
-            model = future_to_model[future]
-            try:
-                loaded_models[model['repo_id']] = future.result()
-                print(f"Modelo cargado en RAM: {model['repo_id']}")
-            except Exception as exc:
-                print(f"Error al cargar modelo {model['repo_id']}: {exc}")
-    return loaded_models
+    with ThreadPoolExecutor(max_workers=len(model_configs)) as executor:
+        futures = [executor.submit(load_model, config) for config in model_configs]
+        models = [future.result() for future in as_completed(futures)]
+    return models
 
 # Load models into memory
 llms = load_all_models()
@@ -47,7 +40,7 @@ class ChatRequest(BaseModel):
     top_p: float = 0.95
     temperature: float = 0.7
 
-# Global function to generate chat responses
+# Function to generate chat responses
 def generate_chat_response(request, llm):
     try:
         user_input = normalize_input(request.message)
@@ -103,8 +96,10 @@ def filter_by_similarity(responses):
             break
     return best_response
 
-def worker_function(llm, request):
-    return generate_chat_response(request, llm)
+def worker_function(llm, request, progress_bar):
+    response = generate_chat_response(request, llm)
+    progress_bar.update(1)
+    return response
 
 @app.post("/generate_chat")
 async def generate_chat(request: ChatRequest):
@@ -114,26 +109,28 @@ async def generate_chat(request: ChatRequest):
     print(f"Procesando solicitud: {request.message}")
 
     responses = []
-    threads = []
-
-    # Create a thread for each model
-    for llm in llms.values():
-        thread = threading.Thread(target=lambda: responses.append(worker_function(llm, request)))
-        threads.append(thread)
-        thread.start()
-
-    # Wait for all threads to finish
-    for thread in threads:
-        thread.join()
+    num_models = len(llms)
+
+    # Create the progress bar
+    with tqdm(total=num_models, desc="Generando respuestas", unit="modelo") as progress_bar:
+        # Run the models in parallel
+        with ThreadPoolExecutor(max_workers=num_models) as executor:
+            futures = [executor.submit(worker_function, llm, request, progress_bar) for llm in llms]
+            for future in as_completed(futures):
+                try:
+                    response = future.result()
+                    responses.append(response['response'])
+                except Exception as exc:
+                    print(f"Error en la generación de respuesta: {exc}")
 
     # Select the best response
-    best_response = select_best_response([response['response'] for response in responses])
+    best_response = select_best_response(responses)
 
     print(f"Mejor respuesta seleccionada: {best_response}")
 
     return {
         "best_response": best_response,
-        "all_responses": [response['response'] for response in responses]
+        "all_responses": responses
    }
 
 if __name__ == "__main__":
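
Example client call, as a minimal sketch: it assumes the app is served locally on port 8000 (the host and port come from the __main__ block, which is outside this diff) and that the third-party requests package is installed. The payload fields match the ChatRequest model shown above, and the response keys match the dict returned by generate_chat; the prompt text is only an illustration.

    # Hypothetical client for the /generate_chat endpoint; the URL and prompt are assumptions
    import requests

    payload = {
        "message": "Hola, ¿qué modelos estás usando?",  # example prompt (assumed)
        "top_p": 0.95,        # same default as ChatRequest.top_p
        "temperature": 0.7,   # same default as ChatRequest.temperature
    }
    resp = requests.post("http://localhost:8000/generate_chat", json=payload)
    resp.raise_for_status()
    data = resp.json()
    print(data["best_response"])       # response chosen by select_best_response
    print(len(data["all_responses"]))  # responses gathered from the loaded models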