|
import logging
import os
from io import BytesIO

from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from huggingface_hub import InferenceClient
from PIL import Image
|
|
|
|
|
# Flask application object; all routes below register on it.
app = Flask(__name__)
# Enable cross-origin requests using Flask-CORS's default settings.
CORS(app)

# Hugging Face API token from the environment (None when the variable is unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Shared inference client used by generate_image().
client = InferenceClient(token=HF_TOKEN)
|
|
|
@app.route('/')
def home():
    """Return a plain-text greeting for requests to the service root.

    The text names what this service actually does (text-to-image generation);
    the previous wording referred to an unrelated background-remover app.
    """
    return "Welcome to the Text-to-Image Generator!"
|
|
|
|
|
|
|
def generate_image(prompt, negative_prompt=None, height=512, width=512,
                   model="stabilityai/stable-diffusion-2-1",
                   num_inference_steps=50, guidance_scale=7.5, seed=None):
    """Generate an image from a text prompt using the module-level ``client``.

    All arguments are forwarded unchanged to ``client.text_to_image``.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Optional text describing what to avoid.
        height: Output height in pixels.
        width: Output width in pixels.
        model: Hugging Face model id to run.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Prompt-adherence strength.
        seed: Optional seed for reproducible output.

    Returns:
        Whatever ``client.text_to_image`` returns (presumably a PIL image — the
        API caller saves it as PNG), or ``None`` if the call raised. Failures
        are logged with their full traceback rather than re-raised.
    """
    try:
        return client.text_to_image(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            model=model,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )
    except Exception as e:
        # Deliberate fail-soft boundary: callers treat None as "generation
        # failed". logging.exception keeps the traceback that print() dropped.
        logging.getLogger(__name__).exception("Error generating image: %s", e)
        return None
|
|
|
|
|
def _parse_generation_options(data):
    """Validate and coerce the optional numeric fields of a generation request.

    Args:
        data: The decoded JSON request object.

    Returns:
        A tuple ``(height, width, num_inference_steps, guidance_scale, seed)``
        using the endpoint's defaults (1024, 720, 50, 7.5, None) for absent keys.

    Raises:
        ValueError: with a client-facing message if any field is malformed.
    """
    def positive_int(key, default):
        # Coerce to int (accepts ints and integer strings) and require > 0.
        raw = data.get(key, default)
        try:
            value = int(raw)
        except (TypeError, ValueError, OverflowError):
            raise ValueError(f"'{key}' must be an integer") from None
        if value <= 0:
            raise ValueError(f"'{key}' must be a positive integer")
        return value

    height = positive_int('height', 1024)
    width = positive_int('width', 720)
    num_inference_steps = positive_int('num_inference_steps', 50)

    try:
        guidance_scale = float(data.get('guidance_scale', 7.5))
    except (TypeError, ValueError):
        raise ValueError("'guidance_scale' must be a number") from None
    # Zero is allowed (some models run without guidance); the chained
    # comparison also rejects NaN and infinity.
    if not 0 <= guidance_scale < float('inf'):
        raise ValueError("'guidance_scale' must be a finite, non-negative number")

    seed = data.get('seed', None)
    if seed is not None:
        try:
            seed = int(seed)
        except (TypeError, ValueError, OverflowError):
            raise ValueError("'seed' must be an integer") from None

    return height, width, num_inference_steps, guidance_scale, seed


@app.route('/generate_image', methods=['POST'])
def generate_api():
    """Generate an image from a JSON request and return it as an inline PNG.

    Request body: a JSON object with a required, non-blank string ``prompt`` and
    optional ``negative_prompt``, ``height`` (default 1024), ``width`` (default
    720), ``num_inference_steps`` (default 50), ``guidance_scale`` (default 7.5),
    ``model`` (default ``stabilityai/stable-diffusion-2-1``) and ``seed``.

    Returns:
        The PNG image (``image/png``, not as an attachment) on success; a JSON
        ``{"error": ...}`` body with status 400 for an invalid request, or 500
        if generation or encoding fails.
    """
    data = request.get_json()
    # A JSON null, list or scalar body (or None from get_json) would otherwise
    # crash on data.get() outside any error handling.
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get('prompt', '')
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Prompt is required"}), 400

    negative_prompt = data.get('negative_prompt', None)
    model_name = data.get('model', 'stabilityai/stable-diffusion-2-1')
    try:
        height, width, num_inference_steps, guidance_scale, seed = (
            _parse_generation_options(data)
        )
    except ValueError as e:
        # Reject bad input up front instead of failing later in the remote call.
        return jsonify({"error": str(e)}), 400

    try:
        image = generate_image(prompt, negative_prompt, height, width, model_name,
                               num_inference_steps, guidance_scale, seed)
        if image is None:
            return jsonify({"error": "Failed to generate image"}), 500

        img_byte_arr = BytesIO()
        image.save(img_byte_arr, format='PNG')
        img_byte_arr.seek(0)  # rewind so send_file streams from the start

        return send_file(
            img_byte_arr,
            mimetype='image/png',
            as_attachment=False,
            download_name='generated_image.png',
        )
    except Exception as e:
        logging.getLogger(__name__).exception("Error in generate_api: %s", e)
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
if __name__ == "__main__":
    # Bind to every interface rather than only localhost so the server is
    # reachable from outside the host; port 7860 is presumably what the
    # deployment platform expects (e.g. Hugging Face Spaces) — confirm.
    app.run(port=7860, host='0.0.0.0')