Saiky2k commited on
Commit
a42a06a
·
verified ·
1 Parent(s): 276286c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ st.set_page_config(layout="wide") # ⚠️ PHẢI đặt đầu tiên
3
+
4
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
5
+ import torch
6
+
7
+ # Định nghĩa đường dẫn mô hình local
8
+ MODELS = {
9
+ "LLaMA 3": "llama3_dpo_final",
10
+ "Flan-T5": "flan-t5-small-vietnamese-ecommerce-alpaca"
11
+ }
12
+
13
+ # Kiểm tra thiết bị
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+
16
+ # Load mô hình và tokenizer
17
+ @st.cache_resource
18
+ def load_model_and_tokenizer(model_name):
19
+ model_path = MODELS[model_name]
20
+ tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
21
+
22
+ if "llama" in model_name.lower():
23
+ model = AutoModelForCausalLM.from_pretrained(
24
+ model_path,
25
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
26
+ device_map="auto" if device == "cuda" else None
27
+ ).to(device)
28
+ else:
29
+ model = AutoModelForSeq2SeqLM.from_pretrained(
30
+ model_path,
31
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
32
+ device_map="auto" if device == "cuda" else None
33
+ ).to(device)
34
+
35
+ return tokenizer, model
36
+
37
+ # Tải mô hình 1 lần
38
+ tokenizer_llama, model_llama = load_model_and_tokenizer("LLaMA 3")
39
+ tokenizer_t5, model_t5 = load_model_and_tokenizer("Flan-T5")
40
+
41
+ # Hàm sinh phản hồi
42
+ def generate_response(prompt, tokenizer, model, model_type):
43
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
44
+
45
+ if "llama" in model_type.lower():
46
+ output = model.generate(
47
+ **inputs, max_length=512, do_sample=True, top_p=0.9, temperature=0.7
48
+ )
49
+ else:
50
+ output = model.generate(
51
+ **inputs, max_length=512, do_sample=True, top_p=0.9, temperature=0.7,
52
+ decoder_start_token_id=tokenizer.pad_token_id
53
+ )
54
+
55
+ response = tokenizer.decode(output[0], skip_special_tokens=True)
56
+ return response
57
+
58
+ # Giao diện người dùng
59
+ st.title("🤖 So sánh Chatbot: LLaMA 3 vs. Flan-T5")
60
+ st.write("💬 Nhập 1 câu hỏi để xem 2 mô hình phản hồi khác nhau như thế nào!")
61
+
62
+ # Nhập từ người dùng
63
+ user_input = st.text_input("✍️ Nhập câu hỏi:")
64
+
65
+ # Khi có input
66
+ if user_input:
67
+ col1, col2 = st.columns(2)
68
+
69
+ with col1:
70
+ st.subheader("🔷 LLaMA 3")
71
+ with st.spinner("Đang sinh phản hồi từ LLaMA 3..."):
72
+ response_llama = generate_response(user_input, tokenizer_llama, model_llama, "llama")
73
+ st.markdown(f"**Phản hồi:**\n\n{response_llama}")
74
+
75
+ with col2:
76
+ st.subheader("🔶 Flan-T5")
77
+ with st.spinner("Đang sinh phản hồi từ Flan-T5..."):
78
+ response_t5 = generate_response(user_input, tokenizer_t5, model_t5, "t5")
79
+ st.markdown(f"**Phản hồi:**\n\n{response_t5}")