Spaces:
Sleeping
Sleeping
import ast
import hashlib
import logging
import os
import re

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from fastapi import FastAPI, File, UploadFile, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from transformers import pipeline
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed cross-origin requests (Starlette's CORSMiddleware
# echoes the caller's Origin in that case — confirm against its docs). Nothing
# visible in this file uses cookies or auth, so allow_credentials=False or an
# explicit origin list is probably sufficient.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Working directories. They are created BEFORE the static mount below because
# StaticFiles checks that its directory exists when it is constructed;
# IMAGES_DIR lives under "frontend", so creating it also creates "frontend".
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
IMAGES_DIR = os.path.join("frontend", "images")
os.makedirs(IMAGES_DIR, exist_ok=True)

app.mount("/static", StaticFiles(directory="frontend"), name="static")

# Hugging Face pipeline setup: the service refuses to start without a token.
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if not HF_TOKEN:
    logger.error("HUGGINGFACE_TOKEN environment variable not set.")
    raise ValueError("HUGGINGFACE_TOKEN environment variable not set.")
try:
    pipe = pipeline(
        "text-generation",
        model="Salesforce/codegen-350M-mono",
        token=HF_TOKEN,
        clean_up_tokenization_spaces=True,
    )
    logger.info("Successfully loaded Salesforce/codegen-350M-mono model.")
except Exception as e:
    # logger.exception records the traceback, which the plain message lost.
    logger.exception("Failed to load model: %s", e)
    raise
async def upload_file(file: UploadFile = File(...)):
    """Store an uploaded ``.xlsx`` file in UPLOAD_DIR.

    Args:
        file: The multipart upload (form field ``file``).

    Returns:
        ``{"filename": name}``, where *name* is the name the file was stored
        under: the final path component of the client-supplied filename.

    Raises:
        HTTPException: 400 if the upload has no usable name or does not end
            in ``.xlsx`` (case-insensitive).
    """
    # NOTE(review): no route decorator is visible for this handler here;
    # confirm it is registered on `app`.
    # The filename is client-controlled (and may be None); keep only its final
    # path component so values like "../../app.py" or "/etc/x.xlsx" cannot
    # write outside UPLOAD_DIR.
    safe_name = os.path.basename(file.filename or "")
    if not safe_name.lower().endswith(".xlsx"):
        raise HTTPException(status_code=400, detail="File must be an Excel file (.xlsx)")
    file_path = os.path.join(UPLOAD_DIR, safe_name)
    # Read first so a failed read does not leave an empty file behind.
    contents = await file.read()
    with open(file_path, "wb") as buffer:
        buffer.write(contents)
    logger.info("File uploaded: %s", safe_name)
    return {"filename": safe_name}
async def generate_visualization(prompt: str = Form(...), filename: str = Form(...)):
    """Generate a scatter plot for an uploaded Excel file from a text prompt.

    The DataFrame's column names and the user's prompt are sent to the
    code-generation pipeline. The returned ``sns.``/``plt.`` lines are
    filtered, syntax-checked, and executed against the DataFrame, and the
    resulting figure is saved under IMAGES_DIR.

    Args:
        prompt: The user's plotting request (form field).
        filename: Name of a file previously stored in UPLOAD_DIR (form field).

    Returns:
        ``{"plot_url": ..., "generated_code": ...}``. ``plot_url`` points
        under ``/static/images`` with a timestamp query string for cache
        busting. ``generated_code`` is the raw, stripped pipeline output.

    Raises:
        HTTPException: 404 if the file is not in UPLOAD_DIR; 400 if it cannot
            be read or is empty; 500 if generation, cleaning, parsing, or
            execution of the generated code fails.
    """
    # NOTE(review): no route decorator is visible for this handler here;
    # confirm it is registered on `app`.
    # Keep only the final path component so a crafted form value such as
    # "../../secret.xlsx" cannot read files outside UPLOAD_DIR.
    safe_name = os.path.basename(filename)
    file_path = os.path.join(UPLOAD_DIR, safe_name)
    if not safe_name or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found on server.")

    df = _read_uploaded_excel(file_path)
    raw_generated_code = _generate_plot_code(df, prompt)

    cleaned_code = _clean_generated_code(raw_generated_code, df.columns)
    logger.info("Cleaned code: '%s'", cleaned_code)
    if not cleaned_code:
        logger.error("Cleaned code is empty after filtering.")
        raise HTTPException(status_code=500, detail="Generated code is empty or invalid")
    try:
        ast.parse(cleaned_code)
    except SyntaxError as e:
        logger.error("Syntax error in cleaned code: '%s' Exception: %s", cleaned_code, e)
        raise HTTPException(status_code=500, detail=f"Syntax error in generated code: {str(e)}") from e

    # md5 is only used to derive a short, stable output file name.
    plot_hash = hashlib.md5(f"{safe_name}_{prompt}".encode()).hexdigest()[:8]
    plot_filename = f"plot_{plot_hash}.png"
    plot_path = os.path.join(IMAGES_DIR, plot_filename)
    _render_plot(cleaned_code, df, plot_path)
    if not os.path.exists(plot_path):
        raise HTTPException(status_code=500, detail="Plot file was not created.")

    plot_url = f"/static/images/{plot_filename}?t={int(pd.Timestamp.now().timestamp())}"
    return {"plot_url": plot_url, "generated_code": raw_generated_code}


# Matches a single- or double-quoted Python string literal (with escapes).
# The capturing group makes re.split keep the literals in its output.
_STRING_LITERAL = re.compile(r"""('(?:[^'\\]|\\.)*'|"(?:[^"\\]|\\.)*")""")


def _read_uploaded_excel(file_path):
    """Read the Excel file at *file_path*; map any failure or an empty sheet to HTTP 400."""
    try:
        df = pd.read_excel(file_path)
        if df.empty:
            raise ValueError("Excel file is empty.")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading Excel file: {str(e)}") from e
    return df


def _generate_plot_code(df, prompt):
    """Ask the code-generation pipeline for plotting code and return its stripped text.

    Raises:
        HTTPException: 500 if the pipeline call fails or returns only whitespace.
    """
    # str() guards against non-string headers (e.g. numeric Excel headers),
    # which would make str.join raise TypeError.
    columns = ", ".join(str(col) for col in df.columns)
    input_text = (
        "\n"
        "You are a Python expert in data visualization. "
        f"The DataFrame 'df' has columns: {columns}.\n"
        "Generate a Python script using Seaborn (sns) and Matplotlib (plt) to create a "
        f"scatter plot based on the user's request: '{prompt}'.\n"
        "Include plt.title(), plt.xlabel(), plt.ylabel(). Return only the code without "
        "explanations below ### CODE START ###:\n"
        "### CODE START ###\n"
    )
    try:
        logger.info("Starting code generation...")
        # NOTE(review): pipe() is synchronous, so the calling async handler
        # blocks the event loop for the whole generation. That also
        # serializes use of pyplot's global figure state in _render_plot;
        # moving this to a thread pool would need a lock there.
        generated_output = pipe(input_text, max_new_tokens=70, do_sample=False)
        raw_generated_code = generated_output[0]["generated_text"].strip()
        logger.info("Raw generated code: '%s'", raw_generated_code)
    except Exception as e:
        logger.error("Error querying model: %s", e)
        raise HTTPException(status_code=500, detail=f"Error querying model: {str(e)}") from e
    if not raw_generated_code:
        logger.error("No code generated by the model.")
        raise HTTPException(status_code=500, detail="No code generated by the model.")
    return raw_generated_code


def _clean_generated_code(raw_generated_code, columns):
    """Keep only ``sns.``/``plt.`` lines after the ``### CODE START ###`` marker.

    Bare column names in the kept lines are quoted, e.g. ``x=Age`` becomes
    ``x='Age'``. Lines are stripped of surrounding whitespace.
    """
    # Longest names first, so "Age Group" is quoted before "Age" could match
    # inside it. The secondary key makes the order deterministic.
    names = sorted({str(col) for col in columns if str(col)}, key=lambda s: (-len(s), s))
    # A name counts as a column reference only when it is not part of a longer
    # identifier or quoted text, not an attribute (preceded by "."), and not a
    # keyword-argument name (followed by "=" but not "==").
    column_patterns = [
        re.compile(rf"(?<![\w.'\"]){re.escape(name)}(?![\w'\"])(?!\s*=(?!=))")
        for name in names
    ]
    cleaned_lines = []
    code_started = False
    for line in raw_generated_code.splitlines():
        line = line.strip()
        if line == "### CODE START ###":
            code_started = True
            continue
        if code_started and line.startswith(("sns.", "plt.")):
            cleaned_lines.append(_quote_column_refs(line, column_patterns))
    return "\n".join(cleaned_lines).strip()


def _quote_column_refs(line, column_patterns):
    """Quote matches of *column_patterns* in *line*, leaving string literals untouched."""
    # re.split with a capturing group alternates code (even indices) and
    # string literals (odd indices); only the code segments are rewritten.
    parts = _STRING_LITERAL.split(line)
    for i in range(0, len(parts), 2):
        for pattern in column_patterns:
            parts[i] = pattern.sub(lambda m: repr(m.group(0)), parts[i])
    return "".join(parts)


def _render_plot(cleaned_code, df, plot_path):
    """Execute *cleaned_code* with df/pd/plt/sns in scope and save the current figure.

    Raises:
        HTTPException: 500 if execution fails or the figure has no axes.
    """
    # SECURITY NOTE(review): this exec()s model output steered by the
    # user-supplied prompt. The sns./plt. prefix filter is NOT a sandbox: an
    # argument such as plt.title(__import__("os").system(...)) still runs.
    # Deploy only in an isolated environment, or replace exec with a
    # constrained interpreter.
    exec_globals = {"pd": pd, "plt": plt, "sns": sns, "df": df}
    try:
        plt.close("all")
        plt.figure(figsize=(8, 6))
        exec(cleaned_code, exec_globals)
        # Inspect the *current* figure. Generated code may call plt.figure()
        # or plt.subplots() itself, which leaves the pre-created figure empty.
        fig = plt.gcf()
        if not fig.get_axes():
            raise ValueError("Generated code produced an empty plot")
        fig.savefig(plot_path, bbox_inches="tight")
        logger.info("Plot saved to %s", plot_path)
    except Exception as e:
        logger.error("Error executing cleaned code: '%s' Exception: %s", cleaned_code, e)
        raise HTTPException(status_code=500, detail=f"Error executing code: {str(e)}") from e
    finally:
        plt.close("all")
async def serve_frontend():
    """Return ``frontend/index.html`` as an HTML response.

    The file is decoded explicitly as UTF-8, so non-ASCII markup does not
    depend on the server locale's default encoding.
    """
    # NOTE(review): no route decorator is visible for this handler here;
    # confirm it is registered on `app`.
    with open("frontend/index.html", "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())
if __name__ == "__main__":
    # Direct execution serves `app` with uvicorn on all interfaces, port 7860.
    # The uvicorn import is deferred so importing this module never needs it.
    from uvicorn import run as serve_app

    serve_app(app, port=7860, host="0.0.0.0")