saifeddinemk commited on
Commit
15487fd
1 Parent(s): 4ee560a

Fixed app v2

Browse files
Files changed (1) hide show
  1. app.py +21 -8
app.py CHANGED
@@ -6,7 +6,7 @@ from transformers import pipeline
6
  # Initialize FastAPI app
7
  app = FastAPI()
8
 
9
- # Load the custom CyBERTuned model directly from Hugging Face
10
  unmasker = pipeline("fill-mask", model="s2w-ai/CyBERTuned-SecurityLLM")
11
 
12
  # Define request model
@@ -17,20 +17,25 @@ class LogRequest(BaseModel):
17
  class ThreatResponse(BaseModel):
18
  log: str
19
  prompt: str
20
- pred: List[str]
 
 
 
 
21
 
22
  # Function to predict masked words for threat level and type
23
- def predict_threat(log: str, unmasker, topk=5) -> List[str]:
24
  # Create prompt with masked tokens for threat level and threat type
25
  prompt = f"{log} Threat level most likely is: <mask>. Most likely the Threat type: <mask>."
26
 
27
  # Predict top options for each <mask>
28
  predictions = unmasker(prompt, top_k=topk)
29
 
30
- # Extract only the token_str values to return a list of strings
31
- pred_strings = [pred["token_str"].strip() for pred in predictions]
 
32
 
33
- return pred_strings
34
 
35
  # FastAPI endpoint for detecting threat level and type
36
  @app.post("/detect_threat", response_model=ThreatResponse)
@@ -38,13 +43,21 @@ async def detect_threat(log_request: LogRequest):
38
  log = log_request.log
39
 
40
  # Predict the threat level and type for the given log entry
41
- pred_strings = predict_threat(log, unmasker)
 
 
 
 
 
 
 
 
42
 
43
  # Prepare response
44
  response = ThreatResponse(
45
  log=log,
46
  prompt=f"{log} Threat level: <mask>. Threat type: <mask>.",
47
- pred=pred_strings
48
  )
49
 
50
  return response
 
6
  # Initialize FastAPI app
7
  app = FastAPI()
8
 
9
+
10
  unmasker = pipeline("fill-mask", model="s2w-ai/CyBERTuned-SecurityLLM")
11
 
12
  # Define request model
 
17
  class ThreatResponse(BaseModel):
18
  log: str
19
  prompt: str
20
+ #threat_level_predictions: List[str]
21
+ #threat_type_predictions: List[str]
22
+ #detected_threat_level: str
23
+ #detected_threat_type: str
24
+ pred : List[object]
25
 
26
  # Function to predict masked words for threat level and type
27
+ def predict_threat(log: str, unmasker, topk=5) -> List[List[object]]:
28
  # Create prompt with masked tokens for threat level and threat type
29
  prompt = f"{log} Threat level most likely is: <mask>. Most likely the Threat type: <mask>."
30
 
31
  # Predict top options for each <mask>
32
  predictions = unmasker(prompt, top_k=topk)
33
 
34
+ # Extract top predictions for each <mask>
35
+ #threat_level_predictions = [pred["token_str"].strip() for pred in predictions[:topk]]
36
+ #threat_type_predictions = [pred["token_str"].strip() for pred in predictions[topk:2*topk]]
37
 
38
+ return predictions
39
 
40
  # FastAPI endpoint for detecting threat level and type
41
  @app.post("/detect_threat", response_model=ThreatResponse)
 
43
  log = log_request.log
44
 
45
  # Predict the threat level and type for the given log entry
46
+ predictions = predict_threat(log, unmasker)
47
+
48
+ # Extract top predictions for threat level and type
49
+ ##threat_level_predictions = predictions[0] if len(predictions) > 0 else ["Unknown"]
50
+ ## threat_type_predictions = predictions[1] if len(predictions) > 1 else ["Unknown"]
51
+
52
+ # Use the top prediction as the most likely threat level and type
53
+ ##detected_threat_level = threat_level_predictions[0] if threat_level_predictions else "Unknown"
54
+ #detected_threat_type = threat_type_predictions[0] if threat_type_predictions else "Unknown"
55
 
56
  # Prepare response
57
  response = ThreatResponse(
58
  log=log,
59
  prompt=f"{log} Threat level: <mask>. Threat type: <mask>.",
60
+ pred=predictions
61
  )
62
 
63
  return response