import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
from threading import Thread
import PyPDF2
import pandas as pd
import torch
import time

# Check if 'peft' is installed
try:
    from peft import PeftModel, PeftConfig
except ImportError:
    raise ImportError(
        "The 'peft' library is required but not installed. "
        "Please install it using: `pip install peft`"
    )

# Set page configuration
st.set_page_config(
    page_title="WizNerd Insp",
    page_icon="πŸš€",
    layout="centered"
)

# Hardcoded Hugging Face token (replace with your actual token)
HF_TOKEN = "your_hugging_face_token_here"
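# Hardcoding is fine for local testing; consider reading the token from st.secrets or an environment variable instead.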

# Model names
BASE_MODEL_NAME = "google-bert/bert-base-uncased"
MODEL_OPTIONS = {
    "Full Fine-Tuned": "amiguel/instruct_BERT-base-uncased_model",
    "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora"  # Hypothetical, adjust if needed
}
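# NOTE: a PEFT adapter can only be attached to the base model it was trained on.
# If the LoRA/QLoRA adapters above target SmolLM2-360M, BASE_MODEL_NAME should point
# at that checkpoint rather than a BERT one.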

# Title with rocket emojis
st.title("πŸš€ WizNerd Insp πŸš€")

# Configure Avatars
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"

# Sidebar configuration
with st.sidebar:
    st.header("Model Selection πŸ€–")
    model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0)
    selected_model = MODEL_OPTIONS[model_type]
    
    st.header("Upload Documents πŸ“‚")
    uploaded_file = st.file_uploader(
        "Choose a PDF or XLSX file",
        type=["pdf", "xlsx"],
        label_visibility="collapsed"
    )

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# File processing function
@st.cache_data
def process_file(uploaded_file):
    if uploaded_file is None:
        return ""
    
    try:
        if uploaded_file.type == "application/pdf":
            pdf_reader = PyPDF2.PdfReader(uploaded_file)
            return "\n".join([page.extract_text() for page in pdf_reader.pages])
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
    except Exception as e:
        st.error(f"πŸ“„ Error processing file: {str(e)}")
        return ""

# Model loading function
@st.cache_resource
def load_model(hf_token, model_type, selected_model):
    try:
        if not hf_token:
            st.error("πŸ” Authentication required! Please provide a valid Hugging Face token.")
            return None
        
        login(token=hf_token)
        
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=hf_token)
        
        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # Load model based on type
        if model_type == "Full Fine-Tuned":
            # Load full fine-tuned model directly
            model = AutoModelForCausalLM.from_pretrained(
                selected_model,
                torch_dtype=torch.bfloat16,
                token=hf_token
            ).to(device)
        else:
            # Load base model and apply PEFT adapter
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_NAME,
                torch_dtype=torch.bfloat16,
                token=hf_token
            ).to(device)
            model = PeftModel.from_pretrained(
                base_model,
                selected_model,
                torch_dtype=torch.bfloat16,
                is_trainable=False,  # Inference mode
                token=hf_token
            ).to(device)
        
        return model, tokenizer
        
    except Exception as e:
        st.error(f"πŸ€– Model loading failed: {str(e)}")
        return None

# Generation function with KV caching
def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
    full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"
    
    streamer = TextIteratorStreamer(
        tokenizer, 
        skip_prompt=True, 
        skip_special_tokens=True
    )
    
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
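    # NOTE: a long file context can exceed the model's maximum sequence length;
    # consider tokenizing with truncation=True and an explicit max_length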
    
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "use_cache": use_cache,
        "streamer": streamer
    }
    
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    return streamer

# Display chat messages
for message in st.session_state.messages:
    try:
        avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    except Exception:  # Fall back to the default avatar if the image URL fails
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

# Chat input handling
if prompt := st.chat_input("Ask your inspection question..."):
    # Load model if not already loaded or if model type changed
    if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
        model_data = load_model(HF_TOKEN, model_type, selected_model)
        if model_data is None:
            st.error("Failed to load model. Please check your token and try again.")
            st.stop()
            
        st.session_state.model, st.session_state.tokenizer = model_data
        st.session_state.model_type = model_type
    
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    
    # Add user message
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Process file
    file_context = process_file(uploaded_file)
    
    # Generate response with KV caching
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()
                streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)
                
                response_container = st.empty()
                full_response = ""
                
                for chunk in streamer:
                    # The streamer yields already-spaced text; avoid stripping and re-joining
                    # with spaces, which would corrupt word boundaries
                    cleaned_chunk = chunk.replace("<think>", "").replace("</think>", "")
                    full_response += cleaned_chunk
                    response_container.markdown(full_response + "β–Œ", unsafe_allow_html=True)
                
                # Calculate performance metrics
                end_time = time.time()
                input_tokens = len(tokenizer(prompt)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                elapsed = max(end_time - start_time, 1e-6)  # Guard against division by zero
                speed = output_tokens / elapsed
                
                # Calculate costs (hypothetical pricing model)
                input_cost = (input_tokens / 1000000) * 5  # $5 per million input tokens
                output_cost = (output_tokens / 1000000) * 15  # $15 per million output tokens
                total_cost_usd = input_cost + output_cost
                total_cost_aoa = total_cost_usd * 1160  # Convert to AOA (Angolan Kwanza) at a hardcoded approximate rate
                
                # Display metrics
                st.caption(
                    f"πŸ”‘ Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"πŸ•’ Speed: {speed:.1f}t/s | πŸ’° Cost (USD): ${total_cost_usd:.4f} | "
                    f"πŸ’΅ Cost (AOA): {total_cost_aoa:.4f}"
                )
                
                response_container.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})
                
        except Exception as e:
            st.error(f"⚑ Generation error: {str(e)}")
    else:
        st.error("πŸ€– Model not loaded!")