Deadmon committed on
Commit 4538515 · verified · 1 Parent(s): bea7193

Update app.py

Files changed (1)
  1. app.py +139 -138
app.py CHANGED
@@ -1,16 +1,15 @@
+ import torch
+ import spaces
+ from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
+ from transformers import AutoFeatureExtractor
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+ from ip_adapter.ip_adapter_faceid import IPAdapterFaceID, IPAdapterFaceIDPlus
+ from huggingface_hub import hf_hub_download
+ from insightface.app import FaceAnalysis
+ from insightface.utils import face_align
  import gradio as gr
- import asyncio
- import fal_client
- from PIL import Image
- import requests
- import io
- import os
-
- # Set up your Fal API key as an environment variable
- os.environ["FAL_KEY"] = "b6fa8d06-4225-4ec3-9aaf-4d01e960d899:cc6a52d0fc818c6f892b2760fd341ee4"
- fal_client.api_key = os.environ["FAL_KEY"]
-
- # Base model paths for model switching
+ import cv2
+
  base_model_paths = {
      "RealisticVisionV4": "SG161222/Realistic_Vision_V4.0_noVAE",
      "RealisticVisionV6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
@@ -20,140 +19,142 @@ base_model_paths = {
  "EpicRealism": "emilianJR/epiCRealism"
  }

- # Updated function to include the API call to the Fal model
- async def generate_image(image_url: str, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, width: int, height: int):
-     """
-     Submit the image generation process using the fal_client's submit method with the ip-adapter-face-id model.
-     Arguments:
-         image_url: URL of the input image to use.
-         prompt: Text prompt for generating the image.
-         negative_prompt: Text for negative prompt to avoid unwanted characteristics in the output.
-         model_type: Model type to use.
-         base_model: Base model to use for image generation.
-         seed: Seed for random generation.
-         guidance_scale: CFG scale for how closely the model sticks to the prompt.
-         num_inference_steps: Number of inference steps.
-         width: Width of the generated image.
-         height: Height of the generated image.
-     Returns:
-         The URL of the generated image along with other attributes like file size, dimensions, etc., or None if failed.
-     """
-     try:
-         handler = fal_client.submit(
-             "fal-ai/ip-adapter-face-id",
-             arguments={
-                 "model_type": model_type,
-                 "prompt": prompt,
-                 "face_image_url": image_url,
-                 "negative_prompt": negative_prompt,
-                 "seed": seed,
-                 "guidance_scale": guidance_scale,
-                 "num_inference_steps": num_inference_steps,
-                 "num_samples": 1, # Adjusted to 1 sample
-                 "width": width,
-                 "height": height,
-                 "face_id_det_size": 640,
-                 "base_1_5_model_repo": base_model_paths[base_model], # Base model
-                 "base_sdxl_model_repo": "SG161222/RealVisXL_V3.0",
-                 "face_images_data_url": None
-             },
-         )
-         # Retrieve the result synchronously
-         result = handler.get()
-
-         if "image" in result and "url" in result["image"]:
-             return result["image"] # Return the full image information dictionary
-         else:
-             return None
-     except Exception as e:
-         print(f"Error generating image: {e}")
-         return None
-
- # Fetch the image from the given URL
- def fetch_image_from_url(url: str) -> Image.Image:
-     response = requests.get(url)
-     return Image.open(io.BytesIO(response.content))
-
- # Process input images and handle the image generation
- async def process_inputs(image: Image.Image, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, width: int, height: int):
-     image_url = await upload_image_to_server(image)
-
-     if not image_url:
-         return None
-
-     image_info = await generate_image(image_url, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, width, height)
-
-     if image_info and "url" in image_info:
-         return fetch_image_from_url(image_info["url"]), image_info # Return both the image and the metadata
-
-     return None, None
-
- # Upload image to server
- async def upload_image_to_server(image: Image.Image) -> str:
-     byte_arr = io.BytesIO()
-     image.save(byte_arr, format='PNG')
-     byte_arr.seek(0)
-
-     with open("temp_image.png", "wb") as f:
-         f.write(byte_arr.getvalue())
-
-     try:
-         upload_url = await fal_client.upload_file_async("temp_image.png")
-         return upload_url
-     except Exception as e:
-         print(f"Error uploading image: {e}")
-         return ""
-
- # Change style between Photorealistic and Stylized
+
+ vae_model_path = "stabilityai/sd-vae-ft-mse"
+ image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+ ip_ckpt = hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sd15.bin", repo_type="model")
+ ip_plus_ckpt = hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid-plusv2_sd15.bin", repo_type="model")
+
+ safety_model_id = "CompVis/stable-diffusion-safety-checker"
+ safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+
+ device = "cuda"
+
+ noise_scheduler = DDIMScheduler(
+     num_train_timesteps=1000,
+     beta_start=0.00085,
+     beta_end=0.012,
+     beta_schedule="scaled_linear",
+     clip_sample=False,
+     set_alpha_to_one=False,
+     steps_offset=1,
+ )
+ vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
+
+ def load_model(base_model_path):
+     pipe = StableDiffusionPipeline.from_pretrained(
+         base_model_path,
+         torch_dtype=torch.float16,
+         scheduler=noise_scheduler,
+         vae=vae,
+         feature_extractor=safety_feature_extractor,
+         safety_checker=None # <--- Disable safety checker
+     ).to(device)
+     return pipe
+
+ ip_model = None
+ ip_model_plus = None
+
+ app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
+ app.prepare(ctx_id=0, det_size=(640, 640))
+
+ cv2.setNumThreads(1)
+
+ @spaces.GPU(enable_queue=True)
+ def generate_image(images, prompt, negative_prompt, preserve_face_structure, face_strength, likeness_strength, nfaa_negative_prompt, base_model, num_inference_steps, guidance_scale, width, height, progress=gr.Progress(track_tqdm=True)):
+     global ip_model, ip_model_plus
+     base_model_path = base_model_paths[base_model]
+     pipe = load_model(base_model_path)
+     ip_model = IPAdapterFaceID(pipe, ip_ckpt, device)
+     ip_model_plus = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_plus_ckpt, device)
+
+     faceid_all_embeds = []
+     first_iteration = True
+     for image in images:
+         face = cv2.imread(image)
+         faces = app.get(face)
+         faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
+         faceid_all_embeds.append(faceid_embed)
+         if(first_iteration and preserve_face_structure):
+             face_image = face_align.norm_crop(face, landmark=faces[0].kps, image_size=224) # you can also segment the face
+             first_iteration = False
+
+     average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
+
+     total_negative_prompt = f"{negative_prompt} {nfaa_negative_prompt}"
+
+     if(not preserve_face_structure):
+         print("Generating normal")
+         image = ip_model.generate(
+             prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
+             scale=likeness_strength, width=width, height=height, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
+         )
+     else:
+         print("Generating plus")
+         image = ip_model_plus.generate(
+             prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
+             scale=likeness_strength, face_image=face_image, shortcut=True, s_scale=face_strength, width=width, height=height, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
+         )
+     print(image)
+     return image
+
  def change_style(style):
      if style == "Photorealistic":
-         return gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0)
+         return(gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0))
      else:
-         return gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8)
+         return(gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8))

- # Gradio Interface
- def gradio_interface(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, width, height):
-     loop = asyncio.new_event_loop()
-     asyncio.set_event_loop(loop)
-
-     result_image, image_info = loop.run_until_complete(
-         process_inputs(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, width, height)
-     )
-
-     if result_image:
-         metadata = f"File Name: {image_info['file_name']}\nFile Size: {image_info['file_size']} bytes\nDimensions: {image_info['width']}x{image_info['height']} px\nSeed: {image_info.get('seed', 'N/A')}"
-         return result_image, metadata
-     return None, "Error generating image"
-
- # Main Gradio App
- with gr.Blocks() as demo:
-     gr.Markdown("## Image Generation with Fal API and Gradio")
-
-     with gr.Row():
-         with gr.Column():
-             image_input = gr.Image(label="Upload Image", type="pil")
-             prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate", lines=2)
-             negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="Describe elements to avoid", lines=2)
-             style = gr.Radio(label="Generation type", choices=["Photorealistic", "Stylized"], value="Photorealistic")
-             model_type = gr.Dropdown(label="Model Type", choices=["1_5-v1", "SDXL-v2-plus", "1_5-auraface-v1"], value="SDXL-v2-plus")
-             base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="RealisticVisionV4")
-             seed_input = gr.Slider(label="Seed", value=42, minimum=0, maximum=1000, step=1)
-             guidance_scale_input = gr.Slider(label="Guidance Scale", value=7.5, minimum=1.0, maximum=20.0, step=0.1)
-             num_inference_steps_input = gr.Slider(label="Inference Steps", value=50, minimum=10, maximum=100, step=1)
-             width_input = gr.Slider(label="Width", value=1024, minimum=512, maximum=1024, step=64)
-             height_input = gr.Slider(label="Height", value=1024, minimum=512, maximum=1024, step=64)
-             generate_button = gr.Button("Generate Image")
-
-         with gr.Column():
-             generated_image = gr.Image(label="Generated Image")
-             metadata_output = gr.Textbox(label="Image Metadata", interactive=False, lines=6)
-
-     generate_button.click(
-         fn=gradio_interface,
-         inputs=[image_input, prompt_input, negative_prompt_input, model_type, base_model, seed_input, guidance_scale_input, num_inference_steps_input, width_input, height_input],
-         outputs=[generated_image, metadata_output]
-     )
-
-     style.change(fn=change_style, inputs=style, outputs=[model_type, guidance_scale_input, num_inference_steps_input])
-
- demo.launch()
+ def swap_to_gallery(images):
+     return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)
+
+ def remove_back_to_files():
+     return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
+
+ css = '''
+ h1{margin-bottom: 0 !important}
+ footer{display:none !important}
+ '''
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown("")
+     gr.Markdown("")
+     with gr.Row():
+         with gr.Column():
+             files = gr.Files(
+                 label="Drag 1 or more photos of your face",
+                 file_types=["image"]
+             )
+             uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)
+             with gr.Column(visible=False) as clear_button:
+                 remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
+             prompt = gr.Textbox(label="Prompt",
+                                 info="Try something like 'a photo of a man/woman/person'",
+                                 placeholder="A photo of a [man/woman/person]...")
+             negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality")
+             style = gr.Radio(label="Generation type", info="For stylized try prompts like 'a watercolor painting of a woman'", choices=["Photorealistic", "Stylized"], value="Photorealistic")
+             base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="Realistic_Vision_V4.0_noVAE")
+             submit = gr.Button("Submit")
+             with gr.Accordion(open=False, label="Advanced Options"):
+                 preserve = gr.Checkbox(label="Preserve Face Structure", info="Higher quality, less versatility (the face structure of your first photo will be preserved). Unchecking this will use the v1 model.", value=True)
+                 face_strength = gr.Slider(label="Face Structure strength", info="Only applied if preserve face structure is checked", value=1.3, step=0.1, minimum=0, maximum=3)
+                 likeness_strength = gr.Slider(label="Face Embed strength", value=1.0, step=0.1, minimum=0, maximum=5)
+                 nfaa_negative_prompts = gr.Textbox(label="Appended Negative Prompts", info="Negative prompts to steer generations towards safe for all audiences outputs", value="naked, bikini, skimpy, scanty, bare skin, lingerie, swimsuit, exposed, see-through")
+                 num_inference_steps = gr.Slider(label="Number of Inference Steps", value=30, step=1, minimum=10, maximum=100)
+                 guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.1, minimum=1, maximum=20)
+                 width = gr.Slider(label="Width", value=512, step=64, minimum=256, maximum=1024)
+                 height = gr.Slider(label="Height", value=512, step=64, minimum=256, maximum=1024)
+         with gr.Column():
+             gallery = gr.Gallery(label="Generated Images")
+     style.change(fn=change_style,
+                  inputs=style,
+                  outputs=[preserve, face_strength, likeness_strength])
+     files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
+     remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
+     submit.click(fn=generate_image,
+                  inputs=[files, prompt, negative_prompt, preserve, face_strength, likeness_strength, nfaa_negative_prompts, base_model, num_inference_steps, guidance_scale, width, height],
+                  outputs=gallery)
+
+     gr.Markdown("")
+
+ demo.launch()
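
In short, the new version trades the remote fal.ai "fal-ai/ip-adapter-face-id" endpoint (and its temp-file upload round-trip) for local inference: it downloads the IP-Adapter-FaceID checkpoints from h94/IP-Adapter-FaceID, embeds each uploaded photo with insightface's buffalo_l models, averages the embeddings, and generates through a diffusers StableDiffusionPipeline wrapped in IPAdapterFaceID (or IPAdapterFaceIDPlus when "Preserve Face Structure" is checked). A minimal sketch of that same generation path without the Gradio UI, assuming a CUDA device, the ip_adapter package from the IP-Adapter repo on the import path, and a hypothetical input photo face.jpg:

    # Sketch only, not part of the commit. Reuses load_model, base_model_paths,
    # ip_ckpt, and IPAdapterFaceID as defined in the new app.py above.
    import cv2
    import torch
    from insightface.app import FaceAnalysis

    face_app = FaceAnalysis(name="buffalo_l", providers=["CPUExecutionProvider"])
    face_app.prepare(ctx_id=0, det_size=(640, 640))

    # Embed the reference face the same way generate_image() does per uploaded file.
    bgr = cv2.imread("face.jpg")  # hypothetical input photo
    faceid_embed = torch.from_numpy(face_app.get(bgr)[0].normed_embedding).unsqueeze(0)

    pipe = load_model(base_model_paths["RealisticVisionV4"])
    ip = IPAdapterFaceID(pipe, ip_ckpt, "cuda")
    images = ip.generate(
        prompt="a photo of a person in a garden",
        negative_prompt="low quality",
        faceid_embeds=faceid_embed,
        scale=1.0,  # likeness_strength in the UI
        num_samples=1,
        width=512, height=512,
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    images[0].save("out.png")

One caveat in the new UI wiring: the Base Model dropdown defaults to "Realistic_Vision_V4.0_noVAE", which is not among the visible keys of base_model_paths ("RealisticVisionV4", "RealisticVisionV6", "EpicRealism"), so base_model_paths[base_model] would raise a KeyError on a default submit unless one of the collapsed dictionary entries defines that key.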