File size: 3,684 Bytes
9c22932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from fastapi import FastAPI, HTTPException 
from fastapi.responses import RedirectResponse
import replicate
import os
import json
from huggingface_hub import InferenceClient

app = FastAPI()

# Hugging Face API token for the /prompt endpoint.
# "ht_token" is the variable this service has always read (it looks like a typo
# of "hf_token"), so it stays first for backward compatibility; the standard
# HF_TOKEN variable is used as a fallback. None when neither is set.
hf_token = os.environ.get("ht_token") or os.environ.get("HF_TOKEN")

# ComfyUI workflow template used by /workflow. Relative path, so it is resolved
# against the process's current working directory.
file_path = "replicate_workflow.json"

def load_workflow_from_file(file_path):
    """Load and return the JSON workflow definition stored at *file_path*.

    The file is decoded as UTF-8 explicitly so that non-ASCII text in the
    workflow loads identically regardless of the platform's locale encoding.

    Raises:
        FileNotFoundError: if *file_path* does not exist.
        json.JSONDecodeError: if the file does not contain valid JSON.
    """
    with open(file_path, encoding="utf-8") as fh:
        return json.load(fh)


@app.post("/prompt")
async def generate_story(prompt: dict):
    """Ask Mixtral (Hugging Face Inference API) to write a kids' comic story.

    The request body is serialized to JSON and embedded in an instruction that
    asks the model for scenes shaped like a JSON example (sceneNumber,
    imageDescription, storyForTheComic).

    Args:
        prompt: Arbitrary JSON body from the client describing the story idea.

    Returns:
        The raw text generated by the model. It is requested to be a JSON list
        of scene objects, but nothing here parses or validates it.
    """
    import logging

    from fastapi.concurrency import run_in_threadpool

    logger = logging.getLogger(__name__)
    logger.info("Story prompt received: %s", prompt)

    prompt_json = json.dumps(prompt)
    # Render the example as real JSON. Interpolating a Python list inside an
    # f-string would emit its repr (single-quoted keys), which is not JSON.
    # NOTE(review): the example shows 3 scenes while the instruction asks for 6.
    example_json = json.dumps(
        [
            {
                "sceneNumber": number,
                "imageDescription": "Describe the scene in 30 words",
                "storyForTheComic": "Write the story here in 10 to 20 words",
            }
            for number in (1, 2, 3)
        ],
        indent=1,
    )
    instruction = (
        "You are a kids story writer. Your task is to analyse the input and "
        "provide a nice little short kids story. "
        f"Given the prompt {prompt_json}, include \"anime style, ultra HD, 4K\" "
        "in every imageDescription prompt suitable for stable diffusion with "
        "upto 20 words, write 6 scenes, maximum of 10 words in comic scene of "
        f"storyForTheComic, generate a JSON response like {example_json}"
    )
    # Mixtral-Instruct template: the whole instruction (including the example)
    # goes between [INST] and [/INST]; the model's answer follows it.
    formatted_prompt = f"<s>[INST] {instruction} [/INST]"

    client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
    # text_generation is a blocking HTTP call; run it in the threadpool so the
    # event loop can keep serving other requests meanwhile.
    response = await run_in_threadpool(
        client.text_generation, formatted_prompt, max_new_tokens=800
    )
    logger.info("Story generated: %s", response)
    return response


# Workflow node IDs whose "text" input receives the image prompts
# input["input"][1] .. input["input"][6], in that order.
_IMAGE_PROMPT_NODES = ("72", "89", "99", "109", "119", "129")
# Workflow node IDs whose "text_bottom" input receives the comic captions
# input["input"][7] .. input["input"][12], in that order.
_CAPTION_NODES = ("13", "83", "93", "103", "113", "123")
# Node whose "header_text"/"footer_text" inputs hold the page title and author.
_TITLE_NODE = "16"
_DEFAULT_TITLE = "Kids going to park"


@app.post("/workflow")
async def run_workflow(input: dict):
    """Fill the ComfyUI workflow template with story text and render it on Replicate.

    Expected body (index 0 of "input" is not used here):
        {"input": [_, img1..img6, caption1..caption6],
         "title": optional str (default "Kids going to park"),
         "author": optional str (default "")}

    Returns:
        {"link": first element of the Replicate output list, or None}.

    Raises:
        HTTPException(500): REPLICATE_API_TOKEN is not configured on the server.
        HTTPException(422): "input" is missing, not a list, or too short.
    """
    import logging

    from fastapi.concurrency import run_in_threadpool

    logger = logging.getLogger(__name__)

    # replicate.run reads REPLICATE_API_TOKEN from the environment itself; this
    # check only fails fast with a clear message. A missing server setting is
    # not the client's fault, hence a 5xx status.
    if not os.environ.get("REPLICATE_API_TOKEN"):
        raise HTTPException(status_code=500, detail="REPLICATE_API_TOKEN not found in environment variables")

    items = input.get("input")
    required = 1 + len(_IMAGE_PROMPT_NODES) + len(_CAPTION_NODES)
    if not isinstance(items, list) or len(items) < required:
        raise HTTPException(
            status_code=422,
            detail=(
                f"'input' must be a list with at least {required} entries "
                "(indices 1-6: image prompts, 7-12: comic captions)"
            ),
        )

    # Loaded fresh per request because the template is mutated below.
    workflow = load_workflow_from_file(file_path)

    header = workflow[_TITLE_NODE]["inputs"]
    header["header_text"] = str(input.get("title", _DEFAULT_TITLE))
    header["footer_text"] = str(input.get("author", ""))

    for index, node in enumerate(_IMAGE_PROMPT_NODES, start=1):
        workflow[node]["inputs"]["text"] = str(items[index])
    for index, node in enumerate(_CAPTION_NODES, start=1 + len(_IMAGE_PROMPT_NODES)):
        workflow[node]["inputs"]["text_bottom"] = str(items[index])

    # replicate.run blocks until the prediction finishes; keep it off the
    # event loop.
    output = await run_in_threadpool(
        replicate.run,
        "fofr/any-comfyui-workflow:5c922132f43c8f7f35825d0687a4733ab974749d7704e6cf0fb62b4a258fa55a",
        input={
            "workflow_json": json.dumps(workflow),
            "randomise_seeds": True,
            "return_temp_files": False,
        },
    )
    logger.info("Replicate output: %s", output)

    link = output[0] if isinstance(output, list) and output else None
    return {"link": link}
    
    



if __name__ == "__main__":
    # Development entry point: running this file directly serves the API with
    # uvicorn on all network interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")