import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

# st.set_page_config must be the first Streamlit command executed in the script
st.set_page_config(layout="wide")

# Local directories holding the two fine-tuned checkpoints
MODELS = {
    "LLaMA 3": "llama3_dpo_final",
    "Flan-T5": "flan-t5-small-vietnamese-ecommerce-alpaca",
}

# Run on GPU when available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

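# Note: local_files_only=True below assumes both checkpoint folders already
# exist on disk (e.g., written with save_pretrained); nothing is downloaded.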
@st.cache_resource
def load_model_and_tokenizer(model_name):
    model_path = MODELS[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

    # fp16 on GPU, fp32 on CPU
    dtype = torch.float16 if device == "cuda" else torch.float32

    if "llama" in model_name.lower():
        model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=dtype, local_files_only=True
        )
    else:
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_path, torch_dtype=dtype, local_files_only=True
        )

    # device_map="auto" combined with an explicit .to(device) conflicts:
    # accelerate places the weights itself and .to() then raises an error,
    # so place the model manually instead.
    return tokenizer, model.to(device)

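# Both models load eagerly at startup; st.cache_resource memoizes them, so
# Streamlit reruns (triggered by every widget interaction) reuse the same
# objects instead of reloading the weights.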
tokenizer_llama, model_llama = load_model_and_tokenizer("LLaMA 3")
tokenizer_t5, model_t5 = load_model_and_tokenizer("Flan-T5")

def generate_response(prompt, tokenizer, model, model_type):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        if "llama" in model_type.lower():
            # max_new_tokens instead of max_length: for a causal LM, max_length
            # counts the prompt tokens too, which silently truncates long replies
            output = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id,  # LLaMA defines no pad token
            )
            # A causal LM echoes the prompt, so decode only the new tokens
            generated = output[0][inputs["input_ids"].shape[1]:]
        else:
            output = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                top_p=0.9,
                temperature=0.7,
                # T5 decodes from the pad token by default; kept explicit here
                decoder_start_token_id=tokenizer.pad_token_id,
            )
            generated = output[0]

    return tokenizer.decode(generated, skip_special_tokens=True)

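# Hypothetical standalone usage (not executed by the app; any short question works):
#   generate_response("What is your return policy?", tokenizer_t5, model_t5, "t5")
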
st.title("🤖 Chatbot Comparison: LLaMA 3 vs. Flan-T5")
st.write("💬 Enter a question to see how differently the two models respond!")

user_input = st.text_input("✍️ Enter a question:")

if user_input:
    # Side-by-side columns, one per model
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("🔷 LLaMA 3")
        with st.spinner("Generating a response from LLaMA 3..."):
            response_llama = generate_response(user_input, tokenizer_llama, model_llama, "llama")
        st.markdown(f"**Response:**\n\n{response_llama}")

    with col2:
        st.subheader("🔶 Flan-T5")
        with st.spinner("Generating a response from Flan-T5..."):
            response_t5 = generate_response(user_input, tokenizer_t5, model_t5, "t5")
        st.markdown(f"**Response:**\n\n{response_t5}")
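
# To run the app locally (assuming this file is saved as app.py):
#   streamlit run app.py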