import os
import nest_asyncio
nest_asyncio.apply()
import streamlit as st
from transformers import pipeline, AutoTokenizer
from huggingface_hub import login
from streamlit.components.v1 import html
import pandas as pd
import torch
import random

# Retrieve the token from environment variables
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
    st.error("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
    st.stop()

# Log in with the token
login(token=hf_token)

# Initialize session state for the timer (left commented out: the button
# handler below resets these keys on every click before they are read)
#if 'timer_started' not in st.session_state:
#    st.session_state.timer_started = False
#if 'timer_frozen' not in st.session_state:
#    st.session_state.timer_frozen = False

# Timer component using HTML and JavaScript
def timer():
    return """
    <div id="timer" style="font-size:16px;color:#666;margin-bottom:10px;">⏱️ Elapsed: 00:00</div>
    <script>
    (function() {
        var start = Date.now();
        var timerElement = document.getElementById('timer');
        localStorage.removeItem("freezeTimer");
        var interval = setInterval(function() {
            if (localStorage.getItem("freezeTimer") === "true") {
                clearInterval(interval);
                timerElement.style.color = '#00cc00';
                return;
            }
            var elapsed = Date.now() - start;
            var minutes = Math.floor(elapsed / 60000);
            var seconds = Math.floor((elapsed % 60000) / 1000);
            timerElement.innerHTML = '⏱️ Elapsed: ' +
                (minutes < 10 ? '0' : '') + minutes + ':' +
                (seconds < 10 ? '0' : '') + seconds;
        }, 1000);
    })();
    </script>
    """

st.set_page_config(page_title="Review Scorer & Report Generator", page_icon="📊")
st.header("Review Scorer & Report Generator")

# Concise introduction
st.write("This app scores the reviews in your CSV file and generates a report based on your query and those scores.")

# Load models with caching to avoid reloading on every rerun
@st.cache_resource
def load_models():
    score_pipe = None
    gemma_pipe = None
    try:
        st.info("Loading sentiment analysis model...")
        score_pipe = pipeline("text-classification",
                              model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                              device=0 if torch.cuda.is_available() else -1)
        st.success("Sentiment analysis model loaded successfully!")
    except Exception as e:
        st.error(f"Error loading score model: {e}")

    try:
        st.info("Loading Gemma model...")
        # Load the tokenizer separately so the chat template is available
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
        gemma_pipe = pipeline("text-generation",
                              model="google/gemma-3-1b-it",
                              tokenizer=tokenizer,  # Pass the loaded tokenizer here
                              device=0 if torch.cuda.is_available() else -1,  # Fall back to CPU when no GPU is present
                              torch_dtype=torch.bfloat16)
        st.success("Gemma model loaded successfully!")
    except Exception as e:
        st.error(f"Error loading Gemma model: {e}")
        st.error(f"Detailed error: {type(e).__name__}: {str(e)}")

    return score_pipe, gemma_pipe

def extract_assistant_content(raw_response):
    """Extract only the assistant's content from the Gemma-3 response."""
    # Convert to string and work with it directly
    response_str = str(raw_response)

    # Look for the assistant's content marker
    assistant_marker = "'role': 'assistant', 'content': '"
    if assistant_marker in response_str:
        start_idx = response_str.find(assistant_marker) + len(assistant_marker)
        # Extract everything after the marker
        content = response_str[start_idx:]
        # Find the end of the content: the earliest of the closing-quote-plus-brace
        # markers that terminate the assistant message in the stringified structure
        end_markers = ["'}", "'}]"]
        end_idx = len(content)
        for marker in end_markers:
            pos = content.rfind(marker)
            if pos != -1 and pos < end_idx:
                end_idx = pos
        return content[:end_idx]

    # Fallback - return the original response
    return response_str
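
# Alternative sketch (unused here): recent transformers versions return chat-format
# results as the input message list with the assistant's reply appended, so the
# reply can usually be read structurally instead of parsing the stringified
# response. Whether this applies depends on the installed transformers version.
def extract_assistant_content_structured(raw_response):
    try:
        last_message = raw_response[0]["generated_text"][-1]
        if isinstance(last_message, dict) and last_message.get("role") == "assistant":
            return last_message.get("content", "")
    except (KeyError, IndexError, TypeError):
        pass
    # Fall back to the raw response when the structure is not chat-shaped
    return str(raw_response)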

score_pipe, gemma_pipe = load_models()

# Input: Query text for scoring and CSV file upload for candidate reviews
query_input = st.text_area("Enter your query text for analysis (this does not need to be part of the CSV):")
uploaded_file = st.file_uploader("Upload Reviews CSV File (must contain a 'reviewText' column)", type=["csv"])

if score_pipe is None or gemma_pipe is None:
    st.error("Model loading failed. Please check your model names, token permissions, and GPU configuration.")
else:
    candidate_docs = []
    if uploaded_file is not None:
        try:
            df = pd.read_csv(uploaded_file)
            if 'reviewText' not in df.columns:
                st.error("CSV must contain a 'reviewText' column.")
            else:
                candidate_docs = df['reviewText'].dropna().astype(str).tolist()
        except Exception as e:
            st.error(f"Error reading CSV file: {e}")

    if st.button("Generate Report"):
        # Reset timer state so that the timer always shows up
        st.session_state.timer_started = False
        st.session_state.timer_frozen = False
        if uploaded_file is None:
            st.error("Please upload a CSV file.")
        elif not candidate_docs:
            st.error("CSV must contain a 'reviewText' column.")
        elif not query_input.strip():
            st.error("Please enter a query text!")
        else:
            if not st.session_state.timer_started and not st.session_state.timer_frozen:
                st.session_state.timer_started = True
                html(timer(), height=50)
            status_text = st.empty()
            progress_bar = st.progress(0)

            # Stage 1: Score candidate documents with the sentiment model.
            status_text.markdown("**🔍 Scoring candidate documents...**")
            # Process each review individually to avoid memory issues
            scored_results = []
            for i, doc in enumerate(candidate_docs):
                # Update progress based on the current document
                progress = int((i / len(candidate_docs)) * 50)  # First half of progress bar (0-50%)
                progress_bar.progress(progress)
                # Score a single document; truncate reviews longer than the model's max length
                result = score_pipe(doc, truncation=True)
                scored_results.append(result[0])
                # Display occasional status updates for large datasets
                if i % max(1, len(candidate_docs) // 10) == 0:
                    status_text.markdown(f"**🔍 Scoring documents... ({i}/{len(candidate_docs)})**")

            # Pair each review with its score, assuming the output order matches the input order.
            scored_docs = list(zip(candidate_docs, [result["score"] for result in scored_results]))
            progress_bar.progress(67)
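
            # Alternative sketch (not enabled): keep the predicted label alongside the
            # confidence score, giving Gemma the sentiment direction, not just its strength.
            # scored_docs = [(doc, r["label"], r["score"]) for doc, r in zip(candidate_docs, scored_results)]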

            # Stage 2: Generate the report with Gemma using the chat messages format.
            status_text.markdown("**📝 Generating report with Gemma...**")
            # For very large datasets, sample the scored docs before sending them to Gemma
            sampled_docs = scored_docs
            if len(scored_docs) > 10000:  # Arbitrary threshold for what's "too large"
                # Option 1: Random sampling
                sampled_docs = random.sample(scored_docs, 1000)
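                # Option 2 (alternative sketch, commented out): keep the most confidently
                # scored reviews instead of a random sample; d[1] is the score paired with
                # each review above, and the 1000 cutoff mirrors Option 1 and is arbitrary.
                # sampled_docs = sorted(scored_docs, key=lambda d: d[1], reverse=True)[:1000]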
            # Build the user prompt from the query and the (possibly sampled) scored
            # reviews, formatted as chat messages for Gemma
            messages = [
                {"role": "user", "content": f"""
Generate a concise 300-word report based on the following analysis without repeating what's in the analysis.
Query:
"{query_input}"
Candidate Reviews with their scores:
{sampled_docs}
"""}
            ]
            # A ~300-word report needs on the order of 400+ tokens of output
            raw_result = gemma_pipe(messages, max_new_tokens=512)
            report = extract_assistant_content(raw_result)
            progress_bar.progress(100)
            status_text.success("**✅ Generation complete!**")
            # Signal the JavaScript timer to stop via localStorage
            html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
            st.session_state.timer_frozen = True
            #st.write("**Scored Candidate Reviews:**", scored_docs)
            st.write("**Generated Report:**", report)