TripleKdev committed
Commit 15eaf83 · 1 Parent(s): cf7303a

Add application file

Files changed (2)
  1. app.py +79 -4
  2. requirements.txt +6 -2
app.py CHANGED
@@ -1,7 +1,82 @@
-from fastapi import FastAPI
+import torch
+import torchvision.transforms as T
+import torch.nn.functional as F
+from torchvision import models
+import torch.nn as nn
+
+import numpy as np
+
+from PIL import Image
+from fastapi import FastAPI, UploadFile, File
+from fastapi.responses import JSONResponse
+import uvicorn
+from fastapi.middleware.cors import CORSMiddleware
+
+device = torch.device('cpu')  # inference runs on CPU
+
+model = models.vgg16()
+model.classifier[6] = nn.Linear(4096, 2)  # replace the final layer with a 2-class head
+model.load_state_dict(torch.load('model_vgg16.pth', map_location=device))
+
+model.eval()
+
+transform = T.Compose([  # standard ImageNet preprocessing for VGG16
+    T.Resize((224, 224)),
+    T.ToTensor(),
+    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+])
 
 app = FastAPI()
 
-@app.get("/")
-def greet_json():
-    return {"Hello": "World!"}
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Allows all origins
+    allow_credentials=True,
+    allow_methods=["*"],  # Allows all methods
+    allow_headers=["*"],  # Allows all headers
+)
+
+def background2white(image):  # flatten transparent (RGBA) images onto a white background
+    arr = np.asarray(image)
+
+    if arr.ndim == 3 and arr.shape[-1] == 4:
+        white_background = Image.new("RGB", image.size, (255, 255, 255))
+        white_background.paste(image, (0, 0), image)
+        return white_background
+    else:
+        return image
+
+@app.post("/predict/")
+async def predict(file: UploadFile = File(...)):
+    image = Image.open(file.file)
+
+    image = background2white(image).convert("RGB")
+
+    # image.save('web/backend/input.jpg')
+
+    print(image.size)
+    print()
+    image = transform(image)
+    image = image.unsqueeze(0)  # add a batch dimension
+
+    print(image.shape)
+
+    with torch.no_grad():
+        output = model(image)
+
+    print(output)
+
+    _, predicted = torch.max(output, 1)
+
+    predicted_class = predicted.item()
+
+    probabilities = F.softmax(output[0], dim=0)
+    probabilities = probabilities.tolist()
+
+    print(probabilities)
+
+    return JSONResponse([predicted_class, probabilities[0], probabilities[1], output.tolist()])
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
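
For quick testing of the new endpoint (not part of this commit), a client could call it roughly as sketched below; the host and port follow the uvicorn.run() call above, while the requests library and the sample.png path are assumptions for illustration.

# Hypothetical client sketch (assumes `pip install requests`); not included in the commit.
import requests

# The multipart field name "file" must match the UploadFile parameter of predict().
with open("sample.png", "rb") as f:  # placeholder image path
    resp = requests.post("http://localhost:8000/predict/", files={"file": f})

# The endpoint returns [predicted_class, prob_class_0, prob_class_1, raw_model_output].
predicted_class, prob0, prob1, raw_output = resp.json()
print(predicted_class, prob0, prob1)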
requirements.txt CHANGED
@@ -1,2 +1,6 @@
-fastapi
-uvicorn
+torch==2.3.1
+torchvision==0.18.1
+numpy==1.26.4
+pillow==10.3.0
+fastapi==0.111.0
+uvicorn==0.30.1