import os

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard

import gradio as gr
from diffusers import FlaxStableDiffusionPipeline

# Register the TPU driver when running on a Colab TPU runtime.
if 'TPU_NAME' in os.environ:
    import requests

    if 'TPU_DRIVER_MODE' not in globals():
        # TPU_NAME looks like "grpc://<ip>:8470"; the driver endpoint listens on port 8475.
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        requests.post(url)
        TPU_DRIVER_MODE = 1

    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')

num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")


def sd2_inference(pipeline, prompts, params, seed=42, num_inference_steps=50):
    """Run data-parallel Stable Diffusion inference across all JAX devices.

    len(prompts) must be divisible by jax.device_count(), since shard()
    splits the prompt batch evenly across devices.
    """
    prng_seed = jax.random.PRNGKey(seed)
    prompt_ids = pipeline.prepare_inputs(prompts)

    # Replicate the parameters to every device and give each device its own
    # RNG key and slice of the prompt batch.
    params = replicate(params)
    prng_seed = jax.random.split(prng_seed, jax.device_count())
    prompt_ids = shard(prompt_ids)

    images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
    # Collapse the (devices, batch_per_device, H, W, C) output to a flat batch.
    images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
    return pipeline.numpy_to_pil(images)


HF_ACCESS_TOKEN = os.environ["HFAUTH"]

# Load model
# - Reference: https://github.com/huggingface/diffusers/blob/main/README.md
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=HF_ACCESS_TOKEN,
    revision="bf16",
    dtype=jnp.bfloat16,
)


def text_to_image_and_image_to_text(text=None, image=None):
    img, txt = None, None
    if image is not None:
        # Decoding an image back to text is not implemented; the provided
        # text is simply passed through.
        txt = text
    if text is not None:
        # Repeat the prompt once per device so shard() can split it evenly.
        images = sd2_inference(pipeline, [text] * num_devices, params,
                               seed=42, num_inference_steps=5)
        img = images[0]
    return img, txt


if __name__ == '__main__':
    interFace = gr.Interface(
        fn=text_to_image_and_image_to_text,
        inputs=[
            gr.Textbox(placeholder="Enter the text to encode to an image",
                       label="Text to Encode to Image", lines=1),
            gr.Image(type="pil", label="Image to Decode to Text"),
        ],
        outputs=[
            gr.Image(type="pil", label="Encoded Image"),
            gr.Textbox(label="Decoded Text"),
        ],
        title="T2I2T",
        description="T2I2T: Text2Image2Text information transmitter",
    )
    interFace.launch()
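
# --- Usage sketch (illustrative, not part of the app) ------------------------
# A minimal smoke test of sd2_inference outside the Gradio UI, assuming the
# pipeline above loaded successfully. shard() needs one prompt slice per
# device, so the prompt is repeated once per JAX device. The prompt string and
# the "sample.png" output path are hypothetical examples, not values from the
# original app.
#
#   prompts = ["an astronaut riding a horse"] * jax.device_count()
#   images = sd2_inference(pipeline, prompts, params,
#                          seed=0, num_inference_steps=50)
#   images[0].save("sample.png")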