lmattingly13 commited on
Commit
822d597
·
1 Parent(s): 9908f8b

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -10
app.py CHANGED
@@ -1,14 +1,68 @@
1
  import gradio as gr
2
-
3
- # inference function takes prompt, negative prompt and image
4
- def infer(image, prompt):
5
- # implement your inference function here
6
- output_image = image
7
- return output_image
8
-
 
 
9
  title = "ControlNet for Cartoon-ifying"
10
  description = "This is a demo on ControlNet for changing images of people into cartoons of different styles."
11
  examples = [["./simpsons_human_1.jpg", "turn into a simpsons character", "./simpsons_animated_1.jpg"]]
12
-
13
- gr.Interface(fn = infer, inputs = ["image", "text"], outputs = "image",
14
- title = title, description = description, examples = examples, theme='gradio/soft').launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import jax
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+ from flax.jax_utils import replicate
6
+ from flax.training.common_utils import shard
7
+ from PIL import Image
8
+ from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
9
+ import cv2
10
+
11
  title = "ControlNet for Cartoon-ifying"
12
  description = "This is a demo on ControlNet for changing images of people into cartoons of different styles."
13
  examples = [["./simpsons_human_1.jpg", "turn into a simpsons character", "./simpsons_animated_1.jpg"]]
14
+
15
+
16
+
17
+ # Constants
18
+ low_threshold = 100
19
+ high_threshold = 200
20
+
21
+ base_model_path = "runwayml/stable-diffusion-v1-5"
22
+ controlnet_path = "lmattingly/controlnet-uncanny-simpsons"
23
+ #controlnet_path = "JFoz/dog-cat-pose"
24
+
25
+ # Models
26
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
27
+ controlnet_path, dtype=jnp.bfloat16
28
+ )
29
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
30
+ "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
31
+ )
32
+
33
+ def create_key(seed=0):
34
+ return jax.random.PRNGKey(seed)
35
+
36
+ def infer(prompts, image):
37
+ params["controlnet"] = controlnet_params
38
+
39
+ num_samples = 1 #jax.device_count()
40
+ rng = create_key(0)
41
+ rng = jax.random.split(rng, jax.device_count())
42
+ im = image
43
+ image = Image.fromarray(im)
44
+
45
+ prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
46
+ processed_image = pipe.prepare_image_inputs([image] * num_samples)
47
+
48
+ p_params = replicate(params)
49
+ prompt_ids = shard(prompt_ids)
50
+ processed_image = shard(processed_image)
51
+
52
+ output = pipe(
53
+ prompt_ids=prompt_ids,
54
+ image=processed_image,
55
+ params=p_params,
56
+ prng_seed=rng,
57
+ num_inference_steps=50,
58
+ jit=True,
59
+ ).images
60
+
61
+ output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
62
+ return output_images
63
+
64
+
65
+ gr.Interface(fn = infer, inputs = ["text", "image"], outputs = "image",
66
+ title = title, description = description, theme='gradio/soft',
67
+ examples=[["a simpsons cartoon character", "simpsons_human_1.jpg"]]
68
+ ).launch()