# deepseek-janus / app.py
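# Gradio Space that generates images with deepseek-ai/Janus-Pro-7B and can expand short
# user prompts with a locally bundled seq2seq "prompt enhancer" model.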
import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from janus.models import VLChatProcessor
import random
import numpy as np
import spaces
import json
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
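# "./Flux-Prompt" is assumed to be a local encoder-decoder checkpoint shipped with the
# Space (it is loaded with AutoModelForSeq2SeqLM); prompt_generator() uses it to expand
# short seed phrases from seeds.json into fuller image prompts.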
model_checkpoint = "./Flux-Prompt"
enhancer_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
enhancer_model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(cuda_device)
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
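# Janus-Pro-7B ships custom modeling code, hence trust_remote_code=True; attention is
# pinned to the 'eager' implementation rather than flash/SDPA kernels.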
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
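# generate() autoregressively samples image_token_num_per_image VQ codes with
# classifier-free guidance, then decodes them into image patches with gen_vision_model.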
def generate(input_ids,
             width,
             height,
             temperature,
             cfg_weight,
             parallel_size: int = 1,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    torch.cuda.empty_cache()
    # Build a 2 * parallel_size batch: even rows keep the full prompt (conditional),
    # odd rows are padded out (unconditional) for classifier-free guidance.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)
    pkv = None
    for i in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
                                                  use_cache=True,
                                                  past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            # Classifier-free guidance: push the conditional logits away from the
            # unconditional ones by cfg_weight.
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logit_diff = logit_cond - logit_uncond
            logits = logit_uncond + cfg_weight * logit_diff
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            # Feed the sampled token back to both the conditional and unconditional streams.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)
    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])
    return generated_tokens.to(dtype=torch.int), patches
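# unpack() converts decoder output from [-1, 1] floats into HWC uint8 images for display.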
def unpack(dec, width, height, parallel_size=1):
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img
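# infer() wraps the user prompt in the Janus SFT chat template, appends the image start
# tag, and returns one 384x384 image; a new random seed is drawn on every call.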
@torch.inference_mode()
@spaces.GPU()
def infer(
    prompt,
    guidance_scale,
    temperature,
    progress=gr.Progress(track_tqdm=True),
):
    seed = random.randint(0, 2000)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    parallel_size = 1
    height = 384
    width = 384
    with torch.no_grad():
        messages = [
            {'role': '<|User|>', 'content': prompt},
            {'role': '<|Assistant|>', 'content': ''}
        ]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt=''
        )
        text += vl_chat_processor.image_start_tag
        input_ids = torch.LongTensor(tokenizer.encode(text))
        try:
            output, patches = generate(input_ids,
                                       width // 16 * 16,
                                       height // 16 * 16,
                                       cfg_weight=guidance_scale,
                                       parallel_size=parallel_size,
                                       temperature=temperature)
            images = unpack(patches,
                            width // 16 * 16,
                            height // 16 * 16,
                            parallel_size=parallel_size)
            return images[0]
        except RuntimeError as e:
            print(f"Error during generation: {e}")
            raise gr.Error("Generation failed. Please try different parameters.")
        finally:
            torch.cuda.empty_cache()
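# load_seeds() reads a local seeds.json (expected shape: {"seeds": ["...", ...]}) and
# returns an empty dict if the file is missing.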
def load_seeds():
    try:
        with open('seeds.json', 'r') as f:
            return json.load(f)
    except FileNotFoundError:
        print("seeds.json not found")
        return {}
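# prompt_generator() picks a random seed phrase and lets the enhancer model rewrite it
# into a longer image prompt for the text box.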
@spaces.GPU()
def prompt_generator():
    seeds = load_seeds()
    if seeds:
        seed = random.choice(seeds["seeds"])
        input_ids = enhancer_tokenizer(seed, return_tensors='pt').input_ids.to(cuda_device)
        random_seed = random.randint(0, 2000)
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed(random_seed)
        # do_sample=True so the temperature and the per-call seed actually affect decoding;
        # without it generate() decodes greedily and the temperature argument is ignored.
        answer = enhancer_model.generate(input_ids, max_length=256, num_return_sequences=1,
                                         do_sample=True, temperature=1.0, repetition_penalty=1.2)
        final_answer = enhancer_tokenizer.decode(answer[0], skip_special_tokens=True)
        return final_answer
    return "Unable to generate prompt - no seeds available"
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
.center-container {
display: flex;
justify-content: center;
align-items: center;
}
"""
with gr.Blocks(css=css) as demo:
    gr.HTML("""
        <style>
        ::-webkit-scrollbar {
            display: none;
        }
        .header-container {
            display: flex;
            align-items: center;
            justify-content: center;
            gap: 1rem;
            margin-bottom: 2rem;
        }
        .header-container h1 {
            margin: 0;
            font-size: 2.5rem;
            font-weight: bold;
        }
        </style>
    """)
    with gr.Column(elem_id="col-container"):
        with gr.Row(elem_classes="header-container"):
            gr.Image("./deepseek.jpg",
                     width=100,
                     height=100,
                     show_fullscreen_button=False,
                     show_download_button=False,
                     show_share_button=False,
                     container=False)
            gr.Markdown("<h1>DeepSeek</h1><h1>Janus-Pro-7B</h1>")
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                placeholder="Enter your prompt",
                container=False,
            )
        with gr.Row(elem_classes="center-container"):
            run_prompt = gr.Button("Generate Prompt", scale=0, variant="primary")
            run_image = gr.Button("Generate Image", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)
        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.0,
                    maximum=2.0,
                    step=0.1,
                    value=1.0,
                )
    gr.on(
        triggers=[run_image.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            guidance_scale,
            temperature
        ],
        outputs=[result],
    )
    gr.on(
        triggers=[run_prompt.click],
        fn=prompt_generator,
        outputs=[prompt],
    )
if __name__ == "__main__":
    demo.launch()