from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import RobertaForMaskedLM, RobertaTokenizerFast

# Initialize FastAPI app
app = FastAPI()

# Load SecureBERT Plus model and tokenizer once at import time; every request
# reuses the same weights.
MODEL_NAME = "ehsanaghaei/SecureBERT_Plus"
tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_NAME)
model = RobertaForMaskedLM.from_pretrained(MODEL_NAME)
# from_pretrained already returns the model in eval mode; kept explicit because
# inference below relies on dropout being disabled.
model.eval()

# Labels the endpoint may report. The model's ranking (not this order) decides
# which one wins; see _pick_label.
THREAT_LEVELS = ["High", "Medium", "Low", "Critical", "Unknown"]
THREAT_TYPES = ["malware", "intrusion", "exploit", "attack", "phishing", "Unknown"]


# Define request model
class LogRequest(BaseModel):
    log: str


# Define response model
class ThreatResponse(BaseModel):
    log: str
    prompt: str  # the exact prompt text fed to the model
    threat_level_predictions: List[str]  # top-k tokens for the threat-level mask, best first
    threat_type_predictions: List[str]  # top-k tokens for the threat-type mask, best first
    detected_threat_level: str
    detected_threat_type: str


class LogTooLongError(ValueError):
    """Raised when the prompt built from a log exceeds the model's token limit."""


def _build_prompt(log: str, mask_token: str) -> str:
    """Return the cloze prompt asking the model to fill in threat level and type.

    The mask placeholder must be the tokenizer's own mask token (``<mask>`` for
    RoBERTa); a literal ``[MASK]`` would be split into ordinary sub-word tokens
    and no position would ever be predicted.
    """
    return f"{log} Threat level: {mask_token}. Threat type: {mask_token}."


def predict_threat(
    log: str, tokenizer, model, topk: int = 5, max_length: int = 512
) -> List[List[str]]:
    """Predict the top-k fill-ins for the threat-level and threat-type masks.

    Args:
        log: Raw log line to classify.
        tokenizer: Tokenizer matching *model*; its ``mask_token`` is used in the prompt.
        model: Masked-LM whose first output is the per-token vocabulary logits.
        topk: Number of candidate tokens to return per mask.
        max_length: Maximum prompt length in tokens (512 for RoBERTa-base models).

    Returns:
        ``[level_candidates, type_candidates]``, each ordered best-first. The list
        is shorter if fewer than two mask positions are found.

    Raises:
        LogTooLongError: If the encoded prompt is longer than *max_length* tokens.
    """
    prompt = _build_prompt(log, tokenizer.mask_token)
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    seq_len = input_ids.shape[1]
    if seq_len > max_length:
        raise LogTooLongError(
            f"log is too long: prompt is {seq_len} tokens, limit is {max_length}"
        )

    # The template's two masks are the last two in the sequence. Any mask token
    # that arrived inside the (untrusted) log text comes earlier and is ignored,
    # so it cannot shift which prediction is "level" and which is "type".
    is_mask = input_ids[0] == tokenizer.mask_token_id
    mask_positions = is_mask.nonzero(as_tuple=True)[0].tolist()[-2:]

    with torch.no_grad():
        # First output of a masked-LM head is the logits; [0] selects the single
        # sequence in the batch.
        logits = model(input_ids)[0][0]

    predictions = []
    for position in mask_positions:
        top_ids = torch.topk(logits[position], k=topk, dim=0).indices
        words = [
            tokenizer.decode(token_id.item()).strip().replace(" ", "")
            for token_id in top_ids
        ]
        predictions.append(words)
    return predictions  # Return predictions for both Threat level and Threat type


def _pick_label(candidates: List[str], labels: List[str]) -> str:
    """Return the allowed label matching the highest-ranked candidate, else "Unknown".

    Candidates are scanned in the model's rank order, so the model's most likely
    allowed answer wins. Matching is case-insensitive and the canonical spelling
    from *labels* is returned.
    """
    canonical = {label.lower(): label for label in labels}
    return next(
        (canonical[word.lower()] for word in candidates if word.lower() in canonical),
        "Unknown",
    )


# FastAPI endpoint for detecting threat level and type.
# Declared with plain ``def`` (not ``async def``): model inference is blocking,
# CPU-bound work, and FastAPI runs sync path functions in its threadpool instead
# of stalling the event loop.
@app.post("/detect_threat", response_model=ThreatResponse)
def detect_threat(log_request: LogRequest):
    """Classify a log entry's threat level and type with SecureBERT Plus.

    Raises:
        HTTPException: 413 if the log does not fit in the model's input window.
    """
    log = log_request.log

    try:
        predictions = predict_threat(log, tokenizer, model)
    except LogTooLongError as err:
        raise HTTPException(status_code=413, detail=str(err)) from err

    # Extract top predictions for threat level and type
    threat_level_predictions = predictions[0] if len(predictions) > 0 else ["Unknown"]
    threat_type_predictions = predictions[1] if len(predictions) > 1 else ["Unknown"]

    return ThreatResponse(
        log=log,
        # Report the exact prompt that was sent to the model.
        prompt=_build_prompt(log, tokenizer.mask_token),
        threat_level_predictions=threat_level_predictions,
        threat_type_predictions=threat_type_predictions,
        detected_threat_level=_pick_label(threat_level_predictions, THREAT_LEVELS),
        detected_threat_type=_pick_label(threat_type_predictions, THREAT_TYPES),
    )


# Run the FastAPI app with uvicorn
# Command to run: uvicorn your_script_name:app --reload --host 0.0.0.0 --port 8000