File size: 3,372 Bytes
98ac5fa
 
 
 
 
 
71fb5d8
98ac5fa
4e9944c
 
 
 
ade9768
 
98ac5fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import base64
import os
from io import BytesIO

import gradio as gr
import requests
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image, PngImagePlugin

# Access token needed by `from_pretrained` to fetch the Stable Diffusion
# weights; read from the `auth_token` environment variable (a Space secret).
auth_token = os.environ.get("auth_token")

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", use_auth_token=auth_token
)
pipe.to(device)

def encode_pil_to_base64(pil_image):
  ''' Serialize a PIL image into a PNG data URI string.

  Adapted from: https://github.com/gradio-app/gradio/blob/main/gradio/processing_utils.py
  Entries of ``pil_image.info`` whose key and value are both ``str`` are
  carried over as PNG text chunks; all other metadata is dropped.
  Returns a string of the form "data:image/png;base64,<payload>".
  '''
  # Keep only text-only metadata, in its original order.
  text_items = [
      (key, value)
      for key, value in pil_image.info.items()
      if isinstance(key, str) and isinstance(value, str)
  ]

  # Only build a PngInfo when there is something to embed.
  png_info = None
  if text_items:
    png_info = PngImagePlugin.PngInfo()
    for key, value in text_items:
      png_info.add_text(key, value)

  with BytesIO() as buffer:
    pil_image.save(buffer, "PNG", pnginfo=png_info)
    raw_png = buffer.getvalue()

  encoded = base64.b64encode(raw_png).decode("utf-8")
  return "data:image/png;base64," + encoded

def decode_base64_to_image(encoding):
  ''' Decode a base64-encoded image string into a PIL image.

  Adapted from: https://github.com/gradio-app/gradio/blob/main/gradio/processing_utils.py
  Accepts either a data URI (e.g. "data:image/png;base64,<payload>", including
  URIs with extra ";"-separated parameters) or a bare base64 payload.
  Returns the result of ``Image.open`` on the decoded bytes (PIL opens images
  lazily, so pixel data may only be read on first use).
  '''
  # The payload is everything after the last comma: base64 text never
  # contains commas, and a bare payload (no comma at all) is taken whole.
  # The original split(";")[1].split(",")[1] raised IndexError for both a
  # bare payload and a data URI carrying extra ";" parameters.
  _, _, image_encoded = encoding.rpartition(",")
  return Image.open(BytesIO(base64.b64decode(image_encoded)))

def improve_image(img, scale=2, timeout=120):
  ''' Improves an input image using GFP-GAN via a hosted Gradio Space API.

  Inputs
  img (PIL): image to improve
  scale (int or float): scale factor for the new image, forwarded to the API
  timeout (float): seconds to wait for the remote API before giving up
  Output
  Improved image. If the request to GFP-GAN is unsuccessful for any reason
  (connection error, timeout, HTTP error status, malformed or missing
  response payload, undecodable image) it returns a black 512x512 RGB image.
  '''
  url = "https://hf.space/embed/NotFungibleIO/GFPGAN/+/api/predict"
  request_objt = {"data": [encode_pil_to_base64(img), 'v1.3', scale]}
  try:
    # Without a timeout, a hung remote Space would block this call forever.
    response = requests.post(url, json=request_objt, timeout=timeout)
    response.raise_for_status()
    imp_img = decode_base64_to_image(response.json()['data'][0])
    # Image.open is lazy; force decoding here so corrupt payloads are caught
    # by the fallback below rather than surfacing later.
    imp_img.load()
  except (requests.RequestException, ValueError, KeyError, IndexError,
          TypeError, AttributeError, OSError):
    # Enhancement is best-effort: any failure along the request/parse/decode
    # path yields the documented black placeholder instead of crashing.
    # (ValueError covers JSON and base64 decode errors; OSError covers PIL's
    # UnidentifiedImageError; AttributeError covers a None payload.)
    return Image.new('RGB', size=(512, 512))
  return imp_img

def generate(celebrity, movie, guidance, improve_flag, scale):
  ''' Generates a movie poster featuring a celebrity with Stable Diffusion.

  Inputs
  celebrity (str): name of the celebrity to put on the poster
  movie (str): movie whose poster inspires the generation
  guidance (float): value passed to the pipeline as `guidance_scale`;
    the lower, the more random the output
  improve_flag (bool): if True, refine the result with GFP-GAN
  scale (int or float): scale factor for GFP-GAN; ignored when
    improve_flag is False
  Output
  The generated (and optionally improved) PIL image.
  '''
  prompt = f"A movie poster of {celebrity} in {movie}."
  # The pipeline's keyword is `guidance_scale`; passing `guidance=` does not
  # reach the sampler, so the UI slider previously had no effect (or raised
  # a TypeError on diffusers versions that reject unknown kwargs).
  image = pipe(prompt, guidance_scale=guidance).images[0]
  if improve_flag:
    image = improve_image(image, scale=scale)
  return image

# Movies offered in the dropdown; the chosen one is interpolated into the prompt.
movie_options = ["James Bond", "Snatch", "Saving Private Ryan", "Scarface", "Avatar", "Top Gun"]
title = "Movie Poster Celebrity Swap"
# Adjacent string literals are joined at compile time. A backslash line
# continuation *inside* a literal would instead embed the next line's
# indentation into the displayed text.
description = (
    "Write the name of a celebrity, and pick a movie from the dropdown menu. "
    "It will generate a new movie poster (inspired by your chosen movie) "
    "with the chosen celebrity in it. See below for explanation of the "
    "input variables."
)
article = (
    "Inputs explained: \n Guidance: the lower, the more random the output "
    "image. Improve and scale: if selected, the image will be refined "
    "using GFP-GAN (google it), and scaled (if scale is >1)."
)

# Input widgets, in the same order as generate()'s parameters:
# celebrity, movie, guidance, improve_flag, scale.
input_components = [
    gr.Textbox(value="Daniel Craig"),
    gr.Dropdown(movie_options, value="Saving Private Ryan"),
    gr.Slider(1, 20, value=7.5, step=0.5),
    gr.Checkbox(label="Improve and scale? (extra time)"),
    gr.Slider(1, 3, value=1, step=0.5),
]

demo = gr.Interface(
    fn=generate,
    inputs=input_components,
    outputs="image",
    title=title,
    description=description,
    article=article,
)

demo.launch()