frankai98 commited on
Commit
9ba919a
Β·
verified Β·
1 Parent(s): fbd3d7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -69
app.py CHANGED
@@ -18,10 +18,10 @@ if not hf_token:
18
  login(token=hf_token)
19
 
20
  # Initialize session state for timer
21
- if 'timer_started' not in st.session_state:
22
- st.session_state.timer_started = False
23
- if 'timer_frozen' not in st.session_state:
24
- st.session_state.timer_frozen = False
25
 
26
  # Timer component using HTML and JavaScript
27
  def timer():
@@ -57,79 +57,90 @@ st.write("This model will score your reviews in your CSV file and generate a rep
57
 
58
  # Load models with caching to avoid reloading on every run
59
  @st.cache_resource
 
60
  def load_models():
61
- # Load the scoring model via pipeline.
62
- score_pipe = pipeline("text-classification", model="mixedbread-ai/mxbai-rerank-base-v1", device=0)
63
- # Load the Gemma text generation pipeline.
64
- gemma_pipe = pipeline("text-generation", model="google/gemma-3-1b-it", device=0)
 
 
 
 
 
 
65
  return score_pipe, gemma_pipe
66
 
 
67
  score_pipe, gemma_pipe = load_models()
68
 
69
  # Input: Query text for scoring and CSV file upload for candidate reviews
70
  query_input = st.text_area("Enter your query text for analysis (this does not need to be part of the CSV):")
71
  uploaded_file = st.file_uploader("Upload Reviews CSV File (must contain a 'reviewText' column)", type=["csv"])
72
 
73
- candidate_docs = []
74
- if uploaded_file is not None:
75
- try:
76
- df = pd.read_csv(uploaded_file)
77
- if 'reviewText' not in df.columns:
78
- st.error("CSV must contain a 'reviewText' column.")
79
- else:
80
- candidate_docs = df['reviewText'].dropna().astype(str).tolist()
81
- except Exception as e:
82
- st.error(f"Error reading CSV file: {e}")
83
-
84
- if st.button("Generate Report"):
85
- # Reset timer state so that the timer always shows up
86
- st.session_state.timer_started = False
87
- st.session_state.timer_frozen = False
88
- # Display the timer every time
89
- html(timer(), height=50)
90
- if uploaded_file is None:
91
- st.error("Please upload a CSV file.")
92
- elif not candidate_docs:
93
- st.error("CSV must contain a 'reviewText' column.")
94
- elif not query_input.strip():
95
- st.error("Please enter a query text!")
96
- else:
97
- if not st.session_state.timer_started and not st.session_state.timer_frozen:
98
- st.session_state.timer_started = True
99
- html(timer(), height=50)
100
- status_text = st.empty()
101
- progress_bar = st.progress(0)
102
  try:
103
- # Stage 1: Score candidate documents using the provided query.
104
- status_text.markdown("**πŸ” Scoring candidate documents...**")
105
- progress_bar.progress(0)
106
-
107
- scored_docs = []
108
- for doc in candidate_docs:
109
- combined_text = f"Query: {query_input} Document: {doc}"
110
- result = score_pipe(combined_text)[0]
111
- scored_docs.append((doc, result["score"]))
112
-
113
- progress_bar.progress(50)
114
-
115
- # Stage 2: Generate Report using Gemma, including query and scored results.
116
- status_text.markdown("**πŸ“ Generating report with Gemma...**")
117
- prompt = f"""
118
- Generate a detailed report based on the following analysis.
119
- Query:
120
- "{query_input}"
121
- Candidate Reviews with their scores:
122
- {scored_docs}
123
- Please provide a concise summary report explaining the insights derived from these scores.
124
- """
125
- report = gemma_pipe(prompt, max_length=200)
126
- progress_bar.progress(100)
127
- status_text.success("**βœ… Generation complete!**")
128
- html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
129
- st.session_state.timer_frozen = True
130
- st.write("**Scored Candidate Reviews:**", scored_docs)
131
- st.write("**Generated Report:**", report[0]['generated_text'])
132
  except Exception as e:
133
- html("<script>document.getElementById('timer').remove();</script>")
134
- status_text.error(f"**❌ Error:** {str(e)}")
135
- progress_bar.empty()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  login(token=hf_token)
19
 
20
  # Initialize session state for timer
21
+ #if 'timer_started' not in st.session_state:
22
+ #st.session_state.timer_started = False
23
+ #if 'timer_frozen' not in st.session_state:
24
+ #st.session_state.timer_frozen = False
25
 
26
  # Timer component using HTML and JavaScript
27
  def timer():
 
57
 
58
  # Load models with caching to avoid reloading on every run
59
  @st.cache_resource
60
+ @st.cache_resource
61
  def load_models():
62
+ try:
63
+ score_pipe = pipeline("text-classification", model="mixedbread-ai/mxbai-rerank-base-v1", device=0)
64
+ except Exception as e:
65
+ st.error(f"Error loading score model: {e}")
66
+ score_pipe = None
67
+ try:
68
+ gemma_pipe = pipeline("text-generation", model="google/gemma-3-1b-it", device=0)
69
+ except Exception as e:
70
+ st.error(f"Error loading Gemma model: {e}")
71
+ gemma_pipe = None
72
  return score_pipe, gemma_pipe
73
 
74
+
75
  score_pipe, gemma_pipe = load_models()
76
 
77
  # Input: Query text for scoring and CSV file upload for candidate reviews
78
  query_input = st.text_area("Enter your query text for analysis (this does not need to be part of the CSV):")
79
  uploaded_file = st.file_uploader("Upload Reviews CSV File (must contain a 'reviewText' column)", type=["csv"])
80
 
81
+ if score_pipe is None or gemma_pipe is None:
82
+ st.error("Model loading failed. Please check your model names, token permissions, and GPU configuration.")
83
+ else:
84
+ candidate_docs = []
85
+ if uploaded_file is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  try:
87
+ df = pd.read_csv(uploaded_file)
88
+ if 'reviewText' not in df.columns:
89
+ st.error("CSV must contain a 'reviewText' column.")
90
+ else:
91
+ candidate_docs = df['reviewText'].dropna().astype(str).tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  except Exception as e:
93
+ st.error(f"Error reading CSV file: {e}")
94
+
95
+ if st.button("Generate Report"):
96
+ # Reset timer state so that the timer always shows up
97
+ st.session_state.timer_started = False
98
+ st.session_state.timer_frozen = False
99
+ # Display the timer every time
100
+ html(timer(), height=50)
101
+ if uploaded_file is None:
102
+ st.error("Please upload a CSV file.")
103
+ elif not candidate_docs:
104
+ st.error("CSV must contain a 'reviewText' column.")
105
+ elif not query_input.strip():
106
+ st.error("Please enter a query text!")
107
+ else:
108
+ if not st.session_state.timer_started and not st.session_state.timer_frozen:
109
+ st.session_state.timer_started = True
110
+ html(timer(), height=50)
111
+ status_text = st.empty()
112
+ progress_bar = st.progress(0)
113
+ try:
114
+ # Stage 1: Score candidate documents using the provided query.
115
+ status_text.markdown("**πŸ” Scoring candidate documents...**")
116
+ progress_bar.progress(0)
117
+
118
+ scored_docs = []
119
+ for doc in candidate_docs:
120
+ combined_text = f"Query: {query_input} Document: {doc}"
121
+ result = score_pipe(combined_text)[0]
122
+ scored_docs.append((doc, result["score"]))
123
+
124
+ progress_bar.progress(50)
125
+
126
+ # Stage 2: Generate Report using Gemma, including query and scored results.
127
+ status_text.markdown("**πŸ“ Generating report with Gemma...**")
128
+ prompt = f"""
129
+ Generate a detailed report based on the following analysis.
130
+ Query:
131
+ "{query_input}"
132
+ Candidate Reviews with their scores:
133
+ {scored_docs}
134
+ Please provide a concise summary report explaining the insights derived from these scores.
135
+ """
136
+ report = gemma_pipe(prompt, max_length=200)
137
+ progress_bar.progress(100)
138
+ status_text.success("**βœ… Generation complete!**")
139
+ html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
140
+ st.session_state.timer_frozen = True
141
+ st.write("**Scored Candidate Reviews:**", scored_docs)
142
+ st.write("**Generated Report:**", report[0]['generated_text'])
143
+ except Exception as e:
144
+ html("<script>document.getElementById('timer').remove();</script>")
145
+ status_text.error(f"**❌ Error:** {str(e)}")
146
+ progress_bar.empty()