# AltDiffusion / app.py
# Last commit d8a5a4d ("modified app.py", author: root); file size 5.7 kB.
import gradio as gr
import io
from PIL import Image
import base64
import requests
import json
from PIL import Image
def read_content(file_path: str) -> str:
    """Return the entire text of *file_path*, decoded as UTF-8.

    The page uses this to load its stylesheet and HTML header/footer snippets.
    """
    with open(file_path, encoding="utf-8") as handle:
        return handle.read()
def base2picture(resbase64):
res=resbase64.split(',')[1]
img_b64decode = base64.b64decode(res)
image = io.BytesIO(img_b64decode)
img = Image.open(image)
return img
def filter_content(raw_style: str):
    """Drop the parenthesised gloss from a UI label.

    Returns the text before the first "(", e.g.
    "国画(traditional Chinese painting)" -> "国画". A label with no "("
    comes back unchanged.
    """
    head, _paren, _rest = raw_style.partition("(")
    return head
def _compose_prompt(raw_text, class_draw, style_draw):
    """Return ``(prompt, url)`` for the chosen painting type and styles.

    The label text before "(" (via ``filter_content``) selects the endpoint:

    * "国画" (traditional Chinese painting) goes to the guohua endpoint.
      ",国画" is appended unless the prompt already ends with "国画", and the
      selected styles are ignored.
    * Every other type goes to the general endpoint. The type keyword is
      appended (except for "通用"/general), followed by each selected style,
      all comma-separated.
    """
    category = filter_content(class_draw)
    if category == "国画":
        if not raw_text.endswith("国画"):
            raw_text = raw_text + ",国画"
        return raw_text, "http://flagart.baai.ac.cn/api/guohua/"
    if category != "通用":
        raw_text = raw_text + f",{category}"
    # style_draw may be None when the function is called with only a prompt.
    for sty in style_draw or []:
        raw_text = raw_text + f",{filter_content(sty)}"
    print(f"raw text is {raw_text}")
    return raw_text, "http://flagart.baai.ac.cn/api/general/"


def request_images(raw_text, class_draw="通用(general)", style_draw=None, batch_size=1):
    """Generate images for a prompt through the FlagArt HTTP API.

    Args:
        raw_text: the user prompt.
        class_draw: painting-type dropdown label, e.g.
            "国画(traditional Chinese painting)"; defaults to the general type.
        style_draw: iterable of style checkbox labels; None means no styles.
        batch_size: number of images to request; coerced to ``int``.

    Returns:
        A list of PIL images decoded from the response (at most ``batch_size``).

    Raises:
        requests.HTTPError: if the API answers with an error status.
    """
    prompt, url = _compose_prompt(raw_text, class_draw, style_draw)
    # The UI slider value may arrive as a float (assumption; confirm for the
    # Gradio version in use). The slice below and the API payload need an int.
    batch_size = int(batch_size)
    # Send only the JSON body. requests already sets Content-Type for json=,
    # and its default Accept-Encoding lists only codecs it can decode. The old
    # hard-coded "gzip, deflate, br" let the server pick Brotli, which requests
    # cannot decode unless a brotli package is installed.
    # NOTE(review): no timeout is passed, so a stalled upstream blocks this
    # worker indefinitely; consider adding one.
    response = requests.post(url, json={"data": [prompt, batch_size]})
    response.raise_for_status()
    content = response.json()["data"][0]
    # Slice instead of indexing so a short response does not raise IndexError.
    return [base2picture(item) for item in content[:batch_size]]
# Sample prompts shown in the UI's example list. They mix Chinese, English,
# and mixed-language text; several end with ",国画" to request the
# traditional-Chinese-painting style.
examples = [
    "水墨蝴蝶和牡丹花,国画",
    "苍劲有力的墨竹,国画",
    "暴风雨中的灯塔",
    "机械小松鼠,科学幻想",
    "中国水墨山水画,国画",
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]
if __name__ == "__main__":
    # Build and launch the Gradio UI. style.css, header.html and footer.html
    # are opened with relative paths, i.e. from the current working directory.
    block = gr.Blocks(css=read_content('style.css'))
    with block:
        gr.HTML(read_content("header.html"))
        with gr.Group():
            with gr.Box():
                # Row 1: single-line prompt box next to the "Generate image" button.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    text = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Input text(输入文字)",
                        interactive=True,
                    ).style(
                        border=(True, False, True, True),
                        rounded=(True, False, False, True),
                        container=False,
                    )
                    btn = gr.Button("Generate image").style(
                        margin=False,
                        rounded=(True, True, True, True),
                    )
                # Row 2: painting type. request_images uses the label text
                # before "(" (via filter_content) to choose the endpoint and
                # extend the prompt.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    class_draw = gr.Dropdown(["通用(general)", "国画(traditional Chinese painting)",
                                              "照片,摄影(picture photography)", "油画(oil painting)",
                                              "铅笔素描(pencil sketch)", "CG",
                                              "水彩画(watercolor painting)", "水墨画(ink and wash)",
                                              "插画(illustrations)", "3D"],
                                             label="生成类型(type)",
                                             show_label=True,
                                             value="通用(general)")
                # Row 3: optional style keywords (multi-select), appended to the
                # prompt by request_images for non-"国画" types.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
                                                   "概念艺术(concept art)", "Warming lighting",
                                                   "Dramatic lighting", "Natural lighting",
                                                   "虚幻引擎(unreal engine)", "4k", "8k",
                                                   "充满细节(full details)"],
                                                  label="画面风格(style)",
                                                  show_label=True,
                                                  )
                # Row 4: number of images to generate (1-4); passed to
                # request_images as batch_size.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    sample_size = gr.Slider(minimum=1,
                                            maximum=4,
                                            step=1,
                                            label="生成数量(number)",
                                            show_label=True,
                                            interactive=True,
                                            )
            # Output gallery in a two-column grid.
            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[2], height="auto")
        # NOTE(review): only `text` is wired as an input here, while
        # request_images is also given class/style/batch inputs below. If
        # Gradio ever calls fn for these examples (e.g. with example caching),
        # request_images' remaining parameters must have defaults. Confirm.
        gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
        # Pressing Enter in the prompt box and clicking the button trigger
        # the same generation call with the same inputs.
        text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size], outputs=gallery)
        btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size], outputs=gallery)
        gr.HTML(read_content("footer.html"))
        # gr.Image('./contributors.png')
    # Enable request queuing before launching (max_size=50, concurrency_count=20;
    # see Gradio 3 Blocks.queue for the exact semantics).
    block.queue(max_size=50, concurrency_count=20).launch()