karths committed (verified)
Commit 2d57f5f · 1 Parent(s): 668a83f

Update app.py

Files changed (1)
  1. app.py +18 -25
app.py CHANGED
@@ -13,8 +13,7 @@ from collections.abc import Iterator
  import csv

  # Increase CSV field size limit
- csv.field_size_limit(1000000) # Or an even larger value if needed
-
+ csv.field_size_limit(1000000)

  # Setup logging
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
@@ -63,14 +62,14 @@ quality_mapping = {

  # Pre-load models and tokenizer for quality prediction
  tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
- models = {path: AutoModelForSequenceClassification.from_pretrained(path) for path in model_paths}
+ models = {path: AutoModelForSequenceClassification.from_pretrained(path) for path in model_paths} # Load to CPU initially

  def get_quality_name(model_name):
  return quality_mapping.get(model_name.split('/')[-1], "Unknown Quality")

- @spaces.GPU
+
  def model_prediction(model, text, device):
- model.to(device)
+ model.to(device) # Move the *specific* model to the GPU
  model.eval()
  inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
  inputs = {k: v.to(device) for k, v in inputs.items()}
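
Note on the hunk above: the @spaces.GPU decorator is dropped from model_prediction; instead each classifier stays on the CPU and is moved to the GPU only for the duration of one prediction (the matching model.to("cpu") appears in the next hunk). A minimal sketch of this staging pattern, assuming a generic Hugging Face sequence-classification model and scipy's softmax; the helper name predict_on_device is illustrative, not part of app.py:

    import torch
    from scipy.special import softmax

    def predict_on_device(model, tokenizer, text, device="cuda"):
        model.to(device)   # stage the classifier onto the GPU only while it is needed
        model.eval()
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        prob = softmax(logits.cpu().numpy(), axis=1)[:, 1].mean()
        model.to("cpu")    # hand the GPU back so other classifiers can share it
        return float(prob)
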
@@ -79,30 +78,26 @@ def model_prediction(model, text, device):
  logits = outputs.logits
  probs = softmax(logits.cpu().numpy(), axis=1)
  avg_prob = np.mean(probs[:, 1])
+ model.to("cpu") # Move the model *back* to the CPU
  return avg_prob

  # --- Llama 3.2 3B Model Setup ---
  LLAMA_MAX_MAX_NEW_TOKENS = 2048
- LLAMA_DEFAULT_MAX_NEW_TOKENS = 1024
- LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
- llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Explicitly define device
+ LLAMA_DEFAULT_MAX_NEW_TOKENS = 512 # Reduced for efficiency
+ LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048")) # Reduced
+ llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Explicit device
  llama_model_id = "meta-llama/Llama-3.2-3B-Instruct"
  llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
  llama_model = AutoModelForCausalLM.from_pretrained(
  llama_model_id,
- device_map="auto", # Automatically distribute model across devices
+ device_map="auto", # Let Transformers handle optimal device placement
  torch_dtype=torch.bfloat16,
  )
  llama_model.eval()

- # --- IMPORTANT: Set Pad Token ---
- # Llama3 does *not* have a default pad token. We *must* set one.
- # Using the EOS token as the PAD token is a common and recommended practice.
  if llama_tokenizer.pad_token is None:
  llama_tokenizer.pad_token = llama_tokenizer.eos_token

-
- @spaces.GPU(duration=150)
  def llama_generate(
  message: str,
  max_new_tokens: int = LLAMA_DEFAULT_MAX_NEW_TOKENS,
@@ -113,7 +108,6 @@ def llama_generate(
  ) -> Iterator[str]:

  inputs = llama_tokenizer(message, return_tensors="pt", padding=True, truncation=True, max_length=LLAMA_MAX_INPUT_TOKEN_LENGTH).to(llama_model.device)
- #The line above was changed to add attention mask

  if inputs.input_ids.shape[1] > LLAMA_MAX_INPUT_TOKEN_LENGTH:
  inputs.input_ids = inputs.input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
@@ -121,7 +115,7 @@ def llama_generate(

  streamer = TextIteratorStreamer(llama_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
  generate_kwargs = dict(
- inputs, # Pass the entire inputs dictionary
+ inputs,
  streamer=streamer,
  max_new_tokens=max_new_tokens,
  do_sample=True,
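
This hunk keeps passing the full tokenizer output (input_ids plus attention_mask) into generate() via the inputs dict and streams tokens with TextIteratorStreamer. The thread that actually drives generation sits outside the hunk; a minimal sketch of the usual Transformers streaming pattern, assuming generation is launched on a worker thread (stream_generate is an illustrative name):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generate(model, tokenizer, prompt, max_new_tokens=512):
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        # dict(inputs, ...) merges input_ids and attention_mask with the generation kwargs
        kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True)
        Thread(target=model.generate, kwargs=kwargs).start()  # generate() blocks, so run it off-thread
        for piece in streamer:
            yield piece  # tokens arrive as soon as they are produced
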
@@ -137,7 +131,7 @@ def llama_generate(
  for text in streamer:
  outputs.append(text)
  yield "".join(outputs)
-
+ torch.cuda.empty_cache() # Clear cache after each generation


  def generate_explanation(issue_text, top_qualities):
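
The torch.cuda.empty_cache() added after the streaming loop returns cached, unreferenced memory blocks to the driver between requests; it does not free tensors that are still referenced. A small illustrative check of the effect:

    import torch

    if torch.cuda.is_available():
        reserved_before = torch.cuda.memory_reserved()
        torch.cuda.empty_cache()  # release unused cached blocks back to the driver
        print(f"reserved: {reserved_before} -> {torch.cuda.memory_reserved()} bytes")
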
@@ -156,14 +150,14 @@ def generate_explanation(issue_text, top_qualities):
  explanation = ""
  try:
  for chunk in llama_generate(prompt):
- explanation += chunk # Accumulate generated text
+ explanation += chunk
  except Exception as e:
  logging.error(f"Error during Llama generation: {e}")
  return "An error occurred while generating the explanation."

  return explanation

-
+ @spaces.GPU(duration=180) # Apply the GPU decorator *only* to the main interface
  def main_interface(text):
  if not text.strip():
  return "<div style='color: red;'>No text provided. Please enter a valid issue description.</div>", "", ""
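
With this hunk the @spaces.GPU decorator, previously attached to model_prediction and llama_generate, wraps only main_interface, so one GPU allocation covers the whole request. On ZeroGPU Spaces, spaces.GPU attaches a GPU only while the decorated function executes, and duration caps that allocation window in seconds. A minimal sketch, with run_on_gpu as an illustrative name:

    import spaces
    import torch

    @spaces.GPU(duration=180)  # GPU is attached only while this function runs
    def run_on_gpu(text: str) -> str:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # ... move models to `device`, run inference, move them back ...
        return f"processed on {device}"
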
@@ -171,25 +165,24 @@ def main_interface(text):
  if len(text) < 30:
  return "<div style='color: red;'>Text is less than 30 characters.</div>", "", ""

- device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = "cuda" if torch.cuda.is_available() else "cpu" # Keep this for model_prediction
  results = []
  for model_path, model in models.items():
  quality_name = get_quality_name(model_path)
- avg_prob = model_prediction(model, text, device)
+ avg_prob = model_prediction(model, text, device) # Pass the device
  if avg_prob >= 0.95:
  results.append((quality_name, avg_prob))
  logging.info(f"Model: {model_path}, Quality: {quality_name}, Average Probability: {avg_prob:.3f}")

+
  if not results:
  return "<div style='color: red;'>No recommendation. Prediction probability is below the threshold. </div>", "", ""

  top_qualities = sorted(results, key=lambda x: x[1], reverse=True)[:3]
  output_html = render_html_output(top_qualities)
-
- # Generate explanation using the top qualities and the original input text
  explanation = generate_explanation(text, top_qualities)

- return output_html, "", explanation # Return explanation as the third output
+ return output_html, "", explanation

  def render_html_output(top_qualities):
  styles = """
@@ -244,7 +237,7 @@ interface = gr.Interface(
  outputs=[
  gr.HTML(label="Prediction Output"),
  gr.Textbox(label="Predictions", visible=False),
- gr.Textbox(label="Explanation", lines=5) # Added Textbox for explanation
+ gr.Textbox(label="Explanation", lines=5)
  ],
  title="QualityTagger",
  description="This tool classifies text into different quality domains such as Security, Usability, etc., and provides explanations.",
 