"""FastAPI service that classifies URLs with a Hugging Face text-classification model.

Endpoints:
    POST /predict -- classify the URL in the JSON body ``{"url": "..."}``.
    GET  /        -- health check.
"""

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import torch

app = FastAPI()

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class URLRequest(BaseModel):
    """Request body for ``POST /predict``."""

    url: str


# Load the tokenizer and model once, at import time, via the pipeline helper.
# The pipeline's ``device`` argument takes a CUDA ordinal, or -1 for CPU, so it
# is derived from ``device`` above rather than re-querying CUDA.
pipe = pipeline(
    "text-classification",
    model="kmack/malicious-url-detection",
    device=device.index if device.type == "cuda" else -1,
)


def get_prediction(url_to_check: str):
    """Run the classifier on one URL and return the raw pipeline output.

    NOTE(review): for a single string input the text-classification pipeline
    presumably returns a list of ``{"label": ..., "score": ...}`` dicts —
    confirm against the installed transformers version.
    """
    return pipe(url_to_check)


@app.post("/predict")
async def predict(url_request: URLRequest):
    """Classify the URL in the request body.

    Surrounding whitespace is stripped before classification.

    Returns:
        ``{"prediction": <pipeline output>}``.

    Raises:
        HTTPException: 422 if the URL is empty or whitespace-only.
    """
    url_to_check = url_request.url.strip()
    if not url_to_check:
        raise HTTPException(status_code=422, detail="url must not be empty")
    # Model inference is blocking work; running it in a worker thread keeps the
    # event loop free to serve other requests (e.g. the health check) meanwhile.
    result = await run_in_threadpool(get_prediction, url_to_check)
    return {"prediction": result}


@app.get("/")
async def read_root():
    """Health check: report that the API is up."""
    return {"message": "API is up and running"}