Geek7 committed on
Commit
33fcfe8
1 Parent(s): a40654e

Update myapp.py

Browse files
Files changed (1) hide show
  1. myapp.py +54 -22
myapp.py CHANGED
@@ -1,34 +1,63 @@
1
  from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
 
 
3
  import os
 
4
  from huggingface_hub import InferenceClient
 
5
  from io import BytesIO # For converting image to bytes
6
 
7
- # Initialize the Flask app
8
  myapp = Flask(__name__)
9
  CORS(myapp) # Enable CORS for all routes
10
 
11
- # Initialize the InferenceClient with your Hugging Face token
12
- HF_TOKEN = os.environ.get("HF_TOKEN") # Ensure to set your Hugging Face token in the environment
13
- client = InferenceClient(token=HF_TOKEN)
 
14
 
15
  @myapp.route('/')
16
  def home():
17
  return "Welcome to the Image Background Remover!"
18
 
 
 
 
19
 
20
- # Function to generate an image from a prompt
21
- def generate_image(prompt, seed=1, model="prompthero/openjourney-v4"):
 
 
 
 
22
  try:
23
- result_image = client.text_to_image(prompt=prompt, seed=seed, model=model)
24
- return result_image
25
- except Exception as e:
26
- print(f"Error generating image: {str(e)}")
27
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Flask route for the API endpoint
30
  @myapp.route('/generate_image', methods=['POST'])
31
- def generate_api():
32
  data = request.get_json()
33
 
34
  # Extract required fields from the request
@@ -39,20 +68,23 @@ def generate_api():
39
  if not prompt:
40
  return jsonify({"error": "Prompt is required"}), 400
41
 
 
 
 
 
 
42
  try:
43
- # Call the generate_image function
44
- image = generate_image(prompt, seed, model_name)
45
-
46
- if image:
47
- # Save the image to a BytesIO object to send as response
48
- image_bytes = BytesIO()
49
- image.save(image_bytes, format='PNG')
50
- image_bytes.seek(0) # Go to the start of the byte stream
51
- return send_file(image_bytes, mimetype='image/png', as_attachment=True, download_name='generated_image.png')
52
  else:
53
  return jsonify({"error": "Failed to generate image"}), 500
54
  except Exception as e:
55
- print(f"Error in generate_api: {str(e)}") # Log the error
56
  return jsonify({"error": str(e)}), 500
57
 
58
  # Add this block to make sure your app runs when called
 
1
  from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
3
+ import asyncio
4
+ import tempfile
5
  import os
6
+ from threading import RLock
7
  from huggingface_hub import InferenceClient
8
+ from PIL import Image # Import Pillow
9
  from io import BytesIO # For converting image to bytes
10
 
 
11
  myapp = Flask(__name__)
12
  CORS(myapp) # Enable CORS for all routes
13
 
14
+ lock = RLock()
15
+ HF_TOKEN = os.environ.get("HF_TOKEN") # Hugging Face token
16
+
17
+ inference_timeout = 600 # Set timeout for inference
18
 
19
@myapp.route('/')
def home():
    """Root endpoint: return a plain-text greeting so clients can health-check the service."""
    return "Welcome to the Image Background Remover!"
22
 
23
+ # Function to dynamically load models from the "models" list
24
+ def get_model_from_name(model_name):
25
+ return model_name if model_name in models else None
26
 
27
+ # Asynchronous function to perform inference
28
+ async def infer(client, prompt, seed=1, timeout=inference_timeout, model="prompthero/openjourney-v4"):
29
+ task = asyncio.create_task(
30
+ asyncio.to_thread(client.text_to_image, prompt=prompt, seed=seed, model=model)
31
+ )
32
+ await asyncio.sleep(0)
33
  try:
34
+ result = await asyncio.wait_for(task, timeout=timeout)
35
+ except (Exception, asyncio.TimeoutError) as e:
36
+ print(e)
37
+ print(f"Task timed out for model: {model}")
38
+ if not task.done():
39
+ task.cancel()
40
+ result = None
41
+
42
+ if task.done() and result is not None:
43
+ with lock:
44
+ # Convert image result to bytes using Pillow
45
+ image_bytes = BytesIO()
46
+ # Assuming result is an image object from huggingface_hub
47
+ result.save(image_bytes, format='PNG') # Save the image to a BytesIO object
48
+ image_bytes.seek(0) # Go to the start of the byte stream
49
+
50
+ # Save the result image as a temporary file
51
+ temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
52
+ with open(temp_image.name, "wb") as f:
53
+ f.write(image_bytes.read()) # Write the bytes to the temp file
54
+
55
+ return temp_image.name # Return the path to the saved image
56
+ return None
57
 
58
  # Flask route for the API endpoint
59
  @myapp.route('/generate_image', methods=['POST'])
60
+ def generate_image():
61
  data = request.get_json()
62
 
63
  # Extract required fields from the request
 
68
  if not prompt:
69
  return jsonify({"error": "Prompt is required"}), 400
70
 
71
+ # Get the model from all_models
72
+ model = get_model_from_name(model_name)
73
+ if not model:
74
+ return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400
75
+
76
  try:
77
+ # Create a generic InferenceClient for the model
78
+ client = InferenceClient(token=HF_TOKEN)
79
+
80
+ # Call the async inference function
81
+ result_path = asyncio.run(infer(client, prompt, seed, model=model))
82
+ if result_path:
83
+ return send_file(result_path, mimetype='image/png') # Send back the generated image file
 
 
84
  else:
85
  return jsonify({"error": "Failed to generate image"}), 500
86
  except Exception as e:
87
+ print(f"Error in generate_image: {str(e)}") # Log the error
88
  return jsonify({"error": str(e)}), 500
89
 
90
  # Add this block to make sure your app runs when called