saifeddinemk committed on
Commit
e71fade
1 Parent(s): 6324ebb

Fixed app v2

Files changed (1)
  1. app.py +26 -62
app.py CHANGED
@@ -1,79 +1,43 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer
+from transformers import pipeline
 import uvicorn
 
 # Initialize FastAPI app
 app = FastAPI()
 
-# Configure and load the quantized model
-model_id = 'bakch92/Llama-3.1-security'
-
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.bfloat16,
-    bnb_4bit_use_double_quant=True,
-)
-
-# Load tokenizer and model with 4-bit quantization settings
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-tokenizer.pad_token = tokenizer.eos_token
-
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    quantization_config=bnb_config,
-    device_map="auto",
-)
-model.eval()
-
-# Define request and response models
-class SecurityLogRequest(BaseModel):
+# Load the text generation pipeline with the specified model
+try:
+    pipe = pipeline("text-generation", model="ammarnasr/codegen2-1B-security", trust_remote_code=True)
+except Exception as e:
+    raise RuntimeError(f"Failed to load model: {e}")
+
+# Define request model for log data
+class LogRequest(BaseModel):
     log_data: str
 
-class SecurityAnalysisResponse(BaseModel):
+# Define response model
+class AnalysisResponse(BaseModel):
     analysis: str
 
-# Inference function
-def generate_response(input_text: str) -> str:
-    streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-    messages = [
-        {"role": "system", "content": "You are an information security AI assistant specialized in analyzing security logs. Identify potential threats, suspicious IP addresses, unauthorized access attempts, and recommend actions based on the logs."},
-        {"role": "user", "content": f"Please analyze the following security logs and provide insights on any potential malicious activity:\n{input_text}"}
-    ]
-
-    input_ids = tokenizer.apply_chat_template(
-        messages,
-        tokenize=True,
-        add_generation_prompt=True,
-        return_tensors="pt",
-    ).to(model.device)
-
-    # Generate response with the model
-    outputs = model.generate(
-        input_ids,
-        streamer=streamer,
-        max_new_tokens=512,  # Limit max tokens for faster response
-        num_beams=1,
-        do_sample=True,
-        temperature=0.1,
-        top_p=0.95,
-        top_k=10
-    )
-
-    # Extract and return generated text
-    response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return response_text
-
 # Define the route for security log analysis
-@app.post("/analyze_security_logs", response_model=SecurityAnalysisResponse)
-async def analyze_security_logs(request: SecurityLogRequest):
+@app.post("/analyze_security_logs", response_model=AnalysisResponse)
+async def analyze_security_logs(request: LogRequest):
     try:
-        # Run inference
-        analysis_text = generate_response(request.log_data)
-        return SecurityAnalysisResponse(analysis=analysis_text)
+        # Security-focused prompt
+        prompt = (
+            "Analyze the following network log data for any indicators of malicious activity, "
+            "such as unusual IP addresses, unauthorized access attempts, data exfiltration, or anomalies. "
+            "Provide details on potential threats, IPs involved, and suggest actions if any threats are detected.\n\n"
+            f"{request.log_data}"
+        )
+
+        # Generate response from the pipeline with a controlled max length
+        response = pipe(prompt, max_length=512, num_return_sequences=1)
+
+        # Extract and return the analysis text
+        analysis_text = response[0]["generated_text"]
+        return AnalysisResponse(analysis=analysis_text)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
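
For reference, a minimal sketch of how the updated endpoint could be exercised once the app is running locally, for example with "uvicorn app:app --host 0.0.0.0 --port 8000". The host, port, timeout, sample log line, and use of the requests package are illustrative assumptions and are not part of this commit.

import requests

# Hypothetical client call to the /analyze_security_logs route defined above.
# The URL assumes a default local uvicorn run; the log line is made-up sample input.
payload = {"log_data": "Failed password for root from 203.0.113.7 port 52211 ssh2"}
resp = requests.post("http://127.0.0.1:8000/analyze_security_logs", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json()["analysis"])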