#!/usr/bin/env python from __future__ import annotations import argparse import functools import os import pathlib import sys from typing import Callable import gradio as gr import huggingface_hub import numpy as np import PIL.Image from io import BytesIO from fastai.vision import * from fastai.vision import load_learner ORIGINAL_REPO_URL = 'https://github.com/vijishmadhavan/ArtLine' TITLE = 'vijishmadhavan/ArtLine' DESCRIPTION = f"""This is a demo for {ORIGINAL_REPO_URL}. """ ARTICLE = """ """ MODEL_REPO = 'hylee/artline_model' def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cpu') parser.add_argument('--theme', type=str) parser.add_argument('--live', action='store_true') parser.add_argument('--share', action='store_true') parser.add_argument('--port', type=int) parser.add_argument('--disable-queue', dest='enable_queue', action='store_false') parser.add_argument('--allow-flagging', type=str, default='never') parser.add_argument('--allow-screenshot', action='store_true') return parser.parse_args() def load_model(): dir = 'model' name = 'ArtLine_650.pkl' model_path = huggingface_hub.hf_hub_download(MODEL_REPO, name, cache_dir=dir, force_filename=name) return model_path def run( image, learn, ) -> tuple[PIL.Image.Image]: img = PIL.Image.open(image.name) img_t = T.ToTensor()(img) img_fast = PIL.Image(img_t) p, img_hr, b = learn.predict(img_fast) r = PIL.Image(img_hr) return r learn = None def main(): gr.close_all() args = parse_args() model_path = load_model() # singleton start def load_pkl(self) -> Any: global learn path = Path("model") learn = load_learner(path, 'ArtLine_650.pkl') PklLoader = type('PklLoader', (), {"load_pkl": load_pkl}) pl = PklLoader() pl.load_pkl() func = functools.partial(run, learn=learn) func = functools.update_wrapper(func, run) gr.Interface( func, [ gr.inputs.Image(type='file', label='Input Image'), ], [ gr.outputs.Image( type='pil', label='Result'), ], #examples=examples, theme=args.theme, title=TITLE, description=DESCRIPTION, article=ARTICLE, allow_screenshot=args.allow_screenshot, allow_flagging=args.allow_flagging, live=args.live, ).launch( enable_queue=args.enable_queue, server_port=args.port, share=args.share, ) if __name__ == '__main__': main()