Geek7 committed
Commit a63902d · verified · 1 Parent(s): 1dec5c9

Update app.py

Files changed (1):
  1. app.py +10 -38
app.py CHANGED
@@ -5,17 +5,10 @@ from externalmod import gr_Interface_load
 import asyncio
 import os
 from threading import RLock
-from flask import Flask, request, jsonify, send_file
-from flask_cors import CORS
 
 lock = RLock()
 HF_TOKEN = os.environ.get("HF_TOKEN")
 
-# Initialize Flask app
-app = Flask(__name__)
-CORS(app)  # Enable CORS for all routes
-
-# Load models using gr_Interface_load
 def load_fn(models):
     global models_load
     models_load = {}
@@ -36,7 +29,6 @@ MAX_SEED = 3999999999
 default_models = models[:num_models]
 inference_timeout = 600
 
-# Inference function with generate_api embedded
 async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
     kwargs = {"seed": seed}
     task = asyncio.create_task(asyncio.to_thread(models_load[model_str].fn, prompt=prompt, **kwargs, token=HF_TOKEN))
@@ -46,43 +38,23 @@ async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
     except (Exception, asyncio.TimeoutError) as e:
         print(e)
         print(f"Task timed out: {model_str}")
-        if not task.done():
+        if not task.done():
             task.cancel()
         result = None
     if task.done() and result is not None:
         with lock:
-            png_path = "generated_image.png"
+            png_path = "image.png"
             result.save(png_path)
         return png_path
     return None
 
-# Flask API to call the generate_api function
-@app.route('/predict', methods=['POST'])
-def predict():
-    data = request.get_json()
-
-    # Validate required fields
-    if not data or 'prompt' not in data or 'model_str' not in data:
-        return jsonify({"error": "Missing required fields"}), 400
-
-    model_str = data['model_str']
-    prompt = data['prompt']
-    seed = data.get('seed', 1)
-
-    # Make the asynchronous call to the infer function within the Flask route
-    try:
-        image_path = asyncio.run(infer(model_str, prompt, seed))  # Directly call infer function here
-        if image_path:
-            return send_file(image_path, mimetype='image/png')
-        else:
-            return jsonify({"error": "Failed to generate image"}), 500
-    except Exception as e:
-        return jsonify({"error": str(e)}), 500
-
-if __name__ == '__main__':
-    # Run Flask app
-    app.run(debug=True)
+# Expose Gradio API
+def generate_api(model_str, prompt, seed=1):
+    result = asyncio.run(infer(model_str, prompt, seed))
+    if result:
+        return result  # Path to generated image
+    return None
 
-# You can optionally launch the Gradio interface in parallel
-iface = gr.Interface(fn=infer, inputs=["text", "text", "number"], outputs="file")
+# Launch Gradio API without frontend
+iface = gr.Interface(fn=generate_api, inputs=["text", "text", "number"], outputs="file")
 iface.launch(show_api=True, share=True)
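
With the Flask wrapper gone, the only entry point left is the Gradio API started by iface.launch(show_api=True, share=True). As a minimal client-side sketch (not part of the commit), the endpoint can be called with the gradio_client package; the Space id, model name, prompt, and seed below are illustrative placeholders:

# Hypothetical client for the exposed generate_api endpoint, using gradio_client.
# All argument values are placeholders, not taken from this repository.
from gradio_client import Client

client = Client("Geek7/your-space-name")   # placeholder Space id or share URL
image_path = client.predict(
    "some-user/some-diffusion-model",      # model_str (placeholder)
    "a watercolor fox in a forest",        # prompt (placeholder)
    1,                                     # seed
    api_name="/predict",                   # default endpoint name for a single gr.Interface
)
print(image_path)  # local path of the image file returned by generate_api

Dropping Flask means there is no second web server to run next to Gradio; the image-generation call is served through Gradio's built-in API instead, so the CORS setup and send_file handling are no longer needed in app.py.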