# Captured from a Hugging Face Space page (app status at capture: Sleeping).
# File size: 4,763 bytes. Revision: 0264912.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
# Example code snippets (C functions) that the "Load Vulnerable Example" and
# "Load Non-Vulnerable Example" buttons copy into the input textbox.
# NOTE(review): these literals are not raw strings, so every "\n" / "\t"
# escape becomes a real newline / tab, and the literal line breaks inside the
# triple quotes add further newlines. The loaded text therefore differs in
# layout from well-formatted C. The vulnerable / non-vulnerable labels are
# presumably taken from the dataset's ground truth; confirm before relying on them.
VULNERABLE_EXAMPLE = """static int cirrus_bitblt_videotovideo_patterncopy(CirrusVGAState * s)\n{\n
return cirrus_bitblt_common_patterncopy(s,\n\t\t\t\t\t s->vram_ptr +\n (s->cirrus_blt_srcaddr & ~7));\n}"""
# Counterpart snippet offered to users as the non-vulnerable example.
NON_VULNERABLE_EXAMPLE = """static void loongarch_cpu_synchronize_from_tb(CPUState *cs,
\n const TranslationBlock *tb)\n{\n LoongArchCPU *cpu = LOONGARCH_CPU(cs);\n CPULoongArchState *env = &cpu->env;\n\n env->pc = tb->pc;\n}"""
# Load the model and tokenizer
def load_model(model_name="moazx/Code-Vulnerability-Classifier_app", num_labels=2):
    """Load a sequence-classification model and its tokenizer.

    The model is moved to a CUDA device when one is available (otherwise
    the CPU) and put into evaluation mode so that inference-time layers
    such as dropout behave deterministically.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.
            Defaults to the classifier this app was built around.
        num_labels: Number of output classes passed to ``from_pretrained``.
            The rest of the app assumes 2 (non-vulnerable / vulnerable).

    Returns:
        A ``(model, tokenizer, device)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=num_labels
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    return model, tokenizer, device
# Instantiate the classifier a single time, at import, so every Gradio
# request reuses the same weights instead of reloading them per call.
(model, tokenizer, device) = load_model()
def classify_code_sample(code_sample, *, model=None, tokenizer=None, device=None,
                         max_length=512):
    """Classify a single code sample and return its class probabilities.

    Args:
        code_sample: Source-code text to classify.
        model: Sequence-classification model. Defaults to the module-level
            ``model`` loaded at startup.
        tokenizer: Tokenizer matching ``model``. Defaults to the module-level
            ``tokenizer``.
        device: Device the encoded inputs are moved to. Defaults to the
            module-level ``device``.
        max_length: Inputs longer than this many tokens are truncated.

    Returns:
        1-D NumPy array holding one softmax probability per class, in the
        order of the model's output logits.
    """
    # The parameters shadow the module-level names, so resolve the defaults
    # through the module namespace explicitly.
    defaults = globals()
    model = defaults["model"] if model is None else model
    tokenizer = defaults["tokenizer"] if tokenizer is None else tokenizer
    device = defaults["device"] if device is None else device

    # A single sequence needs no padding. Padding to max_length only made the
    # model process hundreds of masked filler tokens on every call.
    inputs = tokenizer(
        code_sample,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Take the single batch row explicitly. ``squeeze()`` would also drop the
    # class axis for a one-label head and return a 0-d array.
    probabilities = F.softmax(logits, dim=-1)[0]
    return probabilities.cpu().numpy()
def analyze_code(code_input):
    """Analyze a code snippet and return a Markdown report for the UI.

    Args:
        code_input: Text from the Gradio textbox. ``None`` and
            whitespace-only input are treated as empty.

    Returns:
        A Markdown string with the predicted class, its confidence and the
        per-class probabilities. For "Vulnerable" predictions it also
        includes a review checklist. Any exception raised during inference
        is reported as an "Error during analysis: ..." string instead of
        propagating, so the UI always has something to display.
    """
    if not code_input or not code_input.strip():
        return "Please enter some code to analyze."
    try:
        probabilities = classify_code_sample(code_input)
        # Assumes label id 0 = non-vulnerable and 1 = vulnerable in the
        # checkpoint's config. TODO confirm against the model card.
        class_names = ["Non-vulnerable", "Vulnerable"]
        predicted_class_index = int(probabilities.argmax())
        predicted_class = class_names[predicted_class_index]
        confidence = probabilities[predicted_class_index] * 100

        parts = [
            f"**Prediction:** {predicted_class}\n",
            f"**Confidence:** {confidence:.1f}%\n\n",
            "**Detailed Probabilities:**\n",
        ]
        parts.extend(
            f"- {class_name}: {prob * 100:.1f}%\n"
            for class_name, prob in zip(class_names, probabilities)
        )
        # Additional guidance for code flagged as vulnerable
        if predicted_class == "Vulnerable":
            parts.append("\n⚠️ **Warning:** This code has been flagged as potentially vulnerable. Please review it carefully for:\n")
            parts.append("- Security issues (e.g., input validation, authentication)\n")
            parts.append("- Implementation issues (e.g., memory management, resource handling)\n")
            parts.append("- Design issues (e.g., concurrency, logic errors)\n")
        return "".join(parts)
    except Exception as e:
        # UI-callback boundary: show the failure to the user instead of
        # crashing the handler, but keep the full traceback in the logs.
        import logging
        logging.getLogger(__name__).exception("Error during analysis")
        return f"Error during analysis: {str(e)}"
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# DiverseVul Code Vulnerability Classifier")
    # NOTE(review): analyze_code only reports "Non-vulnerable" vs "Vulnerable".
    # The categories listed below are not detected separately; consider
    # rewording this description so it does not overstate the tool.
    gr.Markdown("""
This tool analyzes code snippets for various types of vulnerabilities, including:
- Security vulnerabilities (e.g., buffer overflows, injection flaws)
- Memory management issues
- Concurrency problems
- Resource leaks
- Logic errors
- Performance issues
- Reliability problems
""")
    # Input on the left, analysis report on the right.
    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                label="Enter your code snippet here:",
                placeholder="Paste your code here...",
                lines=10,
                max_lines=20,
                value=""
            )
            analyze_button = gr.Button("Analyze Code")
        with gr.Column():
            output = gr.Markdown(label="Analysis Results")
    # Example buttons. The original emoji prefix was mis-encoded ("๐ ")
    # and cannot be recovered reliably, so the labels are plain text.
    gr.Markdown("### Try an Example")
    with gr.Row():
        vulnerable_example_button = gr.Button("Load Vulnerable Example")
        non_vulnerable_example_button = gr.Button("Load Non-Vulnerable Example")
    # Event handlers: run the classifier, or fill the textbox with an example.
    analyze_button.click(
        analyze_code,
        inputs=code_input,
        outputs=output
    )
    vulnerable_example_button.click(
        lambda: VULNERABLE_EXAMPLE,
        outputs=code_input
    )
    non_vulnerable_example_button.click(
        lambda: NON_VULNERABLE_EXAMPLE,
        outputs=code_input
    )
# Launch the Gradio app only when this file is executed directly, not when it
# is imported (importing still builds `demo` and loads the model).
if __name__ == "__main__":
    demo.launch()