Geek7 commited on
Commit
1ff4547
1 Parent(s): 39c36a6

Update myapp.py

Browse files
Files changed (1) hide show
  1. myapp.py +59 -44
myapp.py CHANGED
@@ -1,45 +1,60 @@
1
- from flask import Flask, request, jsonify, send_file
2
- from flask_cors import CORS
3
- from gradio_client import Client
4
- from all_models import models # Import the models list
5
-
6
- myapp = Flask(__name__)
7
- CORS(myapp)
8
-
9
- # Initialize Gradio Client with the first model in the list
10
- client = Client("Geek7/mdztxi2")
11
-
12
- @app.route('/predict', methods=['POST'])
13
- def predict():
14
- data = request.get_json()
15
-
16
- # Validate required fields
17
- if not data or 'model_str' not in data or 'prompt' not in data or 'seed' not in data:
18
- return jsonify({"error": "Missing required fields"}), 400
19
-
20
- model_str = data['model_str']
21
- prompt = data['prompt']
22
- seed = data['seed']
23
-
24
- # Check if the model_str exists in the models list
25
- if model_str not in models:
26
- return jsonify({"error": f"Model '{model_str}' is not available."}), 400
27
-
 
 
 
 
 
 
 
 
28
  try:
29
- # Send a request to the Gradio Client and get the result
30
- result = client.predict(
31
- model_str=model_str,
32
- prompt=prompt,
33
- seed=seed,
34
- api_name="/predict"
35
- )
36
-
37
- # Save the result to a file (assuming it returns a filepath)
38
- result_path = result # Result is already the filepath
39
- return send_file(result_path, mimetype='image/png')
40
-
41
- except Exception as e:
42
- return jsonify({"error": str(e)}), 500
43
-
44
- if __name__ == '__main__':
45
- myapp.run(debug=True)
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from random import randint
3
+ from all_models import models
4
+ from externalmod import gr_Interface_load
5
+ import asyncio
6
+ import os
7
+ from threading import RLock
8
+
9
+ lock = RLock()
10
+ HF_TOKEN = os.environ.get("HF_TOKEN")
11
+
12
+ def load_fn(models):
13
+ global models_load
14
+ models_load = {}
15
+
16
+ for model in models:
17
+ if model not in models_load.keys():
18
+ try:
19
+ m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
20
+ except Exception as error:
21
+ print(error)
22
+ m = gr.Interface(lambda: None, ['text'], ['image'])
23
+ models_load.update({model: m})
24
+
25
+ load_fn(models)
26
+
27
+ num_models = 6
28
+ MAX_SEED = 3999999999
29
+ default_models = models[:num_models]
30
+ inference_timeout = 600
31
+
32
+ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
33
+ kwargs = {"seed": seed}
34
+ task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
35
+ await asyncio.sleep(0)
36
  try:
37
+ result = await asyncio.wait_for(task, timeout=timeout)
38
+ except (Exception, asyncio.TimeoutError) as e:
39
+ print(e)
40
+ print(f"Task timed out: {model_str}")
41
+ if not task.done():
42
+ task.cancel()
43
+ result = None
44
+ if task.done() and result is not None:
45
+ with lock:
46
+ png_path = "image.png"
47
+ result.save(png_path)
48
+ return png_path
49
+ return None
50
+
51
+ # Expose Gradio API
52
+ def generate_api(model_str, prompt, seed=1):
53
+ result = asyncio.run(infer(model_str, prompt, seed))
54
+ if result:
55
+ return result # Path to generated image
56
+ return None
57
+
58
+ # Launch Gradio API without frontend
59
+ iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
60
+ iface.launch(show_api=True, share=True)