karths commited on
Commit
e428475
·
verified ·
1 Parent(s): 3728cd0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ import random
6
+ from huggingface_hub import login, HfFolder
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+ from scipy.special import softmax
9
+ import logging
10
+
11
+ # Setup logging
12
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
13
+
14
+ # Set a seed for reproducibility
15
+ seed = 42
16
+ np.random.seed(seed)
17
+ random.seed(seed)
18
+ torch.manual_seed(seed)
19
+ if torch.cuda.is_available():
20
+ torch.cuda.manual_seed_all(seed)
21
+
22
+
23
+ # Login to Hugging Face
24
+ token = os.getenv("hf_token")
25
+ HfFolder.save_token(token)
26
+ login(token)
27
+
28
+ # Model paths and quality mapping
29
+ model_paths = [
30
+ 'karths/binary_classification_train_port',
31
+ 'karths/binary_classification_train_perf',
32
+ "karths/binary_classification_train_main",
33
+ "karths/binary_classification_train_secu",
34
+ "karths/binary_classification_train_reli",
35
+ "karths/binary_classification_train_usab",
36
+ "karths/binary_classification_train_comp"
37
+ ]
38
+
39
+ quality_mapping = {
40
+ 'binary_classification_train_port': 'Portability',
41
+ 'binary_classification_train_main': 'Maintainability',
42
+ 'binary_classification_train_secu': 'Security',
43
+ 'binary_classification_train_reli': 'Reliability',
44
+ 'binary_classification_train_usab': 'Usability',
45
+ 'binary_classification_train_perf': 'Performance',
46
+ 'binary_classification_train_comp': 'Compatibility'
47
+ }
48
+
49
+ # Pre-load models and tokenizer
50
+ tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
51
+ models = {path: AutoModelForSequenceClassification.from_pretrained(path) for path in model_paths}
52
+
53
+ def get_quality_name(model_name):
54
+ return quality_mapping.get(model_name.split('/')[-1], "Unknown Quality")
55
+
56
+ def model_prediction(model, text, device):
57
+ model.to(device)
58
+ model.eval()
59
+
60
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
61
+ inputs = {k: v.to(device) for k, v in inputs.items()}
62
+
63
+ with torch.no_grad():
64
+ outputs = model(**inputs)
65
+ logits = outputs.logits
66
+ probs = softmax(logits.cpu().numpy(), axis=1)
67
+ avg_prob = np.mean(probs[:, 1])
68
+
69
+ return avg_prob
70
+
71
+ def main_interface(text):
72
+ if not text.strip():
73
+ return "<div style='color: red;'>No text provided. Please enter a valid issue description.</div>", ""
74
+
75
+ # Check for text length exceeding the limit
76
+ if len(text) < 30:
77
+ return "<div style='color: red;'>Text is less than 30 characters.</div>", ""
78
+
79
+ device = "cuda" if torch.cuda.is_available() else "cpu"
80
+ results = []
81
+ for model_path, model in models.items():
82
+ quality_name = get_quality_name(model_path)
83
+ avg_prob = model_prediction(model, text, device)
84
+ if avg_prob >= 0.90: # Only consider probabilities >= 0.90
85
+ results.append((quality_name, avg_prob))
86
+ logging.info(f"Model: {model_path}, Quality: {quality_name}, Average Probability: {avg_prob:.3f}")
87
+
88
+ if not results: # If no results meet the criteria
89
+ return "<div style='color: red;'>No recommendation. Prediction probability is below the threshold. </div>", ""
90
+
91
+ top_qualities = sorted(results, key=lambda x: x[1], reverse=True)[:3]
92
+ output_html = render_html_output(top_qualities)
93
+
94
+ return output_html, ""
95
+
96
+ def render_html_output(top_qualities):
97
+ styles = """
98
+ <style>
99
+ .quality-container {
100
+ font-family: Arial, sans-serif;
101
+ text-align: center;
102
+ margin-top: 20px;
103
+ }
104
+ .quality-label, .ranking {
105
+ display: inline-block;
106
+ padding: 0.5em 1em;
107
+ font-size: 18px;
108
+ font-weight: bold;
109
+ color: white;
110
+ background-color: #007bff;
111
+ border-radius: 0.5rem;
112
+ margin-right: 10px;
113
+ box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
114
+ }
115
+ .probability {
116
+ display: block;
117
+ margin-top: 10px;
118
+ font-size: 16px;
119
+ color: #007bff;
120
+ }
121
+ </style>
122
+ """
123
+ html_content = ""
124
+ ranking_labels = ['Top 1 Prediction', 'Top 2 Prediction', 'Top 3 Prediction']
125
+ top_n = min(len(top_qualities), len(ranking_labels))
126
+ for i in range(top_n):
127
+ quality, prob = top_qualities[i]
128
+ html_content += f"""
129
+ <div class="quality-container">
130
+ <span class="ranking">{ranking_labels[i]}</span>
131
+ <span class="quality-label">{quality}</span>
132
+ </div>
133
+ """
134
+ return styles + html_content
135
+
136
+ example_texts = [
137
+
138
+ ["Identified a potential SQL injection vulnerability within the user login page. Attackers could potentially access sensitive user data without authorization. The issue was discovered during routine security audits.\n\nEnvironment: Web app version 2.3, Firefox 78, Windows 10\nReproduction: Submitting specially crafted SQL code into the username field."],
139
+ ["The mobile app crashes when trying to sync large files from cloud storage. This affects both user satisfaction and data integrity, pointing towards an issue with memory management or threading.\n\nEnvironment: Mobile app version 3.4, iOS 13, iPhone X\nReproduction: Syncing more than 50 files, each over 5MB in size."],
140
+ ["User interface inconsistencies across different modules of the software, affecting the learning curve and overall user satisfaction. Some settings reset unexpectedly, requiring frequent readjustments by users.\n\nEnvironment: Desktop app version 2.1, Ubuntu 20.04, Gnome\nReproduction: Adjust settings in the preferences panel, navigate away, and return."],
141
+ ["Issues with newer operating systems. The application fails to start or crashes shortly after launch, likely due to deprecated libraries.\n\nEnvironment: Desktop app version 1.8, Windows 11\nReproduction: Install on a system running Windows 11, attempt to launch the application."],
142
+
143
+ ["The application does not properly adjust to different screen resolutions and operating systems, resulting in UI elements being misplaced or not fully visible. This issue affects user engagement on various devices.\n\nEnvironment: Web app version 4.2, tested on Chrome 88, Firefox 85 on both Windows 10 and macOS Big Sur\nReproduction: Open the application in browsers at screen resolutions ranging from 800x600 to 2560x1440."],
144
+
145
+ ]
146
+
147
+
148
+ interface = gr.Interface(
149
+ fn=main_interface,
150
+ inputs=gr.Textbox(lines=7, label="Issue Description", placeholder="Enter your issue text here"),
151
+ outputs=[gr.HTML(label="Prediction Output"), gr.Textbox(label="Predictions", visible=False)],
152
+ title="QualityTagger",
153
+ description="This tool classifies text into different quality domains such as Security, Usability, etc.",
154
+ examples=example_texts
155
+ )
156
+
157
+ interface.launch(share=True)