Update app.py
app.py
CHANGED
@@ -12,6 +12,10 @@ from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline
 from loguru import logger
 from dotenv import load_dotenv
 from sklearn.metrics.pairwise import cosine_similarity
+from kaggle.api.kaggle_api_extended import KaggleApi
+
+# Import the Spaces GPU helper (the ZeroGPU decorator ships in the standalone `spaces` package)
+import spaces
 
 sys.path.append('..')
 
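Note: the @spaces.GPU() decorator used further down comes from Hugging Face's standalone `spaces` package, which is preinstalled on ZeroGPU Spaces. A minimal sketch of the usual pattern under that assumption (check_gpu is a hypothetical example function):

import spaces
import torch

@spaces.GPU()  # a GPU is attached only for the duration of the decorated call
def check_gpu():
    # torch should see CUDA inside the decorated function on a ZeroGPU Space
    return torch.cuda.is_available()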
@@ -28,6 +32,33 @@ redis_client = redis.Redis(host=redis_host, port=redis_port, password=redis_pass
 
 MAX_ITEMS_PER_TABLE = 10000
 
+# Decorator that runs this function on a GPU when hosted on Spaces (ZeroGPU)
+@spaces.GPU()
+def generate_responses_gpu(q):
+    generated_responses = []
+    try:
+        for model_name in redis_client.hkeys("models"):
+            try:
+                model_data = redis_client.hget("models", model_name)  # cached weights blob (currently unused)
+                model = GPT2LMHeadModel.from_pretrained(model_name)
+                tokenizer = AutoTokenizer.from_pretrained(model_name)
+                text_generation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
+                generated_response = text_generation_pipeline(q, do_sample=True, max_length=50, num_return_sequences=5)
+                generated_responses.extend([response['generated_text'] for response in generated_response])
+            except Exception as e:
+                logger.error(f"Error generating response with model {model_name}: {e}")
+
+        if generated_responses:
+            similarities = calculate_similarity(q, generated_responses)
+            most_coherent_response = generated_responses[np.argmax(similarities)]
+            store_to_redis_table(q, "\n".join(generated_responses))
+            redis_client.hset("responses", q, most_coherent_response)
+        else:
+            logger.warning("No valid responses generated.")
+    except Exception as e:
+        logger.error(f"General error in generate_responses_gpu: {e}")
+
+
 def get_current_table_index():
     return int(redis_client.get("current_table_index") or 0)
 
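Worth noting: generate_responses_gpu reloads every model from the Hub on each query. A minimal sketch of caching pipelines per process instead; the _PIPELINES cache and get_pipeline helper are hypothetical, not part of this commit:

import torch
from transformers import AutoTokenizer, GPT2LMHeadModel, pipeline

_PIPELINES = {}  # hypothetical module-level cache: model name -> pipeline

def get_pipeline(model_name):
    # Build each text-generation pipeline once, then reuse it across requests
    if model_name not in _PIPELINES:
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        _PIPELINES[model_name] = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
    return _PIPELINES[model_name]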
@@ -57,6 +88,20 @@ def load_and_store_models(model_names):
         except Exception as e:
             logger.error(f"Error loading model {name}: {e}")
 
+def load_kaggle_datasets(dataset_names):
+    api = KaggleApi()
+    api.authenticate()
+    for dataset_name in dataset_names:
+        try:
+            api.dataset_download_files(dataset_name, path='./kaggle_datasets', unzip=True)
+            dataset = load_dataset('csv', data_files=[f'./kaggle_datasets/{dataset_name}/*.csv'])['train']
+            sample_data = dataset.to_pandas().head(10).to_json(orient='records')
+            store_to_redis_table(dataset_name, sample_data)
+            redis_client.hset("kaggle_datasets", dataset_name, sample_data)
+        except Exception as e:
+            logger.error(f"Error loading Kaggle dataset {dataset_name}: {e}")
+
+
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
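api.authenticate() assumes Kaggle credentials are already available, and load_dataset is assumed to be datasets.load_dataset imported elsewhere in the file. A minimal sketch of supplying credentials through environment variables, e.g. from a Space's secrets (the placeholder values are hypothetical):

import os

# The kaggle package reads ~/.kaggle/kaggle.json or these two variables;
# they must be set before importing kaggle, whose top-level import tries to authenticate.
os.environ.setdefault("KAGGLE_USERNAME", "your-username")  # hypothetical placeholder
os.environ.setdefault("KAGGLE_KEY", "your-api-key")        # hypothetical placeholder

from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate()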
@@ -89,6 +134,21 @@ async def index():
                 .user-message, .bot-message {{ margin-bottom: 10px; padding: 8px 12px; border-radius: 8px; max-width: 70%; word-wrap: break-word; }}
                 .user-message {{ background-color: #007bff; color: #fff; align-self: flex-end; }}
                 .bot-message {{ background-color: #4CAF50; color: #fff; }}
+                #autocomplete-suggestions {{
+                    position: absolute;
+                    background-color: #fff;
+                    border: 1px solid #ccc;
+                    border-radius: 4px;
+                    z-index: 10;
+                    max-width: calc(100% - 40px);
+                }}
+                .suggestion {{
+                    padding: 8px;
+                    cursor: pointer;
+                }}
+                .suggestion:hover {{
+                    background-color: #f0f0f0;
+                }}
             </style>
         </head>
         <body>
@@ -98,25 +158,53 @@ async def index():
             <div class="chat-box" id="chat-box">
                 {chat_history_html}
             </div>
-            <input type="text" class="chat-input" id="user-input" placeholder="Type your message...">
+            <input type="text" class="chat-input" id="user-input" placeholder="Type your message..." autocomplete="off">
+            <div id="autocomplete-suggestions"></div>
         </div>
     </div>
     <script>
         const userInput = document.getElementById('user-input');
+        const autocompleteSuggestions = document.getElementById('autocomplete-suggestions');
 
         userInput.addEventListener('keyup', function(event) {{
             if (event.key === 'Enter') {{
                 event.preventDefault();
                 sendMessage();
+            }} else {{
+                fetch(`/autocomplete?q=` + encodeURIComponent(userInput.value))
+                    .then(response => response.json())
+                    .then(data => {{
+                        displayAutocompleteSuggestions(data.suggestions);
+                    }})
+                    .catch(error => {{
+                        console.error('Error:', error);
+                    }});
             }}
         }});
 
+        function displayAutocompleteSuggestions(suggestions) {{
+            autocompleteSuggestions.innerHTML = '';
+            if (suggestions.length > 0) {{
+                suggestions.forEach(suggestion => {{
+                    const suggestionElement = document.createElement('div');
+                    suggestionElement.className = 'suggestion';
+                    suggestionElement.innerText = suggestion;
+                    suggestionElement.onclick = () => {{
+                        userInput.value = suggestion;
+                        autocompleteSuggestions.innerHTML = '';
+                    }};
+                    autocompleteSuggestions.appendChild(suggestionElement);
+                }});
+            }}
+        }}
+
         function sendMessage() {{
             const userMessage = userInput.value.trim();
             if (userMessage === '') return;
 
             appendMessage('user', userMessage);
             userInput.value = '';
+            autocompleteSuggestions.innerHTML = '';
 
             fetch(`/autocomplete?q=` + encodeURIComponent(userMessage))
                 .then(response => response.json())
@@ -159,41 +247,26 @@ def calculate_similarity(base_text, candidate_texts):
     return similarities
 
 @app.get('/autocomplete')
 async def autocomplete(q: str = Query(..., title='query'), background_tasks: BackgroundTasks = BackgroundTasks()):
     global message_history
     message_history.append(('user', q))
 
-
-
+    suggestions = []
+    if q:
+        for key in redis_client.hkeys("responses"):
+            if q.lower() in key.lower():
+                suggestions.append(key)
+
+    # Launch the background task using the function decorated with @spaces.GPU()
+    background_tasks.add_task(generate_responses_gpu, q)
+    return {"status": "Processing request, please wait...", "suggestions": suggestions}
 
 @app.get('/get_response')
 async def get_response(q: str = Query(..., title='query')):
     response = redis_client.hget("responses", q)
     return {"response": response}
 
-def generate_responses(q):
-    generated_responses = []
-    try:
-        for model_name in redis_client.hkeys("models"):
-            try:
-                model_data = redis_client.hget("models", model_name)
-                model = GPT2LMHeadModel.from_pretrained(model_name)
-                tokenizer = AutoTokenizer.from_pretrained(model_name)
-                text_generation_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
-                generated_response = text_generation_pipeline(q, do_sample=True, max_length=50, num_return_sequences=5)
-                generated_responses.extend([response['generated_text'] for response in generated_response])
-            except Exception as e:
-                logger.error(f"Error generating response with model {model_name}: {e}")
-
-        if generated_responses:
-            similarities = calculate_similarity(q, generated_responses)
-            most_coherent_response = generated_responses[np.argmax(similarities)]
-            store_to_redis_table(q, "\n".join(generated_responses))
-            redis_client.hset("responses", q, most_coherent_response)
-        else:
-            logger.warning("No valid responses generated.")
-    except Exception as e:
-        logger.error(f"General error in autocomplete: {e}")
 
 if __name__ == '__main__':
     gpt2_models = [
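Because generation now runs as a background task, /autocomplete returns immediately and the reply is fetched later from /get_response. An illustrative client flow, assuming a local run on port 7860 and an arbitrary 5-second wait in place of real polling:

import time
import requests

base = "http://localhost:7860"  # hypothetical local address
q = "hello world"

# Kick off generation; the response carries only cached suggestions
print(requests.get(f"{base}/autocomplete", params={"q": q}).json())

time.sleep(5)  # placeholder for polling until the background task finishes

# Fetch whatever generate_responses_gpu stored for this query
print(requests.get(f"{base}/get_response", params={"q": q}).json())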
@@ -210,6 +283,13 @@ if __name__ == '__main__':
         "Salesforce/codegen-350M-multi"
     ]
 
+    kaggle_datasets = [
+        "uciml/iris",
+        "arshid/iris-flower-dataset",
+        "heesoo37/120-years-of-olympic-history-athletes-and-results"
+    ]
+
     load_and_store_models(gpt2_models + programming_models)
+    load_kaggle_datasets(kaggle_datasets)
 
     uvicorn.run(app=app, host='0.0.0.0', port=int(os.getenv("PORT", 7860)))
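After startup, the sampled Kaggle rows sit in the "kaggle_datasets" Redis hash. A quick illustrative check from a separate Python shell, assuming the same Redis connection settings as the app:

import redis

r = redis.Redis(host="localhost", port=6379)  # hypothetical connection settings
print(r.hkeys("kaggle_datasets"))  # dataset slugs loaded at startup
sample = r.hget("kaggle_datasets", "uciml/iris")
print(sample[:200] if sample else "uciml/iris not loaded")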