kkulchatbot / app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
app = FastAPI()
model_name = "scb10x/llama-3-typhoon-v1.5-8b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use BitsAndBytes for 8-bit quantization (requires a CUDA GPU with the
# bitsandbytes package installed)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# Load the model with 8-bit weights; non-quantized modules stay in float16
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.float16,
)
class Query(BaseModel):
    queryResult: Optional[dict] = None
    queryText: Optional[str] = None
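
# Illustrative request body (hypothetical example matching the Query model above):
# a Dialogflow-style webhook carries the user text in queryResult.queryText, while
# the top-level queryText field acts as a fallback for direct calls.
#   {"queryResult": {"queryText": "Hello"}}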
@app.post("/webhook")
async def webhook(query: Query):
    try:
        # Prefer the Dialogflow-style nested field, fall back to the flat one
        user_query = query.queryResult.get("queryText") if query.queryResult else query.queryText
        if not user_query:
            raise HTTPException(status_code=400, detail="No query text provided")

        prompt = f"Human: {user_query}\nAI:"
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

        with torch.no_grad():
            # do_sample=True is needed for temperature to have any effect
            output = model.generate(input_ids, max_new_tokens=100, do_sample=True, temperature=0.7)

        response = tokenizer.decode(output[0], skip_special_tokens=True)
        ai_response = response.split("AI:")[-1].strip()
        return {"fulfillmentText": ai_response}
    except HTTPException:
        # Let deliberate HTTP errors (like the 400 above) pass through unchanged
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
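
# Quick local test (illustrative; assumes the server is running on port 7860):
#   curl -X POST http://localhost:7860/webhook \
#     -H "Content-Type: application/json" \
#     -d '{"queryResult": {"queryText": "Hello"}}'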