# Spaces: Sleeping (appears to be page-status text captured along with this file; not program code)
import logging

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer
# FastAPI application object for this service.
app = FastAPI()

# Identifier passed to both from_pretrained() calls below.
model_id = 'model_result'

# Quantization settings: 4-bit NF4 weights with double quantization,
# using bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Tokenizer for the checkpoint; the end-of-sequence token doubles as the
# padding token.
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# Causal LM loaded with the quantization settings above; device placement is
# left to device_map="auto". The model is then switched to evaluation mode.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
)
model.eval()
# Request and response schemas for the analysis endpoint.
#
# Request body: carries the raw log text to analyze.
class SecurityLogRequest(BaseModel):
    # Log text that is embedded into the user prompt sent to the model.
    log_data: str
# Response body: carries the text produced by the model.
class SecurityAnalysisResponse(BaseModel):
    # Decoded model output for the submitted logs.
    analysis: str
# Inference function
def generate_response(input_text: str) -> str:
    """Ask the loaded model to analyze ``input_text`` and return its answer.

    The log text is wrapped in a fixed system/user chat prompt, rendered with
    the tokenizer's chat template, and passed to ``model.generate`` with
    sampling enabled (temperature 0.1, top_p 0.95, top_k 10) and at most 512
    new tokens.

    Args:
        input_text: Raw security log text to embed in the user message.

    Returns:
        Only the newly generated text, decoded with special tokens removed.
        The prompt is not included in the returned string.
    """
    # Echoes generated tokens to stdout as they are produced (prompt skipped).
    streamer = TextStreamer(tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True)
    messages = [
        {"role": "system", "content": "You are an information security AI assistant specialized in analyzing security logs. Identify potential threats, suspicious IP addresses, unauthorized access attempts, and recommend actions based on the logs."},
        {"role": "user", "content": f"Please analyze the following security logs and provide insights on any potential malicious activity:\n{input_text}"},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Remember how many tokens belong to the prompt so they can be removed
    # from the generated sequence below.
    prompt_length = input_ids.shape[-1]

    # Generate response with the model
    outputs = model.generate(
        input_ids,
        streamer=streamer,
        max_new_tokens=512,  # Limit max tokens for faster response
        num_beams=1,
        do_sample=True,
        temperature=0.1,
        top_p=0.95,
        top_k=10,
    )

    # For a causal LM the output sequence starts with the prompt tokens;
    # decoding all of outputs[0] would return the system and user prompt as
    # part of the analysis. Keep only the continuation.
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True)
# Route for security log analysis.
# NOTE(review): no route registration was present for this handler, so the
# path "/analyze" is chosen here; adjust it if clients expect another path.
@app.post("/analyze", response_model=SecurityAnalysisResponse)
async def analyze_security_logs(request: SecurityLogRequest):
    """Run the model on the submitted logs and return its analysis.

    Args:
        request: Request body whose ``log_data`` is passed to
            ``generate_response``.

    Returns:
        A ``SecurityAnalysisResponse`` wrapping the generated text.

    Raises:
        HTTPException: Status 500 with the original error message as detail
            when inference or response construction fails.
    """
    try:
        # NOTE(review): generate_response is a synchronous call inside an
        # async handler, so the event loop is blocked while it runs.
        analysis_text = generate_response(request.log_data)
        return SecurityAnalysisResponse(analysis=analysis_text)
    except Exception as e:
        # Record the full traceback server-side; the HTTP response only
        # carries the message string.
        logging.getLogger(__name__).exception("Security log analysis failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Start the server when this file is executed directly.
if __name__ == "__main__":
    # Serve the already-constructed app object in this single process.
    # The previous call combined workers=4 with reload=True, which uvicorn
    # treats as mutually exclusive (the workers setting is ignored), and the
    # "app:app" import string made a child process import this module again,
    # loading the model a second time and assuming the file is named app.py.
    uvicorn.run(app, host="0.0.0.0", port=8000)