Vincent-luo commited on
Commit
cfd47f7
·
1 Parent(s): 9c5ac71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -79,13 +79,12 @@ args = Namespace(
79
  controlnet_revision=None,
80
  controlnet_from_pt=False,
81
  )
82
- weight_dtype = jnp.float32
83
 
84
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
85
  args.controlnet_model_name_or_path,
86
  revision=args.controlnet_revision,
87
  from_pt=args.controlnet_from_pt,
88
- dtype=jnp.float32,
89
  )
90
 
91
  pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
@@ -93,7 +92,7 @@ pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretraine
93
  # tokenizer=tokenizer,
94
  controlnet=controlnet,
95
  safety_checker=None,
96
- dtype=weight_dtype,
97
  revision=args.revision,
98
  from_pt=args.from_pt,
99
  )
 
79
  controlnet_revision=None,
80
  controlnet_from_pt=False,
81
  )
 
82
 
83
  controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
84
  args.controlnet_model_name_or_path,
85
  revision=args.controlnet_revision,
86
  from_pt=args.controlnet_from_pt,
87
+ dtype=jnp.bfloat16,
88
  )
89
 
90
  pipeline, pipeline_params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
 
92
  # tokenizer=tokenizer,
93
  controlnet=controlnet,
94
  safety_checker=None,
95
+ dtype=jnp.bfloat16,
96
  revision=args.revision,
97
  from_pt=args.from_pt,
98
  )