Saiky2k's picture
Create app.py
a42a06a verified
import streamlit as st
st.set_page_config(layout="wide") # ⚠️ PHẢI đặt đầu tiên
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
# Định nghĩa đường dẫn mô hình local
MODELS = {
"LLaMA 3": "llama3_dpo_final",
"Flan-T5": "flan-t5-small-vietnamese-ecommerce-alpaca"
}
# Kiểm tra thiết bị
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load mô hình và tokenizer
@st.cache_resource
def load_model_and_tokenizer(model_name):
model_path = MODELS[model_name]
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
if "llama" in model_name.lower():
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None
).to(device)
else:
model = AutoModelForSeq2SeqLM.from_pretrained(
model_path,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None
).to(device)
return tokenizer, model
# Tải mô hình 1 lần
tokenizer_llama, model_llama = load_model_and_tokenizer("LLaMA 3")
tokenizer_t5, model_t5 = load_model_and_tokenizer("Flan-T5")
# Hàm sinh phản hồi
def generate_response(prompt, tokenizer, model, model_type):
inputs = tokenizer(prompt, return_tensors="pt").to(device)
if "llama" in model_type.lower():
output = model.generate(
**inputs, max_length=512, do_sample=True, top_p=0.9, temperature=0.7
)
else:
output = model.generate(
**inputs, max_length=512, do_sample=True, top_p=0.9, temperature=0.7,
decoder_start_token_id=tokenizer.pad_token_id
)
response = tokenizer.decode(output[0], skip_special_tokens=True)
return response
# Giao diện người dùng
st.title("🤖 So sánh Chatbot: LLaMA 3 vs. Flan-T5")
st.write("💬 Nhập 1 câu hỏi để xem 2 mô hình phản hồi khác nhau như thế nào!")
# Nhập từ người dùng
user_input = st.text_input("✍️ Nhập câu hỏi:")
# Khi có input
if user_input:
col1, col2 = st.columns(2)
with col1:
st.subheader("🔷 LLaMA 3")
with st.spinner("Đang sinh phản hồi từ LLaMA 3..."):
response_llama = generate_response(user_input, tokenizer_llama, model_llama, "llama")
st.markdown(f"**Phản hồi:**\n\n{response_llama}")
with col2:
st.subheader("🔶 Flan-T5")
with st.spinner("Đang sinh phản hồi từ Flan-T5..."):
response_t5 = generate_response(user_input, tokenizer_t5, model_t5, "t5")
st.markdown(f"**Phản hồi:**\n\n{response_t5}")