import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr

# Example code snippets
VULNERABLE_EXAMPLE = """static int cirrus_bitblt_videotovideo_patterncopy(CirrusVGAState * s)\n{\n    
return cirrus_bitblt_common_patterncopy(s,\n\t\t\t\t\t    s->vram_ptr +\n                                            (s->cirrus_blt_srcaddr & ~7));\n}"""

NON_VULNERABLE_EXAMPLE = """static void loongarch_cpu_synchronize_from_tb(CPUState *cs,
\n const TranslationBlock *tb)\n{\n    LoongArchCPU *cpu = LOONGARCH_CPU(cs);\n    CPULoongArchState *env = &cpu->env;\n\n    env->pc = tb->pc;\n}"""

# Load the model and tokenizer
def load_model():
    """Load the model and tokenizer"""
    model_name = "moazx/Code-Vulnerability-Classifier_app"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    
    return model, tokenizer, device

# Load the model and tokenizer once when the app starts
model, tokenizer, device = load_model()

def classify_code_sample(code_sample):
    """Classify a single code sample and get probabilities"""
    inputs = tokenizer(
        code_sample,
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors="pt"
    ).to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    
    probabilities = F.softmax(logits, dim=-1).squeeze().cpu().numpy()
    return probabilities

def analyze_code(code_input):
    """Analyze the code and return results"""
    if not code_input.strip():
        return "Please enter some code to analyze."
    
    try:
        # Get predictions
        probabilities = classify_code_sample(code_input)
        
        # Class names and confidence
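        # Label order assumed to match the model's id2label mapping (index 0 = non-vulnerable, index 1 = vulnerable).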
        class_names = ["Non-vulnerable", "Vulnerable"]
        predicted_class_index = probabilities.argmax()
        predicted_class = class_names[predicted_class_index]
        confidence = probabilities[predicted_class_index] * 100
        
        # Prepare results
        result = f"**Prediction:** {predicted_class}\n"
        result += f"**Confidence:** {confidence:.1f}%\n\n"
        
        # Detailed probabilities
        result += "**Detailed Probabilities:**\n"
        for class_name, prob in zip(class_names, probabilities):
            result += f"- {class_name}: {prob * 100:.1f}%\n"
        
        # Additional warnings for vulnerable code
        if predicted_class == "Vulnerable":
            result += "\nโš ๏ธ **Warning:** This code has been flagged as potentially vulnerable. Please review it carefully for:\n"
            result += "- Security issues (e.g., input validation, authentication)\n"
            result += "- Implementation issues (e.g., memory management, resource handling)\n"
            result += "- Design issues (e.g., concurrency, logic errors)\n"
        
        return result
    
    except Exception as e:
        return f"Error during analysis: {str(e)}"

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# DiverseVul Code Vulnerability Classifier")
    gr.Markdown("""
    This tool analyzes code snippets for various types of vulnerabilities, including:
    - Security vulnerabilities (e.g., buffer overflows, injection flaws)
    - Memory management issues
    - Concurrency problems
    - Resource leaks
    - Logic errors
    - Performance issues
    - Reliability problems
    """)
    
    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                label="Enter your code snippet here:",
                placeholder="Paste your code here...",
                lines=10,
                max_lines=20,
                value=""
            )
            analyze_button = gr.Button("Analyze Code")
        
        with gr.Column():
            output = gr.Markdown(label="Analysis Results")
    
    # Example buttons
    gr.Markdown("### Try an Example")
    with gr.Row():
        vulnerable_example_button = gr.Button("📋 Load Vulnerable Example")
        non_vulnerable_example_button = gr.Button("📋 Load Non-Vulnerable Example")
    
    # Event handlers
    analyze_button.click(
        analyze_code,
        inputs=code_input,
        outputs=output
    )
    
    vulnerable_example_button.click(
        lambda: VULNERABLE_EXAMPLE,
        outputs=code_input
    )
    
    non_vulnerable_example_button.click(
        lambda: NON_VULNERABLE_EXAMPLE,
        outputs=code_input
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()