llm-swp / handler.py
niruemon's picture
Update handler.py
3251ecc verified
raw
history blame
1.84 kB
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import os
class EndpointHandler:
def __init__(self, path=""):
# ระบุชื่อโมเดลใน Hugging Face Hub
model_name = "niruemon/llm-swp"
# กำหนดไดเรกทอรีสำหรับการ offload โมเดล (สร้างขึ้นถ้ายังไม่มี)
offload_dir = "./offload"
os.makedirs(offload_dir, exist_ok=True)
# โหลดโมเดลและ tokenizer
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
torch_dtype=torch.float16,
offload_folder=offload_dir,
offload_state_dict=True # เพิ่มพารามิเตอร์นี้เพื่อจัดการการ offload ให้ดียิ่งขึ้น
)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
# สร้าง pipeline สำหรับการสร้างข้อความ
self.generator = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer, device_map="auto")
def __call__(self, data):
# รับข้อความ input จากผู้ใช้
input_text = data.get("inputs", "")
if not input_text:
return {"error": "No input text provided."}
# สร้างข้อความโดยใช้โมเดล
try:
result = self.generator(input_text, max_length=150, num_return_sequences=1)
generated_text = result[0]["generated_text"]
return {"generated_text": generated_text}
except Exception as e:
return {"error": str(e)}