|
from flask import Flask, request, jsonify, send_file |
|
from flask_cors import CORS |
|
from gradio_client import Client |
|
from all_models import models |
|
|
|
# The Flask application object that the route below is registered on.
# CORS(...) is applied with flask-cors default settings (no explicit
# origins/resources are passed here).
myapp = Flask(__name__)
CORS(myapp)

# Single shared Gradio client for the remote Space "Geek7/mdztxi2", which does
# the actual image generation. It is created at module import time.
# NOTE(review): gradio_client presumably contacts the Space while constructing
# the Client, so importing this module may need network access; verify.
client = Client("Geek7/mdztxi2")
|
|
|
@myapp.route('/predict', methods=['POST'])
def predict():
    """Generate an image through the remote Gradio Space and return it as PNG.

    The request body must be a JSON object with the keys ``model_str``,
    ``prompt`` and ``seed``. ``model_str`` must be present in ``models``.

    Returns:
        On success, the file produced by the Space, sent with mimetype
        ``image/png``.
        ``({"error": "Missing required fields"}, 400)`` when the body is
        missing, is not valid JSON, is not a JSON object, or lacks a key.
        ``({"error": "Model '...' is not available."}, 400)`` for an unknown
        model.
        ``({"error": <exception text>}, 500)`` when the upstream call or the
        file send fails.
    """
    # silent=True: return None (instead of raising) for a missing, non-JSON or
    # malformed body, so the caller gets the same JSON 400 as for missing
    # fields rather than Flask's default error page.
    data = request.get_json(silent=True)

    # A JSON array/string would satisfy the ``in`` checks yet make the
    # ``data[...]`` lookups below raise, so insist on a JSON object.
    if not isinstance(data, dict) or any(
        field not in data for field in ('model_str', 'prompt', 'seed')
    ):
        return jsonify({"error": "Missing required fields"}), 400

    model_str = data['model_str']
    prompt = data['prompt']
    seed = data['seed']

    if model_str not in models:
        return jsonify({"error": f"Model '{model_str}' is not available."}), 400

    try:
        result_path = client.predict(
            model_str=model_str,
            prompt=prompt,
            seed=seed,
            api_name="/predict",
        )
        # NOTE(review): assumes the Space's "/predict" endpoint returns a
        # local file path to a PNG image; confirm its output type/format.
        return send_file(result_path, mimetype='image/png')
    except Exception as e:
        # Record the full traceback server-side so failures are visible in
        # the logs, not only in the client's error response.
        myapp.logger.exception("Prediction failed for model %r", model_str)
        return jsonify({"error": str(e)}), 500
|
|
|
if __name__ == '__main__':
    import os

    # Do not hard-code debug=True: the Werkzeug interactive debugger lets
    # anyone who can reach the server execute arbitrary code. Debug mode is
    # opt-in via the FLASK_DEBUG environment variable (e.g. FLASK_DEBUG=1).
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in (
        '1', 'true', 'yes', 'on',
    )
    myapp.run(debug=debug_enabled)