# OpenCHAT-mini / app.py
# Last commit ef3b11c (verified) by ILLERRAPS: "Update app.py" — 6.93 kB.
import asyncio
import json
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.responses import HTMLResponse
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy import create_engine, Column, Integer, String, MetaData, Table
from sqlalchemy.orm import sessionmaker
import gradio as gr
from transformers import pipeline
from PIL import Image
# Database Setup
# Every table-management helper below talks to this single SQLite file; the
# relative "sqlite:///" URL resolves against the process's working directory.
DATABASE_URL = "sqlite:///chatbot.db"
metadata = MetaData()  # registry for Table objects defined by the helpers below
engine = create_engine(DATABASE_URL)
Session = sessionmaker(engine)  # first positional argument is the bind
# NOTE(review): this module-level session is not referenced by any code
# visible in this file; confirm before relying on or removing it.
session = Session()
def create_table(table_name, columns):
    """Create a table with an integer primary key ``id`` plus the given columns.

    Args:
        table_name: Name of the table to create.
        columns: Mapping of column name to type name. Type names are matched
            case-insensitively against "String" and "Integer".

    Returns:
        A human-readable status message. Validation problems are reported
        through the message rather than raised; database errors still raise.
    """
    # Engine.table_names() was deprecated in SQLAlchemy 1.4 and removed in
    # 2.0; the inspector API is available on both.
    from sqlalchemy import inspect

    if inspect(engine).has_table(table_name):
        return f"Table '{table_name}' already exists."

    type_map = {"string": String, "integer": Integer}
    columns_list = [Column('id', Integer, primary_key=True)]
    for col_name, col_type in columns.items():
        if col_name == 'id':
            # 'id' is always added as the primary key above; a second
            # definition under the same name would clash with it.
            return "Column 'id' is reserved for the primary key."
        # str() so non-string JSON values (e.g. 5) yield the unsupported-type
        # message instead of an AttributeError on .lower().
        column_type = type_map.get(str(col_type).lower())
        if column_type is None:
            return "Unsupported column type. Use 'String' or 'Integer'."
        columns_list.append(Column(col_name, column_type))

    # A Table left in the shared MetaData by an earlier failed attempt (or by a
    # table dropped outside this process) would make Table() reject the name as
    # already defined; the database check above says it does not exist.
    stale = metadata.tables.get(table_name)
    if stale is not None:
        metadata.remove(stale)
    new_table = Table(table_name, metadata, *columns_list)
    # Create only this table rather than everything registered in metadata.
    new_table.create(engine)
    return f"Table '{table_name}' created successfully."
def edit_table(table_name, columns):
    """Add any missing columns to an existing table.

    Columns that already exist are left untouched; their requested type is
    neither checked nor changed. The type of every new column is validated
    before any DDL is issued, so an unsupported type leaves the table as-is.

    Args:
        table_name: Name of an existing table.
        columns: Mapping of column name to type name ("String" or "Integer",
            case-insensitive).

    Returns:
        A human-readable status message. Database errors still raise.
    """
    # Engine.table_names() was removed in SQLAlchemy 2.0; use the inspector.
    from sqlalchemy import inspect

    inspector = inspect(engine)
    if not inspector.has_table(table_name):
        return f"Table '{table_name}' does not exist."

    # Read the live column list instead of reflecting into the shared
    # MetaData: a cached Table there would not see columns added earlier by
    # this function and would try to add them a second time.
    existing = {col["name"] for col in inspector.get_columns(table_name)}
    type_map = {"string": String, "integer": Integer}
    pending = []
    for col_name, col_type in columns.items():
        if col_name in existing:
            continue
        column_type = type_map.get(str(col_type).lower())
        if column_type is None:
            return "Unsupported column type. Use 'String' or 'Integer'."
        pending.append((col_name, column_type()))

    if pending:
        preparer = engine.dialect.identifier_preparer
        quoted_table = preparer.quote(table_name)
        # Core SQLAlchemy has no Column.create() (that was added by the
        # sqlalchemy-migrate extension), so emit ALTER TABLE ... ADD COLUMN
        # directly. Identifiers are quoted because they come from user input,
        # and exec_driver_sql avoids text()'s ":name" bind-parameter parsing.
        with engine.begin() as conn:
            for col_name, column_type in pending:
                type_sql = column_type.compile(dialect=engine.dialect)
                conn.exec_driver_sql(
                    f"ALTER TABLE {quoted_table} "
                    f"ADD COLUMN {preparer.quote(col_name)} {type_sql}"
                )
    return f"Table '{table_name}' updated successfully."
# Bearer-token scheme used as a dependency by the WebSocket route; its
# tokenUrl points at the /token endpoint defined below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# ASGI application that hosts the HTTP and WebSocket routes.
app = FastAPI()
class ConnectionManager:
    """Track one open WebSocket per username and fan messages out to all of them."""

    def __init__(self):
        # username -> accepted WebSocket; connecting again under the same
        # name replaces the previous entry.
        self.active_connections: dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, username: str):
        """Accept the handshake and register *websocket* under *username*."""
        await websocket.accept()
        self.active_connections[username] = websocket

    def disconnect(self, username: str):
        """Forget *username*'s connection; unknown names are ignored."""
        self.active_connections.pop(username, None)

    async def broadcast(self, message: str):
        """Send *message* to every registered connection.

        Iterates over a snapshot because other coroutines may connect or
        disconnect while this one is suspended in ``send_text``; iterating the
        live dict would then raise "dictionary changed size during iteration".
        A connection whose send fails is dropped so one dead client cannot
        stop delivery to the others.
        """
        for username, connection in list(self.active_connections.items()):
            try:
                await connection.send_text(message)
            except (WebSocketDisconnect, RuntimeError, OSError):
                # Errors seen when sending on a closed socket. Only drop the
                # entry if it still refers to this socket (the user may have
                # reconnected in the meantime).
                if self.active_connections.get(username) is connection:
                    self.disconnect(username)


manager = ConnectionManager()
@app.websocket("/ws/{username}")
async def websocket_endpoint(websocket: WebSocket, username: str, token: str = Depends(oauth2_scheme)):
    """Relay every text frame from *username* to all connected clients.

    Each message is broadcast as ``"<username>: <text>"``.
    """
    # NOTE(review): OAuth2PasswordBearer reads the Authorization header from an
    # HTTP Request; FastAPI may be unable to supply that argument on a
    # WebSocket route, and the bundled HTML client sends no such header. The
    # token is also never validated here. Confirm this dependency resolves.
    await manager.connect(websocket, username)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(f"{username}: {data}")
    except WebSocketDisconnect:
        pass  # client closed the socket: the normal way this loop ends
    finally:
        # Always unregister, even when something other than a disconnect ends
        # the loop, but never evict a newer connection that reused the name.
        if manager.active_connections.get(username) is websocket:
            manager.disconnect(username)
@app.post("/token")
async def login():
    """Return a fixed placeholder bearer token.

    Simplified token generation for demo purposes: no credentials are
    accepted or checked, and every caller receives the same token.
    """
    token_type = "bearer"
    return dict(access_token="fake_token", token_type=token_type)
@app.post("/chatbot")
async def chatbot(task: str, table_name: str, columns: str):
    """Run a table task and return its status message and input description.

    ``columns`` is JSON text mapping column names to type names.
    """
    # handle_chatbot performs blocking SQLAlchemy/SQLite work; run it in the
    # default thread pool so it does not stall the event loop that also
    # serves the WebSocket chat.
    loop = asyncio.get_running_loop()
    response, description = await loop.run_in_executor(
        None, handle_chatbot, task, table_name, columns
    )
    return {"result": response, "description": description}
@app.get("/")
async def get():
    """Serve a minimal browser client for the WebSocket chat at /ws/test_user."""
    # The socket URL is built from the page's own scheme and host rather than
    # a hard-coded ws://localhost:8000, so the client also works when the page
    # is opened from another machine or over HTTPS (which requires wss://).
    return HTMLResponse("""
<html>
<head>
<title>Real-time Chat</title>
</head>
<body>
<h1>WebSocket Chat</h1>
<input id="messageText" type="text" autocomplete="off"/>
<button onclick="sendMessage()">Send</button>
<ul id='messages'>
</ul>
<script>
var wsScheme = window.location.protocol === "https:" ? "wss://" : "ws://";
var ws = new WebSocket(wsScheme + window.location.host + "/ws/test_user");
ws.onmessage = function(event) {
var messages = document.getElementById('messages')
var message = document.createElement('li')
var content = document.createTextNode(event.data)
message.appendChild(content)
messages.appendChild(message)
};
function sendMessage() {
var input = document.getElementById("messageText")
ws.send(input.value)
input.value = ''
}
</script>
</body>
</html>
""")
# Image generation setup
# NOTE(review): "image-generation" does not appear to be a task name that
# transformers' pipeline() registers, and CompVis/stable-diffusion-v1-4 is a
# diffusers checkpoint, so this load likely fails; a working setup probably
# needs diffusers' StableDiffusionPipeline (a new dependency) — confirm.
# Loading is guarded so that a failure disables image generation instead of
# stopping the chat server and table tools from starting. On failure
# image_generator is None, and callers must check for that.
import logging

try:
    image_generator = pipeline("image-generation", model="CompVis/stable-diffusion-v1-4")
except Exception:  # startup boundary: log the cause and degrade gracefully
    logging.getLogger(__name__).exception("Image generation model failed to load")
    image_generator = None
# Helper functions
def chatbot_response(task, table_name=None, columns=None):
    """Dispatch *task* to the matching table helper.

    Args:
        task: "create_table" or "edit_table"; anything else is rejected.
        table_name: Target table name; must be non-empty for either task.
        columns: Column-name -> type-name mapping; must be non-empty.

    Returns:
        ``(result, description)``: *result* is the helper's status message or
        a validation message, and *description* echoes the three inputs.
    """
    if task not in ("create_table", "edit_table"):
        result = "Unsupported task. Use 'create_table' or 'edit_table'."
    elif not (table_name and columns):
        result = "Please provide a table name and columns."
    else:
        handler = create_table if task == "create_table" else edit_table
        result = handler(table_name, columns)
    return result, f"Task: {task}, Table Name: {table_name}, Columns: {columns}"
def handle_chatbot(task, table_name, columns):
    """Validate raw request input, then delegate to ``chatbot_response``.

    Args:
        task: Requested task name.
        table_name: Target table name.
        columns: JSON text of an object mapping column names to type names.

    Returns:
        ``(result, description)``. *description* is None when the input is
        rejected here.
    """
    if task not in ['create_table', 'edit_table']:
        return "Unsupported task. Use 'create_table' or 'edit_table'.", None
    try:
        columns_dict = json.loads(columns)
    except (json.JSONDecodeError, TypeError):
        # TypeError: columns was not text at all (e.g. None).
        return "Invalid columns format. Please use JSON format.", None
    if not isinstance(columns_dict, dict):
        # Valid JSON but not an object (e.g. a list or a number); the table
        # helpers iterate it with .items().
        return "Invalid columns format. Please use a JSON object of column names to types.", None
    return chatbot_response(task, table_name, columns_dict)
def generate_image(description):
    """Render *description* with ``image_generator`` and return the first image.

    Returns None when there is nothing to render (empty or None description)
    or when ``image_generator`` is None (no model loaded).
    """
    # Short-circuit keeps rejected input (description=None) away from the
    # model entirely.
    if not description or image_generator is None:
        return None
    images = image_generator(description, num_return_sequences=1)
    # NOTE(review): assumes the pipeline returns a list of dicts with an
    # 'image' key — confirm against the pipeline actually loaded.
    return images[0]['image']
# Gradio interface setup
def handle_gradio_chatbot(task, table_name, columns):
    """Gradio callback: run the table task, then illustrate its description.

    Returns ``(status_message, image)``. *image* is None when the input was
    rejected and therefore has no description to render.
    """
    response, description = handle_chatbot(task, table_name, columns)
    # handle_chatbot returns description=None for rejected input; don't feed
    # None into the image pipeline.
    image = generate_image(description) if description is not None else None
    return response, image
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# gr.Textbox / gr.Image work as both input and output components.
task_input = gr.Textbox(lines=1, placeholder="Task (create_table or edit_table)")
table_name_input = gr.Textbox(lines=1, placeholder="Table Name")
# The example must be valid JSON (double quotes): handle_chatbot parses this
# field with json.loads, which rejects single-quoted keys and values.
columns_input = gr.Textbox(
    lines=2,
    placeholder='Columns (JSON format: {"column1": "String", "column2": "Integer"})',
)
interface = gr.Interface(
    fn=handle_gradio_chatbot,
    inputs=[task_input, table_name_input, columns_input],
    outputs=[gr.Textbox(), gr.Image(type="pil")],
    title="Multiplayer Game with SQL Database and Image Generation",
    description="A multiplayer game interface to create and edit SQL tables with image generation.",
)
# Function to run FastAPI server in the background
def run_server():
    """Serve the FastAPI ``app`` with uvicorn on 0.0.0.0:8000 (blocking call)."""
    import uvicorn as _uvicorn  # imported lazily, only when the server starts

    bind_host, bind_port = "0.0.0.0", 8000
    _uvicorn.run(app, host=bind_host, port=bind_port)
# Run FastAPI server in a separate thread
import threading

# Daemon thread: it does not keep the process alive once the main thread exits.
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Launch Gradio interface
interface.launch()