"""Flask service that segments an uploaded image with Segment Anything (SAM).

The ``/get-masks`` endpoint accepts an image upload, runs SAM's automatic
mask generator on it, filters and de-overlaps the resulting masks, and
returns them as a zip archive of PNG files.
"""

import io
import zipfile

import cv2
import numpy as np
import torch
from flask import Flask, request, send_file, Response, jsonify
from flask_cors import CORS
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

app = Flask(__name__)
CORS(app)

cudaOrNah = "cuda" if torch.cuda.is_available() else "cpu"
print(cudaOrNah)

# Masks covering less than this fraction of the image are dropped after
# overlap removal.
MIN_AREA_FRACTION = 0.0015

# A mask that contains any of the four pixels this far in from the image
# corners is treated as background and dropped.
CORNER_INSET = 10

# Global model setup.
# The larger checkpoint ("sam_vit_h_4b8939.pth", model type "vit_h") was
# swapped for vit_l after running out of memory.
checkpoint = "sam_vit_l_0b3195.pth"
model_type = "vit_l"

sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=cudaOrNah)

# NOTE(review): 0.0015 looks like an area *fraction*, but SAM appears to
# interpret min_mask_region_area as a pixel count -- confirm the intended
# value against the segment_anything documentation. Left unchanged here.
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    min_mask_region_area=0.0015,  # Adjust this value as needed
)
print('Setup SAM model')


def _decode_image(raw_bytes):
    """Decode uploaded image bytes into an RGB array.

    Raises:
        ValueError: if the payload is empty or cannot be decoded as an image.
    """
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    file_bytes = np.frombuffer(raw_bytes, dtype=np.uint8)
    if file_bytes.size == 0:
        raise ValueError("Image not found or unable to read.")

    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    # Validate before colour conversion: cvtColor would otherwise fail with
    # an opaque OpenCV error when decoding did not produce an image.
    if image is None:
        raise ValueError("Image not found or unable to read.")

    # OpenCV decodes to BGR channel order; convert to RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _touches_corner(segmentation):
    """Return True if the mask contains any of the four near-corner pixels."""
    height, width = segmentation.shape[:2]
    # Clamp the inset so images smaller than CORNER_INSET + 1 pixels on a
    # side do not raise IndexError; larger images are sampled exactly
    # CORNER_INSET pixels in from each corner.
    row = min(CORNER_INSET, height - 1)
    col = min(CORNER_INSET, width - 1)
    return bool(
        segmentation[row, col]
        or segmentation[-row, col]
        or segmentation[row, -col]
        or segmentation[-row, -col]
    )


def _remove_overlaps(masks):
    """Subtract every later (smaller) mask from each earlier (larger) one.

    ``masks`` must already be sorted by descending area. The 'segmentation'
    and 'area' entries of every mask except the last are updated in place;
    the list is returned for convenience.
    """
    if not masks:
        return masks

    # Running union of all masks that come after the one being processed.
    # astype() copies, so the last mask's own array is never modified.
    covered = masks[-1]['segmentation'].astype(bool)
    for mask in reversed(masks[:-1]):
        trimmed = np.logical_and(mask['segmentation'], np.logical_not(covered))
        mask['segmentation'] = trimmed
        mask['area'] = int(trimmed.sum())
        # covered | trimmed equals covered | original, so one pass is enough.
        covered = np.logical_or(covered, trimmed)
    return masks


def _is_too_small(segmentation, min_fraction=MIN_AREA_FRACTION):
    """Return True if the mask covers less than ``min_fraction`` of the image."""
    return segmentation.sum() / segmentation.size < min_fraction


def _zip_masks(masks):
    """Write each mask as ``mask_<n>.png`` (0/255 pixels) into an in-memory zip."""
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        for idx, mask in enumerate(masks, start=1):
            alpha = mask['segmentation'].astype('uint8') * 255
            png_buffer = io.BytesIO()
            Image.fromarray(alpha).save(png_buffer, format="PNG")
            zip_file.writestr(f'mask_{idx}.png', png_buffer.getvalue())
    zip_buffer.seek(0)
    return zip_buffer


@app.route('/')
def hello():
    """Return a fixed JSON greeting."""
    return {"hei": "Shredded to peices"}


@app.route('/health', methods=['GET'])
def health_check():
    """Simple health check endpoint."""
    return jsonify({"status": "ok"}), 200


@app.route('/get-masks', methods=['POST'])
def get_masks():
    """Segment the uploaded 'image' file and return the masks as masks.zip.

    Returns a 400 JSON error when no image is supplied or when any step of
    the processing fails.
    """
    try:
        print('received image from frontend')

        # Get the image file from the request.
        if 'image' not in request.files:
            return jsonify({"error": "No image file provided"}), 400

        image_file = request.files['image']
        if image_file.filename == '':
            return jsonify({"error": "No image file provided"}), 400

        image = _decode_image(image_file.read())

        if cudaOrNah == "cuda":
            torch.cuda.empty_cache()
        try:
            masks = mask_generator.generate(image)
        finally:
            # Release cached GPU memory even when generation fails.
            if cudaOrNah == "cuda":
                torch.cuda.empty_cache()

        # Largest first, so that overlap removal subtracts smaller masks
        # from the larger ones that contain them.
        masks = sorted(masks, key=lambda m: m['area'], reverse=True)
        masks = [m for m in masks if not _touches_corner(m['segmentation'])]
        masks = _remove_overlaps(masks)
        masks = [m for m in masks if not _is_too_small(m['segmentation'])]
        # Areas changed during overlap removal, so sort again.
        masks = sorted(masks, key=lambda m: m['area'], reverse=True)

        return send_file(
            _zip_masks(masks),
            mimetype='application/zip',
            as_attachment=True,
            download_name='masks.zip',
        )

    except Exception as e:
        # Log the error message.
        print(f"Error processing the image: {e}")
        # Every failure is reported to the client as a 400 with details.
        return jsonify({"error": "Error processing the image", "details": str(e)}), 400


if __name__ == '__main__':
    app.run(debug=True)