Yhhxhfh committed on
Commit
a4863fb
verified
1 Parent(s): eab90e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -26
app.py CHANGED
@@ -12,6 +12,10 @@ from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline
12
  from loguru import logger
13
  from dotenv import load_dotenv
14
  from sklearn.metrics.pairwise import cosine_similarity
 
 
 
 
15
 
16
  sys.path.append('..')
17
 
@@ -28,6 +32,33 @@ redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_pass
28
 
29
  MAX_ITEMS_PER_TABLE = 10000
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def get_current_table_index():
32
  return int(redis_client.get("current_table_index") or 0)
33
 
@@ -57,6 +88,20 @@ def load_and_store_models(model_names):
57
  except Exception as e:
58
  logger.error(f"Error loading model {name}: {e}")
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  app = FastAPI()
61
  app.add_middleware(
62
  CORSMiddleware,
@@ -89,6 +134,21 @@ async def index():
89
  .user-message, .bot-message {{ margin-bottom: 10px; padding: 8px 12px; border-radius: 8px; max-width: 70%; word-wrap: break-word; }}
90
  .user-message {{ background-color: #007bff; color: #fff; align-self: flex-end; }}
91
  .bot-message {{ background-color: #4CAF50; color: #fff; }}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  </style>
93
  </head>
94
  <body>
@@ -98,25 +158,53 @@ async def index():
98
  <div class="chat-box" id="chat-box">
99
  {chat_history_html}
100
  </div>
101
- <input type="text" class="chat-input" id="user-input" placeholder="Type your message...">
 
102
  </div>
103
  </div>
104
  <script>
105
  const userInput = document.getElementById('user-input');
 
106
 
107
  userInput.addEventListener('keyup', function(event) {{
108
  if (event.key === 'Enter') {{
109
  event.preventDefault();
110
  sendMessage();
 
 
 
 
 
 
 
 
 
111
  }}
112
  }});
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  function sendMessage() {{
115
  const userMessage = userInput.value.trim();
116
  if (userMessage === '') return;
117
 
118
  appendMessage('user', userMessage);
119
  userInput.value = '';
 
120
 
121
  fetch(`/autocomplete?q=` + encodeURIComponent(userMessage))
122
  .then(response => response.json())
@@ -159,41 +247,26 @@ def calculate_similarity(base_text, candidate_texts):
159
  return similarities
160
 
161
  @app.get('/autocomplete')
162
- async def autocomplete(q: str = Query(..., title='query'), background_tasks: BackgroundTasks = BackgroundTasks()): # Correcci贸n: Mover background_tasks al final
163
  global message_history
164
  message_history.append(('user', q))
165
 
166
- background_tasks.add_task(generate_responses, q)
167
- return {"status": "Processing request, please wait..."}
 
 
 
 
 
 
 
168
 
169
  @app.get('/get_response')
170
  async def get_response(q: str = Query(..., title='query')):
171
  response = redis_client.hget("responses", q)
172
  return {"response": response}
173
 
174
- def generate_responses(q):
175
- generated_responses = []
176
- try:
177
- for model_name in redis_client.hkeys("models"):
178
- try:
179
- model_data = redis_client.hget("models", model_name)
180
- model = GPT2LMHeadModel.from_pretrained(model_name)
181
- tokenizer = AutoTokenizer.from_pretrained(model_name)
182
- text_generation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
183
- generated_response = text_generation_pipeline(q, do_sample=True, max_length=50, num_return_sequences=5)
184
- generated_responses.extend([response['generated_text'] for response in generated_response])
185
- except Exception as e:
186
- logger.error(f"Error generating response with model {model_name}: {e}")
187
 
188
- if generated_responses:
189
- similarities = calculate_similarity(q, generated_responses)
190
- most_coherent_response = generated_responses[np.argmax(similarities)]
191
- store_to_redis_table(q, "\n".join(generated_responses))
192
- redis_client.hset("responses", q, most_coherent_response)
193
- else:
194
- logger.warning("No valid responses generated.")
195
- except Exception as e:
196
- logger.error(f"General error in autocomplete: {e}")
197
 
198
  if __name__ == '__main__':
199
  gpt2_models = [
@@ -210,6 +283,13 @@ if __name__ == '__main__':
210
  "Salesforce/codegen-350M-multi"
211
  ]
212
 
 
 
 
 
 
 
213
  load_and_store_models(gpt2_models + programming_models)
 
214
 
215
  uvicorn.run(app=app, host='0.0.0.0', port=int(os.getenv("PORT", 7860)))
 
12
  from loguru import logger
13
  from dotenv import load_dotenv
14
  from sklearn.metrics.pairwise import cosine_similarity
15
+ from kaggle.api.kaggle_api_extended import KaggleApi
16
+
17
+ # Import the `spaces` library (GPU decorator for Hugging Face Spaces)
18
+ from huggingface_hub import spaces
19
 
20
  sys.path.append('..')
21
 
 
32
 
33
  MAX_ITEMS_PER_TABLE = 10000
34
 
35
# Decorated to request GPU time on Hugging Face Spaces (ZeroGPU).
# NOTE(review): this decorator comes from the standalone `spaces` package
# (`import spaces`); `from huggingface_hub import spaces` does not exist —
# confirm the top-of-file import.
@spaces.GPU()
def generate_responses_gpu(q):
    """Generate candidate completions for *q* with every registered model.

    For each model name stored in the Redis hash "models", load the model
    and tokenizer, sample 5 continuations of up to 50 tokens each, then
    pick the candidate most similar to the query (cosine similarity) and
    persist it in the "responses" hash keyed by the query. All failures
    are logged, never raised, so this background task cannot crash the
    request cycle.
    """
    generated_responses = []
    try:
        for raw_name in redis_client.hkeys("models"):
            # redis-py returns bytes keys unless decode_responses=True was
            # set on the client; decode so from_pretrained gets a str.
            model_name = raw_name.decode() if isinstance(raw_name, bytes) else raw_name
            try:
                # NOTE(review): model and tokenizer are reloaded on every
                # request — consider caching pipelines per model_name.
                model = GPT2LMHeadModel.from_pretrained(model_name)
                tokenizer = AutoTokenizer.from_pretrained(model_name)
                text_generation_pipeline = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if torch.cuda.is_available() else -1,
                )
                generated_response = text_generation_pipeline(
                    q, do_sample=True, max_length=50, num_return_sequences=5
                )
                generated_responses.extend(
                    r['generated_text'] for r in generated_response
                )
            except Exception as e:
                logger.error(f"Error generating response with model {model_name}: {e}")

        if generated_responses:
            # Rank candidates against the query and cache the best one.
            similarities = calculate_similarity(q, generated_responses)
            most_coherent_response = generated_responses[np.argmax(similarities)]
            store_to_redis_table(q, "\n".join(generated_responses))
            redis_client.hset("responses", q, most_coherent_response)
        else:
            logger.warning("No valid responses generated.")
    except Exception as e:
        # Was mislabeled "autocomplete" — this is the generation task.
        logger.error(f"General error in generate_responses_gpu: {e}")
60
+
61
+
62
def get_current_table_index():
    """Return the rotating-table index stored in Redis (0 when unset)."""
    stored = redis_client.get("current_table_index")
    return int(stored) if stored else 0
64
 
 
88
  except Exception as e:
89
  logger.error(f"Error loading model {name}: {e}")
90
 
91
def load_kaggle_datasets(dataset_names):
    """Download Kaggle datasets and cache a 10-row JSON sample in Redis.

    dataset_names: iterable of Kaggle slugs such as "uciml/iris".
    Each dataset is downloaded and unzipped into its own directory, the
    first CSV found is sampled, and the sample is stored both in the
    rotating Redis table and in the "kaggle_datasets" hash. Errors are
    logged per dataset so one failure does not abort the rest.
    """
    import glob
    import pandas as pd

    api = KaggleApi()
    api.authenticate()  # reads KAGGLE_USERNAME/KAGGLE_KEY or kaggle.json
    for dataset_name in dataset_names:
        try:
            # Use a per-dataset directory: the slug contains '/', and the
            # Kaggle API unzips files flat into `path` — the original glob
            # './kaggle_datasets/{slug}/*.csv' never matched anything.
            target_dir = os.path.join('./kaggle_datasets', dataset_name.replace('/', '_'))
            api.dataset_download_files(dataset_name, path=target_dir, unzip=True)
            csv_files = glob.glob(os.path.join(target_dir, '*.csv'))
            if not csv_files:
                logger.warning(f"No CSV files found for Kaggle dataset {dataset_name}")
                continue
            # The original called an undefined `load_dataset`; read the
            # first CSV directly with pandas instead.
            sample_data = pd.read_csv(csv_files[0]).head(10).to_json(orient='records')
            store_to_redis_table(dataset_name, sample_data)
            redis_client.hset("kaggle_datasets", dataset_name, sample_data)
        except Exception as e:
            logger.error(f"Error loading Kaggle dataset {dataset_name}: {e}")
103
+
104
+
105
  app = FastAPI()
106
  app.add_middleware(
107
  CORSMiddleware,
 
134
  .user-message, .bot-message {{ margin-bottom: 10px; padding: 8px 12px; border-radius: 8px; max-width: 70%; word-wrap: break-word; }}
135
  .user-message {{ background-color: #007bff; color: #fff; align-self: flex-end; }}
136
  .bot-message {{ background-color: #4CAF50; color: #fff; }}
137
+ #autocomplete-suggestions {{
138
+ position: absolute;
139
+ background-color: #fff;
140
+ border: 1px solid #ccc;
141
+ border-radius: 4px;
142
+ z-index: 10;
143
+ max-width: calc(100% - 40px);
144
+ }}
145
+ .suggestion {{
146
+ padding: 8px;
147
+ cursor: pointer;
148
+ }}
149
+ .suggestion:hover {{
150
+ background-color: #f0f0f0;
151
+ }}
152
  </style>
153
  </head>
154
  <body>
 
158
  <div class="chat-box" id="chat-box">
159
  {chat_history_html}
160
  </div>
161
+ <input type="text" class="chat-input" id="user-input" placeholder="Type your message..." autocomplete="off">
162
+ <div id="autocomplete-suggestions"></div>
163
  </div>
164
  </div>
165
  <script>
166
  const userInput = document.getElementById('user-input');
167
+ const autocompleteSuggestions = document.getElementById('autocomplete-suggestions');
168
 
169
  userInput.addEventListener('keyup', function(event) {{
170
  if (event.key === 'Enter') {{
171
  event.preventDefault();
172
  sendMessage();
173
+ }} else {{
174
+ fetch(`/autocomplete?q=` + encodeURIComponent(userInput.value))
175
+ .then(response => response.json())
176
+ .then(data => {{
177
+ displayAutocompleteSuggestions(data.suggestions);
178
+ }})
179
+ .catch(error => {{
180
+ console.error('Error:', error);
181
+ }});
182
  }}
183
  }});
184
 
185
+ function displayAutocompleteSuggestions(suggestions) {{
186
+ autocompleteSuggestions.innerHTML = '';
187
+ if (suggestions.length > 0) {{
188
+ suggestions.forEach(suggestion => {{
189
+ const suggestionElement = document.createElement('div');
190
+ suggestionElement.className = 'suggestion';
191
+ suggestionElement.innerText = suggestion;
192
+ suggestionElement.onclick = () => {{
193
+ userInput.value = suggestion;
194
+ autocompleteSuggestions.innerHTML = '';
195
+ }};
196
+ autocompleteSuggestions.appendChild(suggestionElement);
197
+ }});
198
+ }}
199
+ }}
200
+
201
  function sendMessage() {{
202
  const userMessage = userInput.value.trim();
203
  if (userMessage === '') return;
204
 
205
  appendMessage('user', userMessage);
206
  userInput.value = '';
207
+ autocompleteSuggestions.innerHTML = '';
208
 
209
  fetch(`/autocomplete?q=` + encodeURIComponent(userMessage))
210
  .then(response => response.json())
 
247
  return similarities
248
 
249
@app.get('/autocomplete')
async def autocomplete(q: str = Query(..., title='query'), background_tasks: BackgroundTasks = BackgroundTasks()):
    """Record the query, return cached-key suggestions, and start generation.

    Responds immediately with a status message plus any previously answered
    queries containing *q* (case-insensitive substring match); the model
    generation itself runs as a background task and its result is polled
    via /get_response.
    """
    global message_history
    message_history.append(('user', q))

    suggestions = []
    if q:
        q_lower = q.lower()
        for key in redis_client.hkeys("responses"):
            # redis-py returns bytes keys unless decode_responses=True was
            # set; decode so the substring test cannot raise TypeError and
            # the JSON body carries plain strings.
            key_str = key.decode() if isinstance(key, bytes) else key
            if q_lower in key_str.lower():
                suggestions.append(key_str)

    # Launch the GPU-decorated generator after the response is sent.
    # NOTE(review): FastAPI's documented form is `background_tasks:
    # BackgroundTasks` with no default — confirm the explicit default
    # instance is injected correctly here.
    background_tasks.add_task(generate_responses_gpu, q)
    return {"status": "Processing request, please wait...", "suggestions": suggestions}
263
 
264
@app.get('/get_response')
async def get_response(q: str = Query(..., title='query')):
    """Return the cached best response for *q*, or None if not ready yet."""
    response = redis_client.hget("responses", q)
    # redis-py returns bytes; decode so the JSON body carries a plain string.
    if isinstance(response, bytes):
        response = response.decode()
    return {"response": response}
268
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
 
 
 
 
 
 
 
 
 
270
 
271
  if __name__ == '__main__':
272
  gpt2_models = [
 
283
  "Salesforce/codegen-350M-multi"
284
  ]
285
 
286
+ kaggle_datasets = [
287
+ "uciml/iris",
288
+ "arshid/iris-flower-dataset",
289
+ "heesoo37/120-years-of-olympic-history-athletes-and-results"
290
+ ]
291
+
292
  load_and_store_models(gpt2_models + programming_models)
293
+ load_kaggle_datasets(kaggle_datasets)
294
 
295
  uvicorn.run(app=app, host='0.0.0.0', port=int(os.getenv("PORT", 7860)))