File size: 1,672 Bytes
04b77a5
 
 
98b1c51
ddbe7d2
04b77a5
 
 
 
 
 
 
 
98b1c51
04b77a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddbe7d2
 
29f5c0c
04b77a5
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import { z } from "zod";
import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
import { chunk } from "$lib/utils/chunk";
import { env } from "$env/dynamic/private";
import { logger } from "$lib/server/logger";

/**
 * Configuration schema for the Hugging Face Inference API embedding endpoint.
 *
 * `authorization` is the raw value for the `Authorization` request header. When it is
 * absent or empty and the `HF_TOKEN` environment variable is set, a `Bearer` header is
 * derived from that token instead.
 */
export const embeddingEndpointHfApiSchema = z.object({
	weight: z.number().int().positive().default(1),
	model: z.any(),
	type: z.literal("hfapi"),
	authorization: z
		.string()
		.optional()
		.transform((header) => {
			// An explicitly configured header wins; without a token there is nothing to fall back to.
			if (header || !env.HF_TOKEN) {
				return header;
			}
			return `Bearer ${env.HF_TOKEN}`;
		}),
});

/**
 * Build an embedding endpoint backed by the Hugging Face Inference API.
 *
 * Inputs are split into batches of at most 128 texts. The batches are requested
 * concurrently, and the per-batch results are concatenated in input order.
 *
 * @param input - Raw endpoint configuration, validated with `embeddingEndpointHfApiSchema`
 *   (which may fill `authorization` from `HF_TOKEN`).
 * @returns A function that maps `{ inputs }` to the concatenated embeddings of all batches.
 *   This assumes the API returns one embedding per input text; confirm for non-sentence models.
 * @throws From the returned function, if any batch request gets a non-2xx response.
 */
export async function embeddingEndpointHfApi(
	input: z.input<typeof embeddingEndpointHfApiSchema>
): Promise<EmbeddingEndpoint> {
	const { model, authorization } = embeddingEndpointHfApiSchema.parse(input);
	const url = "https://api-inference.huggingface.co/models/" + model.id;

	return async ({ inputs }) => {
		const batchesInputs = chunk(inputs, 128);

		const batchesResults = await Promise.all(
			batchesInputs.map(async (batchInputs) => {
				const response = await fetch(url, {
					method: "POST",
					headers: {
						Accept: "application/json",
						"Content-Type": "application/json",
						...(authorization ? { Authorization: authorization } : {}),
					},
					body: JSON.stringify({ inputs: batchInputs }),
				});

				if (!response.ok) {
					const body = await response.text();
					const message = `Failed to get embeddings from Hugging Face API: ${response.status} ${response.statusText}`;
					logger.error(`${message} - ${body}`);
					// Throw rather than return [] for this batch. Dropping a batch would shorten the
					// flattened result and misalign every later embedding with its input text.
					throw new Error(message);
				}

				const embeddings: Embedding[] = await response.json();
				return embeddings;
			})
		);

		return batchesResults.flat();
	};
}