import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from janus.models import VLChatProcessor
import random
import numpy as np
import spaces
import json

cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Local seq2seq checkpoint used to expand short seed phrases into full prompts.
model_checkpoint = "./Flux-Prompt"
enhancer_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
enhancer_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(cuda_device)

# Janus-Pro-7B multimodal model used for text-to-image generation.
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True,
)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer


def generate(input_ids,
             width,
             height,
             temperature,
             cfg_weight,
             parallel_size: int = 1,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    """Autoregressively sample image tokens with classifier-free guidance."""
    torch.cuda.empty_cache()
    # Duplicate the prompt: even rows hold the conditional prompt, odd rows the
    # padded (unconditional) prompt used for classifier-free guidance.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for i in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=pkv,
            )
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            # Classifier-free guidance: mix conditional and unconditional logits.
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            # Feed the sampled token back into both the conditional and
            # unconditional branches for the next step.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode the discrete image tokens back into pixel space.
    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, width // patch_size, height // patch_size],
    )
    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=1):
    """Convert decoder output in [-1, 1] to uint8 images in HWC layout."""
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img


@torch.inference_mode()
@spaces.GPU()
def infer(
    prompt,
    guidance_scale,
    temperature,
    progress=gr.Progress(track_tqdm=True),
):
    seed = random.randint(0, 2000)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    parallel_size = 1
    height = 384
    width = 384

    with torch.no_grad():
        messages = [
            {'role': '<|User|>', 'content': prompt},
            {'role': '<|Assistant|>', 'content': ''},
        ]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt='',
        )
        text += vl_chat_processor.image_start_tag
        input_ids = torch.LongTensor(tokenizer.encode(text))
        try:
            output, patches = generate(
                input_ids,
                width // 16 * 16,
                height // 16 * 16,
                cfg_weight=guidance_scale,
                parallel_size=parallel_size,
                temperature=temperature,
            )
            images = unpack(patches, width // 16 * 16, height // 16 * 16, parallel_size=parallel_size)
            return images[0]
        except RuntimeError as e:
            print(f"Error during generation: {e}")
            raise gr.Error("Generation failed. Please try different parameters.")
        finally:
            torch.cuda.empty_cache()


def load_seeds():
    """Load seed phrases for the prompt enhancer from seeds.json
    (expected format: {"seeds": ["...", ...]})."""
    try:
        with open('seeds.json', 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        print("seeds.json not found")
        return {}


@spaces.GPU()
def prompt_generator():
    """Pick a random seed phrase and expand it into a full prompt."""
    seeds = load_seeds()
    if seeds:
        seed = random.choice(seeds["seeds"])
        input_ids = enhancer_tokenizer(seed, return_tensors='pt').input_ids.to(cuda_device)
        random_seed = random.randint(0, 2000)
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed(random_seed)
        answer = enhancer_model.generate(
            input_ids,
            max_length=256,
            num_return_sequences=1,
            temperature=1.0,
            repetition_penalty=1.2,
        )
        final_answer = enhancer_tokenizer.decode(answer[0], skip_special_tokens=True)
        return final_answer
    return "Unable to generate prompt - no seeds available"


css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
.center-container {
    display: flex;
    justify-content: center;
    align-items: center;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(""" """)
    with gr.Column(elem_id="col-container"):
        with gr.Row(elem_classes="header-container"):
            gr.Image(
                "./deepseek.jpg",
                width=100,
                height=100,
                show_fullscreen_button=False,
                show_download_button=False,
                show_share_button=False,
                container=False,
            )
            gr.Markdown(
                "# DeepSeek\n"
                "## Janus-Pro-7B"
            )

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                placeholder="Enter your prompt",
                container=False,
            )
        with gr.Row(elem_classes="center-container"):
            run_prompt = gr.Button("Generate Prompt", scale=0, variant="primary")
            run_image = gr.Button("Generate Image", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )

    # Generate an image when the button is clicked or the prompt is submitted.
    gr.on(
        triggers=[run_image.click, prompt.submit],
        fn=infer,
        inputs=[prompt, guidance_scale, temperature],
        outputs=[result],
    )
    # Fill the prompt box with an enhanced prompt.
    gr.on(
        triggers=[run_prompt.click],
        fn=prompt_generator,
        outputs=[prompt],
    )

if __name__ == "__main__":
    demo.launch()