saifeddinemk commited on
Commit
31dafcd
1 Parent(s): 15487fd

Fixed app v2

Browse files
Files changed (1) hide show
  1. app.py +47 -40
app.py CHANGED
@@ -1,63 +1,70 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from typing import List
4
  from transformers import pipeline
5
 
6
  # Initialize FastAPI app
7
  app = FastAPI()
8
 
9
-
10
  unmasker = pipeline("fill-mask", model="s2w-ai/CyBERTuned-SecurityLLM")
11
 
12
- # Define request model
13
  class LogRequest(BaseModel):
14
- log: str
15
 
16
  # Define response model
17
  class ThreatResponse(BaseModel):
18
  log: str
19
  prompt: str
20
- #threat_level_predictions: List[str]
21
- #threat_type_predictions: List[str]
22
- #detected_threat_level: str
23
- #detected_threat_type: str
24
- pred : List[object]
25
-
26
- # Function to predict masked words for threat level and type
27
- def predict_threat(log: str, unmasker, topk=5) -> List[List[object]]:
28
- # Create prompt with masked tokens for threat level and threat type
29
- prompt = f"{log} Threat level most likely is: <mask>. Most likely the Threat type: <mask>."
30
 
31
  # Predict top options for each <mask>
32
  predictions = unmasker(prompt, top_k=topk)
33
 
34
- # Extract top predictions for each <mask>
35
- #threat_level_predictions = [pred["token_str"].strip() for pred in predictions[:topk]]
36
- #threat_type_predictions = [pred["token_str"].strip() for pred in predictions[topk:2*topk]]
37
-
38
- return predictions
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # FastAPI endpoint for detecting threat level and type
41
- @app.post("/detect_threat", response_model=ThreatResponse)
42
  async def detect_threat(log_request: LogRequest):
43
- log = log_request.log
 
 
 
44
 
45
- # Predict the threat level and type for the given log entry
46
- predictions = predict_threat(log, unmasker)
47
-
48
- # Extract top predictions for threat level and type
49
- ##threat_level_predictions = predictions[0] if len(predictions) > 0 else ["Unknown"]
50
- ## threat_type_predictions = predictions[1] if len(predictions) > 1 else ["Unknown"]
51
-
52
- # Use the top prediction as the most likely threat level and type
53
- ##detected_threat_level = threat_level_predictions[0] if threat_level_predictions else "Unknown"
54
- #detected_threat_type = threat_type_predictions[0] if threat_type_predictions else "Unknown"
55
-
56
- # Prepare response
57
- response = ThreatResponse(
58
- log=log,
59
- prompt=f"{log} Threat level: <mask>. Threat type: <mask>.",
60
- pred=predictions
61
- )
62
 
63
- return response
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from typing import List, Dict, Any
4
  from transformers import pipeline
5
 
6
  # Initialize FastAPI app
7
  app = FastAPI()
8
 
9
+ # Load the custom CyBERTuned model directly from Hugging Face
10
  unmasker = pipeline("fill-mask", model="s2w-ai/CyBERTuned-SecurityLLM")
11
 
12
+ # Define request model for multiple log entries
13
  class LogRequest(BaseModel):
14
+ logs: List[str] # Expecting a list of log entries
15
 
16
  # Define response model
17
  class ThreatResponse(BaseModel):
18
  log: str
19
  prompt: str
20
+ pred: Dict[str, List[Dict[str, Any]]] # Dictionary with structured lists for predictions
21
+
22
+ # Function to predict masked words for threat level, type, SRC IP, DEST IP, and Protocol
23
+ def predict_threat(log: str, unmasker, topk=5) -> Dict[str, List[Dict[str, Any]]]:
24
+ # Create prompt with masked tokens for each category
25
+ prompt = (
26
+ f"{log} Threat level: <mask>. Threat type: <mask>. "
27
+ f"Attack type: <mask>. SRC IP: <mask>. DEST IP: <mask>. Protocol: <mask>."
28
+ )
 
29
 
30
  # Predict top options for each <mask>
31
  predictions = unmasker(prompt, top_k=topk)
32
 
33
+ # Separate predictions for each masked category
34
+ threat_level_predictions = predictions[:topk]
35
+ threat_type_predictions = predictions[topk:2*topk]
36
+ attack_type_predictions = predictions[2*topk:3*topk]
37
+ src_ip_predictions = predictions[3*topk:4*topk]
38
+ dest_ip_predictions = predictions[4*topk:5*topk]
39
+ protocol_predictions = predictions[5*topk:6*topk]
40
+
41
+ # Structure the response in a dictionary
42
+ return {
43
+ "threat_level_predictions": threat_level_predictions,
44
+ "threat_type_predictions": threat_type_predictions,
45
+ "attack_type_predictions": attack_type_predictions,
46
+ "src_ip_predictions": src_ip_predictions,
47
+ "dest_ip_predictions": dest_ip_predictions,
48
+ "protocol_predictions": protocol_predictions
49
+ }
50
 
51
+ # FastAPI endpoint for detecting threat level and type for multiple logs
52
+ @app.post("/detect_threat", response_model=List[ThreatResponse])
53
  async def detect_threat(log_request: LogRequest):
54
+ responses = []
55
+ for log in log_request.logs:
56
+ # Predict the threat level and type for each log entry
57
+ predictions = predict_threat(log, unmasker)
58
 
59
+ # Prepare response for each log entry
60
+ response = ThreatResponse(
61
+ log=log,
62
+ prompt=(
63
+ f"{log} Threat level: <mask>. Threat type: <mask>. "
64
+ f"Attack type: <mask>. SRC IP: <mask>. DEST IP: <mask>. Protocol: <mask>."
65
+ ),
66
+ pred=predictions
67
+ )
68
+ responses.append(response)
 
 
 
 
 
 
 
69
 
70
+ return responses