File size: 1,947 Bytes
c195a75
 
4c5cefd
c195a75
0ec7e8c
c195a75
 
 
 
 
 
 
fca0532
 
0ec7e8c
c195a75
 
fca0532
 
 
 
 
 
c195a75
 
 
 
 
 
 
 
 
 
 
fca0532
c195a75
 
 
 
 
 
 
 
 
 
 
 
 
 
fca0532
c195a75
 
fca0532
 
c195a75
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import logging
import os

# Set environment variable for cache directory.
# This must run BEFORE `transformers` is imported: the library resolves its
# cache location from the environment while it is being imported, so assigning
# the variable afterwards does not change where models are stored.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME -- confirm against the version pinned for this service.
os.environ['TRANSFORMERS_CACHE'] = '/home/user/.cache/huggingface'

from fastapi import FastAPI, HTTPException  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from transformers import pipeline  # noqa: E402

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="SQL Coder API")

# Initialize pipeline once at import time so every request reuses the same
# loaded model. A failure here is logged (with traceback) and re-raised, which
# aborts application start-up instead of serving without a model.
try:
    pipe = pipeline(
        "text-generation",
        model="defog/llama-3-sqlcoder-8b",
        device_map="auto",
        model_kwargs={"torch_dtype": "auto"},
    )
except Exception as e:
    logger.exception("Error initializing pipeline: %s", e)
    raise
else:
    logger.info("Pipeline initialized successfully")

# One chat turn supplied by the client. The /generate handler renders each
# message into the prompt as "<role>: <content>".
class ChatMessage(BaseModel):
    # Speaker label; inserted verbatim before the colon in the prompt line.
    role: str
    # Text of the turn; inserted verbatim after the colon in the prompt line.
    content: str

# Request body for POST /generate.
class QueryRequest(BaseModel):
    # Conversation turns; joined in list order, one line per message, to form
    # the prompt.
    messages: list[ChatMessage]
    # Forwarded to the pipeline as `max_new_tokens`; defaults to 1024.
    max_new_tokens: int = 1024
    # Forwarded to the pipeline as `temperature`; defaults to 0.7.
    temperature: float = 0.7

# Response body for POST /generate.
class QueryResponse(BaseModel):
    # The `generated_text` field of the first sequence returned by the pipeline.
    generated_text: str

@app.post("/generate", response_model=QueryResponse)
async def generate(request: QueryRequest):
    """Generate text for a chat-style request using the SQL Coder pipeline.

    Each message is rendered as one "role: content" line and the lines are
    joined into a single prompt for the text-generation pipeline. The
    response carries the ``generated_text`` of the first returned sequence.

    A positive ``temperature`` enables sampling at that temperature; a
    temperature of 0 or below selects greedy (deterministic) decoding.

    Any error raised while building the prompt or generating is logged with
    its traceback and reported to the client as HTTP 500.
    """
    try:
        # Format messages into a single string
        formatted_prompt = "\n".join(
            f"{msg.role}: {msg.content}" for msg in request.messages
        )

        generation_kwargs = {
            "max_new_tokens": request.max_new_tokens,
            "num_return_sequences": 1,
            "pad_token_id": pipe.tokenizer.eos_token_id,
        }
        if request.temperature > 0:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = request.temperature
        else:
            # Sampling requires a strictly positive temperature; forcing
            # do_sample=True with temperature 0 makes generation fail, so
            # fall back to greedy decoding and omit the temperature.
            generation_kwargs["do_sample"] = False

        # NOTE(review): pipe(...) is a synchronous call inside a coroutine,
        # so the event loop is held for the whole generation and other
        # requests (including /health) wait until it finishes. Offloading to
        # a worker thread would need a shared lock or queue around the
        # pipeline -- confirm the desired concurrency model before changing.
        response = pipe(formatted_prompt, **generation_kwargs)

        # Extract generated text
        generated_text = response[0]['generated_text']

        return QueryResponse(generated_text=generated_text)

    except Exception as e:
        logger.exception("Error generating response: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

@app.get("/health")
async def health_check():
    # Liveness probe: answers with a fixed payload and does not inspect the
    # model pipeline.
    status_payload = {"status": "healthy"}
    return status_payload