frankai98 committed on
Commit
503f042
·
verified ·
1 Parent(s): baf5aeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -6
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import nest_asyncio
3
  nest_asyncio.apply()
4
  import streamlit as st
5
- from transformers import pipeline
6
  from huggingface_hub import login
7
  from streamlit.components.v1 import html
8
  import pandas as pd
@@ -53,7 +53,7 @@ st.set_page_config(page_title="Review Scorer & Report Generator", page_icon="
53
  st.header("Review Scorer & Report Generator")
54
 
55
  # Concise introduction
56
- st.write("This model will score your reviews in your CSV file and generate a report based on those results.")
57
 
58
  # Load models with caching to avoid reloading on every run
59
  @st.cache_resource
@@ -65,15 +65,18 @@ def load_models():
65
  st.info("Loading sentiment analysis model...")
66
  score_pipe = pipeline("text-classification",
67
  model="nlptown/bert-base-multilingual-uncased-sentiment",
68
- device=0)
69
  st.success("Sentiment analysis model loaded successfully!")
70
  except Exception as e:
71
  st.error(f"Error loading score model: {e}")
72
 
73
  try:
74
  st.info("Loading Gemma model...")
 
 
75
  gemma_pipe = pipeline("text-generation",
76
- model="google/gemma-2-2b-it",
 
77
  device=0,
78
  torch_dtype=torch.bfloat16)
79
  st.success("Gemma model loaded successfully!")
@@ -83,9 +86,34 @@ def load_models():
83
 
84
  return score_pipe, gemma_pipe
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  score_pipe, gemma_pipe = load_models()
88
 
 
89
  # Input: Query text for scoring and CSV file upload for candidate reviews
90
  query_input = st.text_area("Enter your query text for analysis (this does not need to be part of the CSV):")
91
  uploaded_file = st.file_uploader("Upload Reviews CSV File (must contain a 'reviewText' column)", type=["csv"])
@@ -148,10 +176,11 @@ Candidate Reviews with their scores:
148
  """}
149
  ]
150
 
151
- output = gemma_pipe(messages, max_new_tokens=50)
 
152
  progress_bar.progress(100)
153
  status_text.success("**✅ Generation complete!**")
154
  html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
155
  st.session_state.timer_frozen = True
156
  #st.write("**Scored Candidate Reviews:**", scored_docs)
157
- st.write("**Generated Report:**", output)
 
2
  import nest_asyncio
3
  nest_asyncio.apply()
4
  import streamlit as st
5
+ from transformers import pipeline, AutoTokenizer
6
  from huggingface_hub import login
7
  from streamlit.components.v1 import html
8
  import pandas as pd
 
53
  st.header("Review Scorer & Report Generator")
54
 
55
  # Concise introduction
56
+ st.write("This model will score your reviews in your CSV file and generate a report based on your query and those results.")
57
 
58
  # Load models with caching to avoid reloading on every run
59
  @st.cache_resource
 
65
  st.info("Loading sentiment analysis model...")
66
  score_pipe = pipeline("text-classification",
67
  model="nlptown/bert-base-multilingual-uncased-sentiment",
68
+ device=0 if torch.cuda.is_available() else -1)
69
  st.success("Sentiment analysis model loaded successfully!")
70
  except Exception as e:
71
  st.error(f"Error loading score model: {e}")
72
 
73
  try:
74
  st.info("Loading Gemma model...")
75
+ # Load the tokenizer separately with the chat template
76
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
77
  gemma_pipe = pipeline("text-generation",
78
+ model="google/gemma-3-1b-it",
79
+ tokenizer=tokenizer, # Pass the loaded tokenizer here
80
  device=0,
81
  torch_dtype=torch.bfloat16)
82
  st.success("Gemma model loaded successfully!")
 
86
 
87
  return score_pipe, gemma_pipe
88
 
89
def extract_assistant_content(raw_response):
    """Return the assistant's reply text from a Gemma chat-pipeline response.

    The transformers text-generation pipeline, when called with a list of chat
    messages, returns ``[{'generated_text': [{'role': ..., 'content': ...}, ...]}]``.
    We read that structure directly instead of scraping ``str(raw_response)``,
    which is fragile (it breaks if the content itself contains ``'}`` and
    depends on dict repr ordering). The original string-scraping approach is
    kept as a fallback for any other response shape.

    Args:
        raw_response: The value returned by the text-generation pipeline.

    Returns:
        The assistant message content as a string, or the stringified
        response if no assistant content can be located.
    """
    # Preferred path: walk the structured chat output and take the last
    # message with role 'assistant' (the newly generated turn).
    try:
        messages = raw_response[0]["generated_text"]
        for message in reversed(messages):
            if message.get("role") == "assistant":
                return message.get("content", "")
    except (KeyError, IndexError, TypeError, AttributeError):
        # Response is not in the expected chat format (e.g. plain string
        # generated_text) — fall through to the string-based heuristic.
        pass

    # Fallback: scrape the repr of the response for the assistant content.
    response_str = str(raw_response)
    assistant_marker = "'role': 'assistant', 'content': '"
    if assistant_marker in response_str:
        start_idx = response_str.find(assistant_marker) + len(assistant_marker)
        content = response_str[start_idx:]

        # Trim at the closing quote/brace of the content field.
        # NOTE(review): rfind takes the LAST occurrence, so content that
        # itself contains "'}" could be truncated — acceptable for a fallback.
        end_idx = len(content)
        for marker in ("'}", "'}]"):
            pos = content.rfind(marker)
            if pos != -1 and pos < end_idx:
                end_idx = pos
        return content[:end_idx]

    # Last resort: return the whole stringified response unchanged.
    return response_str
113
 
114
  score_pipe, gemma_pipe = load_models()
115
 
116
+
117
  # Input: Query text for scoring and CSV file upload for candidate reviews
118
  query_input = st.text_area("Enter your query text for analysis (this does not need to be part of the CSV):")
119
  uploaded_file = st.file_uploader("Upload Reviews CSV File (must contain a 'reviewText' column)", type=["csv"])
 
176
  """}
177
  ]
178
 
179
+ raw_result = gemma_pipe(messages, max_new_tokens=50)
180
+ report = extract_assistant_content(raw_result)
181
  progress_bar.progress(100)
182
  status_text.success("**✅ Generation complete!**")
183
  html("<script>localStorage.setItem('freezeTimer', 'true');</script>", height=0)
184
  st.session_state.timer_frozen = True
185
  #st.write("**Scored Candidate Reviews:**", scored_docs)
186
+ st.write("**Generated Report:**", report)