"""Gradio demo: generate an image from a text prompt with Flax Stable Diffusion.

If a ``TPU_NAME`` environment variable is present, the script registers that
TPU as the JAX backend. It then loads ``CompVis/stable-diffusion-v1-4``
(bfloat16 revision) and, when run as a script, serves a Gradio UI that maps a
text prompt to one generated image.
"""
from flax.jax_utils import replicate
from jax import pmap
from flax.training.common_utils import shard
import jax
import jax.numpy as jnp
import gradio as gr
from pathlib import Path
from PIL import Image
import numpy as np
from diffusers import FlaxStableDiffusionPipeline
import os

# TPU setup: request the nightly TPU driver from the TPU host and point JAX's
# backend at it. The globals() check skips the driver request when this code
# runs again in the same interpreter (e.g. re-executing a notebook cell).
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        requests.post(url)
        TPU_DRIVER_MODE = 1
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")


def sd2_inference(pipeline, prompts, params, seed = 42, num_inference_steps = 50, *, params_replicated=False):
    """Run the Flax Stable Diffusion pipeline data-parallel over all JAX devices.

    Args:
        pipeline: the ``FlaxStableDiffusionPipeline`` to run.
        prompts: a prompt string, or a sequence of prompt strings.
        params: pipeline weights. Pass them un-replicated, unless
            ``params_replicated`` is True.
        seed: seed for the JAX PRNG; it is split into one key per device.
        num_inference_steps: number of denoising steps.
        params_replicated: set to True when ``params`` were already passed
            through ``replicate`` (avoids re-copying weights on each call).

    Returns:
        A list of PIL images, one per input prompt, in input order.

    Raises:
        ValueError: if ``prompts`` is empty.
    """
    if isinstance(prompts, str):
        prompts = [prompts]
    prompts = list(prompts)
    if not prompts:
        raise ValueError("prompts must contain at least one prompt")

    n_dev = jax.device_count()
    # shard() splits the leading batch axis evenly across devices, so the
    # batch size must be a multiple of the device count. Pad with copies of
    # the last prompt; the extra images are discarded below.
    padded = prompts + [prompts[-1]] * (-len(prompts) % n_dev)

    prompt_ids = shard(pipeline.prepare_inputs(padded))
    prng_seed = jax.random.split(jax.random.PRNGKey(seed), n_dev)
    if not params_replicated:
        params = replicate(params)

    images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
    # Merge the per-device axis and the per-device batch axis into one batch axis.
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(images[:len(prompts)])


# NOTE(review): raises KeyError at import time if HFAUTH is not set.
HF_ACCESS_TOKEN = os.environ["HFAUTH"]

# Load Model
# - Reference: https://github.com/huggingface/diffusers/blob/main/README.md
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token = HF_ACCESS_TOKEN,
    revision="bf16",
    dtype=jnp.bfloat16,
)
# Copy the weights to every device once, instead of on every request.
replicated_params = replicate(params)


def text_to_image(text):
    """Generate one image for ``text`` (seed 42, 5 denoising steps)."""
    images = sd2_inference(
        pipeline,
        [text],
        replicated_params,
        seed = 42,
        num_inference_steps = 5,
        params_replicated=True,
    )
    return images[0]


examples = ["apple", "banana", "chocolate"]

if __name__ == '__main__':
    interFace = gr.Interface(
        fn=text_to_image,
        inputs=gr.inputs.Textbox(placeholder="Enter the text to Encode to an image", label="Text query", lines=1),
        outputs=gr.outputs.Image(type="auto", label="Generated Image"),
        verbose=True,
        examples=examples,
        title="Generate Image from Text",
        description="",
        theme="huggingface",
    )
    interFace.launch()