Ashrafb commited on
Commit
9b7cbf0
·
verified ·
1 Parent(s): 6eb5678

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +2 -103
main.py CHANGED
@@ -1,105 +1,4 @@
1
- # Based on: https://github.com/jantic/DeOldify
2
- import os, re, time
3
 
4
- os.environ["TORCH_HOME"] = os.path.join(os.getcwd(), ".cache")
5
- os.environ["XDG_CACHE_HOME"] = os.path.join(os.getcwd(), ".cache")
6
 
7
- from fastapi import FastAPI, File, UploadFile,Form
8
- from fastapi.responses import FileResponse, StreamingResponse
9
- from fastapi.staticfiles import StaticFiles
10
- import PIL
11
- import cv2
12
- import numpy as np
13
- import uuid
14
- from zipfile import ZipFile, ZIP_DEFLATED
15
- from io import BytesIO
16
- from random import randint
17
- from datetime import datetime
18
-
19
- from src.deoldify import device
20
- from src.deoldify.device_id import DeviceId
21
- from src.deoldify.visualize import *
22
- from src.app_utils import get_model_bin
23
-
24
- app = FastAPI()
25
-
26
- device.set(device=DeviceId.CPU)
27
-
28
-
29
-
30
- def load_model(model_dir, option):
31
- if option.lower() == 'artistic':
32
- model_url = 'https://data.deepai.org/deoldify/ColorizeArtistic_gen.pth'
33
- get_model_bin(model_url, os.path.join(model_dir, "ColorizeArtistic_gen.pth"))
34
- colorizer = get_image_colorizer(artistic=True)
35
- elif option.lower() == 'stable':
36
- model_url = "https://www.dropbox.com/s/usf7uifrctqw9rl/ColorizeStable_gen.pth?dl=0"
37
- get_model_bin(model_url, os.path.join(model_dir, "ColorizeStable_gen.pth"))
38
- colorizer = get_image_colorizer(artistic=False)
39
-
40
- return colorizer
41
-
42
-
43
- def resize_img(input_img, max_size):
44
- img = input_img.copy()
45
- img_height, img_width = img.shape[0],img.shape[1]
46
-
47
- if max(img_height, img_width) > max_size:
48
- if img_height > img_width:
49
- new_width = img_width*(max_size/img_height)
50
- new_height = max_size
51
- resized_img = cv2.resize(img,(int(new_width), int(new_height)))
52
- return resized_img
53
-
54
- elif img_height <= img_width:
55
- new_width = img_height*(max_size/img_width)
56
- new_height = max_size
57
- resized_img = cv2.resize(img,(int(new_width), int(new_height)))
58
- return resized_img
59
-
60
- return img
61
-
62
-
63
- def colorize_image(input_image, colorizer, img_size=800):
64
- pil_img = input_image.convert("RGB")
65
- img_rgb = np.array(pil_img)
66
- resized_img_rgb = resize_img(img_rgb, img_size)
67
- resized_pil_img = PIL.Image.fromarray(resized_img_rgb)
68
- output_pil_img = colorizer.plot_transformed_pil_image(resized_pil_img, render_factor=35, compare=False)
69
-
70
- return output_pil_img
71
-
72
- def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
73
- if fmt not in ["jpg", "png"]:
74
- raise Exception(f"Unknown image format (Available: {fmt} - case sensitive)")
75
-
76
- pil_format = "JPEG" if fmt == "jpg" else "PNG"
77
- file_format = "jpg" if fmt == "jpg" else "png"
78
- mime = "image/jpeg" if fmt == "jpg" else "image/png"
79
-
80
- buf = BytesIO()
81
- pil_image.save(buf, format=pil_format)
82
-
83
- return st.download_button(
84
- label=label,
85
- data=buf.getvalue(),
86
- file_name=f'{filename}.{file_format}',
87
- mime=mime,
88
- )
89
-
90
- @app.post("/upload/")
91
- async def upload_file(file: UploadFile = File(...)):
92
- contents = await file.read()
93
- img_input = PIL.Image.open(BytesIO(contents)).convert("RGB")
94
- colorizer = load_model('models/', 'artistic') # Load the colorizer
95
- img_output = colorize_image(img_input, colorizer) # Pass colorizer to the function
96
- img_output_bytes = io.BytesIO()
97
- img_output.save(img_output_bytes, format="JPEG")
98
- return StreamingResponse(io.BytesIO(img_output_bytes.getvalue()), media_type="image/jpeg")
99
-
100
-
101
- app.mount("/", StaticFiles(directory="AB", html=True), name="static")
102
-
103
- @app.get("/")
104
- def index() -> FileResponse:
105
- return FileResponse(path="/app/AB/index.html", media_type="text/html")
 
1
+ import os
 
2
 
 
 
3
 
4
+ exec(os.environ.get('API'))