import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
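
# App flow: the model and tokenizer are loaded once at startup, each submitted
# snippet is tokenized and classified in a single forward pass, and the softmax
# probabilities (Non-vulnerable vs. Vulnerable) are rendered as Markdown in Gradio.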
# Example code snippets
VULNERABLE_EXAMPLE = """static int cirrus_bitblt_videotovideo_patterncopy(CirrusVGAState * s)\n{\n    return cirrus_bitblt_common_patterncopy(s,\n\t\t\t\t\t s->vram_ptr +\n\t\t\t\t\t (s->cirrus_blt_srcaddr & ~7));\n}"""

NON_VULNERABLE_EXAMPLE = """static void loongarch_cpu_synchronize_from_tb(CPUState *cs,\n                                              const TranslationBlock *tb)\n{\n    LoongArchCPU *cpu = LOONGARCH_CPU(cs);\n    CPULoongArchState *env = &cpu->env;\n\n    env->pc = tb->pc;\n}"""
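
# The two constants above are the snippets loaded by the example buttons in the
# UI below; both are C functions, matching the kind of code (open-source C/C++
# functions, as in the DiverseVul dataset) the classifier is presumably trained on.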
# Load the model and tokenizer
def load_model():
    """Load the model and tokenizer"""
    model_name = "moazx/Code-Vulnerability-Classifier_app"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    return model, tokenizer, device
# Load the model and tokenizer once when the app starts
model, tokenizer, device = load_model()
def classify_code_sample(code_sample):
    """Classify a single code sample and get probabilities"""
    inputs = tokenizer(
        code_sample,
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
    probabilities = F.softmax(logits, dim=-1).squeeze().cpu().numpy()
    return probabilities
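
# Usage sketch outside the UI (index 1 corresponds to the "Vulnerable" class,
# matching the class_names order used in analyze_code below):
#   probs = classify_code_sample("int main(void) { return 0; }")
#   print(f"P(vulnerable) = {probs[1]:.3f}")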
def analyze_code(code_input):
    """Analyze the code and return results"""
    if not code_input.strip():
        return "Please enter some code to analyze."
    try:
        # Get predictions
        probabilities = classify_code_sample(code_input)

        # Class names and confidence
        class_names = ["Non-vulnerable", "Vulnerable"]
        predicted_class_index = probabilities.argmax()
        predicted_class = class_names[predicted_class_index]
        confidence = probabilities[predicted_class_index] * 100

        # Prepare results
        result = f"**Prediction:** {predicted_class}\n"
        result += f"**Confidence:** {confidence:.1f}%\n\n"

        # Detailed probabilities
        result += "**Detailed Probabilities:**\n"
        for class_name, prob in zip(class_names, probabilities):
            result += f"- {class_name}: {prob * 100:.1f}%\n"

        # Additional warnings for vulnerable code
        if predicted_class == "Vulnerable":
            result += "\n⚠️ **Warning:** This code has been flagged as potentially vulnerable. Please review it carefully for:\n"
            result += "- Security issues (e.g., input validation, authentication)\n"
            result += "- Implementation issues (e.g., memory management, resource handling)\n"
            result += "- Design issues (e.g., concurrency, logic errors)\n"
        return result
    except Exception as e:
        return f"Error during analysis: {str(e)}"
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# DiverseVul Code Vulnerability Classifier")
    gr.Markdown("""
This tool analyzes code snippets for various types of vulnerabilities, including:
- Security vulnerabilities (e.g., buffer overflows, injection flaws)
- Memory management issues
- Concurrency problems
- Resource leaks
- Logic errors
- Performance issues
- Reliability problems
""")
    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                label="Enter your code snippet here:",
                placeholder="Paste your code here...",
                lines=10,
                max_lines=20,
                value=""
            )
            analyze_button = gr.Button("Analyze Code")
        with gr.Column():
            output = gr.Markdown(label="Analysis Results")

    # Example buttons
    gr.Markdown("### Try an Example")
    with gr.Row():
        vulnerable_example_button = gr.Button("📋 Load Vulnerable Example")
        non_vulnerable_example_button = gr.Button("📋 Load Non-Vulnerable Example")

    # Event handlers
    analyze_button.click(
        analyze_code,
        inputs=code_input,
        outputs=output
    )
    vulnerable_example_button.click(
        lambda: VULNERABLE_EXAMPLE,
        outputs=code_input
    )
    non_vulnerable_example_button.click(
        lambda: NON_VULNERABLE_EXAMPLE,
        outputs=code_input
    )
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
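    # Note: demo.launch() also accepts standard Gradio options such as
    # server_name="0.0.0.0", server_port=7860, or share=True if the app needs
    # to be reachable beyond localhost; the defaults are used here.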