# testingspace / app.py
# Uploaded by Ay-ouya ("Upload app.py"), commit 2b2cbd7 (verified).
import ast
import hashlib
import logging
import os
import re

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from transformers import pipeline
# Module logging: the root logger is configured at INFO level and this module
# logs through a logger named after the module itself.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)
# Application object. CORS is fully open: every origin, method and header is
# allowed, with credentials. Files under ./frontend are served at /static.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# NOTE(review): StaticFiles appears to check that "frontend" exists when it is
# constructed, which happens before this module creates frontend/images —
# confirm the directory always ships with the app.
_frontend_static = StaticFiles(directory="frontend")
app.mount("/static", _frontend_static, name="static")
# Hugging Face pipeline setup.
# Single source for the model id used by the load call and its log message.
MODEL_ID = "Salesforce/codegen-350M-mono"

# The token is mandatory here: import aborts with ValueError when the
# environment variable is missing or empty.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    logger.error("HUGGINGFACE_TOKEN environment variable not set.")
    raise ValueError("HUGGINGFACE_TOKEN environment variable not set.")

try:
    pipe = pipeline(
        "text-generation",
        model=MODEL_ID,
        token=HF_TOKEN,
        clean_up_tokenization_spaces=True,
    )
except Exception as e:
    # logger.exception records the traceback as well as the message, so the
    # cause of a failed model load is visible in the logs.
    logger.exception("Failed to load model: %s", e)
    raise
logger.info("Successfully loaded %s model.", MODEL_ID)
# Storage locations: uploaded workbooks go to UPLOAD_DIR, rendered plot images
# go to IMAGES_DIR inside the "frontend" tree. Both are created at import time
# when missing; existing directories are left untouched.
UPLOAD_DIR = "uploads"
IMAGES_DIR = os.path.join("frontend", "images")
for _required_dir in (UPLOAD_DIR, IMAGES_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
@app.post("/upload/")
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded ``.xlsx`` workbook in ``UPLOAD_DIR``.

    Only the final path component of the client-supplied name is used, so a
    name such as ``../../somewhere/evil.xlsx`` cannot write outside
    ``UPLOAD_DIR``. An existing file with the same name is overwritten.

    Returns:
        ``{"filename": <stored name>}`` — the name to send back to
        ``/generate-visualization/``.

    Raises:
        HTTPException: 400 if the upload has no usable name or is not ``.xlsx``.
    """
    # The client may omit the filename; treat that like a wrong extension
    # instead of crashing on None.
    safe_name = os.path.basename(file.filename or "")
    # Case-insensitive, so "Report.XLSX" is accepted as well.
    if not safe_name.lower().endswith(".xlsx"):
        raise HTTPException(status_code=400, detail="File must be an Excel file (.xlsx)")
    file_path = os.path.join(UPLOAD_DIR, safe_name)
    # Read before opening the target so a failed read does not leave an empty file.
    contents = await file.read()
    with open(file_path, "wb") as buffer:
        buffer.write(contents)
    logger.info("File uploaded: %s", safe_name)
    return {"filename": safe_name}
# Upper bound on tokens the model may generate for one request.
_MAX_NEW_TOKENS = 70

# Only generated lines starting with one of these prefixes are kept and run.
_PLOT_CODE_PREFIXES = ("sns.", "plt.")

# Names the generated script receives as globals; a column carrying one of
# these names must not be rewritten into a string literal.
_SCRIPT_GLOBAL_NAMES = frozenset({"pd", "plt", "sns", "df"})

# A single- or double-quoted string literal (backslash escapes allowed).
_STRING_LITERAL_RE = re.compile(r"""('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*")""")


def _build_codegen_prompt(columns, user_prompt):
    """Return the model prompt listing *columns* and embedding *user_prompt*.

    Column labels are stringified, so non-string labels (e.g. numeric headers)
    are allowed. Whitespace in the user's request, newlines included, is
    collapsed to single spaces. The request therefore stays on its own prompt
    line and cannot add lines such as a fake ``### CODE START ###`` marker.
    """
    column_list = ", ".join(map(str, columns))
    request = " ".join(user_prompt.split())
    return (
        "\n"
        "You are a Python expert in data visualization. "
        f"The DataFrame 'df' has columns: {column_list}.\n"
        "Generate a Python script using Seaborn (sns) and Matplotlib (plt) to create a "
        f"scatter plot based on the user's request: '{request}'.\n"
        "Include plt.title(), plt.xlabel(), plt.ylabel(). "
        "Return only the code without explanations below ### CODE START ###:\n"
        "### CODE START ###\n"
    )


def _quote_bare_columns(line, columns):
    """Turn bare column names in *line* into string literals.

    The model often writes ``x=Age`` instead of ``x='Age'``. A string column
    label is quoted only where it stands alone in code. It is left alone when it:

    - sits inside an existing string literal,
    - is part of a longer identifier or an attribute (``plt.xlabel`` stays
      intact for a column named ``x``),
    - is a keyword-argument name (``x=`` stays ``x=``).
    """
    names = sorted(
        {c for c in columns if isinstance(c, str) and c and c not in _SCRIPT_GLOBAL_NAMES},
        key=len,
        reverse=True,  # longest first, so "Sales Age" is matched before "Age"
    )
    if not names:
        return line
    bare_name = re.compile(
        r"(?<![\w.'\"])(?:"
        + "|".join(map(re.escape, names))
        + r")(?![\w'\"])(?!\s*=(?!=))"
    )
    # split() with one capturing group alternates: code, literal, code, ...
    pieces = _STRING_LITERAL_RE.split(line)
    pieces[::2] = [bare_name.sub(lambda m: repr(m.group(0)), code) for code in pieces[::2]]
    return "".join(pieces)


def _extract_plot_code(generated, columns):
    """Return the runnable plotting statements contained in *generated*.

    The function:

    - keeps only stripped lines starting with ``sns.`` or ``plt.``,
    - quotes bare column names in them,
    - drops any line that is not valid Python on its own (for example a call
      cut off by the token limit).

    Returns the kept lines joined with newlines, or ``""`` when nothing usable
    remains.
    """
    statements = []
    for raw_line in generated.splitlines():
        line = raw_line.strip()
        if not line.startswith(_PLOT_CODE_PREFIXES):
            continue
        line = _quote_bare_columns(line, columns)
        try:
            ast.parse(line)
        except (SyntaxError, ValueError):
            logger.warning("Dropping generated line that does not parse: %r", line)
            continue
        statements.append(line)
    return "\n".join(statements)


def _render_plot(code, df, plot_path):
    """Run the generated plotting *code* against *df* and save a PNG to *plot_path*.

    Raises:
        ValueError: if the code leaves the current figure without any axes.
        Exception: anything the generated code itself raises.
    """
    # SECURITY WARNING: `code` comes from a model steered by an untrusted user
    # prompt. exec() is given a globals dict without "__builtins__", so the
    # full builtins are available to it. The sns./plt. prefix filter is NOT a
    # sandbox: e.g. plt.title(__import__("os").system("...")) passes it.
    # Only run this service in an isolated, disposable environment.
    script_globals = {"pd": pd, "plt": plt, "sns": sns, "df": df}
    plt.close("all")
    plt.figure(figsize=(8, 6))
    try:
        exec(code, script_globals)
        # The script may open its own figure (plt.figure()/plt.subplots()), so
        # inspect and save whichever figure is current now.
        if not plt.gcf().get_axes():
            raise ValueError("Generated code produced an empty plot")
        plt.savefig(plot_path, bbox_inches="tight")
    finally:
        plt.close("all")


@app.post("/generate-visualization/")
async def generate_visualization(prompt: str = Form(...), filename: str = Form(...)):
    """Generate and render a scatter plot for an uploaded workbook.

    Args:
        prompt: Free-text description of the plot the user wants.
        filename: Name of a workbook stored by ``/upload/``. Only its final
            path component is used, so it cannot point outside ``UPLOAD_DIR``.

    Returns:
        ``{"plot_url": ..., "generated_code": ...}``:

        - ``plot_url`` is the PNG under ``/static/images/``, with a
          cache-busting ``t`` query.
        - ``generated_code`` is the model's continuation of the prompt.

    Raises:
        HTTPException:
            - 404 if the workbook does not exist.
            - 400 if it cannot be read or is empty.
            - 500 if code generation, extraction or rendering fails.

    Note: this coroutine never awaits, so nothing else on the event loop runs
    while it executes. That keeps pyplot's global figure state private to one
    request, but a slow model call also blocks every other request.
    """
    safe_name = os.path.basename(filename)
    file_path = os.path.join(UPLOAD_DIR, safe_name)
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found on server.")
    try:
        df = pd.read_excel(file_path)
        if df.empty:
            raise ValueError("Excel file is empty.")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading Excel file: {e}") from e

    input_text = _build_codegen_prompt(df.columns, prompt)
    try:
        logger.info("Starting code generation...")
        # return_full_text=False returns only the model's continuation, so no
        # text from the prompt (which embeds user input) can reach exec().
        generated_output = pipe(
            input_text,
            max_new_tokens=_MAX_NEW_TOKENS,
            do_sample=False,
            return_full_text=False,
        )
        raw_generated_code = generated_output[0]["generated_text"].strip()
        logger.info("Raw generated code: '%s'", raw_generated_code)
    except Exception as e:
        logger.exception("Error querying model: %s", e)
        raise HTTPException(status_code=500, detail=f"Error querying model: {e}") from e
    if not raw_generated_code:
        logger.error("No code generated by the model.")
        raise HTTPException(status_code=500, detail="No code generated by the model.")

    cleaned_code = _extract_plot_code(raw_generated_code, df.columns)
    logger.info("Cleaned code: '%s'", cleaned_code)
    if not cleaned_code:
        logger.error("Cleaned code is empty after filtering.")
        raise HTTPException(status_code=500, detail="Generated code is empty or invalid")

    # Same workbook + prompt -> same image name (later requests overwrite it).
    plot_hash = hashlib.md5(f"{safe_name}_{prompt}".encode()).hexdigest()[:8]
    plot_filename = f"plot_{plot_hash}.png"
    plot_path = os.path.join(IMAGES_DIR, plot_filename)
    try:
        _render_plot(cleaned_code, df, plot_path)
    except Exception as e:
        logger.exception("Error executing cleaned code: '%s' Exception: %s", cleaned_code, e)
        raise HTTPException(status_code=500, detail=f"Error executing code: {e}") from e
    logger.info("Plot saved to %s", plot_path)

    if not os.path.exists(plot_path):
        raise HTTPException(status_code=500, detail="Plot file was not created.")
    plot_url = f"/static/images/{plot_filename}?t={int(pd.Timestamp.now().timestamp())}"
    return {"plot_url": plot_url, "generated_code": raw_generated_code}
@app.get("/")
async def serve_frontend():
    """Return ``frontend/index.html`` as the page for the site root.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale-dependent default encoding, so non-ASCII page content is read
    the same way on every host.
    """
    index_path = os.path.join("frontend", "index.html")
    with open(index_path, "r", encoding="utf-8") as f:
        html = f.read()
    return HTMLResponse(content=html)
if __name__ == "__main__":
    # Running the module directly starts a uvicorn server on all interfaces,
    # port 7860. uvicorn is imported here so it is only needed for this path.
    from uvicorn import run as serve_app

    serve_app(app, port=7860, host="0.0.0.0")