Update app.py
app.py
CHANGED
@@ -54,76 +54,35 @@ st.write("This model will score your reviews in your CSV file and generate a rep
Before (old lines 54-129; removed lines are marked "-"):

 54 |
 55 |   # Cache the model loading functions
 56 |   @st.cache_resource
 57 | - def load_llama_model():
 58 | -     """Load and cache the Llama 3.2 model"""
 59 | -     return pipeline("text-generation",
 60 | -                     model="meta-llama/Llama-3.2-1B-Instruct",
 61 | -                     device=0,  # Use GPU if available
 62 | -                     torch_dtype=torch.bfloat16)  # Use FP16 for efficiency
 63 | -
 64 | - @st.cache_resource
 65 | - def load_sentiment_model():
 66 | -     """Load and cache the sentiment analysis model"""
 67 |       return pipeline("text-classification",
 68 |                       model="cardiffnlp/twitter-roberta-base-sentiment-latest",
 69 |                       device=0 if torch.cuda.is_available() else -1)
 70 |
 71-73 | -
 74 |
 75-82 | -
 83 | -     success_llama_placeholder.success("Llama 3.2 summarization model loaded successfully!")
 84 | -
 85 | -     # Use st.session_state to track when to clear the message
 86 | -     if "clear_llama_success_time" not in st.session_state:
 87 | -         st.session_state.clear_llama_success_time = time.time() + 5
 88 | -
 89 | -     # Check if it's time to clear the message
 90 | -     if time.time() > st.session_state.clear_llama_success_time:
 91 | -         success_llama_placeholder.empty()
 92 | -
 93 | - except Exception as e:
 94 | -     # Clear loading message
 95 | -     loading_llama_placeholder.empty()
 96 | -
 97 | -     st.error(f"Error loading Llama 3.2 summarization model: {e}")
 98 | -     st.error(f"Detailed error: {type(e).__name__}: {str(e)}")
 99 |
100 | - #
101-102 | -
103 |
104 | - try:
105 | -     score_pipe = load_sentiment_model()
106 | -
107 | -     # Clear loading message
108 | -     loading_sentiment_placeholder.empty()
109 | -
110 | -     # Display success message in a placeholder
111 | -     success_sentiment_placeholder = st.empty()
112 | -     success_sentiment_placeholder.success("Sentiment analysis model loaded successfully!")
113 | -
114 | -     # Use st.session_state to track when to clear the message
115 | -     if "clear_sentiment_success_time" not in st.session_state:
116 | -         st.session_state.clear_sentiment_success_time = time.time() + 5
117 | -
118 | -     # Check if it's time to clear the message
119 | -     if time.time() > st.session_state.clear_sentiment_success_time:
120 | -         success_sentiment_placeholder.empty()
121 | -
122 | - except Exception as e:
123 | -     # Clear loading message
124 | -     loading_sentiment_placeholder.empty()
125 | -
126 | -     st.error(f"Error loading sentiment analysis model: {e}")
127 |
128 |   def extract_assistant_content(raw_response):
129 |       """Extract only the assistant's content from the Gemma-3 response."""
@@ -186,94 +145,111 @@ else:
Before (old lines 186-279; removed lines are marked "-"):

186 |       progress_bar = st.progress(0)
187 |
188 |
189-190 | -
191 |
192 | -     # Process each review individually with summarization for long documents
193 | -     processed_docs = []  # Store processed (original or summarized) documents
194 | -     scored_results = []  # Store sentiment scores
195 | -
196 |       for i, doc in enumerate(candidate_docs):
197-199 | -
200 |
201-208 | -
209 |               summary_result = llama_pipe(
210 |                   summary_prompt,
211 | -                 max_new_tokens=30,
212 |                   do_sample=True,
213 | -                 temperature=0.3,
214 | -                 return_full_text=False
215 |               )
216 |
217 | -             #
218-228 | -
229 |
230 |               # If it's a list, get the first element
231 |               if isinstance(result, list):
232 |                   result = result[0]
233 |
234 |               scored_results.append(result)
235 | -
236 | -             # Free memory
237 | -             torch.cuda.empty_cache()
238 |
239 |           except Exception as e:
240 | -             st.warning(f"Error
241 | -             # Add a placeholder result to maintain indexing
242 |               processed_docs.append("Error processing this document")
243 |               scored_results.append({"label": "NEUTRAL", "score": 0.5})
244 | -
245 |
246 | -         # Display occasional status updates
247 |           if i % max(1, len(candidate_docs) // 10) == 0:
248 |               status_text.markdown(f"**π Scoring documents... ({i}/{len(candidate_docs)})**")
249 | -
250 | -     # Pair each review with its score assuming the output order matches the input order.
251 | -     scored_docs = list(zip(processed_docs, [result.get("score", 0.5) for result in scored_results]))
252 |
253 | -
254 |
255 | -     #
256 | -     status_text.markdown("**π Generating report with Gemma...**")
257 | -
258 | -     # After using score_pipe
259 |       del score_pipe
260 |       gc.collect()
261 |       torch.cuda.empty_cache()
262 |
263-266 | -
267 |
268 | -
269 |       tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
270 |       gemma_pipe = pipeline("text-generation",
271-274 | -
275 |
276 |       # Sample or summarize the data for Gemma to avoid memory issues
277 |       import random
278 |       max_reviews = 50  # Adjust based on your GPU memory
279 |       if len(scored_docs) > max_reviews:
After (new lines 54-88, replacing old lines 54-129 above; added lines are marked "+"):

 54 |
 55 |   # Cache the model loading functions
 56 |   @st.cache_resource
 57 | + def get_sentiment_model():
 58 |       return pipeline("text-classification",
 59 |                       model="cardiffnlp/twitter-roberta-base-sentiment-latest",
 60 |                       device=0 if torch.cuda.is_available() else -1)
 61 |
 62 | + @st.cache_resource
 63 | + def get_llama_model():
 64 | +     return pipeline("text-generation",
 65 | +                     model="meta-llama/Llama-3.2-1B-Instruct",
 66 | +                     device=0,
 67 | +                     torch_dtype=torch.bfloat16)
 68 |
 69 | + @st.cache_resource
 70 | + def get_gemma_model():
 71 | +     tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
 72 | +     return pipeline("text-generation",
 73 | +                     model="google/gemma-3-1b-it",
 74 | +                     tokenizer=tokenizer,
 75 | +                     device=0,
 76 | +                     torch_dtype=torch.bfloat16)
 77 |
 78 | + # Function to clear GPU memory
 79 | + def clear_gpu_memory():
 80 | +     import gc
 81 | +     gc.collect()
 82 | +     if torch.cuda.is_available():
 83 | +         torch.cuda.empty_cache()
 84 | +         torch.cuda.ipc_collect()
 85 |
 86 |
 87 |   def extract_assistant_content(raw_response):
 88 |       """Extract only the assistant's content from the Gemma-3 response."""
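The loaders above are wrapped in @st.cache_resource, so each pipeline is constructed once per server process and then reused on every Streamlit rerun. A minimal sketch of that behavior (not part of this commit; it only reuses the get_sentiment_model() loader added in new lines 56-60):

import streamlit as st
import torch
from transformers import pipeline

@st.cache_resource
def get_sentiment_model():
    # Built on the first call; later calls are served from Streamlit's resource cache.
    return pipeline("text-classification",
                    model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                    device=0 if torch.cuda.is_available() else -1)

pipe_a = get_sentiment_model()  # first call downloads/loads the model
pipe_b = get_sentiment_model()  # subsequent calls return the same cached object
assert pipe_a is pipe_b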
After (new lines 145-255, replacing old lines 186-279 above; added lines are marked "+"):

145 |       progress_bar = st.progress(0)
146 |
147 |
148 | +     processed_docs = []
149 | +     scored_results = []
150 | +
151 | +     # First, check which documents need summarization
152 | +     docs_to_summarize = []
153 | +     docs_indices = []
154 |
155 |       for i, doc in enumerate(candidate_docs):
156 | +         if len(doc) > 1500:
157 | +             docs_to_summarize.append(doc)
158 | +             docs_indices.append(i)
159 |
160 | +     # If we have documents to summarize, load Llama model first
161 | +     if docs_to_summarize:
162 | +         status_text.markdown("**π Loading summarization model...**")
163 | +         llama_pipe = load_llama_model()
164 | +
165 | +         status_text.markdown("**π Summarizing long documents...**")
166 | +
167 | +         # Process documents that need summarization
168 | +         for idx, (i, doc) in enumerate(zip(docs_indices, docs_to_summarize)):
169 | +             progress = int((idx / len(docs_to_summarize)) * 25)  # First quarter of progress
170 | +             progress_bar.progress(progress)
171 | +
172 | +             summary_prompt = [
173 | +                 {"role": "user", "content": f"Summarize the following text into a shorter version that preserves the sentiment and key points: {doc[:2000]}..."}
174 | +             ]
175 | +
176 | +             try:
177 |                   summary_result = llama_pipe(
178 |                       summary_prompt,
179 | +                     max_new_tokens=30,
180 |                       do_sample=True,
181 | +                     temperature=0.3,
182 | +                     return_full_text=False
183 |                   )
184 |
185 | +                 # Store the summary in place of the original text
186 | +                 candidate_docs[i] = summary_result[0]['generated_text']
187 | +
188 | +             except Exception as e:
189 | +                 st.warning(f"Error summarizing document {i}: {str(e)}")
190 | +
191 | +         # Clear Llama model from memory
192 | +         del llama_pipe
193 | +         gc.collect()
194 | +         torch.cuda.empty_cache()
195 | +
196 | +     # Now load sentiment model
197 | +     status_text.markdown("**π Loading sentiment analysis model...**")
198 | +     score_pipe = load_sentiment_model()
199 | +
200 | +     status_text.markdown("**π Scoring documents...**")
201 | +
202 | +     # Process each document with sentiment analysis
203 | +     for i, doc in enumerate(candidate_docs):
204 | +         progress_offset = 25 if docs_to_summarize else 0
205 | +         progress = progress_offset + int((i / len(candidate_docs)) * (50 - progress_offset))
206 | +         progress_bar.progress(progress)
207 | +
208 | +         try:
209 | +             # Process with sentiment analysis
210 | +             result = score_pipe(doc)
211 |
212 |               # If it's a list, get the first element
213 |               if isinstance(result, list):
214 |                   result = result[0]
215 |
216 | +             processed_docs.append(doc)
217 |               scored_results.append(result)
218 |
219 |           except Exception as e:
220 | +             st.warning(f"Error scoring document {i}: {str(e)}")
221 |               processed_docs.append("Error processing this document")
222 |               scored_results.append({"label": "NEUTRAL", "score": 0.5})
223 |
224 | +         # Display occasional status updates
225 |           if i % max(1, len(candidate_docs) // 10) == 0:
226 |               status_text.markdown(f"**π Scoring documents... ({i}/{len(candidate_docs)})**")
227 |
228 | +     # Pair documents with scores
229 | +     scored_docs = list(zip(processed_docs, [result.get("score", 0.5) for result in scored_results]))
230 |
231 | +     # Clear sentiment model from memory
232 |       del score_pipe
233 |       gc.collect()
234 |       torch.cuda.empty_cache()
235 |
236 | +     progress_bar.progress(67)
237 | +
238 | +     # Load Gemma for final report generation
239 | +     status_text.markdown("**π Loading report generation model...**")
240 | +     progress_bar.progress(67)
241 |
242 | +
243 |       tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
244 |       gemma_pipe = pipeline("text-generation",
245 | +                           model="google/gemma-3-1b-it",
246 | +                           tokenizer=tokenizer,
247 | +                           device=0,
248 | +                           torch_dtype=torch.bfloat16)
249 |
250 |       # Sample or summarize the data for Gemma to avoid memory issues
251 | +     status_text.markdown("**π Generating report...**")
252 | +     progress_bar.progress(80)
253 |       import random
254 |       max_reviews = 50  # Adjust based on your GPU memory
255 |       if len(scored_docs) > max_reviews: