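"""WizNerd Insp: a Streamlit chat app that loads a fine-tuned model or a PEFT
(LoRA/QLoRA) adapter from the Hugging Face Hub and answers inspection questions
about an optionally uploaded PDF/XLSX document, streaming the response as it is
generated."""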
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from huggingface_hub import login
from threading import Thread
import PyPDF2
import pandas as pd
import torch
import time
# Check if 'peft' is installed
try:
    from peft import PeftModel, PeftConfig
except ImportError:
    raise ImportError(
        "The 'peft' library is required but not installed. "
        "Please install it using: `pip install peft`"
    )
# Set page configuration
st.set_page_config(
    page_title="WizNerd Insp",
    page_icon="🚀",
    layout="centered"
)
# Hardcoded Hugging Face token (replace with your actual token)
HF_TOKEN = "your_hugging_face_token_here"
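# Note: hardcoding a token is risky if the Space or repo is public. A common
# alternative (assuming a secret named "HF_TOKEN" has been configured for the app)
# is to read it from Streamlit secrets instead:
# HF_TOKEN = st.secrets.get("HF_TOKEN", "")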
# Model names
BASE_MODEL_NAME = "google-bert/bert-base-uncased"
MODEL_OPTIONS = {
    "Full Fine-Tuned": "amiguel/instruct_BERT-base-uncased_model",
    "LoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-lora",
    "QLoRA Adapter": "amiguel/SmolLM2-360M-concise-reasoning-qlora"  # Hypothetical, adjust if needed
}
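# Note: the LoRA/QLoRA adapter repos above target SmolLM2-360M, and PEFT adapters
# must be applied to the base model they were trained from, so BASE_MODEL_NAME may
# need to match the adapter's base architecture.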
# Title with rocket emojis
st.title("🚀 WizNerd Insp 🚀")
# Configure avatars
USER_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/9904d9a0d445ab0488cf7395cb863cce7621d897/USER_AVATAR.png"
BOT_AVATAR = "https://raw.githubusercontent.com/achilela/vila_fofoka_analysis/991f4c6e4e1dc7a8e24876ca5aae5228bcdb4dba/Ataliba_Avatar.jpg"
# Sidebar configuration
with st.sidebar:
    st.header("Model Selection 🤖")
    model_type = st.selectbox("Choose Model Type", list(MODEL_OPTIONS.keys()), index=0)
    selected_model = MODEL_OPTIONS[model_type]
    st.header("Upload Documents 📁")
    uploaded_file = st.file_uploader(
        "Choose a PDF or XLSX file",
        type=["pdf", "xlsx"],
        label_visibility="collapsed"
    )
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
# File processing function
def process_file(uploaded_file):
    if uploaded_file is None:
        return ""
    try:
        if uploaded_file.type == "application/pdf":
            pdf_reader = PyPDF2.PdfReader(uploaded_file)
            # extract_text() can return None for image-only pages, so guard with "or ''"
            return "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
        elif uploaded_file.type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
            df = pd.read_excel(uploaded_file)
            return df.to_markdown()
        else:
            return ""
    except Exception as e:
        st.error(f"📄 Error processing file: {str(e)}")
        return ""
# Model loading function
def load_model(hf_token, model_type, selected_model):
    try:
        if not hf_token:
            st.error("🔒 Authentication required! Please provide a valid Hugging Face token.")
            return None
        login(token=hf_token)
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, token=hf_token)
        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load model based on type
        if model_type == "Full Fine-Tuned":
            # Load full fine-tuned model directly
            model = AutoModelForCausalLM.from_pretrained(
                selected_model,
                torch_dtype=torch.bfloat16,
                token=hf_token
            ).to(device)
        else:
            # Load base model and apply PEFT adapter
            base_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_NAME,
                torch_dtype=torch.bfloat16,
                token=hf_token
            ).to(device)
            model = PeftModel.from_pretrained(
                base_model,
                selected_model,
                torch_dtype=torch.bfloat16,
                is_trainable=False,  # Inference mode
                token=hf_token
            ).to(device)
        return model, tokenizer
    except Exception as e:
        st.error(f"🤖 Model loading failed: {str(e)}")
        return None
# Generation function with KV caching
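# Generation runs in a background thread while TextIteratorStreamer yields decoded
# text pieces as tokens are produced, so the UI can render the answer incrementally.
# use_cache=True keeps the attention key/value cache between decoding steps, avoiding
# recomputation over tokens that were already processed.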
def generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True):
    full_prompt = f"Analyze this context:\n{file_context}\n\nQuestion: {prompt}\nAnswer:"
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )
    inputs = tokenizer(full_prompt, return_tensors="pt").to(model.device)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": 1024,
        "temperature": 0.7,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
        "do_sample": True,
        "use_cache": use_cache,
        "streamer": streamer
    }
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    return streamer
# Display chat messages
for message in st.session_state.messages:
    try:
        avatar = USER_AVATAR if message["role"] == "user" else BOT_AVATAR
        with st.chat_message(message["role"], avatar=avatar):
            st.markdown(message["content"])
    except Exception:
        # Fall back to the default avatar if the custom one cannot be used
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
# Chat input handling
if prompt := st.chat_input("Ask your inspection question..."):
    # Load model if not already loaded or if model type changed
    if "model" not in st.session_state or st.session_state.get("model_type") != model_type:
        model_data = load_model(HF_TOKEN, model_type, selected_model)
        if model_data is None:
            st.error("Failed to load model. Please check your token and try again.")
            st.stop()
        st.session_state.model, st.session_state.tokenizer = model_data
        st.session_state.model_type = model_type
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    # Add user message
    with st.chat_message("user", avatar=USER_AVATAR):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Process file
    file_context = process_file(uploaded_file)
    # Generate response with KV caching
    if model and tokenizer:
        try:
            with st.chat_message("assistant", avatar=BOT_AVATAR):
                start_time = time.time()
                streamer = generate_with_kv_cache(prompt, file_context, model, tokenizer, use_cache=True)
                response_container = st.empty()
                full_response = ""
                for chunk in streamer:
                    # Strip reasoning tags but keep the chunk's own spacing intact
                    cleaned_chunk = chunk.replace("<think>", "").replace("</think>", "")
                    full_response += cleaned_chunk
                    response_container.markdown(full_response + "▌", unsafe_allow_html=True)
                # Calculate performance metrics
                end_time = time.time()
                input_tokens = len(tokenizer(prompt)["input_ids"])
                output_tokens = len(tokenizer(full_response)["input_ids"])
                speed = output_tokens / (end_time - start_time)
                # Calculate costs (hypothetical pricing model)
                input_cost = (input_tokens / 1000000) * 5  # $5 per million input tokens
                output_cost = (output_tokens / 1000000) * 15  # $15 per million output tokens
                total_cost_usd = input_cost + output_cost
                total_cost_aoa = total_cost_usd * 1160  # Convert to AOA (Angolan Kwanza)
                # Display metrics
                st.caption(
                    f"📊 Input Tokens: {input_tokens} | Output Tokens: {output_tokens} | "
                    f"🚀 Speed: {speed:.1f}t/s | 💰 Cost (USD): ${total_cost_usd:.4f} | "
                    f"💵 Cost (AOA): {total_cost_aoa:.4f}"
                )
                response_container.markdown(full_response)
                st.session_state.messages.append({"role": "assistant", "content": full_response})
        except Exception as e:
            st.error(f"⚡ Generation error: {str(e)}")
    else:
        st.error("🤖 Model not loaded!")
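# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py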