"""Gradio app that generates a character description and an illustration.

The user's text is sent to the OpenAI chat completion API (forced function
call ``create_character``); the returned arguments are rendered through
``template.md`` and the visual prompt is sent to Stability AI for an image.
"""

import io
import json
import os
import sys
import time

import gradio as gr
import markdown2
import openai
import PIL
import PIL.Image  # ``import PIL`` alone does not load the ``Image`` submodule.
import requests
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from dotenv import load_dotenv
from stability_sdk import client

# --- UI texts -------------------------------------------------------------
title = "Character Generator AI"
inputs_label = "Please tell me the characteristics of the character you want to create"
outputs_label = "AI will generate a character description and visualize it"
visual_outputs_label = "Character Visuals"
description = (
    " - Please input within the limits of 1000 characters."
    " It may take about 50 seconds to generate a response. "
)
article = " "

# --- API clients / configuration -------------------------------------------
load_dotenv()
openai.api_key = os.getenv('OPENAI_API_KEY')
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
stability_api = client.StabilityInference(
    key=os.getenv('STABILITY_KEY'),
    verbose=True,
    engine="stable-diffusion-xl-1024-v1-0",
)
MODEL = "gpt-4"

# Process-wide cache of file contents used by get_filetext().
_FILE_CACHE = {}


def get_filetext(filename, cache=None):
    """Return the text of *filename*, reading it from disk at most once.

    Args:
        filename: Path of the text file to read.
        cache: Optional dict used as the cache. When omitted, a module-level
            cache shared by all calls is used.

    Returns:
        The file content as a string.

    Raises:
        ValueError: If the file does not exist.
    """
    if cache is None:
        cache = _FILE_CACHE
    if filename in cache:
        return cache[filename]

    try:
        # Explicit encoding so non-ASCII templates read the same everywhere.
        with open(filename, "r", encoding="utf-8") as f:
            text = f.read()
    except FileNotFoundError as err:
        raise ValueError(f"File '{filename}' not found") from err

    cache[filename] = text
    return text


def get_functions_from_schema(filename):
    """Load a JSON file and return the value of its top-level "functions" key.

    Returns ``None`` when the key is absent.
    """
    schema_json = json.loads(get_filetext(filename))
    return schema_json.get("functions")


class StabilityAI:
    """Thin wrapper around the module-level Stability inference client."""

    @classmethod
    def generate_image(cls, visualize_prompt):
        """Generate an image for *visualize_prompt*.

        Returns:
            The first image artifact opened as a PIL image, or ``None`` when
            the response contains no image artifact.
        """
        print("visualize_prompt:" + visualize_prompt)

        answers = stability_api.generate(
            prompt=visualize_prompt,
        )

        for resp in answers:
            for artifact in resp.artifacts:
                if artifact.finish_reason == generation.FILTER:
                    print("NSFW")
                if artifact.type == generation.ARTIFACT_IMAGE:
                    return PIL.Image.open(io.BytesIO(artifact.binary))
        return None


class OpenAI:
    """Helpers for calling the OpenAI chat completion API."""

    @classmethod
    def chat_completion(cls, prompt, start_with=""):
        """Call the chat completions REST endpoint and start image generation.

        The system message comes from ``constraints.md`` and the first
        assistant message from ``template.md``.

        NOTE(review): this method returns nothing and the Stability result is
        never consumed; it is not called anywhere in this file. Confirm
        whether it is still needed.
        """
        constraints = get_filetext(filename="constraints.md")
        template = get_filetext(filename="template.md")

        data = {
            "model": MODEL,
            "messages": [
                {"role": "system", "content": constraints},
                {"role": "assistant", "content": template},
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": start_with},
            ],
        }

        start = time.time()
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {openai.api_key}",
            },
            json=data,
        )
        print("gpt generation time: " + str(time.time() - start))

        result = response.json()
        print(result)

        content = result["choices"][0]["message"]["content"].strip()

        # Everything after this heading is used as the image prompt.
        visualize_prompt = content.split("### Prompt for Visual Expression\n\n")[1]

        answers = stability_api.generate(
            prompt=visualize_prompt,
        )

    @classmethod
    def chat_completion_with_function(cls, prompt, messages, functions):
        """Call ChatCompletion forcing the ``create_character`` function.

        Args:
            prompt: The user's text; only used for logging here.
            messages: Chat messages to send.
            functions: Function definitions passed to the API.

        Returns:
            The message of the first choice in the response.
        """
        print("prompt:" + prompt)

        # Measure how long text generation takes.
        start = time.time()
        # Call the ChatCompletion API.
        response = openai.ChatCompletion.create(
            model=MODEL,
            messages=messages,
            functions=functions,
            function_call={"name": "create_character"},
        )
        print("gpt generation time: " + str(time.time() - start))

        # Extract the message returned by the ChatCompletion API.
        message = response.choices[0].message
        print("chat completion message: " + json.dumps(message, indent=2))

        return message


class NajiminoAI:
    """Turns a user's character request into an image and an HTML description."""

    def __init__(self, user_message):
        self.user_message = user_message

    def generate_recipe_prompt(self):
        """Build a prompt combining the user message and ``template.md``."""
        template = get_filetext(filename="template.md")
        prompt = (
            f" {self.user_message} --- 上記を元に、下記テンプレートを埋めてください。"
            f" --- {template} "
        )
        return prompt

    def create_character(self, lang, title, description, prompt_for_visual_expression):
        """Fill ``template.md`` with the given fields via ``str.format``.

        The rendered text is printed for debugging and returned.
        """
        template = get_filetext(filename="template.md")
        debug_message = template.format(
            lang=lang,
            title=title,
            description=description,
            prompt_for_visual_expression=prompt_for_visual_expression,
        )

        print("debug_message: " + debug_message)

        return debug_message

    @classmethod
    def generate(cls, user_message):
        """Entry point used by the UI: return ``[image, html]`` for the text."""
        return cls(user_message).generate_recipe()

    def generate_recipe(self):
        """Run the full pipeline for ``self.user_message``.

        Returns:
            ``[image, html]``. Both are ``None`` when the model response has
            no function call; ``image`` is also ``None`` when the function
            call carries no visual prompt.
        """
        user_message = self.user_message
        constraints = get_filetext(filename="constraints.md")

        messages = [
            {"role": "system", "content": constraints},
            {"role": "user", "content": user_message},
        ]

        functions = get_functions_from_schema('schema.json')

        message = OpenAI.chat_completion_with_function(
            prompt=user_message, messages=messages, functions=functions
        )

        image = None
        html = None
        function_call = message.get("function_call")
        if function_call:
            args = json.loads(function_call["arguments"])

            lang = args.get("lang")
            title = args.get("title")
            description = args.get("description")
            prompt_for_visual_expression = args.get("prompt_for_visual_expression_in_en")

            # f-string: the prompt may be missing (None) from the arguments.
            print(f"prompt_for_visual_expression: {prompt_for_visual_expression}")

            if prompt_for_visual_expression:
                # Measure how long image generation takes.
                start = time.time()
                image = StabilityAI.generate_image(prompt_for_visual_expression)
                print("image generation time: " + str(time.time() - start))

            function_response = self.create_character(
                lang=lang,
                title=title,
                description=description,
                prompt_for_visual_expression=prompt_for_visual_expression,
            )

            # NOTE(review): the wrapper strings contain only newlines —
            # presumably surrounding markup was lost; confirm.
            html = (
                "\n"
                + "\n\n"
                + markdown2.markdown(function_response)
                + "\n\n"
            )
        return [image, html]


def main():
    """Build the Gradio interface and launch it."""
    iface = gr.Interface(
        fn=NajiminoAI.generate,
        examples=[
            ["子どもに愛される丸顔で青いイルカのキャラクター"],
            ["活発で夏が似合う赤髪の女の子"],
            ["黒い鎧を着た魔王"],
        ],
        inputs=gr.Textbox(label=inputs_label),
        outputs=[
            gr.Image(label="Visual Expression"),
            "html",
        ],
        title=title,
        description=description,
        article=article,
    )
    iface.launch()


if __name__ == '__main__':
    # Optional first CLI argument selects a smoke-test mode instead of the UI.
    function = sys.argv[1] if len(sys.argv) > 1 else ''

    if function == 'generate':
        NajiminoAI.generate("A brave knight with a mighty sword and strong armor")
    elif function == 'generate_image':
        image = StabilityAI.generate_image(
            "Imagine a brave knight with a mighty sword and strong armor. "
            "He has a chiseled jawline and a confident expression on his face. "
            "His armor gleams under the sunlight, showing off its intricate design and craftsmanship. "
            "He holds his sword with pride, ready to protect his kingdom and its people at any cost."
        )
        print(f"image: {image}")
        if isinstance(image, PIL.Image.Image):
            image.save("image.png")
    else:
        main()