saifeddinemk committed
Commit: 7826a83
1 Parent(s): bb55f20

Create app.py

Files changed (1):
  1. app.py (+46, -0)
app.py ADDED
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import LlamaTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and model
tokenizer = LlamaTokenizer.from_pretrained("WhiteRabbitNeo/WhiteRabbitNeo-13B-v1")
model = AutoModelForCausalLM.from_pretrained("WhiteRabbitNeo/WhiteRabbitNeo-13B-v1")

# Llama tokenizers ship without a padding token; reuse EOS so batches of
# logs can be padded to a common length
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Initialize the FastAPI app
app = FastAPI()

# Request body: a batch of raw log lines to analyze
class LogAnalysisRequest(BaseModel):
    logs: list[str]

# Define the /analyze endpoint
@app.post("/analyze")
async def analyze_logs(request: LogAnalysisRequest):
    # Reject empty batches
    if not request.logs:
        raise HTTPException(status_code=400, detail="No logs provided.")

    # Prepend the instruction to each log line and tokenize as a padded batch
    input_texts = ["Analyze this log for malicious activity: " + log for log in request.logs]
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True, truncation=True)

    # Generate predictions; pass the attention mask so padded positions are
    # ignored, and cap the number of newly generated tokens (max_length would
    # also count the prompt tokens)
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            num_return_sequences=1,
        )

    # Decode the generated sequences (each contains the prompt followed by
    # the model's analysis)
    results = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

    # Return one analysis string per input log
    return {"analysis_results": results}

# Run the FastAPI app (if running this script directly)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
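
A deployment note: as written, from_pretrained loads the 13B checkpoint in full fp32 precision, which needs roughly 52 GB of memory. For GPU serving, a common alternative is to load the weights in half precision. A minimal sketch, assuming a CUDA machine with the accelerate package installed (this variant is not part of the commit):

# Alternative load for GPU serving (assumes CUDA hardware and the
# accelerate package; not part of the original commit)
model = AutoModelForCausalLM.from_pretrained(
    "WhiteRabbitNeo/WhiteRabbitNeo-13B-v1",
    torch_dtype=torch.float16,  # halves memory to roughly 26 GB
    device_map="auto",          # let accelerate place layers on available devices
)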
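
Once the server is up, the /analyze endpoint can be exercised with any HTTP client. A sketch using the requests library; the log line is made up for illustration:

# Example client call (illustrative log line, not from the repo)
import requests

resp = requests.post(
    "http://localhost:8000/analyze",
    json={"logs": ["Failed password for root from 203.0.113.7 port 22 ssh2"]},
)
resp.raise_for_status()
print(resp.json()["analysis_results"][0])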