amiguel committed · verified
Commit e476c2e · 1 Parent(s): 037c4ae

Update app.py

Files changed (1):
  1. app.py  +217 -56
app.py CHANGED
@@ -1,62 +1,223 @@
-# Load model if not already loaded or if model type changed
-if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
-    model_data = load_model(hf_token, model_type, selected_model)
-    if model_data is None:
-        st.error("Failed to load model. Please check your token and try again.")
-        st.stop()
-
-    st.session_state.model, st.session_state.tokenizer = model_data
-    st.session_state.model_type = model_type
-
-model = st.session_state.model
-tokenizer = st.session_state.tokenizer
-
-# Add user message
-with st.chat_message("user", avatar=USER_AVATAR):
-    st.markdown(prompt)
-st.session_state.messages.append({"role": "user", "content": prompt})
-
-# Process file
-file_context = process_file(uploaded_file)
-
-# Generate response with KV caching
-if model and tokenizer:
-    try:
-        with st.chat_message("assistant", avatar=BOT_AVATAR):
-            start_time = time.time()
-            streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)
-
-            response_container = st.empty()
-            full_response = ""
-
-            for chunk in streamer:
-                cleaned_chunk = chunk.replace("<think>", "").replace("</think>", "").strip()
-                full_response += cleaned_chunk + " "
-                response_container.markdown(full_response + "▌", unsafe_allow_html=True)
-
-            # Calculate performance metrics
-            end_time = time.time()
-            input_tokens = len(tokenizer(prompt)["input_ids"])
-            output_tokens = len(tokenizer(full_response)["input_ids"])
-            speed = output_tokens / (end_time - start_time)
-
-            # Calculate costs (hypothetical pricing model)
-            input_cost = (input_tokens / 1000000) * 5  # $5 per million input tokens
-            output_cost = (output_tokens / 1000000) * 15  # $15 per million output tokens
-            total_cost_usd = input_cost + output_cost
-            total_cost_aoa = total_cost_usd * 1160  # Convert to AOA (Angolan Kwanza)
-
-            # Display metrics
-            st.caption(
-                f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
-                f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
-                f"💵 Cost (AOA): {total_cost_aoa:.4f}"
-            )
-
-            response_container.markdown(full_response)
-            st.session_state.messages.append({"role": "assistant", "content": full_response})
-
-    except Exception as e:
-        st.error(f"⚡ Generation error: {str(e)}")
-else:
-    st.error("🤖 Model not loaded!")
+import streamlit as st
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from huggingface_hub import login
+from threading import Thread
+import PyPDF2
+import pandas as pd
+import torch
+import time
+
+# Check if 'peft' is installed
+try:
+    from peft import PeftModel, PeftConfig
+except ImportError:
+    raise ImportError(
+        "The 'peft' library is required but not installed. "
+        "Please install it using: `pip install peft`"
+    )
+
+# Set page configuration
+st.set_page_config(
+    page_title="WizNerd Insp",
+    page_icon="🚀",
+    layout="centered"
+)
+
+# Hardcoded Hugging Face token (replace with your actual token)
+HF_TOKEN = "your_hugging_face_token_here"
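+# NOTE: a hardcoded token gets committed with the source. A safer pattern
+# (sketch, assuming a secret named "HF_TOKEN" is configured for the deployment):
+#   HF_TOKEN = st.secrets["HF_TOKEN"]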
+
+# Model names
+BASE_MODEL_NAME = "google-bert/bert-base-uncased"
+MODEL_OPTIONS = {
+    "Full Fine-Tuned": "amiguel/instruct_BERT-base-uncased_model",
+    "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
+    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora"  # Hypothetical, adjust if needed
+}
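+# The LoRA/QLoRA entries are PEFT adapter repos holding only the small adapter
+# weights; load_model() below therefore loads BASE_MODEL_NAME first and then
+# attaches the selected adapter with PeftModel.from_pretrained().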
+
+# Title with rocket emojis
+st.title("🚀 WizNerd Insp 🚀")
+
+# Configure Avatars
+USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
+BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
+
+# Sidebar configuration
+with st.sidebar:
+    st.header("Model Selection 🤖")
+    model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0)
+    selected_model = MODEL_OPTIONS[model_type]
+
+    st.header("Upload Documents 📂")
+    uploaded_file = st.file_uploader(
+        "Choose a PDF or XLSX file",
+        type=["pdf", "xlsx"],
+        label_visibility="collapsed"
+    )
+
+# Initialize chat history
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+
+# File processing function
+@st.cache_data
+def process_file(uploaded_file):
+    if uploaded_file is None:
+        return ""
+
+    try:
+        if uploaded_file.type == "application/pdf":
+            pdf_reader = PyPDF2.PdfReader(uploaded_file)
+            # extract_text() can return None for image-only pages; guard the join
+            return "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
+        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
+            df = pd.read_excel(uploaded_file)
+            return df.to_markdown()
+        return ""  # Unsupported file type: fall back to empty context
+    except Exception as e:
+        st.error(f"📄 Error processing file: {str(e)}")
+        return ""
+
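+# Note: pd.read_excel() needs the `openpyxl` package for .xlsx files, and
+# DataFrame.to_markdown() needs `tabulate`; both belong in requirements.txt
+# alongside peft.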
+# Model loading function
+@st.cache_resource
+def load_model(hf_token, model_type, selected_model):
+    try:
+        if not hf_token:
+            st.error("🔐 Authentication required! Please provide a valid Hugging Face token.")
+            return None
+
+        login(token=hf_token)
+
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=hf_token)
+
+        # Determine device
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        # Load model based on type
+        if model_type == "Full Fine-Tuned":
+            # Load full fine-tuned model directly
+            model = AutoModelForCausalLM.from_pretrained(
+                selected_model,
+                torch_dtype=torch.bfloat16,
+                token=hf_token
+            ).to(device)
+        else:
+            # Load base model and apply PEFT adapter
+            base_model = AutoModelForCausalLM.from_pretrained(
+                BASE_MODEL_NAME,
+                torch_dtype=torch.bfloat16,
+                token=hf_token
+            ).to(device)
+            model = PeftModel.from_pretrained(
+                base_model,
+                selected_model,
+                torch_dtype=torch.bfloat16,
+                is_trainable=False,  # Inference mode
+                token=hf_token
+            ).to(device)
+
+        return model, tokenizer
+
+    except Exception as e:
+        st.error(f"🤖 Model loading failed: {str(e)}")
+        return None
+
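+# @st.cache_resource keys its cache on the call arguments, so each
+# (hf_token, model_type, selected_model) combination is loaded once per
+# process; subsequent Streamlit reruns reuse the cached model object
+# instead of reloading the weights.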
+# Generation function with KV caching
+def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
+    full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"
+
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        skip_prompt=True,
+        skip_special_tokens=True
+    )
+
+    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
+
+    generation_kwargs = {
+        "input_ids": inputs["input_ids"],
+        "attention_mask": inputs["attention_mask"],
+        "max_new_tokens": 1024,
+        "temperature": 0.7,
+        "top_p": 0.9,
+        "repetition_penalty": 1.1,
+        "do_sample": True,
+        "use_cache": use_cache,
+        "streamer": streamer
+    }
+
+    # Run generation in a background thread so the caller can consume the streamer
+    Thread(target=model.generate, kwargs=generation_kwargs).start()
+    return streamer
+
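+# With use_cache=True, generate() reuses cached attention key/value states
+# between decoding steps instead of recomputing them for the whole prefix,
+# which keeps per-token latency low while the UI drains the streamer.
+# Consumption sketch (hypothetical arguments):
+#   for piece in generate_with_kv_cache("Summarize the findings", "", model, tokenizer):
+#       print(piece, end="", flush=True)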
+# Display chat messages
+for message in st.session_state.messages:
+    try:
+        avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
+        with st.chat_message(message["role"], avatar=avatar):
+            st.markdown(message["content"])
+    except Exception:
+        # Fall back to the default avatar if the custom one fails to load
+        with st.chat_message(message["role"]):
+            st.markdown(message["content"])
+
+# Chat input handling
+if prompt := st.chat_input("Ask your inspection question..."):
+    # Load model if not already loaded or if model type changed
+    if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
+        model_data = load_model(HF_TOKEN, model_type, selected_model)
+        if model_data is None:
+            st.error("Failed to load model. Please check your token and try again.")
+            st.stop()
+
+        st.session_state.model, st.session_state.tokenizer = model_data
+        st.session_state.model_type = model_type
+
+    model = st.session_state.model
+    tokenizer = st.session_state.tokenizer
+
+    # Add user message
+    with st.chat_message("user", avatar=USER_AVATAR):
+        st.markdown(prompt)
+    st.session_state.messages.append({"role": "user", "content": prompt})
+
+    # Process file
+    file_context = process_file(uploaded_file)
+
+    # Generate response with KV caching
+    if model and tokenizer:
+        try:
+            with st.chat_message("assistant", avatar=BOT_AVATAR):
+                start_time = time.time()
+                streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)
+
+                response_container = st.empty()
+                full_response = ""
+
+                for chunk in streamer:
+                    cleaned_chunk = chunk.replace("<think>", "").replace("</think>", "").strip()
+                    full_response += cleaned_chunk + " "
+                    response_container.markdown(full_response + "▌", unsafe_allow_html=True)
+
+                # Calculate performance metrics
+                end_time = time.time()
+                input_tokens = len(tokenizer(prompt)["input_ids"])
+                output_tokens = len(tokenizer(full_response)["input_ids"])
+                speed = output_tokens / (end_time - start_time)
+
+                # Calculate costs (hypothetical pricing model)
+                input_cost = (input_tokens / 1000000) * 5    # $5 per million input tokens
+                output_cost = (output_tokens / 1000000) * 15  # $15 per million output tokens
+                total_cost_usd = input_cost + output_cost
+                total_cost_aoa = total_cost_usd * 1160  # Convert to AOA (Angolan Kwanza)
+
+                # Display metrics
+                st.caption(
+                    f"🔑 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
+                    f"🕒 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
+                    f"💵 Cost (AOA): {total_cost_aoa:.4f}"
+                )
+
+                response_container.markdown(full_response)
+                st.session_state.messages.append({"role": "assistant", "content": full_response})
+
+        except Exception as e:
+            st.error(f"⚡ Generation error: {str(e)}")
+    else:
+        st.error("🤖 Model not loaded!")