saifeddinemk commited on
Commit
3bb8b4d
1 Parent(s): 9dee4b2

Fixed app v2

Browse files
Files changed (1) hide show
  1. app.py +61 -45
app.py CHANGED
@@ -1,70 +1,86 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from typing import List, Dict, Any
4
  from transformers import pipeline
5
 
6
  # Initialize FastAPI app
7
  app = FastAPI()
8
 
9
- # Load the custom CyBERTuned model directly from Hugging Face
10
  unmasker = pipeline("fill-mask", model="s2w-ai/CyBERTuned-SecurityLLM")
11
 
12
- # Define request model for multiple log entries
13
  class LogRequest(BaseModel):
14
- logs: List[str] # Expecting a list of log entries
15
 
16
  # Define response model
17
  class ThreatResponse(BaseModel):
18
  log: str
19
  prompt: str
20
- pred: Dict[str, List[Dict[str, Any]]] # Dictionary with structured lists for predictions
 
 
 
 
21
 
22
- # Function to predict masked words for threat level, type, SRC IP, DEST IP, and Protocol
23
- def predict_threat(log: str, unmasker, topk=5) -> Dict[str, List[Dict[str, Any]]]:
24
- # Create prompt with masked tokens for each category
25
- prompt = (
26
- f"{log} Threat level: <mask>. Threat type: <mask>. "
27
- f"Attack type: <mask>. SRC IP: <mask>. DEST IP: <mask>. Protocol: <mask>."
28
- )
29
 
30
  # Predict top options for each <mask>
31
  predictions = unmasker(prompt, top_k=topk)
32
 
33
- # Separate predictions for each masked category
34
- threat_level_predictions = predictions[:topk]
35
- threat_type_predictions = predictions[topk:2*topk]
36
- attack_type_predictions = predictions[2*topk:3*topk]
37
- src_ip_predictions = predictions[3*topk:4*topk]
38
- dest_ip_predictions = predictions[4*topk:5*topk]
39
- protocol_predictions = predictions[5*topk:6*topk]
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- # Structure the response in a dictionary, with each entry containing a list of dictionaries
42
- return {
43
- "threat_level_predictions": [pred for pred in threat_level_predictions],
44
- "threat_type_predictions": [pred for pred in threat_type_predictions],
45
- "attack_type_predictions": [pred for pred in attack_type_predictions],
46
- "src_ip_predictions": [pred for pred in src_ip_predictions],
47
- "dest_ip_predictions": [pred for pred in dest_ip_predictions],
48
- "protocol_predictions": [pred for pred in protocol_predictions]
49
- }
50
 
51
- # FastAPI endpoint for detecting threat level and type for multiple logs
52
- @app.post("/detect_threat", response_model=List[ThreatResponse])
 
 
 
 
53
  async def detect_threat(log_request: LogRequest):
54
- responses = []
55
- for log in log_request.logs:
56
- # Predict the threat level and type for each log entry
57
- predictions = predict_threat(log, unmasker)
 
 
 
 
58
 
59
- # Prepare response for each log entry
60
- response = ThreatResponse(
61
- log=log,
62
- prompt=(
63
- f"{log} Threat level: <mask>. Threat type: <mask>. "
64
- f"Attack type: <mask>. SRC IP: <mask>. DEST IP: <mask>. Protocol: <mask>."
65
- ),
66
- pred=predictions
67
- )
68
- responses.append(response)
69
 
70
- return responses
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from typing import List
4
  from transformers import pipeline
5
 
6
  # Initialize FastAPI app
7
  app = FastAPI()
8
 
9
+
10
  unmasker = pipeline("fill-mask", model="s2w-ai/CyBERTuned-SecurityLLM")
11
 
12
+ # Define request model
13
  class LogRequest(BaseModel):
14
+ log: str
15
 
16
  # Define response model
17
  class ThreatResponse(BaseModel):
18
  log: str
19
  prompt: str
20
+ #threat_level_predictions: List[str]
21
+ #threat_type_predictions: List[str]
22
+ #detected_threat_level: str
23
+ #detected_threat_type: str
24
+ pred : List[object]
25
 
26
+ # Function to predict masked words for threat level and type
27
+ def predict_threat(log: str, unmasker, topk=5) -> List[List[str]]:
28
+ # Create prompt with masked tokens for threat level and threat type
29
+ prompt = f"{log}"
 
 
 
30
 
31
  # Predict top options for each <mask>
32
  predictions = unmasker(prompt, top_k=topk)
33
 
34
+ # Extract top predictions for each <mask>
35
+ #threat_level_predictions = [pred["token_str"].strip() for pred in predictions[:topk]]
36
+ #threat_type_predictions = [pred["token_str"].strip() for pred in predictions[topk:2*topk]]
37
+
38
+ return predictions
39
+ def get_maximum_predictions(data):
40
+ # Initialize list to store maximum values for each prediction array
41
+ max_predictions = []
42
+
43
+ # Loop over each prediction array in "pred"
44
+ for index, predictions in enumerate(data["pred"]):
45
+ max_score = float('-inf')
46
+ max_prediction = None
47
+
48
+ # Find the prediction with the highest score in the current array
49
+ for pred in predictions:
50
+ if pred["score"] > max_score:
51
+ max_score = pred["score"]
52
+ max_prediction = pred["token_str"].strip()
53
 
54
+ # Append the result with the max prediction for this array
55
+ max_predictions.append({
56
+ f"max_prediction_{index + 1}": max_prediction
57
+ })
 
 
 
 
 
58
 
59
+ return max_predictions
60
+
61
+ # Get result
62
+
63
+ # FastAPI endpoint for detecting threat level and type
64
+ @app.post("/detect_threat", response_model=ThreatResponse)
65
  async def detect_threat(log_request: LogRequest):
66
+ log = log_request.log
67
+
68
+ # Predict the threat level and type for the given log entry
69
+ predictions = predict_threat(log, unmasker)
70
+
71
+ # Extract top predictions for threat level and type
72
+ ##threat_level_predictions = predictions[0] if len(predictions) > 0 else ["Unknown"]
73
+ ## threat_type_predictions = predictions[1] if len(predictions) > 1 else ["Unknown"]
74
 
75
+ # Use the top prediction as the most likely threat level and type
76
+ ##detected_threat_level = threat_level_predictions[0] if threat_level_predictions else "Unknown"
77
+ #detected_threat_type = threat_type_predictions[0] if threat_type_predictions else "Unknown"
78
+
79
+ # Prepare response
80
+ response = ThreatResponse(
81
+ log=log,
82
+ prompt=f"{log} Threat level: <mask>. Threat type: <mask>.",
83
+ pred=get_maximum_predictions(predictions)
84
+ )
85
 
86
+ return response