Spaces:
Sleeping
Sleeping
File size: 1,957 Bytes
15eaf83 cf7303a 15eaf83 b0b5cad 15eaf83 b0b5cad 15eaf83 1e32822 15eaf83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import models
# All inference runs on the CPU; checkpoint tensors are remapped there on load.
device = torch.device("cpu")

# Build the VGG16 architecture (no weights argument is passed; the parameters
# are overwritten by the checkpoint below) and replace classifier[6] with a
# 4096 -> 2 linear layer so the module shapes match the saved state dict.
model = models.vgg16()
model.classifier[6] = nn.Linear(in_features=4096, out_features=2)

# The checkpoint path is relative to the current working directory.
model.load_state_dict(
    torch.load("model_vgg16.pth", map_location=device)
)

# Evaluation mode for inference (changes the behaviour of training-only layers
# such as dropout).
model.eval()
# Preprocessing applied to every uploaded image before it reaches the model:
# resize, convert to a tensor, then normalise each of the three channels.
# NOTE(review): the mean/std values are the commonly used ImageNet statistics;
# presumably the checkpoint was trained with the same normalisation - confirm.
transform = T.Compose(
    [
        T.Resize((224, 224)),  # exact (h, w) target, so aspect ratio is not kept
        T.ToTensor(),          # PIL image -> float tensor
        T.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)
app = FastAPI()

# Cross-origin access is wide open: every origin, method and header is
# accepted, so a browser front end hosted on another origin can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests (Starlette's CORSMiddleware echoes the
# caller's Origin when cookies are present). This API does not appear to use
# cookies or auth, so allow_credentials=False or an explicit origin list is
# probably safer - confirm with the front end before changing.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
def background2white(image):
    """Flatten a transparent image onto an opaque white background.

    Images that carry an alpha channel (modes ``RGBA``, ``LA``, ``PA``) or a
    palette transparency entry (mode ``P`` with ``"transparency"`` in
    ``image.info``) are alpha-composited onto white and returned as a new
    ``RGB`` image. Any other image is returned unchanged (the same object), so
    callers still need to ``convert("RGB")`` it themselves.

    Transparency is decided from ``image.mode`` rather than from the channel
    count of the pixel array. A 4-channel image is not necessarily transparent
    (e.g. ``CMYK`` or ``RGBX``), and such images are not valid paste masks, so
    counting channels made those uploads raise ``ValueError``.

    Args:
        image: A ``PIL.Image.Image``.

    Returns:
        An RGB ``PIL.Image.Image`` on white if *image* had transparency,
        otherwise *image* itself.
    """
    has_alpha = image.mode in ("RGBA", "LA", "PA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if not has_alpha:
        return image

    # Normalise every transparent mode to RGBA so its alpha band can act as the
    # paste mask (Pillow accepts only "1", "L", "LA", "RGBA" or "RGBa" masks).
    rgba = image.convert("RGBA")
    white_background = Image.new("RGB", rgba.size, (255, 255, 255))
    white_background.paste(rgba, (0, 0), rgba)
    return white_background
logger = logging.getLogger(__name__)


def _load_rgb_image(fileobj):
    """Decode an uploaded file with Pillow and return it as an RGB image.

    Transparent pixels are flattened onto white by ``background2white`` first.
    May raise ``OSError`` when the data is not a decodable image. Pillow's
    ``UnidentifiedImageError`` is an ``OSError`` subclass, and truncated files
    raise ``OSError`` while decoding.
    """
    image = Image.open(fileobj)
    return background2white(image).convert("RGB")


def _classify(image):
    """Run the classifier on one RGB PIL image.

    Returns:
        ``(predicted_class, probabilities, logits)``:
        - ``predicted_class``: the arg-max class index as an ``int``.
        - ``probabilities``: the softmax over the single batch item, as a list
          of floats.
        - ``logits``: the raw model output as a nested list.
    """
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension of 1
    # Grad mode is thread-local, so no_grad must be entered in the thread that
    # runs the forward pass (predict() runs this function in a worker thread).
    with torch.no_grad():
        logits = model(batch)
    _, predicted = torch.max(logits, 1)
    probabilities = F.softmax(logits[0], dim=0).tolist()
    return predicted.item(), probabilities, logits.tolist()


@app.post("/predict/")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the two-class VGG16 model.

    The upload is decoded, flattened onto white if it has transparency,
    preprocessed with ``transform`` and passed through ``model``. Decoding and
    inference are blocking CPU work, so they run via ``run_in_threadpool``
    instead of stalling the event loop for every other request.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        JSONResponse whose body is the list
        ``[predicted_class, probability_of_class_0, probability_of_class_1,
        logits]``, where ``logits`` is the raw model output as ``[[l0, l1]]``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    try:
        image = await run_in_threadpool(_load_rgb_image, file.file)
    except OSError as err:
        # An undecodable or truncated upload is a client error, not a 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err
    logger.debug("Input image size: %s", image.size)

    predicted_class, probabilities, logits = await run_in_threadpool(
        _classify, image
    )
    logger.debug("Logits: %s, probabilities: %s", logits, probabilities)

    return JSONResponse(
        [predicted_class, probabilities[0], probabilities[1], logits]
    )
if __name__ == "__main__":
    # Development entry point: serve the API with uvicorn on all interfaces,
    # port 8000. `uvicorn` is already imported at module level, so no
    # re-import is needed here.
    uvicorn.run(app, host="0.0.0.0", port=8000)