thomasgauthier commited on
Commit
bdf9962
·
1 Parent(s): fa851d1

first commit

Browse files
Files changed (5) hide show
  1. app.py +14 -5
  2. gradio_interface.py +32 -0
  3. image_generator.py +117 -0
  4. model_loader.py +14 -0
  5. requirements.txt +7 -0
app.py CHANGED
@@ -1,7 +1,16 @@
1
- import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
1
+ import torch
2
+ import spaces
3
+ from model_loader import load_model_and_processor
4
+ from image_generator import process_and_generate
5
+ from gradio_interface import create_gradio_interface
6
 
7
+ if __name__ == "__main__":
8
+ # Set the model path
9
+ model_path = "deepseek-ai/Janus-1.3B"
10
 
11
+ # Load the model and processor
12
+ vl_gpt, vl_chat_processor = load_model_and_processor(model_path)
13
+
14
+ # Create and launch the Gradio interface
15
+ demo = create_gradio_interface(vl_gpt, vl_chat_processor, process_and_generate)
16
+ demo.launch(allowed_paths=["/"])
gradio_interface.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+
4
+ def create_gradio_interface(vl_gpt, vl_chat_processor, process_and_generate):
5
+ def gradio_process_and_generate(input_image, prompt, num_images, cfg_weight):
6
+ return process_and_generate(vl_gpt, vl_chat_processor, input_image, prompt, num_images, cfg_weight)
7
+
8
+ explanation = """Janus 1.3B uses a differerent visual encoder for understanding and generation.
9
+
10
+ ![Janus Model Architecture](file/images/janus_architecture.svg)
11
+
12
+ Here, by feeding the model an image and then asking it to generate that same image, we visualize the model's ability to translate input (understanding) embedding space to generative embedding space."""
13
+
14
+ with gr.Blocks() as demo:
15
+ gr.Markdown("# How Janus-1.3B sees itself")
16
+
17
+ with gr.Row():
18
+ input_image = gr.Image(type="filepath", label="Input Image")
19
+ output_images = gr.Gallery(label="Generated Images", columns=2, rows=2)
20
+ prompt = gr.Textbox(label="Prompt", value="Exactly what is shown in the image.")
21
+ num_images = gr.Slider(minimum=1, maximum=12, value=12, step=1, label="Number of Images to Generate")
22
+ cfg_weight = gr.Slider(minimum=1, maximum=10, value=5, step=0.1, label="CFG Weight")
23
+ generate_btn = gr.Button("Generate", variant="primary", size="lg")
24
+
25
+ generate_btn.click(
26
+ fn=gradio_process_and_generate,
27
+ inputs=[input_image, prompt, num_images, cfg_weight],
28
+ outputs=output_images
29
+ )
30
+ gr.Markdown(explanation)
31
+
32
+ return demo
image_generator.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import PIL.Image
3
+ import torch
4
+ import numpy as np
5
+ from janus.utils.io import load_pil_images
6
+ from janus.models import MultiModalityCausalLM, VLChatProcessor
7
+ from functools import lru_cache
8
+
9
+
10
+ def prepare_classifier_free_guidance_input(input_embeds, vl_chat_processor, mmgpt, batch_size=16):
11
+ uncond_input_ids = torch.full((1, input_embeds.shape[1]),
12
+ vl_chat_processor.pad_id,
13
+ dtype=torch.long,
14
+ device=input_embeds.device)
15
+ uncond_input_ids[:, 0] = input_embeds.shape[1] - 1
16
+ uncond_input_ids[:, -1] = vl_chat_processor.tokenizer.eos_token_id
17
+
18
+ uncond_input_embeds = mmgpt.language_model.get_input_embeddings()(uncond_input_ids)
19
+ uncond_input_embeds[:, -1, :] = input_embeds[:, -1, :]
20
+
21
+ cond_input_embeds = input_embeds.repeat(batch_size, 1, 1)
22
+ uncond_input_embeds = uncond_input_embeds.repeat(batch_size, 1, 1)
23
+
24
+ combined_input_embeds = torch.stack([cond_input_embeds, uncond_input_embeds], dim=1)
25
+ combined_input_embeds = combined_input_embeds.view(batch_size * 2, -1, input_embeds.shape[-1])
26
+
27
+ return combined_input_embeds
28
+
29
+ @spaces.GPU
30
+ @torch.inference_mode()
31
+ def generate(
32
+ mmgpt: MultiModalityCausalLM,
33
+ vl_chat_processor: VLChatProcessor,
34
+ inputs_embeds,
35
+ temperature: float = 1,
36
+ parallel_size: int = 1,
37
+ cfg_weight: float = 5,
38
+ image_token_num_per_image: int = 576,
39
+ img_size: int = 384,
40
+ patch_size: int = 16,
41
+ ):
42
+ generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
43
+
44
+ inputs_embeds = prepare_classifier_free_guidance_input(inputs_embeds, vl_chat_processor, mmgpt, parallel_size)
45
+
46
+ for i in range(image_token_num_per_image):
47
+ outputs = mmgpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=outputs.past_key_values if i != 0 else None)
48
+ hidden_states = outputs.last_hidden_state
49
+
50
+ logits = mmgpt.gen_head(hidden_states[:, -1, :])
51
+
52
+ logit_cond = logits[0::2, :]
53
+ logit_uncond = logits[1::2, :]
54
+
55
+ logits = logit_uncond + cfg_weight * (logit_cond-logit_uncond)
56
+ probs = torch.softmax(logits / temperature, dim=-1)
57
+
58
+ next_token = torch.multinomial(probs, num_samples=1)
59
+ generated_tokens[:, i] = next_token.squeeze(dim=-1)
60
+
61
+ next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
62
+ img_embeds = mmgpt.prepare_gen_img_embeds(next_token)
63
+ inputs_embeds = img_embeds.unsqueeze(dim=1)
64
+
65
+ dec = mmgpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int), shape=[parallel_size, 8, img_size//patch_size, img_size//patch_size])
66
+ dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
67
+
68
+ dec = np.clip((dec + 1) / 2 * 255, 0, 255)
69
+
70
+ visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
71
+ visual_img[:, :, :] = dec
72
+
73
+ generated_images = []
74
+ for i in range(parallel_size):
75
+ generated_images.append(PIL.Image.fromarray(visual_img[i]))
76
+
77
+ return generated_images
78
+
79
+ @lru_cache(maxsize=1)
80
+ def get_start_tag_embed(vl_gpt, vl_chat_processor):
81
+ with torch.no_grad():
82
+ return vl_gpt.language_model.get_input_embeddings()(
83
+ vl_chat_processor.tokenizer.encode(vl_chat_processor.image_start_tag, add_special_tokens=False, return_tensors="pt").to(vl_gpt.device)
84
+ )
85
+
86
+ def process_and_generate(vl_gpt, vl_chat_processor, input_image, prompt, num_images=4, cfg_weight=5):
87
+ start_tag_embed = get_start_tag_embed(vl_gpt, vl_chat_processor)
88
+
89
+ nl = '\n'
90
+ conversation = [
91
+ {
92
+ "role": "User",
93
+ "content": f"<image_placeholder>{nl + prompt if prompt else ''}",
94
+ "images": [input_image],
95
+ },
96
+ {"role": "Assistant", "content": ""},
97
+ ]
98
+
99
+ pil_images = load_pil_images(conversation)
100
+ prepare_inputs = vl_chat_processor(
101
+ conversations=conversation, images=pil_images, force_batchify=True
102
+ ).to(vl_gpt.device)
103
+
104
+ with torch.no_grad():
105
+ inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
106
+
107
+ inputs_embeds = torch.cat((inputs_embeds, start_tag_embed), dim=1)
108
+
109
+ generated_images = generate(
110
+ vl_gpt,
111
+ vl_chat_processor,
112
+ inputs_embeds,
113
+ parallel_size=num_images,
114
+ cfg_weight=cfg_weight
115
+ )
116
+
117
+ return generated_images
model_loader.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM
3
+ from janus.models import MultiModalityCausalLM, VLChatProcessor
4
+
5
+ def load_model_and_processor(model_path):
6
+ vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
7
+ tokenizer = vl_chat_processor.tokenizer
8
+
9
+ vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
10
+ model_path, trust_remote_code=True, torch_dtype=torch.bfloat16
11
+ )
12
+ vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
13
+
14
+ return vl_gpt, vl_chat_processor
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ numpy
3
+ Pillow
4
+ gradio
5
+ janus @ git+https://github.com/deepseek-ai/Janus
6
+ transformers
7
+ spaces