Azure99 commited on
Commit
c5bb7ca
·
verified ·
1 Parent(s): 885aeb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -10,8 +10,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
10
 
11
  device = torch.device("cuda:0")
12
 
13
- llm = AutoModelForCausalLM.from_pretrained("Azure99/blossom-v5-14b", torch_dtype=torch.float16, device_map="auto")
14
- tokenizer = AutoTokenizer.from_pretrained("Azure99/blossom-v5-14b")
15
  diffusion_pipe = DiffusionPipeline.from_pretrained(
16
  "playgroundai/playground-v2.5-1024px-aesthetic",
17
  torch_dtype=torch.float16,
@@ -34,7 +34,7 @@ def save_image(img):
34
 
35
 
36
  LLM_PROMPT = '''你的任务是从输入的[作画要求]中抽取画面描述(description),然后description翻译为英文(en_description),最后对en_description进行扩写(expanded_description),增加足够多的细节,且符合人类的第一直觉。
37
- [输出]是一个json,包含description、en_description、expanded_description三个字符串字段,请直接输出json,不要输出任何 无关内容。
38
 
39
  下面是一些示例:
40
  [作画要求]->"画一幅画:落霞与孤鹜齐飞,秋水共长天一色。"
@@ -60,22 +60,29 @@ def generate(
60
  prompt: str,
61
  progress=gr.Progress(track_tqdm=True),
62
  ):
63
-
64
  input_ids = get_input_ids(LLM_PROMPT.replace("$USER_PROMPT", json.dumps(prompt, ensure_ascii=False)), BOT_PREFIX)
65
  generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
66
  max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
67
  llm_result = llm.generate(**generation_kwargs)
68
  llm_result = llm_result.cpu()[0][len(input_ids):]
69
  llm_result = BOT_PREFIX + tokenizer.decode(llm_result, skip_special_tokens=True)
 
 
70
  print(llm_result)
71
- expanded_description = json.loads(llm_result)["expanded_description"]
72
- print(expanded_description)
 
 
 
 
 
 
73
 
74
  seed = random.randint(0, 2147483647)
75
  generator = torch.Generator().manual_seed(seed)
76
 
77
  images = diffusion_pipe(
78
- prompt=expanded_description,
79
  negative_prompt=None,
80
  width=1024,
81
  height=1024,
@@ -107,7 +114,7 @@ with gr.Blocks(css=css) as demo:
107
  container=False,
108
  )
109
  run_button = gr.Button("Run", scale=0)
110
- result = gr.Gallery(label="Result", columns=1, show_label=False)
111
 
112
  gr.on(
113
  triggers=[
 
10
 
11
  device = torch.device("cuda:0")
12
 
13
+ llm = AutoModelForCausalLM.from_pretrained("Azure99/blossom-v5-4b", torch_dtype=torch.float16, device_map="auto")
14
+ tokenizer = AutoTokenizer.from_pretrained("Azure99/blossom-v5-4b")
15
  diffusion_pipe = DiffusionPipeline.from_pretrained(
16
  "playgroundai/playground-v2.5-1024px-aesthetic",
17
  torch_dtype=torch.float16,
 
34
 
35
 
36
  LLM_PROMPT = '''你的任务是从输入的[作画要求]中抽取画面描述(description),然后description翻译为英文(en_description),最后对en_description进行扩写(expanded_description),增加足够多的细节,且符合人类的第一直觉。
37
+ [输出]是一个json,包含description、en_description、expanded_description三个字符串字段,请直接输出一个完整的json,不要输出任何解释或其他无关内容。
38
 
39
  下面是一些示例:
40
  [作画要求]->"画一幅画:落霞与孤鹜齐飞,秋水共长天一色。"
 
60
  prompt: str,
61
  progress=gr.Progress(track_tqdm=True),
62
  ):
 
63
  input_ids = get_input_ids(LLM_PROMPT.replace("$USER_PROMPT", json.dumps(prompt, ensure_ascii=False)), BOT_PREFIX)
64
  generation_kwargs = dict(input_ids=torch.tensor([input_ids]).to(llm.device), do_sample=True,
65
  max_new_tokens=512, temperature=0.5, top_p=0.85, top_k=50, repetition_penalty=1.05)
66
  llm_result = llm.generate(**generation_kwargs)
67
  llm_result = llm_result.cpu()[0][len(input_ids):]
68
  llm_result = BOT_PREFIX + tokenizer.decode(llm_result, skip_special_tokens=True)
69
+ print("----------")
70
+ print(prompt)
71
  print(llm_result)
72
+ en_prompt = prompt
73
+ expanded_prompt = prompt
74
+ try:
75
+ en_prompt = json.loads(llm_result)["en_description"]
76
+ expanded_prompt = json.loads(llm_result)["expanded_description"]
77
+ except:
78
+ print("error, fallback to original prompt")
79
+ pass
80
 
81
  seed = random.randint(0, 2147483647)
82
  generator = torch.Generator().manual_seed(seed)
83
 
84
  images = diffusion_pipe(
85
+ prompt=[en_prompt, expanded_prompt],
86
  negative_prompt=None,
87
  width=1024,
88
  height=1024,
 
114
  container=False,
115
  )
116
  run_button = gr.Button("Run", scale=0)
117
+ result = gr.Gallery(label="Result", columns=2, rows=1, show_label=False)
118
 
119
  gr.on(
120
  triggers=[