from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

from init import ACCESS_TOKEN, SYSTEM_PROMPT
from utils import extract_sql, is_sql
from database import execute
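
# Assumptions about the local helpers (their implementations live elsewhere in
# this repo): is_sql(text) is expected to report whether the model's reply
# contains a SQL statement, extract_sql(text) to pull that statement out, and
# execute(query) to run it against the project's database, returning rows or
# raising on failure.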


model_name = "Qwen/Qwen2.5-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
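
# Optional convenience sketch: run the model on a GPU when one is available.
# The inputs built in respond() are placed on model.device to match, so this
# is not required for the pipeline to work.
if torch.cuda.is_available():
    model = model.to("cuda")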


def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Rebuild the conversation on every call so earlier requests do not leak
    # into later ones; fall back to SYSTEM_PROMPT when no system message is given.
    messages = [{"role": "system", "content": system_message or SYSTEM_PROMPT}]

    # Replay earlier turns from the chat history as alternating user/assistant messages.
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    # Render the conversation through the model's chat template and tokenize it.
    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # Decode only the newly generated tokens, not the prompt that was fed in.
    response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
    yield response

    # If the reply contains SQL, try to run it; on failure, ask the model for a
    # clarification and retry up to max_attempts times.
    if is_sql(response):
        sql_query = extract_sql(response)
        max_attempts = 3
        attempts = 0
        sql_result = None
        last_error = None

        while attempts < max_attempts:
            try:
                sql_result = execute(sql_query)
                break
            except Exception as e:
                last_error = str(e)
                attempts += 1
                if attempts < max_attempts:
                    # "I ran into an error while executing the SQL query: {last_error}
                    #  Could you revise the question or provide more information?"
                    clarification_prompt = (
                        f"Tôi gặp lỗi khi thực hiện truy vấn SQL: {last_error}\n"
                        "Bạn có thể chỉnh sửa câu hỏi hoặc cung cấp thêm thông tin không?"
                    )

                    messages += [
                        {"role": "assistant", "content": response},
                        {"role": "user", "content": clarification_prompt},
                    ]

                    input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

                    output_ids = model.generate(
                        input_ids,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        do_sample=True,
                    )

                    # Again, decode only the newly generated tokens.
                    response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
                    yield response

                    # If the retry produced a new SQL statement, use it on the next
                    # attempt; otherwise the previous query is retried as-is.
                    if is_sql(response):
                        sql_query = extract_sql(response)
                else:
                    # "I tried {max_attempts} times but still got the error: {last_error}
                    #  Could you provide more detail about the data to query?"
                    retry_prompt = (
                        f"Tôi đã thử {max_attempts} lần nhưng vẫn gặp lỗi: {last_error}\n"
                        "Bạn có thể cung cấp thêm chi tiết về dữ liệu cần truy vấn không?"
                    )
                    yield retry_prompt
                    return

        # Once a query has succeeded, ask the model to turn the raw result
        # into a natural-language answer.
        if sql_result is not None:
            # "SQL query result: {sql_result}
            #  Summarise the result as a natural response."
            reformulation_prompt = (
                f"Kết quả truy vấn SQL:\n{sql_result}\n"
                "Hãy tóm tắt kết quả thành phản hồi tự nhiên."
            )
            messages += [
                {"role": "assistant", "content": response},
                {"role": "user", "content": reformulation_prompt},
            ]

            input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

            output_ids = model.generate(
                input_ids,
                max_new_tokens=512,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
            )

            reformulated_response = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
            yield reformulated_response
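

# A minimal sketch of how this generator could be wired into a Gradio chat UI.
# It assumes the standard gr.ChatInterface pattern (message and history plus
# additional inputs matching respond's signature); the widget labels and slider
# ranges are illustrative, not taken from this repo.
if __name__ == "__main__":
    import gradio as gr

    demo = gr.ChatInterface(
        respond,
        additional_inputs=[
            gr.Textbox(value=SYSTEM_PROMPT, label="System message"),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
        ],
    )
    demo.launch()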
|