Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
st.set_page_config(layout="wide") # ⚠️ PHẢI đặt đầu tiên
|
3 |
+
|
4 |
+
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
|
5 |
+
import torch
|
6 |
+
|
7 |
+
# Định nghĩa đường dẫn mô hình local
|
8 |
+
MODELS = {
|
9 |
+
"LLaMA 3": "llama3_dpo_final",
|
10 |
+
"Flan-T5": "flan-t5-small-vietnamese-ecommerce-alpaca"
|
11 |
+
}
|
12 |
+
|
13 |
+
# Kiểm tra thiết bị
|
14 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
15 |
+
|
16 |
+
# Load mô hình và tokenizer
|
17 |
+
@st.cache_resource
|
18 |
+
def load_model_and_tokenizer(model_name):
|
19 |
+
model_path = MODELS[model_name]
|
20 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
|
21 |
+
|
22 |
+
if "llama" in model_name.lower():
|
23 |
+
model = AutoModelForCausalLM.from_pretrained(
|
24 |
+
model_path,
|
25 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
26 |
+
device_map="auto" if device == "cuda" else None
|
27 |
+
).to(device)
|
28 |
+
else:
|
29 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
30 |
+
model_path,
|
31 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
32 |
+
device_map="auto" if device == "cuda" else None
|
33 |
+
).to(device)
|
34 |
+
|
35 |
+
return tokenizer, model
|
36 |
+
|
37 |
+
# Tải mô hình 1 lần
|
38 |
+
tokenizer_llama, model_llama = load_model_and_tokenizer("LLaMA 3")
|
39 |
+
tokenizer_t5, model_t5 = load_model_and_tokenizer("Flan-T5")
|
40 |
+
|
41 |
+
# Hàm sinh phản hồi
|
42 |
+
def generate_response(prompt, tokenizer, model, model_type):
|
43 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
44 |
+
|
45 |
+
if "llama" in model_type.lower():
|
46 |
+
output = model.generate(
|
47 |
+
**inputs, max_length=512, do_sample=True, top_p=0.9, temperature=0.7
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
output = model.generate(
|
51 |
+
**inputs, max_length=512, do_sample=True, top_p=0.9, temperature=0.7,
|
52 |
+
decoder_start_token_id=tokenizer.pad_token_id
|
53 |
+
)
|
54 |
+
|
55 |
+
response = tokenizer.decode(output[0], skip_special_tokens=True)
|
56 |
+
return response
|
57 |
+
|
58 |
+
# Giao diện người dùng
|
59 |
+
st.title("🤖 So sánh Chatbot: LLaMA 3 vs. Flan-T5")
|
60 |
+
st.write("💬 Nhập 1 câu hỏi để xem 2 mô hình phản hồi khác nhau như thế nào!")
|
61 |
+
|
62 |
+
# Nhập từ người dùng
|
63 |
+
user_input = st.text_input("✍️ Nhập câu hỏi:")
|
64 |
+
|
65 |
+
# Khi có input
|
66 |
+
if user_input:
|
67 |
+
col1, col2 = st.columns(2)
|
68 |
+
|
69 |
+
with col1:
|
70 |
+
st.subheader("🔷 LLaMA 3")
|
71 |
+
with st.spinner("Đang sinh phản hồi từ LLaMA 3..."):
|
72 |
+
response_llama = generate_response(user_input, tokenizer_llama, model_llama, "llama")
|
73 |
+
st.markdown(f"**Phản hồi:**\n\n{response_llama}")
|
74 |
+
|
75 |
+
with col2:
|
76 |
+
st.subheader("🔶 Flan-T5")
|
77 |
+
with st.spinner("Đang sinh phản hồi từ Flan-T5..."):
|
78 |
+
response_t5 = generate_response(user_input, tokenizer_t5, model_t5, "t5")
|
79 |
+
st.markdown(f"**Phản hồi:**\n\n{response_t5}")
|