import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from janus.models import VLChatProcessor
import random
import numpy as np
import spaces
import json

cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Prompt enhancer: a small seq2seq model that expands short seed prompts.
model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
enhancer_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
enhancer_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(cuda_device)

# Janus-Pro multimodal model used here for text-to-image generation.
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer


def generate(input_ids,
             width,
             height,
             temperature,
             cfg_weight,
             parallel_size: int = 1,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    torch.cuda.empty_cache()
    # Duplicate the prompt: even rows carry the conditional branch, odd rows the
    # unconditional branch (prompt tokens padded out) for classifier-free guidance.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for i in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
                                                  use_cache=True,
                                                  past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            # Classifier-free guidance: push the conditional logits away from
            # the unconditional ones by cfg_weight.
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            # Feed the sampled token back to both CFG branches.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=1):
    # decode_code returns (N, 3, H, W) floats in [-1, 1]; convert to uint8 HWC.
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    # width == height everywhere in this demo, so the axis order below is safe.
    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img
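
# Shape walkthrough for the defaults above (a sketch, assuming the 384x384
# output size hard-coded in infer() below):
#   image_token_num_per_image = (384 // 16) * (384 // 16) = 24 * 24 = 576
#   decode_code(...) -> patches of shape (parallel_size, 3, 384, 384)
#   unpack(...)      -> uint8 images of shape (parallel_size, 384, 384, 3)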

@torch.inference_mode()
@spaces.GPU()
def infer(prompt,
          guidance_scale,
          temperature,
          progress=gr.Progress(track_tqdm=True)):
    # Re-seed every call so repeated runs with the same prompt still vary.
    seed = random.randint(0, 2000)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    parallel_size = 1
    height = 384
    width = 384

    with torch.no_grad():
        messages = [
            {'role': '<|User|>', 'content': prompt},
            {'role': '<|Assistant|>', 'content': ''},
        ]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt='',
        )
        # Append the image-start tag so the model begins emitting image tokens.
        text += vl_chat_processor.image_start_tag
        input_ids = torch.LongTensor(tokenizer.encode(text))
        try:
            output, patches = generate(input_ids,
                                       width // 16 * 16,
                                       height // 16 * 16,
                                       cfg_weight=guidance_scale,
                                       parallel_size=parallel_size,
                                       temperature=temperature)
            images = unpack(patches,
                            width // 16 * 16,
                            height // 16 * 16,
                            parallel_size=parallel_size)
            return images[0]
        except RuntimeError as e:
            print(f"Error during generation: {e}")
            raise gr.Error("Generation failed. Please try different parameters.")
        finally:
            torch.cuda.empty_cache()


def load_seeds():
    try:
        with open('seeds.json', 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        print("seeds.json not found")
        return {}


@spaces.GPU()
def prompt_generator():
    seeds = load_seeds()
    if seeds:
        seed = random.choice(seeds["seeds"])
        input_ids = enhancer_tokenizer(seed, return_tensors='pt').input_ids.to(cuda_device)
        random_seed = random.randint(0, 2000)
        torch.manual_seed(random_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(random_seed)
        answer = enhancer_model.generate(input_ids,
                                         max_length=256,
                                         num_return_sequences=1,
                                         do_sample=True,  # assumption: sampling was intended; otherwise temperature is ignored
                                         temperature=1.0,
                                         repetition_penalty=1.2)
        final_answer = enhancer_tokenizer.decode(answer[0], skip_special_tokens=True)
        return final_answer
    return "Unable to generate prompt - no seeds available"
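
# Assumed layout of seeds.json, inferred from load_seeds()/prompt_generator()
# above (the example prompts are placeholders, not from the real file):
# {
#   "seeds": [
#     "a cat astronaut floating in space",
#     "an ancient library lit by candles"
#   ]
# }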

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
.center-container {
    display: flex;
    justify-content: center;
    align-items: center;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        with gr.Row(elem_classes="header-container"):
            gr.Image("./deepseek.jpg",
                     width=100,
                     height=100,
                     show_fullscreen_button=False,
                     show_download_button=False,
                     show_share_button=False,
                     container=False)
            gr.Markdown("# DeepSeek Janus")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                placeholder="Enter your prompt",
                container=False,
            )
        with gr.Row(elem_classes="center-container"):
            run_prompt = gr.Button("Generate Prompt", scale=0, variant="primary")
            run_image = gr.Button("Generate Image", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )

    gr.on(
        triggers=[run_image.click, prompt.submit],
        fn=infer,
        inputs=[prompt, guidance_scale, temperature],
        outputs=[result],
    )
    gr.on(
        triggers=[run_prompt.click],
        fn=prompt_generator,
        outputs=[prompt],
    )

if __name__ == "__main__":
    demo.launch()