Spaces:
Running
on
Zero
Running
on
Zero
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
from huggingface_hub import login, HfFolder
|
7 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
8 |
+
from scipy.special import softmax
|
9 |
+
import logging
|
10 |
+
|
11 |
+
# Setup logging
|
12 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
|
13 |
+
|
14 |
+
# Set a seed for reproducibility
|
15 |
+
seed = 42
|
16 |
+
np.random.seed(seed)
|
17 |
+
random.seed(seed)
|
18 |
+
torch.manual_seed(seed)
|
19 |
+
if torch.cuda.is_available():
|
20 |
+
torch.cuda.manual_seed_all(seed)
|
21 |
+
|
22 |
+
|
23 |
+
# Login to Hugging Face
|
24 |
+
token = os.getenv("hf_token")
|
25 |
+
HfFolder.save_token(token)
|
26 |
+
login(token)
|
27 |
+
|
28 |
+
# Model paths and quality mapping
|
29 |
+
model_paths = [
|
30 |
+
'karths/binary_classification_train_port',
|
31 |
+
'karths/binary_classification_train_perf',
|
32 |
+
"karths/binary_classification_train_main",
|
33 |
+
"karths/binary_classification_train_secu",
|
34 |
+
"karths/binary_classification_train_reli",
|
35 |
+
"karths/binary_classification_train_usab",
|
36 |
+
"karths/binary_classification_train_comp"
|
37 |
+
]
|
38 |
+
|
39 |
+
quality_mapping = {
|
40 |
+
'binary_classification_train_port': 'Portability',
|
41 |
+
'binary_classification_train_main': 'Maintainability',
|
42 |
+
'binary_classification_train_secu': 'Security',
|
43 |
+
'binary_classification_train_reli': 'Reliability',
|
44 |
+
'binary_classification_train_usab': 'Usability',
|
45 |
+
'binary_classification_train_perf': 'Performance',
|
46 |
+
'binary_classification_train_comp': 'Compatibility'
|
47 |
+
}
|
48 |
+
|
49 |
+
# Pre-load models and tokenizer
|
50 |
+
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
|
51 |
+
models = {path: AutoModelForSequenceClassification.from_pretrained(path) for path in model_paths}
|
52 |
+
|
53 |
+
def get_quality_name(model_name):
|
54 |
+
return quality_mapping.get(model_name.split('/')[-1], "Unknown Quality")
|
55 |
+
|
56 |
+
def model_prediction(model, text, device):
|
57 |
+
model.to(device)
|
58 |
+
model.eval()
|
59 |
+
|
60 |
+
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
|
61 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
62 |
+
|
63 |
+
with torch.no_grad():
|
64 |
+
outputs = model(**inputs)
|
65 |
+
logits = outputs.logits
|
66 |
+
probs = softmax(logits.cpu().numpy(), axis=1)
|
67 |
+
avg_prob = np.mean(probs[:, 1])
|
68 |
+
|
69 |
+
return avg_prob
|
70 |
+
|
71 |
+
def main_interface(text):
|
72 |
+
if not text.strip():
|
73 |
+
return "<div style='color: red;'>No text provided. Please enter a valid issue description.</div>", ""
|
74 |
+
|
75 |
+
# Check for text length exceeding the limit
|
76 |
+
if len(text) < 30:
|
77 |
+
return "<div style='color: red;'>Text is less than 30 characters.</div>", ""
|
78 |
+
|
79 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
80 |
+
results = []
|
81 |
+
for model_path, model in models.items():
|
82 |
+
quality_name = get_quality_name(model_path)
|
83 |
+
avg_prob = model_prediction(model, text, device)
|
84 |
+
if avg_prob >= 0.90: # Only consider probabilities >= 0.90
|
85 |
+
results.append((quality_name, avg_prob))
|
86 |
+
logging.info(f"Model: {model_path}, Quality: {quality_name}, Average Probability: {avg_prob:.3f}")
|
87 |
+
|
88 |
+
if not results: # If no results meet the criteria
|
89 |
+
return "<div style='color: red;'>No recommendation. Prediction probability is below the threshold. </div>", ""
|
90 |
+
|
91 |
+
top_qualities = sorted(results, key=lambda x: x[1], reverse=True)[:3]
|
92 |
+
output_html = render_html_output(top_qualities)
|
93 |
+
|
94 |
+
return output_html, ""
|
95 |
+
|
96 |
+
def render_html_output(top_qualities):
|
97 |
+
styles = """
|
98 |
+
<style>
|
99 |
+
.quality-container {
|
100 |
+
font-family: Arial, sans-serif;
|
101 |
+
text-align: center;
|
102 |
+
margin-top: 20px;
|
103 |
+
}
|
104 |
+
.quality-label, .ranking {
|
105 |
+
display: inline-block;
|
106 |
+
padding: 0.5em 1em;
|
107 |
+
font-size: 18px;
|
108 |
+
font-weight: bold;
|
109 |
+
color: white;
|
110 |
+
background-color: #007bff;
|
111 |
+
border-radius: 0.5rem;
|
112 |
+
margin-right: 10px;
|
113 |
+
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
|
114 |
+
}
|
115 |
+
.probability {
|
116 |
+
display: block;
|
117 |
+
margin-top: 10px;
|
118 |
+
font-size: 16px;
|
119 |
+
color: #007bff;
|
120 |
+
}
|
121 |
+
</style>
|
122 |
+
"""
|
123 |
+
html_content = ""
|
124 |
+
ranking_labels = ['Top 1 Prediction', 'Top 2 Prediction', 'Top 3 Prediction']
|
125 |
+
top_n = min(len(top_qualities), len(ranking_labels))
|
126 |
+
for i in range(top_n):
|
127 |
+
quality, prob = top_qualities[i]
|
128 |
+
html_content += f"""
|
129 |
+
<div class="quality-container">
|
130 |
+
<span class="ranking">{ranking_labels[i]}</span>
|
131 |
+
<span class="quality-label">{quality}</span>
|
132 |
+
</div>
|
133 |
+
"""
|
134 |
+
return styles + html_content
|
135 |
+
|
136 |
+
example_texts = [
|
137 |
+
|
138 |
+
["Identified a potential SQL injection vulnerability within the user login page. Attackers could potentially access sensitive user data without authorization. The issue was discovered during routine security audits.\n\nEnvironment: Web app version 2.3, Firefox 78, Windows 10\nReproduction: Submitting specially crafted SQL code into the username field."],
|
139 |
+
["The mobile app crashes when trying to sync large files from cloud storage. This affects both user satisfaction and data integrity, pointing towards an issue with memory management or threading.\n\nEnvironment: Mobile app version 3.4, iOS 13, iPhone X\nReproduction: Syncing more than 50 files, each over 5MB in size."],
|
140 |
+
["User interface inconsistencies across different modules of the software, affecting the learning curve and overall user satisfaction. Some settings reset unexpectedly, requiring frequent readjustments by users.\n\nEnvironment: Desktop app version 2.1, Ubuntu 20.04, Gnome\nReproduction: Adjust settings in the preferences panel, navigate away, and return."],
|
141 |
+
["Issues with newer operating systems. The application fails to start or crashes shortly after launch, likely due to deprecated libraries.\n\nEnvironment: Desktop app version 1.8, Windows 11\nReproduction: Install on a system running Windows 11, attempt to launch the application."],
|
142 |
+
|
143 |
+
["The application does not properly adjust to different screen resolutions and operating systems, resulting in UI elements being misplaced or not fully visible. This issue affects user engagement on various devices.\n\nEnvironment: Web app version 4.2, tested on Chrome 88, Firefox 85 on both Windows 10 and macOS Big Sur\nReproduction: Open the application in browsers at screen resolutions ranging from 800x600 to 2560x1440."],
|
144 |
+
|
145 |
+
]
|
146 |
+
|
147 |
+
|
148 |
+
interface = gr.Interface(
|
149 |
+
fn=main_interface,
|
150 |
+
inputs=gr.Textbox(lines=7, label="Issue Description", placeholder="Enter your issue text here"),
|
151 |
+
outputs=[gr.HTML(label="Prediction Output"), gr.Textbox(label="Predictions", visible=False)],
|
152 |
+
title="QualityTagger",
|
153 |
+
description="This tool classifies text into different quality domains such as Security, Usability, etc.",
|
154 |
+
examples=example_texts
|
155 |
+
)
|
156 |
+
|
157 |
+
interface.launch(share=True)
|