# Scraped from a Hugging Face Space page (status banner: "Spaces: Sleeping").
import logging

import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Example C snippets that the "Load ... Example" buttons place into the input
# box. Built from explicit "\n"/"\t" escapes so no page-scrape artifacts or
# accidental blank lines end up in the text sent to the classifier.
VULNERABLE_EXAMPLE = (
    "static int cirrus_bitblt_videotovideo_patterncopy(CirrusVGAState * s)\n"
    "{\n"
    "    return cirrus_bitblt_common_patterncopy(s,\n"
    "\t\t\t\t\t s->vram_ptr +\n"
    "\t\t\t\t\t (s->cirrus_blt_srcaddr & ~7));\n"
    "}"
)

NON_VULNERABLE_EXAMPLE = (
    "static void loongarch_cpu_synchronize_from_tb(CPUState *cs,\n"
    "                                              const TranslationBlock *tb)\n"
    "{\n"
    "    LoongArchCPU *cpu = LOONGARCH_CPU(cs);\n"
    "    CPULoongArchState *env = &cpu->env;\n"
    "\n"
    "    env->pc = tb->pc;\n"
    "}"
)
# Load the model and tokenizer
def load_model(
    model_name="moazx/Code-Vulnerability-Classifier_app",
    device=None,
):
    """Load the vulnerability classifier, its tokenizer, and the run device.

    Args:
        model_name: Hugging Face Hub id or local path of the sequence
            classification checkpoint. Defaults to the checkpoint this app
            was built for.
        device: Where to run the model -- a ``torch.device`` or anything else
            ``model.to`` accepts (e.g. ``"cpu"``). When ``None``, CUDA is used
            if available, otherwise the CPU.

    Returns:
        Tuple ``(model, tokenizer, device)``; the model has been moved to
        ``device`` and switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Binary head: the app reads output index 0 as non-vulnerable and
    # index 1 as vulnerable.
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Inference only: disables dropout and other train-time behavior.
    model.eval()
    return model, tokenizer, device


# Load the model and tokenizer once, at import time, so every request reuses them.
model, tokenizer, device = load_model()
def classify_code_sample(code_sample):
    """Score a single code snippet with the classifier.

    The snippet is tokenized to exactly 512 tokens (longer input is truncated,
    shorter input is padded), moved to the module-level ``device``, and run
    through the module-level ``model`` with autograd disabled.

    Returns:
        NumPy array of softmax probabilities over the output classes, with
        size-1 dimensions (the single-sample batch) squeezed away.
    """
    encoded = tokenizer(
        code_sample,
        return_tensors="pt",
        max_length=512,
        padding='max_length',
        truncation=True,
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        scores = model(**encoded).logits
        class_probs = F.softmax(scores, dim=-1)

    return class_probs.squeeze().cpu().numpy()
# Display names indexed by the classifier's output position (the argmax of the
# probability vector): 0 -> non-vulnerable, 1 -> vulnerable.
_CLASS_NAMES = ("Non-vulnerable", "Vulnerable")


def _format_results(probabilities):
    """Render class probabilities as the Markdown report shown in the UI.

    Args:
        probabilities: NumPy array of per-class probabilities ordered like
            ``_CLASS_NAMES`` (as returned by ``classify_code_sample``).

    Returns:
        Markdown with the predicted class, its confidence, every class
        probability, and a review checklist when the prediction is
        "Vulnerable".
    """
    predicted_index = probabilities.argmax()
    predicted_class = _CLASS_NAMES[predicted_index]
    confidence = probabilities[predicted_index] * 100

    parts = [
        # Blank line (not a single "\n") between Prediction and Confidence:
        # a single newline is a Markdown soft break, so gr.Markdown would
        # otherwise render both on one line.
        f"**Prediction:** {predicted_class}\n\n",
        f"**Confidence:** {confidence:.1f}%\n\n",
        "**Detailed Probabilities:**\n",
    ]
    parts.extend(
        f"- {class_name}: {prob * 100:.1f}%\n"
        for class_name, prob in zip(_CLASS_NAMES, probabilities)
    )

    # Extra guidance only for snippets flagged as vulnerable.
    if predicted_class == "Vulnerable":
        parts.extend([
            "\n⚠️ **Warning:** This code has been flagged as potentially vulnerable. Please review it carefully for:\n",
            "- Security issues (e.g., input validation, authentication)\n",
            "- Implementation issues (e.g., memory management, resource handling)\n",
            "- Design issues (e.g., concurrency, logic errors)\n",
        ])
    return "".join(parts)


def analyze_code(code_input):
    """Classify a code snippet and return a Markdown report for the UI.

    Args:
        code_input: Source text from the input textbox. ``None`` and
            whitespace-only input are treated as "nothing entered".

    Returns:
        The Markdown report from ``_format_results``, a prompt asking for
        input when nothing was entered, or ``"Error during analysis: ..."``
        if classification or formatting raised.
    """
    if code_input is None or not code_input.strip():
        return "Please enter some code to analyze."
    try:
        probabilities = classify_code_sample(code_input)
        return _format_results(probabilities)
    except Exception as e:
        # UI boundary: show the error to the user instead of failing the
        # request, but keep the full traceback in the server logs.
        logging.getLogger(__name__).exception("Code analysis failed")
        return f"Error during analysis: {e}"
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# DiverseVul Code Vulnerability Classifier")
    gr.Markdown("""
This tool analyzes code snippets for various types of vulnerabilities, including:
- Security vulnerabilities (e.g., buffer overflows, injection flaws)
- Memory management issues
- Concurrency problems
- Resource leaks
- Logic errors
- Performance issues
- Reliability problems
""")

    with gr.Row():
        with gr.Column():
            code_input = gr.Textbox(
                label="Enter your code snippet here:",
                placeholder="Paste your code here...",
                lines=10,
                max_lines=20,
                value="",
            )
            analyze_button = gr.Button("Analyze Code")
        with gr.Column():
            output = gr.Markdown(label="Analysis Results")

    # Example buttons
    gr.Markdown("### Try an Example")
    with gr.Row():
        # Plain-text labels: the emoji prefixes in the scraped source were
        # mis-encoded (shown as "π") and the originals cannot be recovered.
        vulnerable_example_button = gr.Button("Load Vulnerable Example")
        non_vulnerable_example_button = gr.Button("Load Non-Vulnerable Example")

    # Event handlers (must be registered inside the Blocks context).
    analyze_button.click(
        analyze_code,
        inputs=code_input,
        outputs=output,
    )
    # The example buttons only replace the textbox contents; the user still
    # presses "Analyze Code" to run the classifier.
    vulnerable_example_button.click(
        lambda: VULNERABLE_EXAMPLE,
        outputs=code_input,
    )
    non_vulnerable_example_button.click(
        lambda: NON_VULNERABLE_EXAMPLE,
        outputs=code_input,
    )

# Launch the Gradio app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()