OpenCHAT-mini / app.py
ILLERRAPS's picture
Update app.py
ef3b11c verified
import asyncio
import json
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends
from fastapi.responses import HTMLResponse
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy import create_engine, Column, Integer, String, MetaData, Table
from sqlalchemy.orm import sessionmaker
import gradio as gr
from transformers import pipeline
from PIL import Image
# Database setup: a single SQLite file (chatbot.db, resolved relative to the
# current working directory) used by create_table / edit_table below.
metadata = MetaData()
DATABASE_URL = "sqlite:///chatbot.db"
engine = create_engine(DATABASE_URL)
Session = sessionmaker(bind=engine)
# Module-wide ORM session; nothing in the code shown here uses it.
session = Session()
def create_table(table_name, columns):
    """Create table ``table_name`` with an ``id`` integer primary key plus ``columns``.

    Args:
        table_name: Name of the table to create.
        columns: Mapping of column name -> type name. The type name is matched
            case-insensitively against ``'String'`` and ``'Integer'``.

    Returns:
        A human-readable status message. Problems are reported through the
        return value rather than raised.
    """
    from sqlalchemy import inspect as sa_inspect

    # Engine.table_names() was deprecated in SQLAlchemy 1.4 and removed in 2.0;
    # the Inspector is the supported way to list tables.
    if table_name in sa_inspect(engine).get_table_names():
        return f"Table '{table_name}' already exists."

    type_map = {'string': String, 'integer': Integer}
    columns_list = [Column('id', Integer, primary_key=True)]
    for col_name, col_type in columns.items():
        # 'id' is always added as the primary key above; SQLite compares column
        # names case-insensitively, so any spelling of it would collide.
        if str(col_name).lower() == 'id':
            return "Column 'id' is reserved for the primary key."
        sa_type = type_map.get(col_type.lower()) if isinstance(col_type, str) else None
        if sa_type is None:
            return "Unsupported column type. Use 'String' or 'Integer'."
        columns_list.append(Column(col_name, sa_type))

    # The database says the table does not exist, so any same-named entry still
    # in the shared MetaData is stale (e.g. an earlier failed create or an
    # external DROP). Remove it, or Table() would reject the redefinition.
    stale = metadata.tables.get(table_name)
    if stale is not None:
        metadata.remove(stale)
    new_table = Table(table_name, metadata, *columns_list)
    # Create just this table instead of everything registered in ``metadata``.
    new_table.create(engine)
    return f"Table '{table_name}' created successfully."
def edit_table(table_name, columns):
    """Add to ``table_name`` every column in ``columns`` that it does not have yet.

    Columns that already exist are skipped without checking their type.

    Args:
        table_name: Name of an existing table.
        columns: Mapping of column name -> type name (``'String'`` or
            ``'Integer'``, case-insensitive).

    Returns:
        A human-readable status message. Problems are reported through the
        return value rather than raised.
    """
    from sqlalchemy import inspect as sa_inspect

    inspector = sa_inspect(engine)
    # Engine.table_names() no longer exists in SQLAlchemy 2.0.
    if table_name not in inspector.get_table_names():
        return f"Table '{table_name}' does not exist."

    # Read the live column list from the database. Reflecting into the shared
    # MetaData would return a cached Table on later calls that misses columns
    # added since.
    existing = {col['name'] for col in inspector.get_columns(table_name)}

    type_map = {'string': String, 'integer': Integer}
    # Validate every requested column before altering anything, so an
    # unsupported type cannot leave the table half-updated.
    to_add = []
    for col_name, col_type in columns.items():
        if col_name in existing:
            continue
        sa_type = type_map.get(col_type.lower()) if isinstance(col_type, str) else None
        if sa_type is None:
            return "Unsupported column type. Use 'String' or 'Integer'."
        to_add.append((col_name, sa_type))

    # Core SQLAlchemy has no Column.create() (that came from the
    # sqlalchemy-migrate extension), so emit ALTER TABLE ... ADD COLUMN
    # directly. Names come from user input, so quote them through the dialect.
    # exec_driver_sql avoids text()'s ":name" bind parsing on such names.
    preparer = engine.dialect.identifier_preparer
    with engine.begin() as conn:
        for col_name, sa_type in to_add:
            type_sql = sa_type().compile(dialect=engine.dialect)
            conn.exec_driver_sql(
                f"ALTER TABLE {preparer.quote(table_name)} "
                f"ADD COLUMN {preparer.quote(col_name)} {type_sql}"
            )
    return f"Table '{table_name}' updated successfully."
# FastAPI application: serves the chat page, the WebSocket relay and the table API.
app = FastAPI()
# Bearer-token scheme whose tokenUrl points at the "/token" route below.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
class ConnectionManager:
    """Track open chat WebSockets by username and fan messages out to them."""

    def __init__(self):
        # username -> accepted WebSocket. Connecting again under the same name
        # replaces the previous entry.
        self.active_connections: dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, username: str):
        """Accept the handshake and register ``websocket`` under ``username``."""
        await websocket.accept()
        self.active_connections[username] = websocket

    def disconnect(self, username: str):
        """Forget ``username``'s connection; a no-op if it is not registered."""
        self.active_connections.pop(username, None)

    async def broadcast(self, message: str):
        """Send ``message`` to every registered connection.

        Iterates over a snapshot because other handlers may connect or
        disconnect while a send is being awaited. Iterating the live dict would
        then raise "dictionary changed size during iteration".

        A connection whose send fails is unregistered. One dead client then
        neither aborts delivery to the others nor breaks every later broadcast.
        """
        for username, connection in list(self.active_connections.items()):
            try:
                await connection.send_text(message)
            except (WebSocketDisconnect, RuntimeError, OSError):
                # Only drop the entry if it still refers to this socket; the
                # user may have reconnected while we were awaiting.
                if self.active_connections.get(username) is connection:
                    self.disconnect(username)


manager = ConnectionManager()
@app.websocket("/ws/{username}")
async def websocket_endpoint(websocket: WebSocket, username: str, token: str = Depends(oauth2_scheme)):
    """Relay each text frame from ``username`` to all clients as ``"<username>: <text>"``.

    NOTE(review): ``token`` is never checked here, and OAuth2PasswordBearer is
    designed around HTTP requests. Confirm that this dependency resolves on a
    WebSocket route and whether authentication is actually intended.
    """
    await manager.connect(websocket, username)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(f"{username}: {data}")
    except WebSocketDisconnect:
        # Normal client close; cleanup happens in ``finally``.
        pass
    finally:
        # Unregister on every exit path, including errors raised while
        # broadcasting, so a dead socket is not left in the manager.
        manager.disconnect(username)
@app.post("/token")
async def login():
    """Issue a placeholder bearer token.

    Demo only: no credentials are read, and every caller receives the same
    constant token.
    """
    token_type = "bearer"
    return {"access_token": "fake_token", "token_type": token_type}
@app.post("/chatbot")
async def chatbot(task: str, table_name: str, columns: str):
    """HTTP entry point for create/edit table requests.

    ``columns`` is a JSON-encoded object string. All three arguments are plain
    ``str`` parameters, so FastAPI reads them from the query string.
    Returns ``{"result": <status message>, "description": <summary or None>}``.
    """
    message, summary = handle_chatbot(task, table_name, columns)
    return dict(result=message, description=summary)
@app.get("/")
async def get():
    """Serve a minimal browser chat client for the WebSocket relay.

    The page always joins as ``test_user``. The socket URL is built from the
    page's own location, so the client reaches whichever host and port served
    it. It uses ``wss://`` when the page was loaded over HTTPS, because browsers
    block a plain ``ws://`` socket from an HTTPS page as mixed content.
    """
    return HTMLResponse("""
<html>
    <head>
        <title>Real-time Chat</title>
    </head>
    <body>
        <h1>WebSocket Chat</h1>
        <input id="messageText" type="text" autocomplete="off"/>
        <button onclick="sendMessage()">Send</button>
        <ul id='messages'>
        </ul>
        <script>
            // Derive the socket URL from the serving origin instead of
            // hard-coding localhost, and match the page's http/https scheme.
            var wsScheme = window.location.protocol === "https:" ? "wss://" : "ws://";
            var ws = new WebSocket(wsScheme + window.location.host + "/ws/test_user");
            ws.onmessage = function(event) {
                var messages = document.getElementById('messages')
                var message = document.createElement('li')
                var content = document.createTextNode(event.data)
                message.appendChild(content)
                messages.appendChild(message)
            };
            function sendMessage() {
                var input = document.getElementById("messageText")
                ws.send(input.value)
                input.value = ''
            }
        </script>
    </body>
</html>
""")
# Image generation setup.
# NOTE(review): transformers' pipeline() task registry does not appear to
# include an "image-generation" task, and Stable Diffusion checkpoints are
# normally loaded with the `diffusers` library; confirm and switch loaders.
# Until then, loading is fail-soft: if the model cannot be loaded,
# image_generator is None and the chat/table features still start.
import logging

try:
    image_generator = pipeline("image-generation", model="CompVis/stable-diffusion-v1-4")
except Exception:
    # Broad on purpose: this is a top-level boundary where an unknown task,
    # a failed download or a missing backend all mean "no images", not "no app".
    logging.getLogger(__name__).exception(
        "Image generation pipeline could not be loaded; image output is disabled."
    )
    image_generator = None
# Helper functions
def chatbot_response(task, table_name=None, columns=None):
    """Dispatch ``task`` to the matching table helper.

    Args:
        task: ``'create_table'`` or ``'edit_table'``; anything else is rejected.
        table_name: Target table name; must be truthy for a supported task.
        columns: Column mapping forwarded to the helper; must be truthy.

    Returns:
        ``(message, description)``. ``message`` is the helper's status text or
        an error string. ``description`` echoes the inputs.
    """
    description = f"Task: {task}, Table Name: {table_name}, Columns: {columns}"
    if task not in ("create_table", "edit_table"):
        return "Unsupported task. Use 'create_table' or 'edit_table'.", description
    if not (table_name and columns):
        return "Please provide a table name and columns.", description
    action = create_table if task == "create_table" else edit_table
    return action(table_name, columns), description
def handle_chatbot(task, table_name, columns):
    """Validate a raw request, decode its JSON columns and dispatch it.

    Args:
        task: ``'create_table'`` or ``'edit_table'``.
        table_name: Target table name, passed through unchanged.
        columns: JSON text encoding an object that maps column name -> type name.

    Returns:
        ``(message, description)`` from chatbot_response, or
        ``(error message, None)`` when the request is rejected here.
    """
    if task not in ['create_table', 'edit_table']:
        return "Unsupported task. Use 'create_table' or 'edit_table'.", None
    try:
        columns_dict = json.loads(columns)
    except (json.JSONDecodeError, TypeError):
        # TypeError: ``columns`` was not text at all (e.g. None).
        return "Invalid columns format. Please use JSON format.", None
    # The table helpers iterate ``columns.items()``, so valid JSON that is not
    # an object (array, string, number) must be rejected here.
    if not isinstance(columns_dict, dict):
        return 'Invalid columns format. Please provide a JSON object such as {"name": "String"}.', None
    return chatbot_response(task, table_name, columns_dict)
def generate_image(description):
    """Return an image generated from ``description``, or None if none can be made.

    Returns None without calling the model in two cases:
      - ``description`` is empty or None (handle_chatbot returns a None
        description for rejected requests);
      - ``image_generator`` is None (model not loaded).

    NOTE(review): the ``num_return_sequences`` kwarg and the ``[0]['image']``
    result shape are kept from the original. They assume a pipeline that
    returns a list of dicts with an ``'image'`` key; confirm this against the
    pipeline actually loaded.
    """
    if not description:
        return None
    if image_generator is None:
        return None
    images = image_generator(description, num_return_sequences=1)
    return images[0]['image']
# Gradio interface setup
def handle_gradio_chatbot(task, table_name, columns):
    """Gradio callback: run the table task, then pair its message with an image.

    Returns ``(message, image)``. When handle_chatbot rejects the input it
    returns a None description. In that case no image is requested and
    ``image`` is None, so the error text is still shown to the user instead of
    the image step failing on a None prompt.
    """
    response, description = handle_chatbot(task, table_name, columns)
    image = generate_image(description) if description is not None else None
    return response, image
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4.
# The top-level components accept the same arguments.
task_input = gr.Textbox(lines=1, placeholder="Task (create_table or edit_table)")
table_name_input = gr.Textbox(lines=1, placeholder="Table Name")
# The example must be valid JSON (double quotes), because handle_chatbot
# parses this field with json.loads.
columns_input = gr.Textbox(
    lines=2,
    placeholder='Columns (JSON format: {"column1": "String", "column2": "Integer"})',
)
interface = gr.Interface(
    fn=handle_gradio_chatbot,
    inputs=[task_input, table_name_input, columns_input],
    outputs=[gr.Textbox(), gr.Image(type="pil")],
    title="Multiplayer Game with SQL Database and Image Generation",
    description="A multiplayer game interface to create and edit SQL tables with image generation.",
)
# Entry point for the background FastAPI server thread.
def run_server():
    """Serve ``app`` with uvicorn on 0.0.0.0:8000; does not return while serving."""
    from uvicorn import run as serve_app

    bind_host, bind_port = "0.0.0.0", 8000
    serve_app(app, host=bind_host, port=bind_port)
# Start the FastAPI server on a background thread. daemon=True means this
# thread does not keep the process alive once the main thread exits.
import threading

_server_thread = threading.Thread(target=run_server, daemon=True)
_server_thread.start()

# Launch the Gradio interface from the main thread.
interface.launch()