akhil20187 commited on
Commit
9c22932
·
verified ·
1 Parent(s): 5486374

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +97 -0
main.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.responses import RedirectResponse
3
+ import replicate
4
+ import os
5
+ import json
6
+ from huggingface_hub import InferenceClient
7
+
8
+ app = FastAPI()
9
+ #defining hf api key
10
+
11
+ hf_token=os.environ.get("ht_token")
12
+
13
+ file_path = "replicate_workflow.json"
14
+
15
+ def load_workflow_from_file(file_path):
16
+ with open(file_path, 'r') as file:
17
+ return json.load(file)
18
+
19
+
20
+ @app.post("/prompt")
21
+ async def generate_story(prompt: dict):
22
+ print (prompt)
23
+ prompt = json.dumps(prompt)
24
+ client = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
25
+ response = client.text_generation(f"""<s><INST>You are a kids story writer. Your task is to analyse the input and provide a nice little short kids story. Given the prompt {prompt}, include "anime style, ultra HD, 4K" in every imageDescription prompt suitable for stable diffusion with upto 20 words, write 6 scenes, maximum of 10 words in comic scene of storyForTheComic, generate a JSON response like [/INST]</s>{[
26
+ {
27
+ "sceneNumber": 1,
28
+ "imageDescription": "Describe the scene in 30 words ",
29
+ "storyForTheComic": "Write the story here in 10 to 20 words"
30
+ },
31
+ {
32
+ "sceneNumber": 2,
33
+ "imageDescription": "Describe the scene in 30 words",
34
+ "storyForTheComic": "Write the story here in 10 to 20 words"
35
+ },
36
+ {
37
+ "sceneNumber": 3,
38
+ "imageDescription": "Describe the scene in 30 words",
39
+ "storyForTheComic": "Write the story here in 10 to 20 words"
40
+ }
41
+ ]}""", max_new_tokens=800)
42
+ print (response)
43
+ return response
44
+
45
+
46
+ @app.post("/workflow")
47
+ async def run_workflow(input: dict):
48
+ workflow = load_workflow_from_file(file_path)
49
+ #print (input)
50
+ print (input['input'][1])
51
+
52
+ workflow["16"]["inputs"]["header_text"] = "Kids going to park" #Title
53
+ workflow["16"]["inputs"]["footer_text"] = "" #Author
54
+ workflow["13"]["inputs"]["text_bottom"] = f"{input['input'][7]}" #comic1
55
+ workflow["83"]["inputs"]["text_bottom"] = f"{input['input'][8]}" #comic2
56
+ workflow["93"]["inputs"]["text_bottom"] = f"{input['input'][9]}"#comic3
57
+ workflow["103"]["inputs"]["text_bottom"] = f"{input['input'][10]}" #comic4
58
+ workflow["113"]["inputs"]["text_bottom"] = f"{input['input'][11]}" #comic5
59
+ workflow["123"]["inputs"]["text_bottom"] = f"{input['input'][12]}" #comic6
60
+ workflow["72"]["inputs"]["text"] = f"{input['input'][1]}"#image1
61
+ workflow["89"]["inputs"]["text"] = f"{input['input'][2]}"#image2
62
+ workflow["99"]["inputs"]["text"] = f"{input['input'][3]}"#image3
63
+ workflow["109"]["inputs"]["text"] = f"{input['input'][4]}"#image4
64
+ workflow["119"]["inputs"]["text"] = f"{input['input'][5]}"#image5
65
+ workflow["129"]["inputs"]["text"] = f"{input['input'][6]}"#image6
66
+
67
+ auth_token = os.environ.get("REPLICATE_API_TOKEN")
68
+ if not auth_token:
69
+ raise HTTPException(status_code=400, detail="REPLICATE_API_TOKEN not found in environment variables")
70
+
71
+ output = replicate.run(
72
+ "fofr/any-comfyui-workflow:5c922132f43c8f7f35825d0687a4733ab974749d7704e6cf0fb62b4a258fa55a",
73
+ input={
74
+ "workflow_json": json.dumps(workflow),
75
+ "randomise_seeds": True,
76
+ "return_temp_files": False
77
+ },
78
+ #auth_token=auth_token
79
+ )
80
+ # Extracting the link from the output
81
+ link = output[0] if isinstance(output, list) and len(output) > 0 else None
82
+
83
+ # Returning the link in the response JSON
84
+
85
+ print ({"output": output})
86
+ print ({"link": link})
87
+ #image_link = output.get("output",[None])[0]
88
+ return {"link": link}
89
+ #return {"output": output}
90
+
91
+
92
+
93
+
94
+
95
+ if __name__ == "__main__":
96
+ import uvicorn
97
+ uvicorn.run(app, host="0.0.0.0", port=8000)