# main.py
# Author: akhil20187 — commit 9c22932 "Create main.py" (verified), 3.68 kB.
import asyncio
import functools
import json
import os

import replicate
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from huggingface_hub import InferenceClient

app = FastAPI()

# Hugging Face API token passed to InferenceClient by the /prompt endpoint.
# The primary variable name "ht_token" looks like a typo for "hf_token"; it is
# still read first so existing deployments keep working, with "hf_token" as a
# fallback. Stays None when neither variable is set.
hf_token = os.environ.get("ht_token") or os.environ.get("hf_token")

# ComfyUI workflow template used by /workflow (relative to the working
# directory the server is started from).
file_path = "replicate_workflow.json"


def load_workflow_from_file(file_path):
    """Load and return the JSON document stored at *file_path*.

    The file is decoded as UTF-8, the JSON interchange encoding, instead of
    the platform's locale-dependent default, so non-ASCII text in the
    workflow loads the same way on every host.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)


@app.post("/prompt")
async def generate_story(prompt: dict):
    """Ask Mixtral-8x7B-Instruct to write a short kids' comic story as JSON.

    The request body is serialised with ``json.dumps`` and embedded in an
    instruction asking for 6 scenes, each with ``sceneNumber``,
    ``imageDescription`` and ``storyForTheComic`` fields.

    Args:
        prompt: JSON object from the request body describing the story idea.

    Returns:
        The text produced by ``InferenceClient.text_generation``, returned
        as-is (the model's JSON is neither parsed nor validated here).
    """
    print(prompt)
    prompt_json = json.dumps(prompt)

    # Reply shape shown to the model. Serialised with json.dumps so the model
    # sees real JSON (double-quoted keys) instead of a Python repr.
    # NOTE(review): this example has 3 scenes and asks for 30-word
    # descriptions, while the instruction asks for 6 scenes and up to 20
    # words — confirm which limits are intended.
    example_reply = [
        {
            "sceneNumber": 1,
            "imageDescription": "Describe the scene in 30 words ",
            "storyForTheComic": "Write the story here in 10 to 20 words",
        },
        {
            "sceneNumber": 2,
            "imageDescription": "Describe the scene in 30 words",
            "storyForTheComic": "Write the story here in 10 to 20 words",
        },
        {
            "sceneNumber": 3,
            "imageDescription": "Describe the scene in 30 words",
            "storyForTheComic": "Write the story here in 10 to 20 words",
        },
    ]
    instruction = (
        "You are a kids story writer. Your task is to analyse the input and "
        "provide a nice little short kids story. Given the prompt "
        f'{prompt_json}, include "anime style, ultra HD, 4K" in every '
        "imageDescription prompt suitable for stable diffusion with upto 20 "
        "words, write 6 scenes, maximum of 10 words in comic scene of "
        "storyForTheComic, generate a JSON response like "
        f"{json.dumps(example_reply)}"
    )
    # Mixtral instruct template: "<s>[INST] <instruction> [/INST]". The whole
    # instruction (example included) goes between the tags, and no "</s>"
    # (end-of-sequence) may appear before the model's answer.
    text = f"<s>[INST] {instruction} [/INST]"

    client = InferenceClient(
        model="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token
    )
    # InferenceClient.text_generation is a synchronous HTTP call; run it in
    # the default thread pool so it does not block the event loop (and every
    # other request) while the model generates.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        functools.partial(client.text_generation, text, max_new_tokens=800),
    )
    print(response)
    return response


# ComfyUI node IDs in the workflow JSON that receive the request's values.
# The k-th image node takes input["input"][1 + k]; the k-th caption node
# takes the k-th item after the image prompts (i.e. items 7-12).
_IMAGE_PROMPT_NODES = ("72", "89", "99", "109", "119", "129")  # items 1-6
_CAPTION_NODES = ("13", "83", "93", "103", "113", "123")  # items 7-12
# Item 0 of the list is never read, hence the leading 1.
_MIN_WORKFLOW_ITEMS = 1 + len(_IMAGE_PROMPT_NODES) + len(_CAPTION_NODES)


# Write the title, footer, image prompts and captions into workflow in place.
def _fill_workflow(workflow, items):
    header = workflow["16"]["inputs"]
    header["header_text"] = "Kids going to park"  # title (hard-coded)
    header["footer_text"] = ""  # author
    first_caption = 1 + len(_IMAGE_PROMPT_NODES)
    for node_id, item in zip(_IMAGE_PROMPT_NODES, items[1:first_caption]):
        workflow[node_id]["inputs"]["text"] = str(item)
    for node_id, item in zip(_CAPTION_NODES, items[first_caption:]):
        workflow[node_id]["inputs"]["text_bottom"] = str(item)


@app.post("/workflow")
async def run_workflow(input: dict):
    """Render the comic by running the ComfyUI workflow on Replicate.

    Loads the workflow template from ``file_path``, fills in a fixed title,
    an empty footer, six image prompts and six captions, and runs it with
    ``fofr/any-comfyui-workflow``.

    Args:
        input: Request body; ``input["input"]`` must be a list whose items
            1-6 are image prompts and 7-12 are captions (item 0 is ignored).

    Returns:
        ``{"link": ...}`` holding the first element of Replicate's output
        when that output is a non-empty list, otherwise ``None``.

    Raises:
        HTTPException: 422 if ``input["input"]`` is not a list of at least
            13 items; 500 if ``REPLICATE_API_TOKEN`` is not set.
    """
    items = input.get("input")
    if not isinstance(items, list) or len(items) < _MIN_WORKFLOW_ITEMS:
        raise HTTPException(
            status_code=422,
            detail=(
                f'"input" must be a list of at least {_MIN_WORKFLOW_ITEMS} '
                "items: index 0 unused, 1-6 image prompts, 7-12 captions"
            ),
        )
    print(items[1])

    # The token is not passed to replicate.run explicitly; the client picks
    # it up from the environment. A missing token is a server misconfiguration,
    # so report it as 500 rather than blaming the client with a 4xx.
    if not os.environ.get("REPLICATE_API_TOKEN"):
        raise HTTPException(
            status_code=500,
            detail="REPLICATE_API_TOKEN not found in environment variables",
        )

    workflow = load_workflow_from_file(file_path)
    _fill_workflow(workflow, items)

    # replicate.run is synchronous and waits for the prediction to finish;
    # run it in the default thread pool so the event loop stays responsive.
    loop = asyncio.get_running_loop()
    output = await loop.run_in_executor(
        None,
        functools.partial(
            replicate.run,
            "fofr/any-comfyui-workflow:5c922132f43c8f7f35825d0687a4733ab974749d7704e6cf0fb62b4a258fa55a",
            input={
                "workflow_json": json.dumps(workflow),
                "randomise_seeds": True,
                "return_temp_files": False,
            },
        ),
    )

    link = output[0] if isinstance(output, list) and output else None
    print({"output": output})
    print({"link": link})
    # NOTE(review): newer replicate clients may return file-output objects
    # instead of URL strings; confirm "link" still serialises as a URL.
    return {"link": link}


if __name__ == "__main__":
    # Running this file directly serves the FastAPI app with uvicorn on all
    # interfaces (0.0.0.0), port 8000.
    from uvicorn import run as serve_app

    serve_app(app, host="0.0.0.0", port=8000)