# Tokentesting / app.py
# Uploaded by frankai98, commit 496495e ("Update app.py").
import os
import nest_asyncio
nest_asyncio.apply()
import streamlit as st
from transformers import pipeline, AutoTokenizer
from huggingface_hub import login
from streamlit.components.v1 import html
import pandas as pd
import torch
import random
import gc
import time
# Read the Hugging Face access token from the environment; without it the
# app reports the problem and halts the script run.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    st.error("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
    st.stop()

# Authenticate against the Hugging Face Hub with the token.
login(token=hf_token)
# Timer component using HTML and JavaScript
def timer():
    """Return the HTML/JavaScript markup for the elapsed-time widget.

    The embedded script clears any previous ``freezeTimer`` entry in
    ``localStorage``, then refreshes the ``mm:ss`` display once per second.
    When ``freezeTimer`` is found set to ``"true"`` the interval is cancelled
    and the text colour is switched to ``#00cc00``.

    Returns:
        str: The markup to embed in the page.
    """
    widget_markup = """
<div id="timer" style="font-size:16px;color:#666;margin-bottom:10px;">⏱️ Elapsed: 00:00</div>
<script>
(function() {
var start = Date.now();
var timerElement = document.getElementById('timer');
localStorage.removeItem("freezeTimer");
var interval = setInterval(function() {
if(localStorage.getItem("freezeTimer") === "true"){
clearInterval(interval);
timerElement.style.color = '#00cc00';
return;
}
var elapsed = Date.now() - start;
var minutes = Math.floor(elapsed / 60000);
var seconds = Math.floor((elapsed % 60000) / 1000);
timerElement.innerHTML = '⏱️ Elapsed: ' +
(minutes < 10 ? '0' : '') + minutes + ':' +
(seconds < 10 ? '0' : '') + seconds;
}, 1000);
})();
</script>
"""
    return widget_markup
# Page setup. The icon string was a mis-decoded byte sequence ("πŸ“"); it is
# restored to the memo emoji it was encoding so the browser tab shows a real icon.
st.set_page_config(page_title="Review Scorer & Report Generator", page_icon="📝")
st.header("Review Scorer & Report Generator")

# Concise introduction
st.write("This model will score your reviews in your CSV file and generate a report based on your query and those results.")
# Cache the model loading functions
@st.cache_resource
def load_llama_model():
    """Load and cache the Llama 3.2 1B Instruct text-generation pipeline.

    Returns:
        The ``transformers`` text-generation pipeline for
        ``meta-llama/Llama-3.2-1B-Instruct``.
    """
    return pipeline(
        "text-generation",
        model="meta-llama/Llama-3.2-1B-Instruct",
        # Use the first CUDA device when one exists, otherwise fall back to
        # CPU (-1) -- the same selection load_sentiment_model uses. A
        # hard-coded device=0 fails on hosts without a GPU.
        device=0 if torch.cuda.is_available() else -1,
        torch_dtype=torch.bfloat16,  # bfloat16 weights to reduce memory use
    )
@st.cache_resource
def load_sentiment_model():
    """Build the sentiment text-classification pipeline once and cache it.

    Returns:
        The ``transformers`` text-classification pipeline for
        ``cardiffnlp/twitter-roberta-base-sentiment-latest``, placed on CUDA
        device 0 when CUDA is available and on CPU (-1) otherwise.
    """
    target_device = 0 if torch.cuda.is_available() else -1
    return pipeline(
        "text-classification",
        model="cardiffnlp/twitter-roberta-base-sentiment-latest",
        device=target_device,
    )
# Load Llama 3.2 model
# Bind the name up front: the readiness check further down compares it to
# None, so a failed load must leave it defined instead of raising NameError.
llama_pipe = None
loading_llama_placeholder = st.empty()
loading_llama_placeholder.info("Loading Llama 3.2 summarization model...")
try:
    llama_pipe = load_llama_model()
except Exception as e:
    # Clear loading message and report the failure; llama_pipe stays None.
    loading_llama_placeholder.empty()
    st.error(f"Error loading Llama 3.2 summarization model: {e}")
    st.error(f"Detailed error: {type(e).__name__}: {str(e)}")
else:
    # Clear loading message
    loading_llama_placeholder.empty()
    # Display success message in a placeholder
    success_llama_placeholder = st.empty()
    success_llama_placeholder.success("Llama 3.2 summarization model loaded successfully!")
    # Remember (once per session) when the success banner should disappear
    if "clear_llama_success_time" not in st.session_state:
        st.session_state.clear_llama_success_time = time.time() + 5
    # On a script run after that deadline, remove the banner
    if time.time() > st.session_state.clear_llama_success_time:
        success_llama_placeholder.empty()
# Load sentiment analysis model
# Bind the name up front: the readiness check further down compares it to
# None, so a failed load must leave it defined instead of raising NameError.
score_pipe = None
loading_sentiment_placeholder = st.empty()
loading_sentiment_placeholder.info("Loading sentiment analysis model...")
try:
    score_pipe = load_sentiment_model()
except Exception as e:
    # Clear loading message and report the failure; score_pipe stays None.
    loading_sentiment_placeholder.empty()
    st.error(f"Error loading sentiment analysis model: {e}")
else:
    # Clear loading message
    loading_sentiment_placeholder.empty()
    # Display success message in a placeholder
    success_sentiment_placeholder = st.empty()
    success_sentiment_placeholder.success("Sentiment analysis model loaded successfully!")
    # Remember (once per session) when the success banner should disappear
    if "clear_sentiment_success_time" not in st.session_state:
        st.session_state.clear_sentiment_success_time = time.time() + 5
    # On a script run after that deadline, remove the banner
    if time.time() > st.session_state.clear_sentiment_success_time:
        success_sentiment_placeholder.empty()
def _find_assistant_content(node):
    """Return the content of the last assistant message nested in *node*, or None."""
    if isinstance(node, dict):
        if node.get("role") == "assistant" and "content" in node:
            return node["content"]
        children = node.values()
    elif isinstance(node, (list, tuple)):
        children = node
    else:
        return None

    found = None
    for child in children:
        candidate = _find_assistant_content(child)
        if candidate is not None:
            # Keep scanning so the last assistant turn wins.
            found = candidate
    return found


def extract_assistant_content(raw_response):
    """Extract only the assistant's content from a chat-style model response.

    Nested lists/dicts are searched for message dicts with
    ``role == "assistant"``, and the content of the last one is returned
    as-is. Reading the structure directly avoids the failures of parsing
    ``str(raw_response)``: text containing an apostrophe is repr'd with double
    quotes (so the textual marker never matches), and escapes such as ``\\n``
    would come back as literal backslash sequences.

    When no assistant message is found structurally (for example the input is
    already a string), the original textual parsing is applied to
    ``str(raw_response)``.

    Args:
        raw_response: The model output; any object is accepted.

    Returns:
        str: The assistant's content, or ``str(raw_response)`` when no
        assistant message can be located.
    """
    structured = _find_assistant_content(raw_response)
    if structured is not None:
        return structured if isinstance(structured, str) else str(structured)

    # Fallback: textual parsing of the stringified response.
    response_str = str(raw_response)
    assistant_marker = "'role': 'assistant', 'content': '"
    marker_pos = response_str.find(assistant_marker)
    if marker_pos == -1:
        # No assistant marker at all - return the stringified response.
        return response_str

    content = response_str[marker_pos + len(assistant_marker):]
    # Cut at the earliest of the last occurrences of the closing markers.
    end_idx = len(content)
    for end_marker in ("'}", "'}]"):
        pos = content.rfind(end_marker)
        if pos != -1 and pos < end_idx:
            end_idx = pos
    return content[:end_idx]
# Input: Query text for scoring and CSV file upload for candidate reviews
query_input = st.text_area("Enter your query text for analysis (this does not need to be part of the CSV):")
uploaded_file = st.file_uploader("Upload Reviews CSV File (must contain a 'reviewText' column)", type=["csv"])

if score_pipe is None or llama_pipe is None:
    st.error("Model loading failed. Please check your model names, token permissions, and GPU configuration.")
else:
    # Read the review texts from the uploaded CSV (empty list until a valid file arrives).
    candidate_docs = []
    if uploaded_file is not None:
        try:
            df = pd.read_csv(uploaded_file)
            if 'reviewText' not in df.columns:
                st.error("CSV must contain a 'reviewText' column.")
            else:
                candidate_docs = df['reviewText'].dropna().astype(str).tolist()
        except Exception as e:
            st.error(f"Error reading CSV file: {e}")

    if st.button("Generate Report"):
        # Reset timer state so that the timer always shows up
        st.session_state.timer_started = False
        st.session_state.timer_frozen = False

        if uploaded_file is None:
            st.error("Please upload a CSV file.")
        elif not candidate_docs:
            st.error("CSV must contain a 'reviewText' column.")
        elif not query_input.strip():
            st.error("Please enter a query text!")
        else:
            if not st.session_state.timer_started and not st.session_state.timer_frozen:
                st.session_state.timer_started = True
                html(timer(), height=50)

            status_text = st.empty()
            progress_bar = st.progress(0)

            # Stage 1: Process, summarize (if needed), and score candidate documents
            status_text.markdown("**🔍 Processing and scoring candidate documents...**")

            processed_docs = []  # Processed (original or summarized) documents
            scored_results = []  # Sentiment results, index-aligned with processed_docs
            total_docs = len(candidate_docs)
            status_step = max(1, total_docs // 10)

            for i, doc in enumerate(candidate_docs):
                # First half of the progress bar (0-50%) tracks scoring
                progress_bar.progress(int((i / total_docs) * 50))

                try:
                    if len(doc) > 1500:  # Approximate limit for sentiment model
                        # Use Llama 3.2 to shorten the document before scoring
                        summary_prompt = [
                            {"role": "user", "content": f"Summarize the following text into a shorter version that preserves the sentiment and key points: {doc[:2000]}..."}
                        ]
                        summary_result = llama_pipe(
                            summary_prompt,
                            max_new_tokens=30,  # Limit summary length
                            do_sample=True,
                            temperature=0.3,  # Lower temperature for more factual summaries
                            return_full_text=False,  # Return only the generated text
                        )
                        processed_doc = summary_result[0]['generated_text']
                        status_text.markdown(f"**📝 Summarized document {i+1}/{total_docs}**")
                    else:
                        # Use the original document if it's short enough
                        processed_doc = doc

                    result = score_pipe(processed_doc)
                    # If it's a list, get the first element
                    if isinstance(result, list):
                        result = result[0]
                except Exception as e:
                    st.warning(f"Error processing document {i}: {str(e)}")
                    # Substitute placeholders for BOTH values so the lists stay aligned
                    processed_doc = "Error processing this document"
                    result = {"label": "NEUTRAL", "score": 0.5}

                # Append exactly once per document. Appending inside the try
                # and again in the except handler left processed_docs one
                # entry longer than scored_results when scoring failed,
                # shifting every later (document, score) pair.
                processed_docs.append(processed_doc)
                scored_results.append(result)

                # Free memory
                torch.cuda.empty_cache()

                # Display occasional status updates for large datasets
                if i % status_step == 0:
                    status_text.markdown(f"**🔍 Scoring documents... ({i}/{total_docs})**")

            # Pair each review with its score; both lists share the input order.
            scored_docs = [
                (text, outcome.get("score", 0.5))
                for text, outcome in zip(processed_docs, scored_results)
            ]
            progress_bar.progress(67)

            # Stage 2: Generate Report using Gemma in the new messages format.
            status_text.markdown("**📝 Generating report with Gemma...**")

            # Drop this script's references to the scoring and summarization
            # pipelines before loading Gemma.
            # NOTE(review): both pipelines come from st.cache_resource loaders,
            # which presumably keep their own reference, so this may not
            # actually release GPU memory - confirm.
            del score_pipe
            gc.collect()
            torch.cuda.empty_cache()
            del llama_pipe
            gc.collect()
            torch.cuda.empty_cache()

            # Then load Gemma specifically for the final step
            tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
            gemma_pipe = pipeline(
                "text-generation",
                model="google/gemma-3-1b-it",
                tokenizer=tokenizer,
                # Same GPU-or-CPU selection as load_sentiment_model
                device=0 if torch.cuda.is_available() else -1,
                torch_dtype=torch.bfloat16,
            )

            # Sample the data for Gemma to avoid memory issues
            max_reviews = 50  # Adjust based on your GPU memory
            if len(scored_docs) > max_reviews:
                sampled_docs = random.sample(scored_docs, max_reviews)
                st.info(f"Sampling {max_reviews} out of {len(scored_docs)} reviews for report generation")
            else:
                sampled_docs = scored_docs

            # Build the chat prompt from the query and the SAMPLED reviews.
            # The sample is what keeps the prompt bounded; interpolating the
            # full scored_docs list here made the sampling step a no-op.
            report_prompt = f"""
Generate a concise 300-word report based on the following analysis without repeating what's in the analysis.
Query:
"{query_input}"
Candidate Reviews with their scores:
{sampled_docs}
"""
            messages = [{"role": "user", "content": report_prompt}]

            # Token budget sized for the requested ~300-word report; the
            # earlier limit of 150 tokens cut the report off part-way.
            raw_result = gemma_pipe(messages, max_new_tokens=512)
            report = extract_assistant_content(raw_result)

            progress_bar.progress(100)
            status_text.success("**✅ Generation complete!**")
            # Tell the timer widget (via localStorage) to stop counting
            html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
            st.session_state.timer_frozen = True
            st.write("**Generated Report:**", report)