File size: 4,860 Bytes
7f7bf76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
//load the candle SAM Model wasm module
import init, { Model } from "./build/m.js";

/**
 * Download `url` and return its body as bytes.
 *
 * When `cacheModel` is true the response is looked up in, and on a miss stored
 * into, the Cache Storage bucket "sam-candle-cache" so later calls for the same
 * URL skip the network.
 *
 * @param {string} url - Resource to download.
 * @param {boolean} [cacheModel=true] - Whether to read from / write to Cache Storage.
 * @returns {Promise<Uint8Array>} The response body.
 * @throws {Error} If the server answers with a non-2xx status.
 */
async function fetchArrayBuffer(url, cacheModel = true) {
  // Reject HTTP error responses up front so an error page is never handed to
  // the caller (or written to the cache) as if it were the requested payload.
  const assertOk = (res) => {
    if (!res.ok) {
      throw new Error(`Failed to fetch ${url}: ${res.status} ${res.statusText}`);
    }
    return res;
  };

  if (!cacheModel) {
    const res = assertOk(await fetch(url));
    return new Uint8Array(await res.arrayBuffer());
  }

  const cacheName = "sam-candle-cache";
  const cache = await caches.open(cacheName);
  const cachedResponse = await cache.match(url);
  if (cachedResponse) {
    const data = await cachedResponse.arrayBuffer();
    return new Uint8Array(data);
  }

  const res = assertOk(await fetch(url, { cache: "force-cache" }));
  // Clone BEFORE reading the body: a consumed response can no longer be cloned.
  // Caching is best-effort — a failed write must not fail the download, but it
  // must not become an unhandled rejection either.
  const cacheWrite = cache.put(url, res.clone()).catch((err) => {
    console.warn(`Could not cache ${url}`, err);
  });
  const bytes = new Uint8Array(await res.arrayBuffer());
  await cacheWrite;
  return bytes;
}
/**
 * Static registry of loaded SAM models and, per model, a fingerprint of the
 * image whose embeddings are currently set. The class is used only through its
 * static members.
 */
class SAMModel {
  // Loaded wasm Model objects, keyed by modelID.
  static instance = {};
  // Fingerprint (see getSimpleHash) of the image last embedded, keyed by modelID.
  static imageArrayHash = {};
  // modelID most recently passed to getInstance; setImageEmbeddings targets it.
  static currentModelID = null;

  /**
   * Return the model registered under `modelID`, downloading its weights from
   * `modelURL` and constructing it on first use. Also makes `modelID` the
   * current model.
   *
   * @param {string} modelURL - Location of the weights file.
   * @param {string} modelID - Registry key for the model.
   * @returns {Promise<object>} The wasm Model instance.
   */
  static async getInstance(modelURL, modelID) {
    if (!this.instance[modelID]) {
      await init();

      self.postMessage({
        status: "loading",
        message: `Loading Model ${modelID}`,
      });
      const weightsArrayU8 = await fetchArrayBuffer(modelURL);
      // The second constructor argument is true when the id contains "tiny" or
      // "mobile" — presumably selecting the lightweight architecture; confirm
      // against the wasm bindings.
      this.instance[modelID] = new Model(
        weightsArrayU8,
        /tiny|mobile/.test(modelID)
      );
    } else {
      self.postMessage({ status: "loading", message: "Model Already Loaded" });
    }
    // Set the current modelID to the modelID that was passed in
    this.currentModelID = modelID;
    return this.instance[modelID];
  }

  /**
   * Compute embeddings for `imageArrayU8` on the current model, unless the
   * same image (by fingerprint) was already embedded on it.
   *
   * @param {Uint8Array} imageArrayU8 - Encoded image bytes.
   * @throws {Error} If no model has been loaded via getInstance.
   */
  static setImageEmbeddings(imageArrayU8) {
    const modelID = this.currentModelID;
    const model = this.instance[modelID];
    if (!model) {
      throw new Error(
        "No SAM model loaded; call getInstance before setImageEmbeddings"
      );
    }

    // check if image embeddings are already set for this image and model
    const imageArrayHash = this.getSimpleHash(imageArrayU8);
    if (this.imageArrayHash[modelID] === imageArrayHash) {
      self.postMessage({
        status: "embedding",
        message: "Embeddings Already Set",
      });
      return;
    }

    // Forget the previous fingerprint first and record the new one only after
    // success, so a failed attempt is never reported as "already set" on retry.
    this.imageArrayHash[modelID] = undefined;
    model.set_image_embeddings(imageArrayU8);
    this.imageArrayHash[modelID] = imageArrayHash;
    self.postMessage({ status: "embedding", message: "Embeddings Set" });
  }

  /**
   * Cheap fingerprint of an image buffer: 32-bit FNV-1a over the buffer length
   * and every 100th byte. It is a change-detection heuristic, not a
   * cryptographic hash; buffers that differ only in unsampled bytes and have
   * the same length still compare equal.
   *
   * @param {Uint8Array} imageArrayU8 - Bytes to fingerprint.
   * @returns {string} Hexadecimal fingerprint.
   */
  static getSimpleHash(imageArrayU8) {
    const SAMPLE_STRIDE = 100;
    const FNV_PRIME = 0x01000193;
    let hash = 0x811c9dc5; // FNV-1a 32-bit offset basis
    const mix = (byte) => {
      hash = Math.imul(hash ^ byte, FNV_PRIME);
    };

    // Fold the length in so buffers of different sizes do not collide merely
    // because their sampled bytes agree.
    const length = imageArrayU8.length;
    for (let shift = 0; shift < 32; shift += 8) {
      mix((length >>> shift) & 0xff);
    }
    for (let i = 0; i < length; i += SAMPLE_STRIDE) {
      mix(imageArrayU8[i]);
    }
    return (hash >>> 0).toString(16);
  }
}

/**
 * Render a mask as a black image whose alpha channel carries the mask values,
 * stretch it to the original image size, and return an object URL for it.
 *
 * @param {{mask_shape: number[], mask_data: number[]}} mask - Mask dimensions
 *   (only entries 2 and 3 are read) and one value per pixel.
 * @param {{original_width: number, original_height: number}} image - Size of
 *   the output canvas.
 * @returns {Promise<string>} Object URL of the rendered mask blob.
 */
async function createImageCanvas(
  { mask_shape, mask_data }, // mask
  { original_width, original_height, width, height } // original image
) {
  // The first two mask dimensions are ignored; the last two size the mask canvas.
  const [, , maskWidth, maskHeight] = mask_shape;
  const maskCanvas = new OffscreenCanvas(maskWidth, maskHeight);
  const maskCtx = maskCanvas.getContext("2d");

  // Black RGB everywhere; alpha is the mask value scaled by 255.
  const maskImage = maskCtx.createImageData(maskCanvas.width, maskCanvas.height);
  const pixels = maskImage.data;
  pixels.fill(0);
  const pixelCount = pixels.length / 4;
  for (let i = 0; i < pixelCount; i++) {
    pixels[i * 4 + 3] = mask_data[i] * 255;
  }
  maskCtx.putImageData(maskImage, 0, 0);

  // Fraction of the mask canvas to sample on each axis: the longer side of the
  // original image uses the full extent, the shorter side a proportional part.
  const isLandscape = original_height < original_width;
  const sx = isLandscape ? 1 : original_width / original_height;
  const sy = isLandscape ? original_height / original_width : 1;

  // Output canvas at the original image size.
  const outputCanvas = new OffscreenCanvas(original_width, original_height);
  const outputCtx = outputCanvas.getContext("2d");
  const sourceWidth = maskCanvas.width * sx;
  const sourceHeight = maskCanvas.height * sy;
  outputCtx.drawImage(
    maskCanvas,
    0,
    0,
    sourceWidth,
    sourceHeight,
    0,
    0,
    original_width,
    original_height
  );

  return URL.createObjectURL(await outputCanvas.convertToBlob());
}

// Worker entry point. Each message loads (or reuses) the requested model,
// embeds the requested image, and — when `points` is supplied — computes a mask
// for that point. Progress and results are reported back via postMessage.
self.addEventListener("message", async (event) => {
  const { modelURL, modelID, imageURL, points } = event.data;
  // Small helper for the { status, message } progress notifications.
  const report = (status, message) => self.postMessage({ status, message });

  try {
    report("loading", "Starting SAM");
    const sam = await SAMModel.getInstance(modelURL, modelID);

    report("loading", "Loading Image");
    // Images are fetched without the Cache Storage layer.
    const imageArrayU8 = await fetchArrayBuffer(imageURL, false);

    report("embedding", "Creating Embeddings");
    SAMModel.setImageEmbeddings(imageArrayU8);

    // Without points the request is for embeddings only.
    if (!points) {
      report("complete-embedding", "Embeddings Complete");
      return;
    }

    report("segmenting", "Segmenting");
    // mask_for_point returns a JSON string carrying the mask and image info.
    const { mask, image } = JSON.parse(sam.mask_for_point(points.x, points.y));
    const maskDataURL = await createImageCanvas(mask, image);

    // Hand the rendered mask's object URL back to the main thread.
    self.postMessage({
      status: "complete",
      message: "Segmentation Complete",
      output: { maskURL: maskDataURL },
    });
  } catch (e) {
    self.postMessage({ error: e });
  }
});