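# app.py -- commit 3fd0786 ("Update app.py"), uploaded by gojiteji, 3.11 kB.
# (The "raw" / "history blame" entries of the original page header were
# navigation links, not part of the program.)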
from flax.jax_utils import replicate
from jax import pmap
from flax.training.common_utils import shard
import jax
import jax.numpy as jnp
import gradio as gr
from pathlib import Path
from PIL import Image
import numpy as np
from diffusers import FlaxStableDiffusionPipeline
import os
# --- Accelerator setup --------------------------------------------------------
# If the TPU_NAME environment variable is present, point JAX at that remote TPU
# through the "tpu_driver" backend; otherwise only print a hint and let JAX use
# whatever devices it finds on its own.
if 'TPU_NAME' in os.environ:
    import requests

    # TPU_DRIVER_MODE is a module-level marker: once it exists, the
    # driver-version request below is not sent again if this code is executed
    # a second time in the same namespace.
    if 'TPU_DRIVER_MODE' not in globals():
        # TPU_NAME is split on ':' and the second piece (the host part of an
        # address such as "scheme://host:port") is reused with port 8475.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        # Bounded wait: without a timeout an unreachable TPU host would block
        # start-up forever instead of failing with a clear error.
        resp = requests.post(url, timeout=30)
        TPU_DRIVER_MODE = 1

    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

import jax

# NOTE(review): the return value is discarded; presumably this call is only
# here to force backend initialisation before the devices are counted.
jax.local_devices()
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
def sd2_inference(pipeline, prompts, params, seed = 42, num_inference_steps = 50 ):
    """Run a Flax Stable Diffusion pipeline and return one PIL image per prompt.

    Args:
        pipeline: Pipeline object providing ``prepare_inputs``, ``__call__``
            (returning an object with an ``images`` attribute) and
            ``numpy_to_pil``.
        prompts: A prompt string or a sequence of prompt strings.
        params: Un-replicated model parameters; they are replicated across
            devices on every call.
        seed: Integer seed for the JAX PRNG key.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        A list of PIL images, one per entry of ``prompts``, in the same order.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    prompts = list(prompts)
    num_prompts = len(prompts)
    device_count = jax.device_count()

    # shard() splits the batch evenly across devices and fails when the batch
    # size is not a multiple of the device count (e.g. one prompt on an
    # 8-device TPU). Pad by cycling through the given prompts; the surplus
    # images are dropped again before returning.
    shortfall = -num_prompts % device_count
    if shortfall:
        prompts += [prompts[i % num_prompts] for i in range(shortfall)]

    prompt_ids = shard(pipeline.prepare_inputs(prompts))
    replicated_params = replicate(params)
    # One PRNG key per device.
    prng_seed = jax.random.split(jax.random.PRNGKey(seed), device_count)

    images = pipeline(prompt_ids, replicated_params, prng_seed, num_inference_steps, jit=True).images
    # Merge the two leading axes (presumably device and per-device batch) into
    # a single batch axis, keeping the last three axes untouched.
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    # Hand a host-side NumPy array to the PIL conversion helper.
    images = pipeline.numpy_to_pil(np.asarray(images))
    return images[:num_prompts]
# Hub access token, read from the HFAUTH environment variable. A missing
# variable raises KeyError right here, before any model loading is attempted.
HF_ACCESS_TOKEN = os.environ["HFAUTH"]

# Load Model
# - Reference: https://github.com/huggingface/diffusers/blob/main/README.md
# The call returns the pipeline object together with its parameters; both are
# kept as module-level names.
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    dtype=jnp.bfloat16,
    revision="bf16",
    use_auth_token=HF_ACCESS_TOKEN,
)
def text_to_image_and_image_to_text(text,image):
    """Gradio callback: produce the (image, text) pair shown in the outputs.

    Args:
        text: Prompt entered by the user; may be ``None`` or empty.
        image: Image supplied by the user, or ``None`` when nothing was
            uploaded.

    Returns:
        A tuple ``(img, txt)``. ``img`` is an image generated from ``text``
        when no image was uploaded and the prompt is non-empty, otherwise the
        uploaded ``image`` unchanged (``None`` if there is none). ``txt`` is
        ``text`` echoed back.
    """
    # NOTE(review): the original condition (`if image "" != None:`) did not
    # parse, and the generated image was immediately overwritten by `image`.
    # Assumed intent: generate only when there is no uploaded image and there
    # is a prompt to generate from -- confirm with the author.
    if image is None and text:
        generated = sd2_inference(pipeline, [text], params, seed = 42, num_inference_steps = 5 )
        img = generated[0]
    else:
        img = image

    # No image-to-text decoding is implemented; the input text is returned
    # as-is (None stays None).
    txt = text
    return img, txt
# Example rows offered in the UI. The interface below has two inputs
# (text, image), so every example is a row with one entry per input; None
# leaves the image input empty.
examples = [
    ["apple", None],
    ["banana", None],
    ["chocolate", None],
]

if __name__ == '__main__':
    # Build and start the web UI only when the file is run as a script.
    interFace = gr.Interface(
        fn=text_to_image_and_image_to_text,
        inputs=[
            gr.Textbox(
                placeholder="Enter the text to Encode to an image",
                label="Text query",
                lines=1,
            ),
            gr.Image(type="pil"),
        ],
        outputs=[
            # The callback returns a PIL image (or None) and a string.
            gr.Image(type="pil", label="Generated Image"),
            gr.Textbox(placeholder="Decoded Text"),
        ],
        verbose=True,
        examples=examples,
        title="Generate Image from Text",
        description="",
        theme="huggingface",
    )
    interFace.launch()