File size: 4,860 Bytes
7f7bf76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
// Load the candle SAM model WebAssembly module.
import init, { Model } from "./build/m.js";
/**
 * Download a URL and return its body as bytes, optionally going through the
 * Cache Storage API so large files (model weights) are fetched only once.
 *
 * @param {string} url - Resource to download.
 * @param {boolean} [cacheModel=true] - When true, serve from the
 *   "sam-candle-cache" cache if present and store successful downloads there;
 *   when false, always fetch and never touch the cache.
 * @returns {Promise<Uint8Array>} The response body.
 * @throws {Error} If the server answers with a non-2xx status. Error
 *   responses are never written to the cache.
 */
async function fetchArrayBuffer(url, cacheModel = true) {
  // Reject HTTP errors so an error page is never treated as model/image bytes.
  const ensureOk = (res) => {
    if (!res.ok) {
      throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
    }
    return res;
  };

  if (!cacheModel) {
    const res = ensureOk(await fetch(url));
    return new Uint8Array(await res.arrayBuffer());
  }

  const cacheName = "sam-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    return new Uint8Array(await cachedResponse.arrayBuffer());
  }

  const res = ensureOk(await fetch(url, { cache: "force-cache" }));
  // Caching is best-effort (e.g. quota may be exceeded): don't block on it,
  // and don't let a failure surface as an unhandled rejection.
  cache.put(url, res.clone()).catch((err) => {
    console.warn(`Could not cache ${url}:`, err);
  });
  return new Uint8Array(await res.arrayBuffer());
}
/**
 * Worker-wide registry of loaded candle SAM models, plus a fingerprint of the
 * image whose embeddings were last computed on each model so repeated
 * requests for the same image can skip the embedding step.
 */
class SAMModel {
  // Loaded `Model` instances, keyed by modelID.
  static instance = {};
  // Fingerprint (see getSimpleHash) of the image last embedded on each model,
  // keyed by modelID.
  static imageArrayHash = {};
  // modelID of the most recent getInstance() call; setImageEmbeddings acts on it.
  static currentModelID = null;
  // Promise for the one-time wasm module initialisation (null until started).
  static #wasmReady = null;

  /**
   * Return the model for `modelID`, downloading its weights on first use, and
   * make it the current model.
   *
   * @param {string} modelURL - Weights URL; only fetched the first time.
   * @param {string} modelID - Registry key. IDs matching /tiny|mobile/ pass
   *   `true` as the Model constructor's second argument (presumably selects a
   *   smaller encoder variant — confirm against the wasm bindings).
   * @returns {Promise<Model>}
   */
  static async getInstance(modelURL, modelID) {
    if (!this.instance[modelID]) {
      await SAMModel.#initWasm();
      self.postMessage({
        status: "loading",
        message: `Loading Model ${modelID}`,
      });
      const weightsArrayU8 = await fetchArrayBuffer(modelURL);
      this.instance[modelID] = new Model(
        weightsArrayU8,
        /tiny|mobile/.test(modelID)
      );
    } else {
      self.postMessage({ status: "loading", message: "Model Already Loaded" });
    }
    this.currentModelID = modelID;
    return this.instance[modelID];
  }

  // Run the wasm `init()` exactly once per worker; clear the memo on failure
  // so a later request can retry.
  static #initWasm() {
    if (!SAMModel.#wasmReady) {
      SAMModel.#wasmReady = init().catch((err) => {
        SAMModel.#wasmReady = null;
        throw err;
      });
    }
    return SAMModel.#wasmReady;
  }

  /**
   * Compute image embeddings on the current model, unless the same image (by
   * fingerprint) has already been embedded on it.
   *
   * @param {Uint8Array} imageArrayU8 - Image file bytes as fetched.
   * @throws {Error} If no model has been loaded via getInstance(), or whatever
   *   `set_image_embeddings` throws.
   */
  static setImageEmbeddings(imageArrayU8) {
    const modelID = this.currentModelID;
    const model = this.instance[modelID];
    if (!model) {
      throw new Error("No SAM model loaded; call getInstance() first");
    }
    const imageArrayHash = this.getSimpleHash(imageArrayU8);
    if (this.imageArrayHash[modelID] === imageArrayHash) {
      self.postMessage({
        status: "embedding",
        message: "Embeddings Already Set",
      });
      return;
    }
    model.set_image_embeddings(imageArrayU8);
    // Record the fingerprint only after success, so a failed attempt is
    // retried instead of being reported as "Already Set".
    this.imageArrayHash[modelID] = imageArrayHash;
    self.postMessage({ status: "embedding", message: "Embeddings Set" });
  }

  /**
   * Cheap, non-cryptographic fingerprint of a byte array: the length (hex)
   * followed by a 32-bit FNV-1a hash of every byte. Only used to detect "same
   * image as last time"; collisions are possible but unlikely.
   *
   * @param {Uint8Array} imageArrayU8
   * @returns {string}
   */
  static getSimpleHash(imageArrayU8) {
    let hash = 0x811c9dc5; // FNV-1a 32-bit offset basis
    for (let i = 0; i < imageArrayU8.length; i++) {
      hash ^= imageArrayU8[i];
      hash = Math.imul(hash, 0x01000193); // FNV 32-bit prime
    }
    return `${imageArrayU8.length.toString(16)}-${(hash >>> 0).toString(16)}`;
  }
}
/**
 * Render a SAM mask as a black image whose alpha channel is the mask, scaled
 * to the original image size, and return an object URL for the PNG blob.
 *
 * @param {{mask_shape: number[], mask_data: ArrayLike<number>}} mask - The last
 *   two entries of `mask_shape` are read as (height, width); `mask_data` is read
 *   row-major, one value per pixel (assumed 0/1 — TODO confirm against the wasm
 *   output).
 * @param {{original_width: number, original_height: number}} image - Size of
 *   the source image; the output canvas uses exactly these dimensions.
 * @returns {Promise<string>} A `blob:` URL. Nothing here revokes it; the
 *   receiver should call URL.revokeObjectURL when done.
 */
async function createImageCanvas(
  { mask_shape, mask_data }, // mask
  { original_width, original_height } // original image
) {
  // Tensor shapes list rows before columns (e.g. [1, 1, H, W]); the row-major
  // mask_data therefore has the width as its fastest-varying axis.
  const [maskHeight, maskWidth] = mask_shape.slice(-2);
  const maskCanvas = new OffscreenCanvas(maskWidth, maskHeight); // canvas for mask
  const maskCtx = maskCanvas.getContext("2d");
  const canvas = new OffscreenCanvas(original_width, original_height); // canvas for creating mask with original image size
  const ctx = canvas.getContext("2d");

  const imageData = maskCtx.createImageData(maskWidth, maskHeight);
  const pixels = imageData.data;
  const pixelCount = maskWidth * maskHeight;
  for (let i = 0; i < pixelCount; i++) {
    // createImageData starts as transparent black, so only alpha is written.
    pixels[i * 4 + 3] = mask_data[i] * 255;
  }
  maskCtx.putImageData(imageData, 0, 0);

  // Crop the mask to the image's aspect ratio, anchored at the top-left
  // corner (presumably the model pads the resized image on the right/bottom),
  // then stretch that region to the original size.
  const longestSide = Math.max(original_width, original_height);
  const sx = original_width / longestSide;
  const sy = original_height / longestSide;
  ctx.drawImage(
    maskCanvas,
    0,
    0,
    maskCanvas.width * sx,
    maskCanvas.height * sy,
    0,
    0,
    original_width,
    original_height
  );
  const blob = await canvas.convertToBlob();
  return URL.createObjectURL(blob);
}
// Worker entry point. Each message carries { modelURL, modelID, imageURL, points }:
// load (or reuse) the model, compute embeddings for the image, and, when
// `points` is present, segment at (points.x, points.y). Progress, results and
// failures ({ error }) are reported back to the main thread via postMessage.
self.addEventListener("message", async (event) => {
  const { modelURL, modelID, imageURL, points } = event.data;
  try {
    self.postMessage({ status: "loading", message: "Starting SAM" });
    const sam = await SAMModel.getInstance(modelURL, modelID);
    self.postMessage({ status: "loading", message: "Loading Image" });
    const imageArrayU8 = await fetchArrayBuffer(imageURL, false);
    self.postMessage({ status: "embedding", message: "Creating Embeddings" });
    SAMModel.setImageEmbeddings(imageArrayU8);
    if (!points) {
      // no points only do the embeddings
      self.postMessage({
        status: "complete-embedding",
        message: "Embeddings Complete",
      });
      return;
    }
    self.postMessage({ status: "segmenting", message: "Segmenting" });
    const result = sam.mask_for_point(points.x, points.y);
    const { mask, image } = JSON.parse(result);
    // createImageCanvas returns a blob: URL (not a data: URL).
    const maskURL = await createImageCanvas(mask, image);
    // Send the mask URL back to the main thread.
    self.postMessage({
      status: "complete",
      message: "Segmentation Complete",
      output: { maskURL },
    });
  } catch (e) {
    try {
      self.postMessage({ error: e });
    } catch {
      // postMessage throws (DataCloneError) when the thrown value cannot be
      // structured-cloned; fall back to text so the main thread still hears
      // about the failure instead of waiting forever.
      self.postMessage({ error: String(e?.message ?? e) });
    }
  }
});
|