frankai98 committed
Commit 1bd0566 · verified · Parent: d9e2c35

Update app.py

Files changed (1):
  1. app.py +49 -37
app.py CHANGED
@@ -69,51 +69,63 @@ def display_temp_message(message, message_type="info", duration=5):
     placeholder.empty()

 # Load models with caching to avoid reloading on every run
-@st.cache_resource
-def load_models():
-    llama_pipe = None
-    score_pipe = None
-
+def load_model_with_messages(loading_message, success_message, loading_function, error_message_prefix):
+    """Load a model with temporary loading and success messages."""
+    # Create placeholder for the loading message
+    loading_placeholder = st.empty()
+    loading_placeholder.info(loading_message)
+
     try:
-        st.info("Loading Llama 3.2 summarization model...")
-        llama_pipe = pipeline("text-generation",
-                              model="meta-llama/Llama-3.2-1B-Instruct",
-                              device=0,  # Use GPU if available
-                              torch_dtype=torch.bfloat16,)  # Use FP16 for efficiency
-        #st.success("Llama 3.2 summarization model loaded successfully!")
-
-        # Display success message that will disappear after 5 seconds
+        # Load the model
+        result = loading_function()
+
+        # Clear the loading message
+        loading_placeholder.empty()
+
+        # Show temporary success message
         Thread(
             target=display_temp_message,
-            args=("Llama 3.2 summarization model loaded successfully!", "success"),
+            args=(success_message, "success"),
             daemon=True
         ).start()
-
+
+        return result
+
     except Exception as e:
-        st.error(f"Error loading Llama 3.2 summarization model: {e}")
+        # Clear the loading message
+        loading_placeholder.empty()
+
+        # Show error message
+        st.error(f"{error_message_prefix}: {e}")
         st.error(f"Detailed error: {type(e).__name__}: {str(e)}")
-
-    try:
-        st.info("Loading sentiment analysis model...")
-        score_pipe = pipeline("text-classification",
-                              model="cardiffnlp/twitter-roberta-base-sentiment-latest",
-                              device=0 if torch.cuda.is_available() else -1)
-        #st.success("Sentiment analysis model loaded successfully!")
+        return None

-        # Display success message that will disappear after 5 seconds
-        Thread(
-            target=display_temp_message,
-            args=("Sentiment analysis model loaded successfully!", "success"),
-            daemon=True
-        ).start()
-
-    except Exception as e:
-        st.error(f"Error loading sentiment analysis model: {e}")
-
-
-    return llama_pipe, score_pipe
+# Define loading functions
+def load_llama_model():
+    return pipeline("text-generation",
+                    model="meta-llama/Llama-3.2-1B-Instruct",
+                    device=0,  # Assumes a CUDA GPU is available
+                    torch_dtype=torch.bfloat16)  # Use bfloat16 for efficiency
+
+def load_sentiment_model():
+    return pipeline("text-classification",
+                    model="cardiffnlp/twitter-roberta-base-sentiment-latest",
+                    device=0 if torch.cuda.is_available() else -1)
+
+# Load models with temporary messages
+llama_pipe = load_model_with_messages(
+    "Loading Llama 3.2 summarization model...",
+    "Llama 3.2 summarization model loaded successfully!",
+    load_llama_model,
+    "Error loading Llama 3.2 summarization model"
+)

-llama_pipe, score_pipe = load_models()
+score_pipe = load_model_with_messages(
+    "Loading sentiment analysis model...",
+    "Sentiment analysis model loaded successfully!",
+    load_sentiment_model,
+    "Error loading sentiment analysis model"
+)

 def extract_assistant_content(raw_response):
     """Extract only the assistant's content from the Gemma-3 response."""
@@ -144,7 +156,7 @@ def extract_assistant_content(raw_response):
 query_input = st.text_area("Enter your query text for analysis (this does not need to be part of the CSV):")
 uploaded_file = st.file_uploader("Upload Reviews CSV File (must contain a 'reviewText' column)", type=["csv"])

-if score_pipe is None or gemma_pipe is None:
+if score_pipe is None or llama_pipe is None:
     st.error("Model loading failed. Please check your model names, token permissions, and GPU configuration.")
 else:
     candidate_docs = []
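Note: the hunk header shows only the tail of display_temp_message(message, message_type="info", duration=5); its body sits outside this diff. For context, a minimal sketch of what such a helper presumably looks like, given that its last visible statement is placeholder.empty():

import time
import streamlit as st

def display_temp_message(message, message_type="info", duration=5):
    """Show a message in a placeholder, then clear it after `duration` seconds."""
    placeholder = st.empty()
    # Pick the widget matching the requested message type
    if message_type == "success":
        placeholder.success(message)
    elif message_type == "error":
        placeholder.error(message)
    else:
        placeholder.info(message)
    time.sleep(duration)
    placeholder.empty()

One caveat with the Thread(...) pattern used in the diff: Streamlit calls made from a plain background thread run without a ScriptRunContext, so the message may be dropped with a warning; attaching the context to the thread (for example via streamlit.runtime.scriptrunner.add_script_run_ctx) is usually needed for it to render reliably.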
 
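A side note on load_llama_model: it hardcodes device=0 and therefore assumes a CUDA device is present, whereas load_sentiment_model falls back to CPU via torch.cuda.is_available(). A sketch of the same guard applied to the Llama loader (the float32 fallback for CPU hosts is an assumption, not part of this commit):

import torch
from transformers import pipeline

def load_llama_model():
    # Mirror the sentiment loader's fallback so CPU-only machines
    # load on device -1 instead of raising on a missing GPU
    has_gpu = torch.cuda.is_available()
    return pipeline("text-generation",
                    model="meta-llama/Llama-3.2-1B-Instruct",
                    device=0 if has_gpu else -1,
                    torch_dtype=torch.bfloat16 if has_gpu else torch.float32)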
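Finally, the unchanged comment "# Load models with caching to avoid reloading on every run" no longer matches the code: this commit drops the @st.cache_resource decorator, so both pipelines are rebuilt on every script rerun. A minimal sketch of how caching could be layered back over the new helper (get_pipelines is a hypothetical wrapper, not part of this commit):

import streamlit as st

@st.cache_resource(show_spinner=False)
def get_pipelines():
    # Cached across reruns; the loading/success messages run only while
    # the cache is cold (behavior of st elements inside cached functions
    # varies by Streamlit version)
    llama = load_model_with_messages(
        "Loading Llama 3.2 summarization model...",
        "Llama 3.2 summarization model loaded successfully!",
        load_llama_model,
        "Error loading Llama 3.2 summarization model",
    )
    sentiment = load_model_with_messages(
        "Loading sentiment analysis model...",
        "Sentiment analysis model loaded successfully!",
        load_sentiment_model,
        "Error loading sentiment analysis model",
    )
    return llama, sentiment

llama_pipe, score_pipe = get_pipelines()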