moazx committed on
Commit
0264912
·
verified ·
1 Parent(s): 28993e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -223
app.py CHANGED
@@ -1,223 +1,135 @@
1
- import streamlit as st
2
- import torch
3
- import torch.nn.functional as F
4
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
- import pandas as pd
6
-
7
- # Page configuration
8
- st.set_page_config(
9
- page_title="DiverseVul Code Vulnerability Classifier",
10
- page_icon="🔍",
11
- layout="wide"
12
- )
13
-
14
- # Example code snippets
15
- VULNERABLE_EXAMPLE = """static int cirrus_bitblt_videotovideo_patterncopy(CirrusVGAState * s)\n{\n
16
- return cirrus_bitblt_common_patterncopy(s,\n\t\t\t\t\t s->vram_ptr +\n (s->cirrus_blt_srcaddr & ~7));\n}"""
17
-
18
- NON_VULNERABLE_EXAMPLE = """static void loongarch_cpu_synchronize_from_tb(CPUState *cs,
19
- \n const TranslationBlock *tb)\n{\n LoongArchCPU *cpu = LOONGARCH_CPU(cs);\n CPULoongArchState *env = &cpu->env;\n\n env->pc = tb->pc;\n}"""
20
-
21
- @st.cache_resource
22
- def load_model():
23
- """Load the model and tokenizer with caching"""
24
- # Fine-tuned model
25
- model_name = "trained_model"
26
- tokenizer = AutoTokenizer.from_pretrained(model_name)
27
- model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
28
-
29
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- model = model.to(device)
31
- model.eval()
32
-
33
- return model, tokenizer, device
34
-
35
- def classify_code_sample(code_sample, model, tokenizer, device, max_length=512):
36
- """Classify a single code sample and get probabilities"""
37
- inputs = tokenizer(
38
- code_sample,
39
- truncation=True,
40
- padding='max_length',
41
- max_length=max_length,
42
- return_tensors="pt"
43
- ).to(device)
44
-
45
- with torch.no_grad():
46
- outputs = model(**inputs)
47
- logits = outputs.logits
48
-
49
- probabilities = F.softmax(logits, dim=-1).squeeze().cpu().numpy()
50
- return probabilities
51
-
52
- def main():
53
- st.title("DiverseVul Code Vulnerability Classifier")
54
- st.write("""
55
- This tool analyzes code snippets for various types of vulnerabilities, including but not limited to:
56
- - Security vulnerabilities (e.g., buffer overflows, injection flaws)
57
- - Memory management issues
58
- - Concurrency problems
59
- - Resource leaks
60
- - Logic errors
61
- - Performance issues
62
- - Reliability problems
63
- """)
64
-
65
- # Load model and tokenizer
66
- try:
67
- with st.spinner("Loading model..."):
68
- model, tokenizer, device = load_model()
69
- st.success("Model loaded successfully!")
70
- except Exception as e:
71
- st.error(f"Error loading model: {str(e)}")
72
- return
73
-
74
- # Example buttons
75
- st.subheader("Try an Example")
76
- col1, col2 = st.columns(2)
77
- with col1:
78
- if st.button("📋 Load Vulnerable Example"):
79
- st.session_state['code_input'] = VULNERABLE_EXAMPLE
80
- with col2:
81
- if st.button("📋 Load Non-Vulnerable Example"):
82
- st.session_state['code_input'] = NON_VULNERABLE_EXAMPLE
83
-
84
- # Input area
85
- st.subheader("Input Code")
86
- code_input = st.text_area(
87
- "Enter your code snippet here:",
88
- value=st.session_state.get('code_input', ''),
89
- height=300,
90
- help="Paste your code here for comprehensive vulnerability analysis"
91
- )
92
-
93
- # Analysis button
94
- if st.button("Analyze Code"):
95
- if not code_input.strip():
96
- st.warning("Please enter some code to analyze.")
97
- return
98
-
99
- with st.spinner("Analyzing code..."):
100
- try:
101
- # Get predictions
102
- probabilities = classify_code_sample(code_input, model, tokenizer, device)
103
-
104
- # Create results section
105
- st.subheader("Analysis Results")
106
-
107
- # Display prediction with confidence
108
- class_names = ["Non-vulnerable", "Vulnerable"]
109
- predicted_class_index = probabilities.argmax()
110
- predicted_class = class_names[predicted_class_index]
111
- confidence = probabilities[predicted_class_index] * 100
112
-
113
- # Create columns for layout
114
- col1, col2 = st.columns(2)
115
-
116
- # Display prediction and confidence
117
- with col1:
118
- st.metric(
119
- "Prediction",
120
- predicted_class,
121
- help="The model's classification of the code"
122
- )
123
-
124
- with col2:
125
- st.metric(
126
- "Confidence",
127
- f"{confidence:.1f}%",
128
- help="How confident the model is in its prediction"
129
- )
130
-
131
- # Create a DataFrame for detailed probabilities
132
- results_df = pd.DataFrame({
133
- 'Class': class_names,
134
- 'Probability': probabilities
135
- })
136
-
137
- # Display probability distribution
138
- st.subheader("Detailed Probabilities")
139
- st.bar_chart(
140
- results_df.set_index('Class')['Probability']
141
- )
142
-
143
- # Additional information and disclaimers
144
- if predicted_class == "Vulnerable":
145
- st.warning("""
146
- ⚠️ This code has been flagged as potentially vulnerable.
147
- Please review it carefully for various types of vulnerabilities including:
148
-
149
- Security:
150
- - Input validation
151
- - Authentication issues
152
- - Access control problems
153
-
154
- Implementation:
155
- - Memory management
156
- - Resource handling
157
- - Error handling
158
-
159
- Design:
160
- - Concurrency issues
161
- - Logic errors
162
- - Performance problems
163
-
164
- Best Practices:
165
- - Code structure
166
- - Error handling patterns
167
- - Resource cleanup
168
- """)
169
-
170
- st.info("""
171
- Note: This tool is trained on the DiverseVul dataset, which covers 150 different
172
- types of Common Weakness Enumeration (CWE) categories. While comprehensive, it
173
- should be used as part of a larger code review process. False positives and
174
- negatives are possible.
175
- """)
176
-
177
- except Exception as e:
178
- st.error(f"Error during analysis: {str(e)}")
179
-
180
- # Add sidebar with information
181
- with st.sidebar:
182
- st.header("About")
183
- st.write("""
184
- This tool uses a machine learning model trained on the DiverseVul dataset, which includes:
185
- - 18,945 vulnerable functions
186
- - 330,492 non-vulnerable functions
187
- - 150 different CWE types
188
- - Code from thousands of real-world projects
189
- """)
190
-
191
- st.subheader("Example Code Explanation")
192
- st.write("""
193
- The vulnerable example contains:
194
- - SQL injection vulnerability
195
- - Path traversal vulnerability
196
- - Buffer overflow vulnerability
197
-
198
- The non-vulnerable example shows:
199
- - Parameterized SQL queries
200
- - Safe path validation
201
- - Proper buffer bounds checking
202
- """)
203
-
204
- st.subheader("How to Use")
205
- st.write("""
206
- 1. Click an example button or paste your code
207
- 2. Click 'Analyze Code'
208
- 3. Review the results and probability scores
209
- 4. Consider all flagged issues in context
210
- 5. Verify findings with manual review
211
- """)
212
-
213
- st.subheader("Limitations")
214
- st.write("""
215
- - The model may not catch all vulnerabilities
216
- - Some safe code might be flagged as vulnerable
217
- - Results should be verified by domain experts
218
- - Performance varies across different CWE types
219
- - Best used as part of a comprehensive code review process
220
- """)
221
-
222
- if __name__ == "__main__":
223
- main()
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import gradio as gr
5
+
6
# Example code snippets.
# Both constants hold the text of a C function; the "Load ... Example" buttons
# in the UI copy them into the input textbox.
# NOTE(review): the vulnerable / non-vulnerable labelling is implied only by
# the constant names -- confirm against the dataset the samples came from.
VULNERABLE_EXAMPLE = """static int cirrus_bitblt_videotovideo_patterncopy(CirrusVGAState * s)\n{\n
return cirrus_bitblt_common_patterncopy(s,\n\t\t\t\t\t s->vram_ptr +\n (s->cirrus_blt_srcaddr & ~7));\n}"""

NON_VULNERABLE_EXAMPLE = """static void loongarch_cpu_synchronize_from_tb(CPUState *cs,
\n const TranslationBlock *tb)\n{\n LoongArchCPU *cpu = LOONGARCH_CPU(cs);\n CPULoongArchState *env = &cpu->env;\n\n env->pc = tb->pc;\n}"""
12
+
13
+ # Load the model and tokenizer
14
def load_model(model_name="moazx/Code-Vulnerability-Classifier_app"):
    """Load the classifier and its tokenizer and prepare them for inference.

    Args:
        model_name: Identifier handed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to the checkpoint this app
            was written for, so existing zero-argument calls are unchanged.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model has been moved to
        ``device`` (CUDA when ``torch.cuda.is_available()``, otherwise CPU)
        and switched to evaluation mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Binary classification head: the app reports exactly two classes.
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # eval() switches the module to evaluation mode for inference.
    model.eval()

    return model, tokenizer, device
25
+
26
# Load the model and tokenizer once, at import time. classify_code_sample reads
# these module-level names on every call instead of reloading the checkpoint.
model, tokenizer, device = load_model()
28
+
29
def classify_code_sample(code_sample, max_length=512):
    """Run the classifier on one code sample and return class probabilities.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.

    Args:
        code_sample: Source code text to classify.
        max_length: Token length the input is truncated or padded to.
            Defaults to 512, the value that used to be hard-coded here.

    Returns:
        NumPy array holding the softmax of the model's logits over the last
        dimension, with the leading batch dimension removed.
    """
    # truncation=True with padding='max_length' fixes the encoded length at
    # max_length tokens regardless of how long the sample is.
    inputs = tokenizer(
        code_sample,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors="pt",
    ).to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = model(**inputs).logits

    # squeeze(0) drops only the batch axis; a bare squeeze() would also
    # collapse any other size-1 axis.
    # NOTE(review): assumes logits is (1, num_labels) for a single string
    # input -- confirm against the model in use.
    probabilities = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
    return probabilities
45
+
46
def analyze_code(code_input):
    """Classify a code snippet and describe the result as Markdown.

    Args:
        code_input: Source code to analyze. ``None`` and whitespace-only
            strings are treated as empty input.

    Returns:
        A Markdown string. For empty input it is a prompt to enter code;
        otherwise it holds the predicted class, its confidence, the
        probability of every class and, when the prediction is
        "Vulnerable", a short review checklist. If anything raises during
        classification or formatting, an error message is returned instead
        of propagating the exception.
    """
    # `not code_input` also covers None, which would otherwise raise on
    # .strip() before the try block is even entered.
    if not code_input or not code_input.strip():
        return "Please enter some code to analyze."

    try:
        probabilities = classify_code_sample(code_input)

        # Index i of the probability vector is reported under class_names[i].
        class_names = ["Non-vulnerable", "Vulnerable"]
        predicted_class_index = probabilities.argmax()
        predicted_class = class_names[predicted_class_index]
        confidence = probabilities[predicted_class_index] * 100

        # Blank lines separate paragraphs and lists: the Markdown component
        # folds single newlines into one paragraph by default, which would
        # render the prediction and the confidence on the same line.
        lines = [
            f"**Prediction:** {predicted_class}",
            "",
            f"**Confidence:** {confidence:.1f}%",
            "",
            "**Detailed Probabilities:**",
            "",
        ]
        lines.extend(
            f"- {class_name}: {prob * 100:.1f}%"
            for class_name, prob in zip(class_names, probabilities)
        )

        # Extra review checklist only for a "Vulnerable" prediction.
        if predicted_class == "Vulnerable":
            lines.extend([
                "",
                "⚠️ **Warning:** This code has been flagged as potentially vulnerable. Please review it carefully for:",
                "",
                "- Security issues (e.g., input validation, authentication)",
                "- Implementation issues (e.g., memory management, resource handling)",
                "- Design issues (e.g., concurrency, logic errors)",
            ])

        return "\n".join(lines) + "\n"

    except Exception as e:
        # UI boundary: show the failure in the output panel rather than
        # letting the event handler crash.
        return f"Error during analysis: {str(e)}"
81
+
82
+ # Gradio Interface
83
# Build the UI. Components appear in the page in the order they are created
# inside each layout context, so the statement order below is the layout.
with gr.Blocks() as demo:
    gr.Markdown("# DiverseVul Code Vulnerability Classifier")
    gr.Markdown("""
    This tool analyzes code snippets for various types of vulnerabilities, including:
    - Security vulnerabilities (e.g., buffer overflows, injection flaws)
    - Memory management issues
    - Concurrency problems
    - Resource leaks
    - Logic errors
    - Performance issues
    - Reliability problems
    """)

    with gr.Row():
        # First column: the code input and the button that triggers analysis.
        with gr.Column():
            code_input = gr.Textbox(
                label="Enter your code snippet here:",
                placeholder="Paste your code here...",
                lines=10,
                max_lines=20,
                value=""
            )
            analyze_button = gr.Button("Analyze Code")

        # Second column: Markdown panel that receives analyze_code's return value.
        with gr.Column():
            output = gr.Markdown(label="Analysis Results")

    # Example buttons
    gr.Markdown("### Try an Example")
    with gr.Row():
        vulnerable_example_button = gr.Button("📋 Load Vulnerable Example")
        non_vulnerable_example_button = gr.Button("📋 Load Non-Vulnerable Example")

    # Event handlers
    # The textbox content is passed to analyze_code; its result goes to `output`.
    analyze_button.click(
        analyze_code,
        inputs=code_input,
        outputs=output
    )

    # The example handlers take no inputs: each returns a constant that
    # replaces the textbox content. They do not run the analysis themselves.
    vulnerable_example_button.click(
        lambda: VULNERABLE_EXAMPLE,
        outputs=code_input
    )

    non_vulnerable_example_button.click(
        lambda: NON_VULNERABLE_EXAMPLE,
        outputs=code_input
    )

# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()