Update app.py
app.py CHANGED
@@ -75,7 +75,7 @@ def model_prediction(model, text, device):
 # --- Llama 3.2 3B Model Setup ---
 LLAMA_MAX_MAX_NEW_TOKENS = 512  # Max tokens for Explanation
 LLAMA_DEFAULT_MAX_NEW_TOKENS = 512  # Max tokens for explantion
-LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "
+LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))  # Reduced
 llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Explicit device
 llama_model_id = "meta-llama/Llama-3.2-3B-Instruct"
 llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
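Note: the llama_generate helper that consumes these constants is defined elsewhere in app.py and is not part of this diff. A minimal sketch of how it presumably ties the tokenizer, device, and token limits together (the model loading call and generation settings below are assumptions, not the commit's code):

# Sketch only -- llama_model loading and the generation arguments are assumptions,
# not taken from this diff.
import torch
from transformers import AutoModelForCausalLM

llama_model = AutoModelForCausalLM.from_pretrained(
    llama_model_id, torch_dtype=torch.bfloat16
).to(llama_device)

def llama_generate(prompt: str, max_new_tokens: int = LLAMA_DEFAULT_MAX_NEW_TOKENS) -> str:
    # Cap the request at the hard ceiling, then truncate the prompt to the input budget.
    max_new_tokens = min(max_new_tokens, LLAMA_MAX_MAX_NEW_TOKENS)
    inputs = llama_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=LLAMA_MAX_INPUT_TOKEN_LENGTH,
    ).to(llama_device)
    with torch.no_grad():
        output_ids = llama_model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = output_ids[0][inputs["input_ids"].shape[-1]:]
    return llama_tokenizer.decode(new_tokens, skip_special_tokens=True)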
@@ -125,28 +125,43 @@ def llama_generate(
 
 
 def generate_explanation(issue_text, top_qualities):
-    """Generates an explanation using Llama 3.2 3B."""
     if not top_qualities:
-        return "No explanation available as no quality tags were predicted
+        return "<div style='color: red;'>No explanation available as no quality tags were predicted.</div>"
 
-    for quality, _ in top_qualities:
-    prompt = "".join(prompt_parts)
+    prompt = f"""
+    Given the following issue description:
+    ---
+    {issue_text}
+    ---
+    Explain why this issue might be classified under the following quality categories. Provide a concise explanation for each category, relating it back to the issue description:
+    """
+    for quality, _ in top_qualities:
+        prompt += f"- {quality}\n"
 
     try:
-        explanation = llama_generate(prompt)
+        explanation = llama_generate(prompt)
+        # Format the explanation for better readability
+        formatted_explanation = ""
+        for quality, _ in top_qualities:
+            formatted_explanation += f"<p><b>{quality}:</b></p>"  # Bold the quality name
+            # Find the explanation for this specific quality. This is a simple
+            # approach that works if Llama follows the prompt structure.
+            # A more robust approach might use regex or sentence embeddings.
+            start = explanation.find(quality)
+            if start != -1:
+                start += len(quality) + 2  # Move past "Quality:"
+                end = explanation.find("\n", start)  # Find next newline
+                if end == -1:
+                    end = len(explanation)
+                formatted_explanation += f"<p>{explanation[start:end].strip()}</p>"  # Add the explanation text
+            else:
+                formatted_explanation += f"<p>Explanation for {quality} not found.</p>"
+
+        return f"<div style='overflow-y: scroll; max-height: 400px;'>{formatted_explanation}</div>"  # Added scroll
     except Exception as e:
         logging.error(f"Error during Llama generation: {e}")
-        return "An error occurred while generating the explanation
+        return "<div style='color: red;'>An error occurred while generating the explanation.</div>"
 
 
 # @spaces.GPU(duration=60)  # Apply the GPU decorator *only* to the main interface
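The formatting loop added above locates each category with explanation.find(quality) and then skips len(quality) + 2 characters, which only works when the model answers with plain "Quality: explanation" lines; output such as "**Usability** - ..." or a numbered list would make the lookup miss or cut the text short. As the in-code comment itself suggests, a regex is one more robust option; a small illustrative helper (the function name and pattern are not part of the commit):

import re

def extract_quality_explanation(explanation: str, quality: str):
    # Illustrative helper, not from the commit: capture the text after the
    # quality name (tolerating "**" around it and ":", "-", or ")" after it)
    # up to the next blank line or the end of the model output.
    pattern = re.compile(
        rf"\**{re.escape(quality)}\**\s*[:\-\)]\s*(.+?)(?=\n\s*\n|\Z)",
        re.IGNORECASE | re.DOTALL,
    )
    match = pattern.search(explanation)
    return match.group(1).strip() if match else None

Inside the loop, a call like extract_quality_explanation(explanation, quality) would then replace the manual offset arithmetic, falling back to the "not found" branch whenever it returns None.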
@@ -229,10 +244,10 @@ interface = gr.Interface(
    outputs=[
        gr.HTML(label="Prediction Output"),
        gr.Textbox(label="Predictions", visible=False),
-        gr.
+        gr.HTML(label="Explanation")  # Change to gr.HTML
    ],
    title="QualityTagger",
    description="This tool classifies text into different quality domains such as Security, Usability, etc., and provides explanations.",
    examples=example_texts
)
-interface.launch(share=True)
+interface.launch(share=True)
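Switching the third output to gr.HTML means the value returned for that slot is rendered as markup rather than shown as plain text, so the scrollable <div> built by generate_explanation displays correctly. A tiny self-contained sketch of that output wiring (the echo function below is a placeholder, not the app's prediction pipeline):

import gradio as gr

def demo_fn(text):
    # One return value per output component declared below; placeholder logic only.
    prediction_html = "<p><b>Predicted quality:</b> Usability (placeholder)</p>"
    raw_predictions = "Usability:0.91"
    explanation_html = (
        "<div style='overflow-y: scroll; max-height: 400px;'>"
        f"<p><b>Usability:</b></p><p>{text}</p></div>"
    )
    return prediction_html, raw_predictions, explanation_html

demo = gr.Interface(
    fn=demo_fn,
    inputs=gr.Textbox(label="Issue description"),
    outputs=[
        gr.HTML(label="Prediction Output"),
        gr.Textbox(label="Predictions", visible=False),
        gr.HTML(label="Explanation"),
    ],
)

if __name__ == "__main__":
    demo.launch()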