hylee commited on
Commit
72c39f4
1 Parent(s): 3e2f491
Files changed (3) hide show
  1. app.py +121 -0
  2. packages.txt +2 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+ import argparse
5
+ import functools
6
+ import os
7
+ import pathlib
8
+ import sys
9
+ from typing import Callable
10
+
11
+
12
+ import gradio as gr
13
+ import huggingface_hub
14
+ import numpy as np
15
+ import PIL.Image
16
+
17
+ from io import BytesIO
18
+ from fastai.vision import *
19
+ from fastai.vision import load_learner
20
+
21
+ ORIGINAL_REPO_URL = 'https://github.com/vijishmadhavan/ArtLine'
22
+ TITLE = 'vijishmadhavan/ArtLine'
23
+ DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}.
24
+
25
+ """
26
+ ARTICLE = """
27
+
28
+ """
29
+
30
+
31
+ MODEL_REPO = 'hylee/artline_model'
32
+
33
+ def parse_args() -> argparse.Namespace:
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument('--device', type=str, default='cpu')
36
+ parser.add_argument('--theme', type=str)
37
+ parser.add_argument('--live', action='store_true')
38
+ parser.add_argument('--share', action='store_true')
39
+ parser.add_argument('--port', type=int)
40
+ parser.add_argument('--disable-queue',
41
+ dest='enable_queue',
42
+ action='store_false')
43
+ parser.add_argument('--allow-flagging', type=str, default='never')
44
+ parser.add_argument('--allow-screenshot', action='store_true')
45
+ return parser.parse_args()
46
+
47
+ def load_model():
48
+ dir = 'model'
49
+ name = 'ArtLine_650.pkl'
50
+ model_path = huggingface_hub.hf_hub_download(MODEL_REPO,
51
+ name,
52
+ cache_dir=dir,
53
+ force_filename=name)
54
+ return model_path
55
+
56
+
57
+
58
+ def run(
59
+ image,
60
+ learn,
61
+ ) -> tuple[PIL.Image.Image]:
62
+
63
+ img = PIL.Image.open(image.name)
64
+ img_t = T.ToTensor()(img)
65
+ img_fast = PIL.Image(img_t)
66
+
67
+ p, img_hr, b = learn.predict(img_fast)
68
+ r = PIL.Image(img_hr)
69
+
70
+ return r
71
+
72
+ learn = None
73
+ def main():
74
+ gr.close_all()
75
+
76
+ args = parse_args()
77
+
78
+ model_path = load_model()
79
+
80
+ # singleton start
81
+ def load_pkl(self) -> Any:
82
+ global learn
83
+ path = Path("model")
84
+ learn = load_learner(path, 'ArtLine_650.pkl')
85
+
86
+ PklLoader = type('PklLoader', (), {"load_pkl": load_pkl})
87
+ pl = PklLoader()
88
+ pl.load_pkl()
89
+
90
+
91
+ func = functools.partial(run, learn=learn)
92
+ func = functools.update_wrapper(func, run)
93
+
94
+
95
+ gr.Interface(
96
+ func,
97
+ [
98
+ gr.inputs.Image(type='file', label='Input Image'),
99
+ ],
100
+ [
101
+ gr.outputs.Image(
102
+ type='pil',
103
+ label='Result'),
104
+ ],
105
+ #examples=examples,
106
+ theme=args.theme,
107
+ title=TITLE,
108
+ description=DESCRIPTION,
109
+ article=ARTICLE,
110
+ allow_screenshot=args.allow_screenshot,
111
+ allow_flagging=args.allow_flagging,
112
+ live=args.live,
113
+ ).launch(
114
+ enable_queue=args.enable_queue,
115
+ server_port=args.port,
116
+ share=args.share,
117
+ )
118
+
119
+
120
+ if __name__ == '__main__':
121
+ main()
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ opencv-python-headless==4.5.5.62
2
+ Pillow==9.0.1
3
+ scipy==1.7.3
4
+ fastai==1.0.61
5
+ numpy==1.17.2
6
+ pandas==1.1.2
7
+ torch==1.6.0
8
+ torchvision===0.7.0