import os

from gradio_tool import GradioTool


class StableDiffusionTool(GradioTool):
    """Tool for calling Stable Diffusion from an LLM.

    Wraps a hosted Gradio app (by default ``gradio-client-demos/stable-diffusion``)
    so that a text prompt can be submitted and the path of a generated image
    file read back.
    """

    def __init__(
        self,
        name="StableDiffusion",
        description=(
            "An image generator. Use this to generate images based on "
            "text input. Input should be a description of what the image should "
            "look like. The output will be a path to an image file."
        ),
        src="gradio-client-demos/stable-diffusion",
        hf_token=None,
    ) -> None:
        """Forward the tool configuration to ``GradioTool``.

        Args:
            name: Name of the tool.
            description: Text describing the tool and its expected input/output.
            src: Identifier of the Gradio app to call.
            hf_token: Optional token passed through to the base class.
        """
        super().__init__(name, description, src, hf_token)

    # ``Job`` is not imported in this module, so the annotation is written as a
    # string: a bare name here is evaluated when the method is defined and would
    # raise NameError while the class body executes.
    def create_job(self, query: str) -> "Job":
        """Submit ``query`` to the remote app and return the resulting job.

        Args:
            query: Text description of the desired image.

        Returns:
            Whatever ``self.client.submit`` returns for this request.
        """
        # NOTE(review): the positional "" and 9 look like the demo's
        # negative-prompt and guidance-scale inputs, and fn_index=1 selects the
        # endpoint -- confirm against the app's API before changing them.
        return self.client.submit(query, "", 9, fn_index=1)

    def postprocess(self, output: str) -> str:
        """Return the path of the first non-JSON entry in directory ``output``.

        Args:
            output: Path of a directory produced by the job.

        Returns:
            ``output`` joined with the first entry (in ``os.listdir`` order)
            whose name does not end with ``"json"``.

        Raises:
            IndexError: If the directory contains no such entry.
        """
        for entry in os.listdir(output):
            # Skip entries ending in "json"; anything else is returned as-is.
            if not entry.endswith("json"):
                return os.path.join(output, entry)
        # Same exception type the former ``[...][0]`` lookup produced, but with
        # a message that names the offending directory.
        raise IndexError(f"no non-JSON file found in output directory {output!r}")

    def _block_input(self, gr) -> "gr.components.Component":
        """Return a ``gr.Textbox`` as the input component."""
        return gr.Textbox()

    def _block_output(self, gr) -> "gr.components.Component":
        """Return a ``gr.Image`` as the output component."""
        return gr.Image()