Yhhxhfh committed on
Commit
77c2378
verified
1 Parent(s): a1f943d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +198 -0
app.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import asyncio
4
+ import uvicorn
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
+ from fastapi import FastAPI, Query, HTTPException
8
+ from fastapi.responses import HTMLResponse
9
+
10
# Logging configuration (DEBUG is verbose; appropriate for development)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Initialize the FastAPI application
app = FastAPI()

# Cache for the lazily-loaded (model, tokenizer) pair, stored under 'model'
data_and_models_dict = {}

# Append-only log of user prompts and bot replies.
# NOTE(review): this history is written but never read by the generation
# code — presumably intended for future conversational context; confirm.
message_history = []
22
+
23
# Model loading with fallback across several GPT-2 checkpoints
async def load_models():
    """Load the first GPT-2 checkpoint that is available.

    Tries each candidate checkpoint in order and returns as soon as one
    loads successfully.

    Returns:
        tuple: ``(model, tokenizer)`` for the first checkpoint that loads.

    Raises:
        HTTPException: 500 when every candidate checkpoint fails to load.
    """
    gpt_models = ["gpt2-medium", "gpt2-large", "gpt2"]
    for model_name in gpt_models:
        try:
            model = AutoModelForCausalLM.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            logger.info(f"Successfully loaded {model_name} model")
            return model, tokenizer
        except Exception as e:
            # Log and fall through to the next candidate. The original
            # raised inside this handler, so the first failure aborted the
            # whole function and the fallback checkpoints were never tried.
            logger.error(f"Failed to load {model_name} model: {e}")
    # Only reached when every checkpoint in the list failed.
    raise HTTPException(status_code=500, detail="Failed to load any models")
35
+
36
# Populate the module-level model cache on demand
async def download_models():
    """Load a (model, tokenizer) pair and cache it under the 'model' key."""
    pair = await load_models()
    data_and_models_dict['model'] = pair
40
+
41
@app.get('/')
async def main():
    """Serve the single-page chat UI.

    Returns:
        HTMLResponse: a static HTML/CSS/JS client that sends each prompt
        to ``/autocomplete`` and appends the reply to the chat box.
    """
    # FIX: the fetch URL now passes the message through encodeURIComponent,
    # so prompts containing '&', '#', '+', '?' etc. survive as a query
    # parameter instead of being truncated or misparsed by the server.
    html_code = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>ChatGPT Chatbot</title>
        <style>
            body, html {
                height: 100%;
                margin: 0;
                padding: 0;
                font-family: Arial, sans-serif;
            }
            .container {
                height: 100%;
                display: flex;
                flex-direction: column;
                justify-content: center;
                align-items: center;
            }
            .chat-container {
                border-radius: 10px;
                overflow: hidden;
                box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
                width: 100%;
                height: 100%;
            }
            .chat-box {
                height: calc(100% - 60px);
                overflow-y: auto;
                padding: 10px;
            }
            .chat-input {
                width: calc(100% - 100px);
                padding: 10px;
                border: none;
                border-top: 1px solid #ccc;
                font-size: 16px;
            }
            .input-container {
                display: flex;
                align-items: center;
                justify-content: space-between;
                padding: 10px;
                background-color: #f5f5f5;
                border-top: 1px solid #ccc;
                width: 100%;
            }
            button {
                padding: 10px;
                border: none;
                cursor: pointer;
                background-color: #007bff;
                color: #fff;
                font-size: 16px;
            }
            .user-message {
                background-color: #cce5ff;
                border-radius: 5px;
                align-self: flex-end;
                max-width: 70%;
                margin-left: auto;
                margin-right: 10px;
                margin-bottom: 10px;
            }
            .bot-message {
                background-color: #d1ecf1;
                border-radius: 5px;
                align-self: flex-start;
                max-width: 70%;
                margin-bottom: 10px;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <div class="chat-container">
                <div class="chat-box" id="chat-box"></div>
                <div class="input-container">
                    <input type="text" class="chat-input" id="user-input" placeholder="Escribe un mensaje...">
                    <button onclick="sendMessage()">Enviar</button>
                </div>
            </div>
        </div>
        <script>
            const chatBox = document.getElementById('chat-box');
            const userInput = document.getElementById('user-input');

            function saveMessage(sender, message) {
                const messageElement = document.createElement('div');
                messageElement.textContent = `${sender}: ${message}`;
                messageElement.classList.add(`${sender}-message`);
                chatBox.appendChild(messageElement);
                userInput.value = '';
            }

            async function sendMessage() {
                const userMessage = userInput.value.trim();
                if (!userMessage) return;

                saveMessage('user', userMessage);
                await fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`)
                    .then(response => response.text())
                    .then(data => {
                        saveMessage('bot', data);
                        chatBox.scrollTop = chatBox.scrollHeight;
                    })
                    .catch(error => console.error('Error:', error));
            }

            userInput.addEventListener("keyup", function(event) {
                if (event.keyCode === 13) {
                    event.preventDefault();
                    sendMessage();
                }
            });
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code, status_code=200)
165
+
166
# Route for generating chatbot responses
@app.get('/autocomplete')
async def autocomplete(q: str = Query(...)):
    """Generate a GPT-2 continuation for the prompt *q*.

    Lazily loads the model on the first request, records both the prompt
    and the reply in ``message_history``, and returns the decoded text.

    NOTE(review): the decoded output includes the prompt itself —
    GPT-2's ``generate`` returns input + continuation — so the client
    sees its own message echoed back at the start of every reply.
    """
    # Lazily load the model on the first request.
    if 'model' not in data_and_models_dict:
        await download_models()

    model, tokenizer = data_and_models_dict['model']

    # Record the user's prompt (history is append-only and not yet fed
    # back into generation).
    message_history.append(q)

    # Inference only: no_grad avoids building an autograd graph, and
    # pad_token_id silences the missing-pad warning for GPT-2, which has
    # no pad token of its own (does not change the single-sequence output).
    input_ids = tokenizer.encode(q, return_tensors="pt")
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=50,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
    response_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Record the bot's reply as well.
    message_history.append(response_text)

    return response_text
190
+
191
# Entry point: warm the model cache, then serve the app without reloading
def run_app():
    """Pre-load the model, then start the uvicorn server on port 7860."""
    # Load the model once up front so the first HTTP request is not slow;
    # asyncio.run is safe here because uvicorn's own loop starts afterwards.
    asyncio.run(download_models())
    uvicorn.run(app, host='0.0.0.0', port=7860)
195
+
196
# Run the application when executed as a script
if __name__ == "__main__":
    run_app()