MemeMaster / src /generation.py
承弱
add tracked files
2d9bfd2
raw
history blame
3.37 kB
import json
import os
import time
import gradio as gr
import requests
from src.log import logger
from src.util import download_images
# Task states after which polling can never succeed. The original loop only
# stopped on "FAILED", so a task reported as "CANCELED" or "UNKNOWN" (both
# documented DashScope task states) made it poll forever.
_TERMINAL_FAILURE_STATUSES = ("FAILED", "CANCELED", "UNKNOWN")


def _create_generation_task(url, payload, headers, request_timeout):
    """Submit one asynchronous generation task and return its ``task_id``.

    Raises:
        gr.Error: on a non-200 response or a response without
            ``output.task_id``.
        requests.RequestException: on network errors or timeouts
            (propagated unchanged).
    """
    response = requests.post(
        url, data=json.dumps(payload), headers=headers, timeout=request_timeout
    )
    if response.status_code != 200:
        logger.error(f'Fail to create Generation task: {response.content}')
        raise gr.Error("Fail to create Generation task.")
    try:
        task_id = json.loads(response.content.decode())['output']['task_id']
    except (ValueError, KeyError, TypeError) as err:
        # ValueError covers both JSON and UTF-8 decode failures.
        logger.error(f'Fail to create Generation task: {response.content}')
        raise gr.Error("Fail to create Generation task.") from err
    logger.info(f"task_id: {task_id}: Create Poster Imitation request success. Params: {payload}")
    return task_id


def _wait_for_task(task_id, headers, request_timeout, poll_interval, poll_timeout):
    """Poll ``task_id`` until it succeeds and return the response's ``output`` dict.

    Raises:
        gr.Error: on a non-200 or malformed query response, when the task
            reaches a terminal failure state, or when ``poll_timeout`` seconds
            elapse without a final state.
        requests.RequestException: on network errors or timeouts
            (propagated unchanged).
    """
    url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
    deadline = None if poll_timeout is None else time.monotonic() + poll_timeout
    while True:
        # The DashScope task-query endpoint is documented as a GET request.
        response = requests.get(url_query, headers=headers, timeout=request_timeout)
        if response.status_code != 200:
            logger.error(f'task_id: {task_id}: Fail to query task result: {response.content}')
            raise gr.Error("Fail to query task result.")
        try:
            res = json.loads(response.content.decode())
            status = res['output']['task_status']
        except (ValueError, KeyError, TypeError) as err:
            logger.error(f'task_id: {task_id}: Fail to query task result: {response.content}')
            raise gr.Error("Fail to query task result.") from err

        if status == "SUCCEEDED":
            logger.info(f"task_id: {task_id}: Generation task query success.")
            logger.info(f"task_id: {task_id}: {res}")
            return res['output']
        if status in _TERMINAL_FAILURE_STATUSES:
            logger.error(f"task_id: {task_id}: Generation task ended with status {status}: {res}")
            raise gr.Error('Fail to get results from Generation task.')
        if deadline is not None and time.monotonic() >= deadline:
            logger.error(f"task_id: {task_id}: Generation task still {status} after {poll_timeout}s.")
            raise gr.Error('Fail to get results from Generation task.')

        logger.debug(f"task_id: {task_id}: query result...")
        time.sleep(poll_interval)


def call_generation(
    prompt,
    mask_image_url,
    lora_path_ratio="0 1.0",
    image_width=512,
    image_height=512,
    BATCH_SIZE=1,
    *,
    poll_interval=1.0,
    poll_timeout=600.0,
    request_timeout=60.0,
):
    """Generate images with the DashScope ``jinshu-emoji`` model and download them.

    Submits ``REPEAT`` asynchronous generation tasks and polls each one until
    it finishes. It then downloads each task's ``result_url`` with
    ``download_images`` and returns all downloaded data concatenated.

    Args:
        prompt: Text prompt, sent as ``input.prompt``.
        mask_image_url: URL sent as ``input.mask_image_url``.
        lora_path_ratio: Sent verbatim as ``input.lora_path_ratio``. The format
            is presumably "<lora index> <weight>"; TODO confirm with the model
            docs.
        image_width: Requested output width (``parameters.image_width``).
        image_height: Requested output height (``parameters.image_height``).
        BATCH_SIZE: Images requested per task (``parameters.n``).
        poll_interval: Seconds to sleep between task-status queries.
        poll_timeout: Maximum seconds to wait for a single task. ``None``
            waits indefinitely, as the previous implementation did.
        request_timeout: Timeout in seconds for each HTTP request, passed to
            ``requests``. ``None`` disables it.

    Returns:
        The concatenation of what ``download_images`` returns for every task.
        Its length is checked to be ``REPEAT * BATCH_SIZE``.

    Raises:
        gr.Error: if ``API_KEY_GENERATION`` is unset, a task cannot be created
            or queried, a task fails or times out, or the number of downloaded
            images is not ``REPEAT * BATCH_SIZE``.
        requests.RequestException: on network errors or HTTP timeouts.
    """
    api_key = os.getenv("API_KEY_GENERATION")
    if not api_key:
        # Without a key every request would be sent as "Bearer None" and rejected.
        logger.error("Fail to create Generation task: API_KEY_GENERATION is not set.")
        raise gr.Error("Fail to create Generation task.")

    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {api_key}",
        "X-DashScope-Async": "enable",
    }
    query_headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    data = {
        "model": "jinshu-emoji",
        "input": {
            "prompt": prompt,
            "mask_image_url": mask_image_url,
            "lora_path_ratio": lora_path_ratio,
            "base_model_path": 0,
        },
        "parameters": {
            "n": BATCH_SIZE,
            "image_width": image_width,
            "image_height": image_height,
            "text_position_revise": True,
        },
    }
    url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/anytext/generation'
    REPEAT = 1

    # Submit every task before polling any of them, as the original flow did.
    task_ids = [
        _create_generation_task(url_create_task, data, headers, request_timeout)
        for _ in range(REPEAT)
    ]

    all_image_data = []
    for task_id in task_ids:
        output = _wait_for_task(task_id, query_headers, request_timeout, poll_interval, poll_timeout)
        img_urls = output.get('result_url')
        if img_urls is None:
            logger.error(f"task_id: {task_id}: succeeded task has no result_url: {output}")
            raise gr.Error('Fail to get results from Generation task.')
        logger.info(f"task_id: {task_id}: download generated images.")
        img_data = download_images(img_urls, BATCH_SIZE)
        logger.info(f"task_id: {task_id}: Generate done.")
        all_image_data += img_data

    if len(all_image_data) != REPEAT * BATCH_SIZE:
        raise gr.Error("Fail to Generation.")
    return all_image_data
if __name__ == "__main__":
    # Minimal CLI for manual runs. The previous entry point called
    # call_generation() with no arguments, which always raised TypeError
    # because `prompt` and `mask_image_url` are required.
    import argparse

    parser = argparse.ArgumentParser(description="Run one DashScope generation request.")
    parser.add_argument("prompt", help="text prompt for the generation")
    parser.add_argument("mask_image_url", help="URL of the mask image")
    parser.add_argument("--lora-path-ratio", default="0 1.0")
    parser.add_argument("--width", type=int, default=512)
    parser.add_argument("--height", type=int, default=512)
    parser.add_argument("--batch-size", type=int, default=1)
    args = parser.parse_args()

    images = call_generation(
        args.prompt,
        args.mask_image_url,
        lora_path_ratio=args.lora_path_ratio,
        image_width=args.width,
        image_height=args.height,
        BATCH_SIZE=args.batch_size,
    )
    logger.info(f"Generated {len(images)} image(s).")