frankai98 committed
Commit 1ee20a5 · verified · 1 Parent(s): 496495e

Update app.py

Files changed (1)
  1. app.py +95 -119
app.py CHANGED
@@ -54,76 +54,35 @@ st.write("This model will score your reviews in your CSV file and generate a rep
 
 # Cache the model loading functions
 @st.cache_resource
-def load_llama_model():
-    """Load and cache the Llama 3.2 model"""
-    return pipeline("text-generation",
-                    model="meta-llama/Llama-3.2-1B-Instruct",
-                    device=0,  # Use GPU if available
-                    torch_dtype=torch.bfloat16)  # Use FP16 for efficiency
-
-@st.cache_resource
-def load_sentiment_model():
-    """Load and cache the sentiment analysis model"""
+def get_sentiment_model():
     return pipeline("text-classification",
                     model="cardiffnlp/twitter-roberta-base-sentiment-latest",
                     device=0 if torch.cuda.is_available() else -1)
 
-# Load Llama 3.2 model
-loading_llama_placeholder = st.empty()
-loading_llama_placeholder.info("Loading Llama 3.2 summarization model...")
-
-try:
-    llama_pipe = load_llama_model()
-
-    # Clear loading message
-    loading_llama_placeholder.empty()
-
-    # Display success message in a placeholder
-    success_llama_placeholder = st.empty()
-    success_llama_placeholder.success("Llama 3.2 summarization model loaded successfully!")
-
-    # Use st.session_state to track when to clear the message
-    if "clear_llama_success_time" not in st.session_state:
-        st.session_state.clear_llama_success_time = time.time() + 5
-
-    # Check if it's time to clear the message
-    if time.time() > st.session_state.clear_llama_success_time:
-        success_llama_placeholder.empty()
-
-except Exception as e:
-    # Clear loading message
-    loading_llama_placeholder.empty()
-
-    st.error(f"Error loading Llama 3.2 summarization model: {e}")
-    st.error(f"Detailed error: {type(e).__name__}: {str(e)}")
-
-# Load sentiment analysis model
-loading_sentiment_placeholder = st.empty()
-loading_sentiment_placeholder.info("Loading sentiment analysis model...")
-
-try:
-    score_pipe = load_sentiment_model()
-
-    # Clear loading message
-    loading_sentiment_placeholder.empty()
-
-    # Display success message in a placeholder
-    success_sentiment_placeholder = st.empty()
-    success_sentiment_placeholder.success("Sentiment analysis model loaded successfully!")
-
-    # Use st.session_state to track when to clear the message
-    if "clear_sentiment_success_time" not in st.session_state:
-        st.session_state.clear_sentiment_success_time = time.time() + 5
-
-    # Check if it's time to clear the message
-    if time.time() > st.session_state.clear_sentiment_success_time:
-        success_sentiment_placeholder.empty()
-
-except Exception as e:
-    # Clear loading message
-    loading_sentiment_placeholder.empty()
-
-    st.error(f"Error loading sentiment analysis model: {e}")
+@st.cache_resource
+def get_llama_model():
+    return pipeline("text-generation",
+                    model="meta-llama/Llama-3.2-1B-Instruct",
+                    device=0,
+                    torch_dtype=torch.bfloat16)
+
+@st.cache_resource
+def get_gemma_model():
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
+    return pipeline("text-generation",
+                    model="google/gemma-3-1b-it",
+                    tokenizer=tokenizer,
+                    device=0,
+                    torch_dtype=torch.bfloat16)
+
+# Function to clear GPU memory
+def clear_gpu_memory():
+    import gc
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
 
 def extract_assistant_content(raw_response):
     """Extract only the assistant's content from the Gemma-3 response."""
@@ -186,94 +145,111 @@ else:
     progress_bar = st.progress(0)
 
 
-    # Stage 1: Process, summarize (if needed), and score candidate documents
-    status_text.markdown("**🔍 Processing and scoring candidate documents...**")
-
-    # Process each review individually with summarization for long documents
-    processed_docs = []  # Store processed (original or summarized) documents
-    scored_results = []  # Store sentiment scores
-
-    for i, doc in enumerate(candidate_docs):
-        # Update progress based on current document
-        progress = int((i / len(candidate_docs)) * 50)  # First half of progress bar (0-50%)
-        progress_bar.progress(progress)
-
-        try:
-            # Check if document exceeds the length limit for sentiment analysis
-            if len(doc) > 1500:  # Approximate limit for sentiment model
-                # Use Llama 3.2 to summarize the document
-                summary_prompt = [
-                    {"role": "user", "content": f"Summarize the following text into a shorter version that preserves the sentiment and key points: {doc[:2000]}..."}
-                ]
-
-                summary_result = llama_pipe(
-                    summary_prompt,
-                    max_new_tokens=30,  # Limit summary length
-                    do_sample=True,
-                    temperature=0.3,  # Lower temperature for more factual summaries
-                    return_full_text=False  # Return only the generated text
-                )
-
-                # Extract the summary from the result
-                processed_doc = summary_result[0]['generated_text']
-                status_text.markdown(f"**📝 Summarized document {i+1}/{len(candidate_docs)}**")
-            else:
-                # Use the original document if it's short enough
-                processed_doc = doc
-
-            # Store the processed document (original or summary)
-            processed_docs.append(processed_doc)
-
-            # Process the document with sentiment analysis
-            result = score_pipe(processed_doc)
-
-            # If it's a list, get the first element
-            if isinstance(result, list):
-                result = result[0]
-
-            scored_results.append(result)
-
-            # Free memory
-            torch.cuda.empty_cache()
-
-        except Exception as e:
-            st.warning(f"Error processing document {i}: {str(e)}")
-            # Add a placeholder result to maintain indexing
-            processed_docs.append("Error processing this document")
-            scored_results.append({"label": "NEUTRAL", "score": 0.5})
-
-        # Display occasional status updates for large datasets
-        if i % max(1, len(candidate_docs) // 10) == 0:
-            status_text.markdown(f"**🔍 Scoring documents... ({i}/{len(candidate_docs)})**")
-
-    # Pair each review with its score assuming the output order matches the input order.
-    scored_docs = list(zip(processed_docs, [result.get("score", 0.5) for result in scored_results]))
-
-    progress_bar.progress(67)
-
-    # Stage 2: Generate Report using Gemma in the new messages format.
-    status_text.markdown("**📝 Generating report with Gemma...**")
-
-    # After using score_pipe
-    del score_pipe
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    # After using summarization_pipe
-    del llama_pipe
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    # Then reload Gemma specifically for the final step
-    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
-    gemma_pipe = pipeline("text-generation",
-                          model="google/gemma-3-1b-it",
-                          tokenizer=tokenizer,
-                          device=0,
-                          torch_dtype=torch.bfloat16)
+    processed_docs = []
+    scored_results = []
+
+    # First, check which documents need summarization
+    docs_to_summarize = []
+    docs_indices = []
+
+    for i, doc in enumerate(candidate_docs):
+        if len(doc) > 1500:
+            docs_to_summarize.append(doc)
+            docs_indices.append(i)
+
+    # If we have documents to summarize, load Llama model first
+    if docs_to_summarize:
+        status_text.markdown("**📝 Loading summarization model...**")
+        llama_pipe = get_llama_model()
+
+        status_text.markdown("**📝 Summarizing long documents...**")
+
+        # Process documents that need summarization
+        for idx, (i, doc) in enumerate(zip(docs_indices, docs_to_summarize)):
+            progress = int((idx / len(docs_to_summarize)) * 25)  # First quarter of progress
+            progress_bar.progress(progress)
+
+            summary_prompt = [
+                {"role": "user", "content": f"Summarize the following text into a shorter version that preserves the sentiment and key points: {doc[:2000]}..."}
+            ]
+
+            try:
+                summary_result = llama_pipe(
+                    summary_prompt,
+                    max_new_tokens=30,
+                    do_sample=True,
+                    temperature=0.3,
+                    return_full_text=False
+                )
+
+                # Store the summary in place of the original text
+                candidate_docs[i] = summary_result[0]['generated_text']
+
+            except Exception as e:
+                st.warning(f"Error summarizing document {i}: {str(e)}")
+
+        # Clear Llama model from memory
+        del llama_pipe
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    # Now load sentiment model
+    status_text.markdown("**🔍 Loading sentiment analysis model...**")
+    score_pipe = get_sentiment_model()
+
+    status_text.markdown("**🔍 Scoring documents...**")
+
+    # Process each document with sentiment analysis
+    for i, doc in enumerate(candidate_docs):
+        progress_offset = 25 if docs_to_summarize else 0
+        progress = progress_offset + int((i / len(candidate_docs)) * (50 - progress_offset))
+        progress_bar.progress(progress)
+
+        try:
+            # Process with sentiment analysis
+            result = score_pipe(doc)
+
+            # If it's a list, get the first element
+            if isinstance(result, list):
+                result = result[0]
+
+            processed_docs.append(doc)
+            scored_results.append(result)
+
+        except Exception as e:
+            st.warning(f"Error scoring document {i}: {str(e)}")
+            processed_docs.append("Error processing this document")
+            scored_results.append({"label": "NEUTRAL", "score": 0.5})
+
+        # Display occasional status updates
+        if i % max(1, len(candidate_docs) // 10) == 0:
+            status_text.markdown(f"**🔍 Scoring documents... ({i}/{len(candidate_docs)})**")
+
+    # Pair documents with scores
+    scored_docs = list(zip(processed_docs, [result.get("score", 0.5) for result in scored_results]))
+
+    # Clear sentiment model from memory
+    del score_pipe
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    progress_bar.progress(67)
+
+    # Load Gemma for final report generation
+    status_text.markdown("**📊 Loading report generation model...**")
+    progress_bar.progress(67)
+
+    tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
+    gemma_pipe = pipeline("text-generation",
+                          model="google/gemma-3-1b-it",
+                          tokenizer=tokenizer,
+                          device=0,
+                          torch_dtype=torch.bfloat16)
 
     # Sample or summarize the data for Gemma to avoid memory issues
+    status_text.markdown("**📝 Generating report...**")
+    progress_bar.progress(80)
     import random
     max_reviews = 50  # Adjust based on your GPU memory
     if len(scored_docs) > max_reviews:
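For reference, the core idea behind the new flow is to keep only one pipeline on the GPU at a time: load a model, run its pass over the documents, drop the reference, then clear the CUDA cache before the next model is loaded. Below is a minimal standalone sketch of that load, use, unload pattern using only transformers and torch; it leaves out the Streamlit caching and progress UI, and the helper names free_gpu and run_sentiment_pass are illustrative, not taken from app.py.

    # Minimal sketch: sequential load -> use -> unload to keep peak GPU memory low.
    # Illustrative only; free_gpu/run_sentiment_pass are not names from app.py.
    import gc
    import torch
    from transformers import pipeline


    def free_gpu():
        # Drop Python garbage first, then release cached CUDA blocks.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()


    def run_sentiment_pass(docs):
        # Load the sentiment model only for the duration of this pass.
        score_pipe = pipeline(
            "text-classification",
            model="cardiffnlp/twitter-roberta-base-sentiment-latest",
            device=0 if torch.cuda.is_available() else -1,
        )
        results = []
        for doc in docs:
            out = score_pipe(doc)
            results.append(out[0] if isinstance(out, list) else out)
        # Unload before the next model (e.g. Gemma) is loaded.
        del score_pipe
        free_gpu()
        return results


    scores = run_sentiment_pass(["Great product!", "Terrible support experience."])
    print([r["label"] for r in scores])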
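The loaders added in the first hunk are wrapped in @st.cache_resource, so each pipeline is constructed once per server process and reused on later Streamlit reruns instead of being rebuilt every time the script executes. A minimal sketch of that caching behaviour, assuming the same streamlit stack; the expensive_load function is an illustrative stand-in for a model loader.

    # Sketch of @st.cache_resource: the decorated function runs once, later reruns reuse its return value.
    import time
    import streamlit as st

    @st.cache_resource
    def expensive_load():
        time.sleep(2)  # stand-in for loading model weights
        return {"loaded_at": time.time()}

    model = expensive_load()
    st.write(model["loaded_at"])  # same timestamp on every rerun until the cache is cleared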