Spaces:
Runtime error
Runtime error
File size: 1,947 Bytes
c195a75 4c5cefd c195a75 0ec7e8c c195a75 fca0532 0ec7e8c c195a75 fca0532 c195a75 fca0532 c195a75 fca0532 c195a75 fca0532 c195a75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import asyncio
import logging
import os
import threading

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import pipeline
# Route log records at INFO and above through the root logger, and create a
# logger named after this module for the messages emitted in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Ask the Hugging Face libraries to keep downloaded files in this directory.
# NOTE(review): this assignment executes after `transformers` has already been
# imported; if the library reads the variable at import time it will not see
# this value -- confirm, or set the variable before the import.
os.environ.update({"TRANSFORMERS_CACHE": "/home/user/.cache/huggingface"})

# Application object; the route decorators further down register on it.
app = FastAPI(title="SQL Coder API")
# Build the text-generation pipeline once, at import time, so that every
# request reuses the same loaded model. A failure here is logged with its
# traceback and re-raised, which aborts application start-up.
try:
    pipe = pipeline(
        "text-generation",
        model="defog/llama-3-sqlcoder-8b",
        device_map="auto",
        model_kwargs={
            "torch_dtype": "auto",
            # The TRANSFORMERS_CACHE environment variable is assigned only
            # after `transformers` has been imported, so it cannot be relied
            # on to redirect downloads. Forward the directory explicitly;
            # `None` (variable unset) keeps the library default.
            "cache_dir": os.environ.get("TRANSFORMERS_CACHE"),
        },
    )
except Exception as e:
    # logger.exception records the traceback in addition to the message.
    logger.exception("Error initializing pipeline: %s", e)
    raise
else:
    logger.info("Pipeline initialized successfully")
# One conversation turn as supplied by the client in a /generate request.
class ChatMessage(BaseModel):
    # Speaker label; the /generate handler writes it in front of the content
    # as "<role>: <content>" when it builds the prompt.
    role: str
    # Text of the message.
    content: str
class QueryRequest(BaseModel):
    """Request body accepted by ``POST /generate``.

    Attributes:
        messages: Conversation turns; they are joined into a single prompt
            string by the handler.
        max_new_tokens: Maximum number of tokens to generate. Must be
            positive; defaults to 1024.
        temperature: Sampling temperature. Must not be negative; defaults
            to 0.7.
    """

    messages: list[ChatMessage]
    # Bounds are enforced at validation time so that nonsensical values are
    # rejected as a client error instead of failing inside the model call.
    max_new_tokens: int = Field(default=1024, gt=0)
    temperature: float = Field(default=0.7, ge=0.0)
# Response body returned by the /generate endpoint.
class QueryResponse(BaseModel):
    # The 'generated_text' value of the first sequence returned by the
    # pipeline.
    generated_text: str
# The pipeline wraps a single shared model. Calls are moved off the event
# loop into worker threads (see `generate`), and this lock keeps them running
# one at a time, as they did when the call blocked the loop directly.
_pipe_lock = threading.Lock()


def _run_pipeline(prompt: str, generation_kwargs: dict) -> str:
    """Run the shared pipeline on *prompt* and return the first sequence's text."""
    with _pipe_lock:
        outputs = pipe(prompt, **generation_kwargs)
    return outputs[0]["generated_text"]


@app.post("/generate", response_model=QueryResponse)
async def generate(request: QueryRequest):
    """Generate text for a list of chat messages.

    The messages are flattened into one prompt, one ``"<role>: <content>"``
    line per message, and passed to the text-generation pipeline.

    Args:
        request: Messages plus the generation settings ``max_new_tokens``
            and ``temperature``.

    Returns:
        A ``QueryResponse`` carrying the ``generated_text`` of the first
        sequence returned by the pipeline.

    Raises:
        HTTPException: Status 500, with the error message as detail, when
            the pipeline call fails.
    """
    formatted_prompt = "\n".join(
        f"{msg.role}: {msg.content}" for msg in request.messages
    )

    try:
        generation_kwargs = {
            "max_new_tokens": request.max_new_tokens,
            "num_return_sequences": 1,
            "pad_token_id": pipe.tokenizer.eos_token_id,
        }
        if request.temperature > 0:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = request.temperature
        else:
            # Sampling requires a strictly positive temperature; treat a
            # non-positive value as a request for greedy decoding instead of
            # letting the generation call fail.
            generation_kwargs["do_sample"] = False

        # Generation is long-running and synchronous. Running it in a worker
        # thread keeps the event loop free to serve other requests.
        generated_text = await asyncio.to_thread(
            _run_pipeline, formatted_prompt, generation_kwargs
        )
    except Exception as e:
        logger.exception("Error generating response: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return QueryResponse(generated_text=generated_text)
@app.get("/health")
async def health_check():
    """Report that the service is up.

    The check is static: it returns a constant payload and does not exercise
    the model.

    Returns:
        The dictionary ``{"status": "healthy"}``.
    """
    return {"status": "healthy"}