feng2022's picture
Update app.py
e7d7286
raw
history blame
3.95 kB
#!/usr/bin/env python
from __future__ import annotations
import argparse
import functools
import os
import pickle
import sys
import subprocess
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import pipeline
# Extend the module search path before the repo-local imports that follow
# (`utils`, `projector`); presumably those packages live in the working
# directory or in the Time_TravelRephotography checkout — confirm layout.
sys.path.append('.')
sys.path.append('./Time_TravelRephotography')
from utils import torch_helpers as th
from argparse import Namespace
from projector import (
ProjectorArguments,
main,
create_generator,
make_image,
)
# Put the StyleGAN-Human checkout first on the search path. This runs after the
# imports above, so it only affects modules imported later.
sys.path.insert(0, 'StyleGAN-Human')

# Settings for the (currently disabled) projector experiment; not used by any
# live code path in this file as shown.
input_path = ''
spectral_sensitivity = 'b'

# Text shown by the Gradio UI.
TITLE = 'Time-TravelRephotography'
DESCRIPTION = '''This is an unofficial demo for https://github.com/Time-Travel-Rephotography.
'''
ARTICLE = '<center><img src="https://visitor-badge.glitch.me/badge?page_id=Time-TravelRephotography" alt="visitor badge"/></center>'

# SECURITY: a Hugging Face access token used to be hard-coded here, which
# published the credential to anyone able to read this file. The token is now
# read from the environment (e.g. a Space secret named HF_TOKEN). The token
# that was previously committed must be revoked. When the variable is unset,
# TOKEN is None, which means anonymous Hub access.
TOKEN = os.environ.get('HF_TOKEN')

# English -> Spanish translation pipeline. It is built at import time and used
# by predict().
pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")

# Process-wide score history. track_score() appends one entry per call, and
# nothing is ever removed.
scores = []
def parse_args() -> argparse.Namespace:
    """Build the demo's command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with these fields:
        - ``device`` (default ``'cpu'``)
        - ``theme``
        - ``live`` and ``share`` (boolean switches)
        - ``port``
        - ``enable_queue`` (true unless ``--disable-queue`` is given)
        - ``allow_flagging`` (default ``'never'``)
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--device', type=str, default='cpu')
    cli.add_argument('--theme', type=str)
    # Plain on/off switches, registered in their original order.
    for switch in ('--live', '--share'):
        cli.add_argument(switch, action='store_true')
    cli.add_argument('--port', type=int)
    # The user-facing flag is negative; the stored attribute is positive.
    cli.add_argument('--disable-queue', action='store_false', dest='enable_queue')
    cli.add_argument('--allow-flagging', type=str, default='never')
    parsed = cli.parse_args()
    return parsed
def load_model(file_name: str, path: str, device: torch.device) -> nn.Module:
    """Download a serialized model from the Hugging Face Hub and prepare it.

    Args:
        file_name: Name of the file inside the Hub repository.
        path: Hub repository id passed to ``hf_hub_download``. Despite its name,
            this is not a local filesystem path.
        device: Device the model is loaded onto and moved to.

    Returns:
        The loaded module, in eval mode on ``device``, after one warm-up
        forward pass.
    """
    local_path = hf_hub_download(path, file_name, use_auth_token=TOKEN)
    # SECURITY: torch.load unpickles arbitrary objects, so only use it with
    # trusted repositories.
    # map_location remaps storages saved on another device (e.g. CUDA), so a
    # GPU-saved checkpoint still loads on a CPU-only host.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects a
    # pickled nn.Module. If that triggers, pass weights_only=False, but only for
    # trusted repos.
    with open(local_path, 'rb') as f:
        model = torch.load(f, map_location=device)
    model.eval()
    model.to(device)
    # Warm-up pass with an all-zero latent and label. This assumes the model
    # exposes z_dim/c_dim and accepts force_fp32 (StyleGAN-style generator);
    # TODO confirm.
    with torch.inference_mode():
        z = torch.zeros((1, model.z_dim), device=device)
        label = torch.zeros([1, model.c_dim], device=device)
        model(z, label, force_fp32=True)
    return model
def predict(text):
    """Translate ``text`` with the module-level ``pipe`` translation pipeline.

    ``pipe`` is built for the Helsinki-NLP/opus-mt-en-es model (English to
    Spanish). Returns the ``translation_text`` of the first result.
    """
    results = pipe(text)
    first_result = results[0]
    return first_result["translation_text"]
def track_score(score):
    """Record ``score`` in the shared history and return the best three so far.

    The history is the module-level ``scores`` list. It is shared by every
    caller in the process and grows by one entry per call; nothing is ever
    removed.

    Returns:
        A list of up to three recorded scores, highest first.
    """
    scores.append(score)
    ranking = sorted(scores, reverse=True)
    best_three = ranking[:3]
    return best_three
def main():
    """Build and launch the Gradio score-tracking demo.

    The interface takes a single number (``Score``). Each submission is passed
    to ``track_score``, and the returned top scores are shown as JSON
    (``Top Scores``).
    """
    # NOTE(review): this definition shadows the ``main`` imported from
    # ``projector`` at the top of the file, so that import is unreachable by name.
    # NOTE(review): parse_args() is defined but never called here, so CLI flags
    # such as --port, --share and --theme currently have no effect.
    # Earlier experiments (StyleGAN generator sampling via the projector, and a
    # translation UI built on predict()) were kept here as commented-out code.
    # They also left a dead string literal that became this function's docstring.
    # All of that has been removed; recover it from version control if needed.
    demo = gr.Interface(
        track_score,
        gr.Number(label="Score"),
        gr.JSON(label="Top Scores"),
    )
    demo.launch()
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()