saifeddinemk commited on
Commit
5b2d750
1 Parent(s): 52422d2

Fixed app v2

Browse files
Files changed (1) hide show
  1. app.py +75 -23
app.py CHANGED
@@ -1,30 +1,82 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import pipeline
 
 
4
 
5
- # Define the FastAPI app
6
  app = FastAPI()
7
 
8
- # Initialize the pipeline for text generation with the Lily-Cybersecurity-7B-v0.2 model
9
- pipe = pipeline("text-generation", model="segolilylabs/Lily-Cybersecurity-7B-v0.2")
10
-
11
- # Define a request model for input
12
- class MessageRequest(BaseModel):
13
- content: str
14
-
15
- # Define the route for message processing
16
- @app.post("/generate_response/")
17
- async def generate_response(message: MessageRequest):
18
- try:
19
- # Prepare the input for the model
20
- input_data = [{"role": "user", "content": message.content}]
21
-
22
- # Generate a response using the model pipeline
23
- response = pipe(input_data)[0]["generated_text"]
24
-
25
- # Return the generated response as JSON
26
- return {"response": response}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- except Exception as e:
29
- raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from typing import List
4
+ import torch
5
+ from transformers import RobertaTokenizerFast, RobertaForMaskedLM
6
 
7
+ # Initialize FastAPI app
8
  app = FastAPI()
9
 
10
+ # Load SecureBERT Plus model and tokenizer
11
+ tokenizer = RobertaTokenizerFast.from_pretrained("ehsanaghaei/SecureBERT_Plus")
12
+ model = RobertaForMaskedLM.from_pretrained("ehsanaghaei/SecureBERT_Plus")
13
+
14
+ # Define request model
15
+ class LogRequest(BaseModel):
16
+ log: str
17
+
18
+ # Define response model
19
+ class ThreatResponse(BaseModel):
20
+ log: str
21
+ prompt: str
22
+ threat_level_predictions: List[str]
23
+ threat_type_predictions: List[str]
24
+ detected_threat_level: str
25
+ detected_threat_type: str
26
+
27
+ # Function to predict masked words for threat level and type
28
+ def predict_threat(log: str, tokenizer, model, topk=5) -> List[List[str]]:
29
+ # Create prompt with masked tokens for threat level and threat type
30
+ prompt = f"{log} Threat level: [MASK]. Threat type: [MASK]."
31
+
32
+ # Tokenize the prompt
33
+ token_ids = tokenizer.encode(prompt, return_tensors='pt')
34
+ masked_position = (token_ids.squeeze() == tokenizer.mask_token_id).nonzero()
35
+ masked_pos = [mask.item() for mask in masked_position]
36
+ predictions = []
37
+
38
+ with torch.no_grad():
39
+ output = model(token_ids)
40
+ last_hidden_state = output[0].squeeze()
41
+
42
+ # Get predictions for each masked token (Threat level and Threat type)
43
+ for mask_index in masked_pos:
44
+ mask_hidden_state = last_hidden_state[mask_index]
45
+ idx = torch.topk(mask_hidden_state, k=topk, dim=0)[1]
46
+ words = [tokenizer.decode(i.item()).strip().replace(" ", "") for i in idx]
47
+ predictions.append(words)
48
+
49
+ return predictions # Return predictions for both Threat level and Threat type
50
+
51
+ # FastAPI endpoint for detecting threat level and type
52
+ @app.post("/detect_threat", response_model=ThreatResponse)
53
+ async def detect_threat(log_request: LogRequest):
54
+ log = log_request.log
55
+
56
+ # Predict the threat level and type for the given log entry
57
+ predictions = predict_threat(log, tokenizer, model)
58
 
59
+ # Extract top predictions for threat level and type
60
+ threat_level_predictions = predictions[0] if len(predictions) > 0 else ["Unknown"]
61
+ threat_type_predictions = predictions[1] if len(predictions) > 1 else ["Unknown"]
62
+
63
+ # Determine the most likely threat level and type based on model predictions
64
+ threat_levels = ["High", "Medium", "Low", "Critical", "Unknown"]
65
+ threat_types = ["malware", "intrusion", "exploit", "attack", "phishing", "Unknown"]
66
+ detected_threat_level = next((level for level in threat_levels if level in threat_level_predictions), "Unknown")
67
+ detected_threat_type = next((t_type for t_type in threat_types if t_type in threat_type_predictions), "Unknown")
68
+
69
+ # Prepare response
70
+ response = ThreatResponse(
71
+ log=log,
72
+ prompt=f"{log} Threat level: [MASK]. Threat type: [MASK].",
73
+ threat_level_predictions=threat_level_predictions,
74
+ threat_type_predictions=threat_type_predictions,
75
+ detected_threat_level=detected_threat_level,
76
+ detected_threat_type=detected_threat_type
77
+ )
78
+
79
+ return response
80
 
81
+ # Run the FastAPI app with uvicorn
82
+ # Command to run: uvicorn your_script_name:app --reload --host 0.0.0.0 --port 8000