Spaces:
Runtime error
Runtime error
import gradio as gr | |
import io | |
from PIL import Image | |
import base64 | |
import requests | |
import json | |
from PIL import Image | |
def read_content(file_path: str) -> str:
    """Return the entire text of the file at *file_path*, decoded as UTF-8."""
    with open(file_path, encoding="utf-8") as handle:
        return handle.read()
def base2picture(resbase64):
    """Decode a base64-encoded image string into a PIL image.

    Args:
        resbase64: Either a data URI such as
            ``"data:image/png;base64,<payload>"`` or a bare base64 payload
            with no header.

    Returns:
        The image returned by ``PIL.Image.open`` for the decoded bytes.

    Raises:
        binascii.Error: If the payload has invalid base64 padding.
        PIL raises if the decoded bytes are not a recognised image format.
    """
    # A data URI carries the payload after the first comma.
    # A string without a comma is treated as a bare base64 payload.
    # The original split(',')[1] raised IndexError for bare payloads.
    header, sep, payload = resbase64.partition(",")
    if not sep:
        payload = header
    img_bytes = base64.b64decode(payload)
    return Image.open(io.BytesIO(img_bytes))
def filter_content(raw_style: str):
    """Return *raw_style* truncated just before its first ``"("``.

    UI labels look like ``"国画(traditional Chinese painting)"``. This keeps
    only the leading tag (``"国画"``). Labels without a parenthesis are
    returned unchanged.
    """
    if "(" not in raw_style:
        # Nothing to strip.
        return raw_style
    cut = raw_style.index("(")
    return raw_style[:cut]
def _build_prompt(raw_text, class_draw, style_draw):
    """Compose the final prompt text and pick the API endpoint for it.

    Args:
        raw_text: Prompt typed by the user.
        class_draw: Selected category label (only the part before ``"("``
            is used, via ``filter_content``).
        style_draw: Iterable of selected style labels; may be empty or
            ``None``. Styles are ignored for the ``国画`` category.

    Returns:
        A ``(prompt, url)`` tuple.
    """
    category = filter_content(class_draw)
    if category == "国画":
        # Ensure the prompt carries the "国画" tag without duplicating it.
        if not raw_text.endswith("国画"):
            raw_text = raw_text + ",国画"
        return raw_text, "http://flagart.baai.ac.cn/api/guohua/"

    # "通用" (general) adds no category tag; any other category is appended.
    if category != "通用":
        raw_text = raw_text + f",{category}"
    for sty in style_draw or ():
        raw_text = raw_text + f",{filter_content(sty)}"
    print(f"raw text is {raw_text}")
    return raw_text, "http://flagart.baai.ac.cn/api/general/"


def request_images(raw_text, class_draw="通用(general)", style_draw=(),
                   batch_size=1, timeout=300):
    """Request generated images for a prompt from the FlagArt API.

    The trailing arguments have defaults so callers that pass only the
    prompt work. ``gr.Examples`` in this file is wired with only the text
    input.

    Args:
        raw_text: Prompt typed by the user.
        class_draw: Category label from the "type" dropdown.
        style_draw: Selected style labels (empty or ``None`` means none).
        batch_size: Number of images to request (coerced to ``int``).
        timeout: Seconds passed to ``requests.post`` so a stalled service
            cannot block a worker forever.

    Returns:
        A list of PIL images, at most ``batch_size`` long.

    Raises:
        requests.HTTPError: If the service answers with an error status.
        requests.Timeout: If the service does not answer within *timeout*.
    """
    prompt, url = _build_prompt(raw_text, class_draw, style_draw)
    # Coerce in case the slider value arrives as a float (e.g. 2.0).
    # range/slicing and the API payload need an int.
    batch_size = int(batch_size)
    response = requests.post(
        url,
        json={"data": [prompt, batch_size]},
        # json= already sets Content-Type. Accept-Encoding is left to
        # requests, so it only advertises encodings it can decode; a
        # hard-coded "br" breaks decoding when brotli is not installed.
        headers={"Accept": "*/*"},
        timeout=timeout,
    )
    response.raise_for_status()
    # As read by the original code, the reply's data[0] holds the list of
    # base64-encoded images.
    encoded_images = response.json()["data"][0]
    # Slice instead of indexing so a short reply does not raise IndexError.
    return [base2picture(item) for item in encoded_images[:batch_size]]
# Example prompts shown under the gallery (clicking one fills the prompt box).
examples = [
    # Chinese prompts, some already tagged with a category.
    "水墨蝴蝶和牡丹花,国画",
    "苍劲有力的墨竹,国画",
    "暴风雨中的灯塔",
    "机械小松鼠,科学幻想",
    "中国水墨山水画,国画",
    # English and mixed English/Chinese prompts.
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]
if __name__ == "__main__":
    # Layout options shared by every horizontal row of controls.
    row_style = dict(mobile_collapse=False, equal_height=True)

    # Category labels: a Chinese tag with an English gloss in parentheses.
    # request_images keeps only the text before "(" via filter_content.
    category_choices = [
        "通用(general)", "国画(traditional Chinese painting)",
        "照片,摄影(picture photography)", "油画(oil painting)",
        "铅笔素描(pencil sketch)", "CG",
        "水彩画(watercolor painting)", "水墨画(ink and wash)",
        "插画(illustrations)", "3D",
    ]
    style_choices = [
        "蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
        "概念艺术(concept art)", "Warming lighting",
        "Dramatic lighting", "Natural lighting",
        "虚幻引擎(unreal engine)", "4k", "8k",
        "充满细节(full details)",
    ]

    # NOTE(review): `.style()`, `gr.Box`, `mobile_collapse` and
    # `queue(concurrency_count=...)` are gradio 3.x APIs that were removed in
    # gradio 4 — confirm the Space pins a compatible gradio version.
    with gr.Blocks(css=read_content('style.css')) as demo:
        gr.HTML(read_content("header.html"))
        with gr.Group():
            with gr.Box():
                # Prompt box and generate button side by side.
                with gr.Row().style(**row_style):
                    prompt_box = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Input text(输入文字)",
                        interactive=True,
                    ).style(
                        border=(True, False, True, True),
                        rounded=(True, False, False, True),
                        container=False,
                    )
                    generate_btn = gr.Button("Generate image").style(
                        margin=False,
                        rounded=(True, True, True, True),
                    )
                with gr.Row().style(**row_style):
                    category_dropdown = gr.Dropdown(
                        category_choices,
                        label="生成类型(type)",
                        show_label=True,
                        value="通用(general)",
                    )
                with gr.Row().style(**row_style):
                    style_checks = gr.CheckboxGroup(
                        style_choices,
                        label="画面风格(style)",
                        show_label=True,
                    )
                with gr.Row().style(**row_style):
                    count_slider = gr.Slider(
                        minimum=1,
                        maximum=4,
                        step=1,
                        label="生成数量(number)",
                        show_label=True,
                        interactive=True,
                    )
            result_gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[2], height="auto")

            # NOTE(review): only the prompt is wired as an example input, so
            # if examples are ever executed (e.g. cached), request_images
            # receives a single argument — confirm it accepts that.
            gr.Examples(
                examples=examples,
                fn=request_images,
                inputs=prompt_box,
                outputs=result_gallery,
                examples_per_page=100,
            )

            # Pressing Enter in the prompt box and clicking the button run
            # the same generation with the same inputs.
            generate_inputs = [prompt_box, category_dropdown, style_checks, count_slider]
            prompt_box.submit(request_images, inputs=generate_inputs, outputs=result_gallery)
            generate_btn.click(request_images, inputs=generate_inputs, outputs=result_gallery)
        gr.HTML(read_content("footer.html"))

    demo.queue(max_size=50, concurrency_count=20).launch()