Update app.py
app.py CHANGED
```diff
@@ -2,7 +2,7 @@ import os
 import nest_asyncio
 nest_asyncio.apply()
 import streamlit as st
-from transformers import pipeline
+from transformers import pipeline, AutoTokenizer
 from huggingface_hub import login
 from streamlit.components.v1 import html
 import pandas as pd
```
```diff
@@ -53,7 +53,7 @@ st.set_page_config(page_title="Review Scorer & Report Generator", page_icon="
 st.header("Review Scorer & Report Generator")
 
 # Concise introduction
-st.write("This model will score your reviews in your CSV file and generate a report based on those results.")
+st.write("This model will score your reviews in your CSV file and generate a report based on your query and those results.")
 
 # Load models with caching to avoid reloading on every run
 @st.cache_resource
```
```diff
@@ -65,15 +65,18 @@ def load_models():
         st.info("Loading sentiment analysis model...")
         score_pipe = pipeline("text-classification",
                               model="nlptown/bert-base-multilingual-uncased-sentiment",
-                              device=0)
+                              device=0 if torch.cuda.is_available() else -1)
         st.success("Sentiment analysis model loaded successfully!")
     except Exception as e:
         st.error(f"Error loading score model: {e}")
 
     try:
         st.info("Loading Gemma model...")
+        # Load the tokenizer separately with the chat template
+        tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
         gemma_pipe = pipeline("text-generation",
-                              model="google/gemma-
+                              model="google/gemma-3-1b-it",
+                              tokenizer=tokenizer,  # Pass the loaded tokenizer here
                               device=0,
                               torch_dtype=torch.bfloat16)
         st.success("Gemma model loaded successfully!")
```
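Review note on this hunk: the sentiment pipeline now falls back to CPU when CUDA is absent, but the Gemma pipeline below it still hard-codes `device=0`; both branches also rely on `torch` being imported somewhere outside the visible hunks. A minimal sketch of the same loading pattern with the fallback applied to both pipelines (the model names come from the diff; the shared `device` variable is an illustrative addition, not the app's exact code):

```python
# Sketch under the stated assumptions, not the app's exact code.
import torch
from transformers import pipeline, AutoTokenizer

device = 0 if torch.cuda.is_available() else -1  # GPU if available, else CPU

score_pipe = pipeline("text-classification",
                      model="nlptown/bert-base-multilingual-uncased-sentiment",
                      device=device)

# Loading the tokenizer explicitly guarantees the chat template ships with it.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
gemma_pipe = pipeline("text-generation",
                      model="google/gemma-3-1b-it",
                      tokenizer=tokenizer,
                      device=device,
                      torch_dtype=torch.bfloat16)  # bfloat16 can be slow on CPU
```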
```diff
@@ -83,9 +86,34 @@ def load_models():
 
     return score_pipe, gemma_pipe
 
+def extract_assistant_content(raw_response):
+    """Extract only the assistant's content from the Gemma-3 response."""
+    # Convert to string and work with it directly
+    response_str = str(raw_response)
+
+    # Look for the assistant's content marker
+    assistant_marker = "'role': 'assistant', 'content': '"
+    if assistant_marker in response_str:
+        start_idx = response_str.find(assistant_marker) + len(assistant_marker)
+        # Extract everything after the marker until the end or closing quote
+        content = response_str[start_idx:]
+
+        # Find the end of the content (last single quote before the end of the string or before closing curly brace)
+        end_markers = ["'}", "'}]"]
+        end_idx = len(content)
+        for marker in end_markers:
+            pos = content.rfind(marker)
+            if pos != -1 and pos < end_idx:
+                end_idx = pos
+
+        return content[:end_idx]
+
+    # Fallback - return the original response
+    return response_str
 
 score_pipe, gemma_pipe = load_models()
 
+
 # Input: Query text for scoring and CSV file upload for candidate reviews
 query_input = st.text_area("Enter your query text for analysis (this does not need to be part of the CSV):")
 uploaded_file = st.file_uploader("Upload Reviews CSV File (must contain a 'reviewText' column)", type=["csv"])
```
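The new `extract_assistant_content` helper parses the `str()` of the pipeline output, which is brittle: any `'}` inside the generated text truncates the report, and any change in the output's repr breaks the markers. If the pipeline returns the usual chat-style structure for message-list inputs, indexing into it directly is sturdier. A hedged alternative, assuming the common transformers output shape `[{"generated_text": [..., {"role": "assistant", "content": "..."}]}]` (worth verifying against the actual return value):

```python
def extract_assistant_content(raw_response):
    """Pull the assistant turn out of a chat-style pipeline result.

    Assumes the usual transformers output for chat inputs:
    [{"generated_text": [..., {"role": "assistant", "content": "..."}]}];
    falls back to stringifying the response if the shape differs.
    """
    try:
        generated = raw_response[0]["generated_text"]
        if isinstance(generated, list):
            # Chat format: the final message is the assistant's reply.
            return generated[-1]["content"]
        return generated  # plain text-generation output is already a string
    except (KeyError, IndexError, TypeError):
        return str(raw_response)
```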
```diff
@@ -148,10 +176,11 @@ Candidate Reviews with their scores:
 """}
 ]
 
-
+raw_result = gemma_pipe(messages, max_new_tokens=50)
+report = extract_assistant_content(raw_result)
 progress_bar.progress(100)
 status_text.success("**✅ Generation complete!**")
 html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
 st.session_state.timer_frozen = True
 #st.write("**Scored Candidate Reviews:**", scored_docs)
-st.write("**Generated Report:**",
+st.write("**Generated Report:**", report)
```
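One parameter worth a second look: `max_new_tokens=50` caps the generated report at roughly a sentence or two, which is tight for anything called a report. A hypothetical invocation with a larger budget (`messages` here is a stand-in for the list the app assembles from the query and the scored reviews):

```python
# Illustrative only; the real `messages` is assembled earlier in app.py.
messages = [
    {"role": "user", "content": "Generate a short report from these scored reviews: ..."},
]
raw_result = gemma_pipe(messages, max_new_tokens=300)  # roomier than 50
report = extract_assistant_content(raw_result)
```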