import os
import nest_asyncio
nest_asyncio.apply()
import streamlit as st
from transformers import pipeline, AutoTokenizer
from huggingface_hub import login
from streamlit.components.v1 import html
import pandas as pd
import torch
import random
import gc
import time
# Retrieve the token from environment variables
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    st.error("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
    st.stop()

# Login with the token
login(token=hf_token)
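# Note: the account behind HF_TOKEN must have accepted the licenses for the gated
# checkpoints used below (meta-llama/Llama-3.2-1B-Instruct and google/gemma-3-1b-it),
# otherwise loading them will fail.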
# Timer component using HTML and JavaScript
def timer():
    return """
    <div id="timer" style="font-size:16px;color:#666;margin-bottom:10px;">⏱️ Elapsed: 00:00</div>
    <script>
    (function() {
        var start = Date.now();
        var timerElement = document.getElementById('timer');
        localStorage.removeItem("freezeTimer");
        var interval = setInterval(function() {
            if (localStorage.getItem("freezeTimer") === "true") {
                clearInterval(interval);
                timerElement.style.color = '#00cc00';
                return;
            }
            var elapsed = Date.now() - start;
            var minutes = Math.floor(elapsed / 60000);
            var seconds = Math.floor((elapsed % 60000) / 1000);
            timerElement.innerHTML = '⏱️ Elapsed: ' +
                (minutes < 10 ? '0' : '') + minutes + ':' +
                (seconds < 10 ? '0' : '') + seconds;
        }, 1000);
    })();
    </script>
    """
st.set_page_config(page_title="Review Scorer & Report Generator", page_icon="📊")
st.header("Review Scorer & Report Generator")

# Concise introduction
st.write("This app scores the reviews in your CSV file and generates a report based on your query and those scores.")
# Cache the model loading functions
@st.cache_resource
def load_llama_model():
    """Load and cache the Llama 3.2 model"""
    return pipeline("text-generation",
                    model="meta-llama/Llama-3.2-1B-Instruct",
                    device=0 if torch.cuda.is_available() else -1,  # Use GPU if available
                    torch_dtype=torch.bfloat16)  # Use bfloat16 for efficiency

@st.cache_resource
def load_sentiment_model():
    """Load and cache the sentiment analysis model"""
    return pipeline("text-classification",
                    model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                    device=0 if torch.cuda.is_available() else -1)
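# Rough shape of what score_pipe returns for a single string (this Cardiff NLP
# checkpoint typically uses "negative"/"neutral"/"positive" labels):
#   [{"label": "positive", "score": 0.98}]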
# Load Llama 3.2 model
loading_llama_placeholder = st.empty()
loading_llama_placeholder.info("Loading Llama 3.2 summarization model...")
try:
    llama_pipe = load_llama_model()
    # Clear loading message
    loading_llama_placeholder.empty()
    # Display success message in a placeholder
    success_llama_placeholder = st.empty()
    success_llama_placeholder.success("Llama 3.2 summarization model loaded successfully!")
    # Use st.session_state to track when to clear the message
    if "clear_llama_success_time" not in st.session_state:
        st.session_state.clear_llama_success_time = time.time() + 5
    # Check if it's time to clear the message
    if time.time() > st.session_state.clear_llama_success_time:
        success_llama_placeholder.empty()
except Exception as e:
    # Clear loading message
    loading_llama_placeholder.empty()
    llama_pipe = None  # Sentinel so the later None check works
    st.error(f"Error loading Llama 3.2 summarization model: {e}")
    st.error(f"Detailed error: {type(e).__name__}: {str(e)}")

# Load sentiment analysis model
loading_sentiment_placeholder = st.empty()
loading_sentiment_placeholder.info("Loading sentiment analysis model...")
try:
    score_pipe = load_sentiment_model()
    # Clear loading message
    loading_sentiment_placeholder.empty()
    # Display success message in a placeholder
    success_sentiment_placeholder = st.empty()
    success_sentiment_placeholder.success("Sentiment analysis model loaded successfully!")
    # Use st.session_state to track when to clear the message
    if "clear_sentiment_success_time" not in st.session_state:
        st.session_state.clear_sentiment_success_time = time.time() + 5
    # Check if it's time to clear the message
    if time.time() > st.session_state.clear_sentiment_success_time:
        success_sentiment_placeholder.empty()
except Exception as e:
    # Clear loading message
    loading_sentiment_placeholder.empty()
    score_pipe = None  # Sentinel so the later None check works
    st.error(f"Error loading sentiment analysis model: {e}")
def extract_assistant_content(raw_response):
    """Extract only the assistant's content from the Gemma-3 response."""
    # Convert to string and work with it directly
    response_str = str(raw_response)
    # Look for the assistant's content marker
    assistant_marker = "'role': 'assistant', 'content': '"
    if assistant_marker in response_str:
        start_idx = response_str.find(assistant_marker) + len(assistant_marker)
        # Extract everything after the marker until the end or closing quote
        content = response_str[start_idx:]
        # Find the end of the content (last single quote before the end of the string or before the closing curly brace)
        end_markers = ["'}", "'}]"]
        end_idx = len(content)
        for marker in end_markers:
            pos = content.rfind(marker)
            if pos != -1 and pos < end_idx:
                end_idx = pos
        return content[:end_idx]
    # Fallback - return the original response
    return response_str
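# Illustrative shape of the raw pipeline output this function parses (the exact
# structure may vary with the transformers version):
#   [{'generated_text': [{'role': 'user', 'content': '...'},
#                        {'role': 'assistant', 'content': '...'}]}]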
# Input: Query text for scoring and CSV file upload for candidate reviews
query_input = st.text_area("Enter your query text for analysis (this does not need to be part of the CSV):")
uploaded_file = st.file_uploader("Upload Reviews CSV File (must contain a 'reviewText' column)", type=["csv"])
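# Illustrative layout of the expected CSV (any additional columns are ignored):
#   reviewText
#   "Great product, arrived on time."
#   "Battery died after two days of use."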
if score_pipe is None or llama_pipe is None:
    st.error("Model loading failed. Please check your model names, token permissions, and GPU configuration.")
else:
    candidate_docs = []
    if uploaded_file is not None:
        try:
            df = pd.read_csv(uploaded_file)
            if 'reviewText' not in df.columns:
                st.error("CSV must contain a 'reviewText' column.")
            else:
                candidate_docs = df['reviewText'].dropna().astype(str).tolist()
        except Exception as e:
            st.error(f"Error reading CSV file: {e}")
    if st.button("Generate Report"):
        # Reset timer state so that the timer always shows up
        st.session_state.timer_started = False
        st.session_state.timer_frozen = False
        if uploaded_file is None:
            st.error("Please upload a CSV file.")
        elif not candidate_docs:
            st.error("CSV must contain a 'reviewText' column.")
        elif not query_input.strip():
            st.error("Please enter a query text!")
        else:
            if not st.session_state.timer_started and not st.session_state.timer_frozen:
                st.session_state.timer_started = True
                html(timer(), height=50)
            status_text = st.empty()
            progress_bar = st.progress(0)
            # Stage 1: Process, summarize (if needed), and score candidate documents
            status_text.markdown("**Processing and scoring candidate documents...**")
            # Process each review individually with summarization for long documents
            processed_docs = []   # Store processed (original or summarized) documents
            scored_results = []   # Store sentiment scores
            for i, doc in enumerate(candidate_docs):
                # Update progress based on current document
                progress = int((i / len(candidate_docs)) * 50)  # First half of progress bar (0-50%)
                progress_bar.progress(progress)
                try:
                    # Check if document exceeds the length limit for sentiment analysis
                    if len(doc) > 1500:  # Approximate limit for sentiment model
                        # Use Llama 3.2 to summarize the document
                        summary_prompt = [
                            {"role": "user", "content": f"Summarize the following text into a shorter version that preserves the sentiment and key points: {doc[:2000]}..."}
                        ]
                        summary_result = llama_pipe(
                            summary_prompt,
                            max_new_tokens=30,       # Limit summary length
                            do_sample=True,
                            temperature=0.3,         # Lower temperature for more factual summaries
                            return_full_text=False   # Return only the generated text
                        )
                        # Extract the summary from the result
                        processed_doc = summary_result[0]['generated_text']
                        status_text.markdown(f"**Summarized document {i+1}/{len(candidate_docs)}**")
                    else:
                        # Use the original document if it's short enough
                        processed_doc = doc
                    # Store the processed document (original or summary)
                    processed_docs.append(processed_doc)
                    # Process the document with sentiment analysis
                    result = score_pipe(processed_doc)
                    # If it's a list, get the first element
                    if isinstance(result, list):
                        result = result[0]
                    scored_results.append(result)
                    # Free memory
                    torch.cuda.empty_cache()
                except Exception as e:
                    st.warning(f"Error processing document {i}: {str(e)}")
                    # Add a placeholder result to maintain indexing
                    processed_docs.append("Error processing this document")
                    scored_results.append({"label": "NEUTRAL", "score": 0.5})
                # Display occasional status updates for large datasets
                if i % max(1, len(candidate_docs) // 10) == 0:
                    status_text.markdown(f"**Scoring documents... ({i}/{len(candidate_docs)})**")
            # Pair each review with its score, assuming the output order matches the input order.
            scored_docs = list(zip(processed_docs, [result.get("score", 0.5) for result in scored_results]))
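            # scored_docs is now a list of (review_text, sentiment_score) tuples,
            # for example ("Battery died after two days.", 0.91) with illustrative values.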
            progress_bar.progress(67)
            # Stage 2: Generate Report using Gemma in the new messages format.
            status_text.markdown("**Generating report with Gemma...**")
            # Free the sentiment pipeline now that scoring is done
            del score_pipe
            gc.collect()
            torch.cuda.empty_cache()
            # Free the Llama summarization pipeline as well
            del llama_pipe
            gc.collect()
            torch.cuda.empty_cache()
            # Then load Gemma specifically for the final step
            tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
            gemma_pipe = pipeline("text-generation",
                                  model="google/gemma-3-1b-it",
                                  tokenizer=tokenizer,
                                  device=0 if torch.cuda.is_available() else -1,
                                  torch_dtype=torch.bfloat16)
            # Sample the data for Gemma to avoid memory issues
            max_reviews = 50  # Adjust based on your GPU memory
            if len(scored_docs) > max_reviews:
                sampled_docs = random.sample(scored_docs, max_reviews)
                st.info(f"Sampling {max_reviews} out of {len(scored_docs)} reviews for report generation")
            else:
                sampled_docs = scored_docs
            # Build the user content with the query, sentiment scores, and review text.
            # Format the prompt as chat messages for Gemma
            messages = [
                {"role": "user", "content": f"""
Generate a concise 300-word report based on the following analysis without repeating what's in the analysis.

Query:
"{query_input}"

Candidate Reviews with their scores:
{sampled_docs}
"""}
            ]
            raw_result = gemma_pipe(messages, max_new_tokens=400)  # ~300 words needs roughly 400 tokens of headroom
            report = extract_assistant_content(raw_result)
            progress_bar.progress(100)
            status_text.success("**✅ Generation complete!**")
            html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
            st.session_state.timer_frozen = True
            # st.write("**Scored Candidate Reviews:**", scored_docs)
            st.write("**Generated Report:**", report)