Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -31,33 +31,24 @@ token = os.getenv("hf_token")
 HfFolder.save_token(token)
 login(token)
 
-# --- Quality Prediction Model Setup ---
 model_paths = [
-    'karths/…
-    …
-    "karths/…
-    "karths/…
-    "karths/…
-    "karths/…
-    "karths/…
-    "karths/binary_classification_train_build",
-    "karths/binary_classification_train_automation",
-    "karths/binary_classification_train_people",
-    "karths/binary_classification_train_architecture",
+    'karths/binary_classification_train_port',
+    'karths/binary_classification_train_perf',
+    "karths/binary_classification_train_main",
+    "karths/binary_classification_train_secu",
+    "karths/binary_classification_train_reli",
+    "karths/binary_classification_train_usab",
+    "karths/binary_classification_train_comp"
 ]
 
 quality_mapping = {
-    '…
-    '…
-    '…
-    '…
-    '…
-    '…
-    '…
-    'binary_classification_train_build': 'Build',
-    'binary_classification_train_automation': 'Automation',
-    'binary_classification_train_people': 'People',
-    'binary_classification_train_architecture': 'Architecture'
+    'binary_classification_train_port': 'Portability',
+    'binary_classification_train_main': 'Maintainability',
+    'binary_classification_train_secu': 'Security',
+    'binary_classification_train_reli': 'Reliability',
+    'binary_classification_train_usab': 'Usability',
+    'binary_classification_train_perf': 'Performance',
+    'binary_classification_train_comp': 'Compatibility'
 }
 
 # Pre-load models and tokenizer for quality prediction
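For context, a minimal sketch (not part of this commit) of how a path list and mapping like the ones above are typically consumed: the repository name after the `karths/` prefix is used as the key into `quality_mapping` to obtain a display label. The helper name `label_for_path` is hypothetical.

# Hypothetical helper, assuming quality_mapping keys are the repo names
# that follow the "karths/" prefix in model_paths.
def label_for_path(model_path: str) -> str:
    repo_name = model_path.split("/")[-1]  # e.g. "binary_classification_train_port"
    return quality_mapping.get(repo_name, repo_name)  # fall back to the raw name

# label_for_path("karths/binary_classification_train_port") -> "Portability"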
@@ -82,9 +73,9 @@ def model_prediction(model, text, device):
     return avg_prob
 
 # --- Llama 3.2 3B Model Setup ---
-LLAMA_MAX_MAX_NEW_TOKENS = …
-LLAMA_DEFAULT_MAX_NEW_TOKENS = 512  # …
-LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "…
+LLAMA_MAX_MAX_NEW_TOKENS = 512  # Max tokens for explanation
+LLAMA_DEFAULT_MAX_NEW_TOKENS = 512  # Max tokens for explanation
+LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "700"))  # Reduced
 llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Explicit device
 llama_model_id = "meta-llama/Llama-3.2-3B-Instruct"
 llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
@@ -105,7 +96,7 @@ def llama_generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> …
+) -> str:  # Return string, not iterator
 
     inputs = llama_tokenizer(message, return_tensors="pt", padding=True, truncation=True, max_length=LLAMA_MAX_INPUT_TOKEN_LENGTH).to(llama_model.device)
 
@@ -113,25 +104,24 @@ def llama_generate(
         inputs.input_ids = inputs.input_ids[:, -LLAMA_MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {LLAMA_MAX_INPUT_TOKEN_LENGTH} tokens.")
 
-    …
-        outputs.append(text)
-        yield "".join(outputs)
+    # Generate *without* streaming
+    with torch.no_grad():  # Ensure no gradient calculation
+        generate_ids = llama_model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            num_beams=1,
+            repetition_penalty=repetition_penalty,
+            pad_token_id=llama_tokenizer.pad_token_id,  # Pass pad_token_id here
+            eos_token_id=llama_tokenizer.eos_token_id,  # Pass eos_token_id here
+        )
+    output_text = llama_tokenizer.decode(generate_ids[0], skip_special_tokens=True)
     torch.cuda.empty_cache()  # Clear cache after each generation
+    return output_text
 
 
 def generate_explanation(issue_text, top_qualities):
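Since `llama_generate` now returns a complete string rather than yielding incremental chunks, callers can use the result directly. A hedged sketch of the calling pattern this change implies (the prompt text below is invented for illustration):

# Before this commit (streaming): callers consumed an iterator of partial text.
# for partial in llama_generate(prompt):
#     latest_text = partial  # each yield was the accumulated output so far

# After this commit (non-streaming): one blocking call returns the full text.
explanation = llama_generate("Explain why this issue might affect maintainability.")
print(explanation)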
@@ -139,25 +129,27 @@ def generate_explanation(issue_text, top_qualities):
     if not top_qualities:
         return "No explanation available as no quality tags were predicted."
 
-    prompt …
-    …
+    # Build the prompt, explicitly mentioning each quality
+    prompt_parts = [
+        "Given the following issue description:\n---\n",
+        issue_text,
+        "\n---\n",
+        "Explain why this issue might be classified under the following quality categories. Provide a concise explanation for each category, relating it back to the issue description:\n"
+    ]
+    for quality, _ in top_qualities:  # Iterate through qualities
+        prompt_parts.append(f"- {quality}\n")
+
+    prompt = "".join(prompt_parts)
+
     try:
-        …
+        explanation = llama_generate(prompt)  # Get the explanation (not streamed)
+        return explanation
     except Exception as e:
         logging.error(f"Error during Llama generation: {e}")
         return "An error occurred while generating the explanation."
 
-    return explanation
 
-# @spaces.GPU(duration=…
+# @spaces.GPU(duration=60)  # Apply the GPU decorator *only* to the main interface
 def main_interface(text):
     if not text.strip():
         return "<div style='color: red;'>No text provided. Please enter a valid issue description.</div>", "", ""
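To make the new prompt construction concrete, this is roughly what `generate_explanation` assembles before calling `llama_generate`; the issue text and the `(label, score)` shape of `top_qualities` are assumptions for illustration only.

top_qualities = [("Maintainability", 0.91), ("Reliability", 0.78)]  # hypothetical prediction output
issue_text = "The build script duplicates configuration in three places and fails intermittently."

# Mirrors the prompt_parts logic added in this commit:
prompt = "".join(
    [
        "Given the following issue description:\n---\n",
        issue_text,
        "\n---\n",
        "Explain why this issue might be classified under the following quality categories. "
        "Provide a concise explanation for each category, relating it back to the issue description:\n",
    ]
    + [f"- {quality}\n" for quality, _ in top_qualities]
)
print(prompt)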