Geek7 commited on
Commit
44aff97
1 Parent(s): a63902d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -59
app.py CHANGED
@@ -1,60 +1,45 @@
1
- import gradio as gr
2
- from random import randint
3
- from all_models import models
4
- from externalmod import gr_Interface_load
5
- import asyncio
6
- import os
7
- from threading import RLock
8
-
9
- lock = RLock()
10
- HF_TOKEN = os.environ.get("HF_TOKEN")
11
-
12
- def load_fn(models):
13
- global models_load
14
- models_load = {}
15
-
16
- for model in models:
17
- if model not in models_load.keys():
18
- try:
19
- m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
20
- except Exception as error:
21
- print(error)
22
- m = gr.Interface(lambda: None, ['text'], ['image'])
23
- models_load.update({model: m})
24
-
25
- load_fn(models)
26
-
27
- num_models = 6
28
- MAX_SEED = 3999999999
29
- default_models = models[:num_models]
30
- inference_timeout = 600
31
-
32
- async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
33
- kwargs = {"seed": seed}
34
- task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
35
- await asyncio.sleep(0)
36
  try:
37
- result = await asyncio.wait_for(task, timeout=timeout)
38
- except (Exception, asyncio.TimeoutError) as e:
39
- print(e)
40
- print(f"Task timed out: {model_str}")
41
- if not task.done():
42
- task.cancel()
43
- result = None
44
- if task.done() and result is not None:
45
- with lock:
46
- png_path = "image.png"
47
- result.save(png_path)
48
- return png_path
49
- return None
50
-
51
- # Expose Gradio API
52
- def generate_api(model_str, prompt, seed=1):
53
- result = asyncio.run(infer(model_str, prompt, seed))
54
- if result:
55
- return result # Path to generated image
56
- return None
57
-
58
- # Launch Gradio API without frontend
59
- iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
60
- iface.launch(show_api=True, share=True)
 
1
+ from flask import Flask, request, jsonify, send_file
2
+ from flask_cors import CORS
3
+ from gradio_client import Client
4
+ from all_models import models # Import the models list
5
+
6
+ app = Flask(__name__)
7
+ CORS(app)
8
+
9
+ # Initialize Gradio Client with the first model in the list
10
+ client = Client("Geek7/mdztxi2")
11
+
12
+ @app.route('/predict', methods=['POST'])
13
+ def predict():
14
+ data = request.get_json()
15
+
16
+ # Validate required fields
17
+ if not data or 'model_str' not in data or 'prompt' not in data or 'seed' not in data:
18
+ return jsonify({"error": "Missing required fields"}), 400
19
+
20
+ model_str = data['model_str']
21
+ prompt = data['prompt']
22
+ seed = data['seed']
23
+
24
+ # Check if the model_str exists in the models list
25
+ if model_str not in models:
26
+ return jsonify({"error": f"Model '{model_str}' is not available."}), 400
27
+
 
 
 
 
 
 
 
 
28
  try:
29
+ # Send a request to the Gradio Client and get the result
30
+ result = client.predict(
31
+ model_str=model_str,
32
+ prompt=prompt,
33
+ seed=seed,
34
+ api_name="/predict"
35
+ )
36
+
37
+ # Save the result to a file (assuming it returns a filepath)
38
+ result_path = result # Result is already the filepath
39
+ return send_file(result_path, mimetype='image/png')
40
+
41
+ except Exception as e:
42
+ return jsonify({"error": str(e)}), 500
43
+
44
+ if __name__ == '__main__':
45
+ app.run(debug=True)