karths committed
Commit 740d8bb · verified · 1 parent: 1afa8fd

Update app.py

Files changed (1):
  1. app.py (+50 −58)
app.py CHANGED
@@ -31,33 +31,24 @@ token = os.getenv("hf_token")
 HfFolder.save_token(token)
 login(token)

-# --- Quality Prediction Model Setup ---
 model_paths = [
-    'karths/binary_classification_train_test',
-    "karths/binary_classification_train_process",
-    "karths/binary_classification_train_infrastructure",
-    "karths/binary_classification_train_documentation",
-    "karths/binary_classification_train_design",
-    "karths/binary_classification_train_defect",
-    "karths/binary_classification_train_code",
-    "karths/binary_classification_train_build",
-    "karths/binary_classification_train_automation",
-    "karths/binary_classification_train_people",
-    "karths/binary_classification_train_architecture",
+    'karths/binary_classification_train_port',
+    'karths/binary_classification_train_perf',
+    "karths/binary_classification_train_main",
+    "karths/binary_classification_train_secu",
+    "karths/binary_classification_train_reli",
+    "karths/binary_classification_train_usab",
+    "karths/binary_classification_train_comp"
 ]

 quality_mapping = {
-    'binary_classification_train_test': 'Test',
-    'binary_classification_train_process': 'Process',
-    'binary_classification_train_infrastructure': 'Infrastructure',
-    'binary_classification_train_documentation': 'Documentation',
-    'binary_classification_train_design': 'Design',
-    'binary_classification_train_defect': 'Defect',
-    'binary_classification_train_code': 'Code',
-    'binary_classification_train_build': 'Build',
-    'binary_classification_train_automation': 'Automation',
-    'binary_classification_train_people': 'People',
-    'binary_classification_train_architecture': 'Architecture'
+    'binary_classification_train_port': 'Portability',
+    'binary_classification_train_main': 'Maintainability',
+    'binary_classification_train_secu': 'Security',
+    'binary_classification_train_reli': 'Reliability',
+    'binary_classification_train_usab': 'Usability',
+    'binary_classification_train_perf': 'Performance',
+    'binary_classification_train_comp': 'Compatibility'
 }

 # Pre-load models and tokenizer for quality prediction
@@ -82,9 +73,9 @@ def model_prediction(model, text, device):
     return avg_prob

 # --- Llama 3.2 3B Model Setup ---
-LLAMA_MAX_MAX_NEW_TOKENS = 2048
-LLAMA_DEFAULT_MAX_NEW_TOKENS = 512  # Reduced for efficiency
-LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))  # Reduced
+LLAMA_MAX_MAX_NEW_TOKENS = 512  # Max tokens for Explanation
+LLAMA_DEFAULT_MAX_NEW_TOKENS = 512  # Max tokens for explanation
+LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "700"))  # Reduced
 llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Explicit device
 llama_model_id = "meta-llama/Llama-3.2-3B-Instruct"
 llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
@@ -105,7 +96,7 @@ def llama_generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> Iterator[str]:
+) -> str:  # Return a string, not an iterator

     inputs = llama_tokenizer(message, return_tensors="pt", padding=True, truncation=True, max_length=LLAMA_MAX_INPUT_TOKEN_LENGTH).to(llama_model.device)

@@ -113,25 +104,24 @@ def llama_generate(
         inputs.input_ids = inputs.input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {LLAMA_MAX_INPUT_TOKEN_LENGTH} tokens.")

-    streamer = TextIteratorStreamer(llama_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        inputs,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=llama_model.generate, kwargs=generate_kwargs)
-    t.start()
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    # Generate *without* streaming
+    with torch.no_grad():  # Ensure no gradient calculation
+        generate_ids = llama_model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            num_beams=1,
+            repetition_penalty=repetition_penalty,
+            pad_token_id=llama_tokenizer.pad_token_id,  # Pass pad_token_id here
+            eos_token_id=llama_tokenizer.eos_token_id,  # Pass eos_token_id here
+
+        )
+        output_text = llama_tokenizer.decode(generate_ids[0], skip_special_tokens=True)
     torch.cuda.empty_cache()  # Clear cache after each generation
+    return output_text


 def generate_explanation(issue_text, top_qualities):
@@ -139,25 +129,27 @@ def generate_explanation(issue_text, top_qualities):
     if not top_qualities:
         return "No explanation available as no quality tags were predicted."

-    prompt = f"""
-    Given the following issue description:
-    ---
-    {issue_text}
-    ---
-    Explain why this issue might be classified under the following quality categories: {', '.join([q[0] for q in top_qualities])}.
-    Provide a concise explanation for each category, relating it back to the issue description.
-    """
-    explanation = ""
+    # Build the prompt, explicitly mentioning each quality
+    prompt_parts = [
+        "Given the following issue description:\n---\n",
+        issue_text,
+        "\n---\n",
+        "Explain why this issue might be classified under the following quality categories. Provide a concise explanation for each category, relating it back to the issue description:\n"
+    ]
+    for quality, _ in top_qualities:  # Iterate through the predicted qualities
+        prompt_parts.append(f"- {quality}\n")
+
+    prompt = "".join(prompt_parts)
+
     try:
-        for chunk in llama_generate(prompt):
-            explanation += chunk
+        explanation = llama_generate(prompt)  # Get the explanation (not streamed)
+        return explanation
     except Exception as e:
         logging.error(f"Error during Llama generation: {e}")
         return "An error occurred while generating the explanation."

-    return explanation

-# @spaces.GPU(duration=180)  # Apply the GPU decorator *only* to the main interface
+# @spaces.GPU(duration=60)  # Apply the GPU decorator *only* to the main interface
 def main_interface(text):
     if not text.strip():
         return "<div style='color: red;'>No text provided. Please enter a valid issue description.</div>", "", ""
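For illustration, a minimal standalone sketch (not part of the commit) of what the new prompt-building loop in generate_explanation() assembles. The sample issue text and the (quality, probability) pairs below are invented for the example; the resulting string is what the now non-streaming llama_generate() is called with.

# Hypothetical inputs, made up for this sketch only.
sample_issue = "The service times out when the database connection pool is exhausted."
top_qualities = [("Reliability", 0.91), ("Performance", 0.78)]  # (quality, probability) pairs

# Mirrors the prompt construction added in this commit.
prompt_parts = [
    "Given the following issue description:\n---\n",
    sample_issue,
    "\n---\n",
    "Explain why this issue might be classified under the following quality categories. "
    "Provide a concise explanation for each category, relating it back to the issue description:\n",
]
for quality, _ in top_qualities:
    prompt_parts.append(f"- {quality}\n")

prompt = "".join(prompt_parts)
print(prompt)  # This string would be passed to llama_generate(), which now returns plain text.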