karths committed
Commit 1ca494b · verified · 1 Parent(s): a754efe

Update app.py

Files changed (1)
  1. app.py +84 -69
app.py CHANGED
@@ -53,14 +53,14 @@ quality_mapping = {
 
 # Pre-load models and tokenizer for quality prediction
 tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
-models = {path: AutoModelForSequenceClassification.from_pretrained(path) for path in model_paths}  # Load to CPU initially
+models = {path: AutoModelForSequenceClassification.from_pretrained(path) for path in model_paths}
 
 def get_quality_name(model_name):
     return quality_mapping.get(model_name.split('/')[-1], "Unknown Quality")
 
 
 def model_prediction(model, text, device):
-    model.to(device)  # Move the *specific* model to the GPU
+    model.to(device)
     model.eval()
     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
     inputs = {k: v.to(device) for k, v in inputs.items()}
@@ -69,19 +69,19 @@ def model_prediction(model, text, device):
     logits = outputs.logits
     probs = softmax(logits.cpu().numpy(), axis=1)
     avg_prob = np.mean(probs[:, 1])
-    model.to("cpu")  # Move the model *back* to the CPU
+    model.to("cpu")
     return avg_prob
 
 # --- Llama 3.2 3B Model Setup ---
-LLAMA_MAX_MAX_NEW_TOKENS = 512  # Max tokens for Explanation
-LLAMA_DEFAULT_MAX_NEW_TOKENS = 512  # Max tokens for explantion
-LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))  # Reduced
-llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Explicit device
+LLAMA_MAX_MAX_NEW_TOKENS = 512
+LLAMA_DEFAULT_MAX_NEW_TOKENS = 512
+LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
+llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 llama_model_id = "meta-llama/Llama-3.2-3B-Instruct"
 llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
 llama_model = AutoModelForCausalLM.from_pretrained(
     llama_model_id,
-    device_map="auto",  # Let Transformers handle optimal device placement
+    device_map="auto",
     torch_dtype=torch.bfloat16,
 )
 llama_model.eval()
@@ -96,7 +96,7 @@ def llama_generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> str:  # Return string, not iterator
+) -> str:
 
     inputs = llama_tokenizer(message, return_tensors="pt", padding=True, truncation=True, max_length=LLAMA_MAX_INPUT_TOKEN_LENGTH).to(llama_model.device)
 
@@ -104,8 +104,7 @@ def llama_generate(
         inputs.input_ids = inputs.input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {LLAMA_MAX_INPUT_TOKEN_LENGTH} tokens.")
 
-    # Generate *without* streaming
-    with torch.no_grad():  # Ensure no gradient calculation
+    with torch.no_grad():
         generate_ids = llama_model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
@@ -115,56 +114,41 @@ def llama_generate(
             temperature=temperature,
             num_beams=1,
             repetition_penalty=repetition_penalty,
-            pad_token_id=llama_tokenizer.pad_token_id,  # Pass pad_token_id here
-            eos_token_id=llama_tokenizer.eos_token_id,  # Pass eos_token_id here
+            pad_token_id=llama_tokenizer.pad_token_id,
+            eos_token_id=llama_tokenizer.eos_token_id,
 
         )
     output_text = llama_tokenizer.decode(generate_ids[0], skip_special_tokens=True)
-    torch.cuda.empty_cache()  # Clear cache after each generation
+    torch.cuda.empty_cache()
     return output_text
 
 
-def generate_explanation(issue_text, top_qualities):
-    if not top_qualities:
-        return "<div style='color: red;'>No explanation available as no quality tags were predicted.</div>"
+def generate_explanation(issue_text, top_quality):
+    """Generates an explanation for the *single* top quality above threshold."""
+    if not top_quality:
+        return "<div style='color: red;'>No explanation available as no quality tags met the threshold.</div>"
+
+    quality_name = top_quality[0]  # Get the name of the top quality
 
     prompt = f"""
     Given the following issue description:
     ---
     {issue_text}
    ---
-    Explain why this issue might be classified under the following quality categories. Provide a concise explanation for each category, relating it back to the issue description:
+    Explain why this issue might be classified as a **{quality_name}** issue. Provide a concise explanation, relating it back to the issue description.
    """
-    for quality, _ in top_qualities:
-        prompt += f"- {quality}\n"
-
-
    try:
        explanation = llama_generate(prompt)
-        # Format the explanation for better readability
-        formatted_explanation = ""
-        for quality, _ in top_qualities:
-            formatted_explanation += f"<p><b>{quality}:</b></p>"  # Bold the quality name
-            # Find the explanation for this specific quality. This is a simple
-            # approach that works if Llama follows the prompt structure.
-            # A more robust approach might use regex or sentence embeddings.
-            start = explanation.find(quality)
-            if start != -1:
-                start += len(quality) + 2  # Move past "Quality:"
-                end = explanation.find("\n", start)  # Find next newline
-                if end == -1:
-                    end = len(explanation)
-                formatted_explanation += f"<p>{explanation[start:end].strip()}</p>"  # Add the explanation text
-            else:
-                formatted_explanation += f"<p>Explanation for {quality} not found.</p>"
-
-        return f"<div style='overflow-y: scroll; max-height: 400px;'>{formatted_explanation}</div>"  # Added scroll
+        # Format for better readability, directly including the quality name.
+        formatted_explanation = f"<p><b>{quality_name}:</b></p><p>{explanation}</p>"
+        return f"<div style='overflow-y: scroll; max-height: 400px;'>{formatted_explanation}</div>"
    except Exception as e:
        logging.error(f"Error during Llama generation: {e}")
        return "<div style='color: red;'>An error occurred while generating the explanation.</div>"
 
 
-# @spaces.GPU(duration=60)  # Apply the GPU decorator *only* to the main interface
+
+# @spaces.GPU(duration=60)
 def main_interface(text):
    if not text.strip():
        return "<div style='color: red;'>No text provided. Please enter a valid issue description.</div>", "", ""
@@ -172,26 +156,33 @@ def main_interface(text):
    if len(text) < 30:
        return "<div style='color: red;'>Text is less than 30 characters.</div>", "", ""
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"  # Keep this for model_prediction
+    device = "cuda" if torch.cuda.is_available() else "cpu"
    results = []
    for model_path, model in models.items():
        quality_name = get_quality_name(model_path)
-        avg_prob = model_prediction(model, text, device)  # Pass the device
-        if avg_prob >= 0.95:
+        avg_prob = model_prediction(model, text, device)
+        if avg_prob >= 0.95:  # Keep *all* results above the threshold
            results.append((quality_name, avg_prob))
            logging.info(f"Model: {model_path}, Quality: {quality_name}, Average Probability: {avg_prob:.3f}")
 
-
    if not results:
-        return "<div style='color: red;'>No recommendation. Prediction probability is below the threshold. </div>", "", ""
+        return "<div style='color: red;'>No recommendation. Prediction probability is below the threshold.</div>", "", ""
+
+    # Sort and get the top result (if any meet the threshold)
+    top_result = sorted(results, key=lambda x: x[1], reverse=True)
+    if top_result:
+        top_quality = top_result[:1]  # Select only the top result
+        output_html = render_html_output(top_quality)
+        explanation = generate_explanation(text, top_quality)
+    else:  # Handle case no predictions >= 0.95
+        output_html = "<div style='color: red;'>No quality tag met the prediction probability threshold (>= 0.95).</div>"
+        explanation = ""
 
-    top_qualities = sorted(results, key=lambda x: x[1], reverse=True)[:3]
-    output_html = render_html_output(top_qualities)
-    explanation = generate_explanation(text, top_qualities)
 
    return output_html, "", explanation
 
 def render_html_output(top_qualities):
+    # Simplified to show only the top prediction
    styles = """
    <style>
    .quality-container {
@@ -210,25 +201,18 @@ def render_html_output(top_qualities):
        margin-right: 10px;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
    }
-    .probability {
-        display: block;
-        margin-top: 10px;
-        font-size: 16px;
-        color: #007bff;
-    }
    </style>
    """
-    html_content = ""
-    ranking_labels = ['Top 1 Prediction', 'Top 2 Prediction', 'Top 3 Prediction']
-    top_n = min(len(top_qualities), len(ranking_labels))
-    for i in range(top_n):
-        quality, prob = top_qualities[i]
-        html_content += f"""
-        <div class="quality-container">
-            <span class="ranking">{ranking_labels[i]}</span>
-            <span class="quality-label">{quality}</span>
-        </div>
-        """
+    if not top_qualities:  # Handle empty case
+        return styles + "<div class='quality-container'>No Top Prediction</div>"
+
+    quality, _ = top_qualities[0]  # We know there is only one
+    html_content = f"""
+    <div class="quality-container">
+        <span class="ranking">Top Prediction</span>
+        <span class="quality-label">{quality}</span>
+    </div>
+    """
    return styles + html_content
 
 example_texts = [
@@ -237,17 +221,48 @@ example_texts = [
    ["There is frequent miscommunication between the development and QA teams regarding feature specifications.\n\nEnvironment: Inter-team meetings\nReproduction: Audit recent communication logs and meeting notes between the teams."],
    ["The service-oriented architecture does not effectively isolate failures, leading to cascading failures across services.\n\nEnvironment: Microservices architecture\nReproduction: Simulate a service failure and observe the impact on other services."]
 ]
-
+# Improved CSS for better layout and appearance
+css = """
+.quality-container {
+    font-family: Arial, sans-serif;
+    text-align: center;
+    margin-top: 20px;
+    padding: 10px;
+    border: 1px solid #ddd; /* Added border */
+    border-radius: 8px; /* Rounded corners */
+    background-color: #f9f9f9; /* Light background */
+}
+.quality-label, .ranking {
+    display: inline-block;
+    padding: 0.5em 1em;
+    font-size: 18px;
+    font-weight: bold;
+    color: white;
+    background-color: #007bff;
+    border-radius: 0.5rem;
+    margin-right: 10px;
+    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+}
+#explanation {
+    border: 1px solid #ccc;
+    padding: 10px;
+    margin-top: 10px;
+    border-radius: 4px;
+    background-color: #fff; /* White background for explanation */
+    overflow-y: auto; /* Ensure scrollbar appears if needed */
+}
+"""
 interface = gr.Interface(
    fn=main_interface,
    inputs=gr.Textbox(lines=7, label="Issue Description", placeholder="Enter your issue text here"),
    outputs=[
        gr.HTML(label="Prediction Output"),
        gr.Textbox(label="Predictions", visible=False),
-        gr.HTML(label="Explanation")  # Change to gr.HTML
+        gr.HTML(label="Explanation")
    ],
    title="QualityTagger",
    description="This tool classifies text into different quality domains such as Security, Usability, etc., and provides explanations.",
-    examples=example_texts
+    examples=example_texts,
+    css=css  # Apply the CSS
 )
  interface.launch(share=True)
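
For reference, a minimal standalone sketch (not part of the commit) of the selection logic the updated main_interface now follows: only predictions at or above the 0.95 threshold are kept, and a one-element slice of (quality_name, avg_prob) tuples is passed to render_html_output and generate_explanation. The quality names and probabilities below are illustrative placeholders, not values produced by the app.

# Illustrative mock of the classifier outputs; in the app these come from model_prediction().
results = [("Usability", 0.97), ("Security", 0.96), ("Maintainability", 0.41)]

# Keep everything above the threshold, mirroring the committed loop.
above_threshold = [(name, prob) for name, prob in results if prob >= 0.95]

# Sort descending and keep only the single top result, as the new code does with top_result[:1].
top_quality = sorted(above_threshold, key=lambda x: x[1], reverse=True)[:1]

if top_quality:
    name, prob = top_quality[0]  # each element is a (quality_name, avg_prob) tuple
    print(f"Top Prediction: {name} ({prob:.2f})")
else:
    print("No quality tag met the prediction probability threshold (>= 0.95).")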