bguisard commited on
Commit
680c3bf
1 Parent(s): 54771e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -1,5 +1,7 @@
1
  import gradio as gr
2
  import jax
 
 
3
  from diffusers import FlaxStableDiffusionPipeline
4
 
5
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(
 
1
  import gradio as gr
2
  import jax
3
+ from flax.jax_utils import replicate
4
+ from flax.training.common_utils import shard
5
  from diffusers import FlaxStableDiffusionPipeline
6
 
7
  pipeline, pipeline_params = FlaxStableDiffusionPipeline.from_pretrained(