Geek7 committed on
Commit
c23f4d4
1 Parent(s): 33fcfe8

Update myapp.py

Browse files
Files changed (1) hide show
  1. myapp.py +25 -54
myapp.py CHANGED
@@ -1,63 +1,35 @@
1
  from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
3
- import asyncio
4
- import tempfile
5
  import os
6
- from threading import RLock
7
  from huggingface_hub import InferenceClient
8
- from PIL import Image # Import Pillow
9
  from io import BytesIO # For converting image to bytes
 
10
 
 
11
  myapp = Flask(__name__)
12
  CORS(myapp) # Enable CORS for all routes
13
 
14
- lock = RLock()
15
- HF_TOKEN = os.environ.get("HF_TOKEN") # Hugging Face token
16
-
17
- inference_timeout = 600 # Set timeout for inference
18
 
19
  @myapp.route('/')
20
  def home():
21
  return "Welcome to the Image Background Remover!"
22
 
23
- # Function to dynamically load models from the "models" list
24
- def get_model_from_name(model_name):
25
- return model_name if model_name in models else None
26
-
27
- # Asynchronous function to perform inference
28
- async def infer(client, prompt, seed=1, timeout=inference_timeout, model="prompthero/openjourney-v4"):
29
- task = asyncio.create_task(
30
- asyncio.to_thread(client.text_to_image, prompt=prompt, seed=seed, model=model)
31
- )
32
- await asyncio.sleep(0)
33
  try:
34
- result = await asyncio.wait_for(task, timeout=timeout)
35
- except (Exception, asyncio.TimeoutError) as e:
36
- print(e)
37
- print(f"Task timed out for model: {model}")
38
- if not task.done():
39
- task.cancel()
40
- result = None
41
-
42
- if task.done() and result is not None:
43
- with lock:
44
- # Convert image result to bytes using Pillow
45
- image_bytes = BytesIO()
46
- # Assuming result is an image object from huggingface_hub
47
- result.save(image_bytes, format='PNG') # Save the image to a BytesIO object
48
- image_bytes.seek(0) # Go to the start of the byte stream
49
-
50
- # Save the result image as a temporary file
51
- temp_image = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
52
- with open(temp_image.name, "wb") as f:
53
- f.write(image_bytes.read()) # Write the bytes to the temp file
54
-
55
- return temp_image.name # Return the path to the saved image
56
- return None
57
 
58
  # Flask route for the API endpoint
59
  @myapp.route('/generate_image', methods=['POST'])
60
- def generate_image():
61
  data = request.get_json()
62
 
63
  # Extract required fields from the request
@@ -68,23 +40,22 @@ def generate_image():
68
  if not prompt:
69
  return jsonify({"error": "Prompt is required"}), 400
70
 
71
- # Get the model from all_models
72
- model = get_model_from_name(model_name)
73
- if not model:
74
- return jsonify({"error": f"Model '{model_name}' not found in available models"}), 400
75
-
76
  try:
77
- # Create a generic InferenceClient for the model
78
- client = InferenceClient(token=HF_TOKEN)
 
 
 
 
 
 
79
 
80
- # Call the async inference function
81
- result_path = asyncio.run(infer(client, prompt, seed, model=model))
82
- if result_path:
83
- return send_file(result_path, mimetype='image/png') # Send back the generated image file
84
  else:
85
  return jsonify({"error": "Failed to generate image"}), 500
86
  except Exception as e:
87
- print(f"Error in generate_image: {str(e)}") # Log the error
88
  return jsonify({"error": str(e)}), 500
89
 
90
  # Add this block to make sure your app runs when called
 
1
  from flask import Flask, request, jsonify, send_file
2
  from flask_cors import CORS
 
 
3
  import os
 
4
  from huggingface_hub import InferenceClient
 
5
  from io import BytesIO # For converting image to bytes
6
+ from PIL import Image # Import Pillow for image processing
# Initialize the Flask app
myapp = Flask(__name__)
CORS(myapp)  # Enable CORS for all routes

# Initialize the InferenceClient with your Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN")  # Ensure to set your Hugging Face token in the environment
# Module-level client shared by every request handler in this file.
# NOTE(review): if HF_TOKEN is unset this is None and the client is anonymous — confirm that is intended.
client = InferenceClient(token=HF_TOKEN)
 
16
  @myapp.route('/')
17
  def home():
18
  return "Welcome to the Image Background Remover!"
19
 
20
# Text-to-image helper shared by the API route below.
def generate_image(prompt, seed=1, model="prompthero/openjourney-v4"):
    """Run the given text-to-image model on *prompt* via the Hugging Face
    inference API and return the resulting PIL image.

    Returns None on any failure so the caller can map it to an HTTP 500;
    the error is logged to stdout rather than propagated.
    """
    try:
        # client is the module-level InferenceClient configured at import time.
        return client.text_to_image(prompt=prompt, seed=seed, model=model)
    except Exception as e:
        # Best-effort: report and signal failure instead of raising.
        print(f"Error generating image: {str(e)}")
        return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
# Flask route for the API endpoint
@myapp.route('/generate_image', methods=['POST'])
def generate_api():
    """POST /generate_image — generate a PNG from a JSON request body.

    Expected JSON fields: 'prompt' (required), 'seed' (optional int),
    'model' (optional model id). Responds with the generated PNG as a
    download, or a JSON error with status 400/500.
    """
    # silent=True makes a missing or non-JSON body yield None instead of
    # raising, so we can answer a clean 400 rather than an opaque 500.
    data = request.get_json(silent=True) or {}

    # Extract required fields from the request.
    # NOTE(review): these extraction lines are elided in the diff view; the
    # key names are inferred from how prompt/seed/model_name are used below
    # and from generate_image's defaults — confirm against the full file.
    prompt = data.get('prompt')
    seed = data.get('seed', 1)  # mirrors generate_image's default seed
    model_name = data.get('model', "prompthero/openjourney-v4")

    if not prompt:
        return jsonify({"error": "Prompt is required"}), 400

    try:
        # Call the generate_image function with the custom model name
        image = generate_image(prompt, seed, model_name)

        if image:
            # Save the image to a BytesIO object to send as response
            image_bytes = BytesIO()
            image.save(image_bytes, format='PNG')
            image_bytes.seek(0)  # rewind so send_file streams from the start

            # Send the generated image as a response with a download option
            return send_file(image_bytes, mimetype='image/png',
                             as_attachment=True,
                             download_name='generated_image.png')
        else:
            # generate_image logged the underlying cause and returned None.
            return jsonify({"error": "Failed to generate image"}), 500
    except Exception as e:
        print(f"Error in generate_api: {str(e)}")  # Log the error
        return jsonify({"error": str(e)}), 500
60
 
61
  # Add this block to make sure your app runs when called