Spaces:
Running
Running
File size: 1,672 Bytes
04b77a5 98b1c51 ddbe7d2 04b77a5 98b1c51 04b77a5 ddbe7d2 29f5c0c 04b77a5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import { z } from "zod";
import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
import { chunk } from "$lib/utils/chunk";
import { env } from "$env/dynamic/private";
import { logger } from "$lib/server/logger";
/**
 * Configuration schema for an embedding endpoint of type `"hfapi"`.
 *
 * - `weight`: positive integer, defaults to 1.
 * - `model`: accepted without validation (`z.any()`).
 * - `type`: must be the literal `"hfapi"`.
 * - `authorization`: optional string; when it is missing or empty and
 *   `env.HF_TOKEN` is set, it is replaced by `"Bearer " + env.HF_TOKEN`.
 */
export const embeddingEndpointHfApiSchema = z.object({
  weight: z.number().int().positive().default(1),
  model: z.any(),
  type: z.literal("hfapi"),
  authorization: z
    .string()
    .optional()
    .transform((header) => {
      // An explicitly configured header wins; with no HF_TOKEN there is nothing to fall back to.
      if (header || !env.HF_TOKEN) {
        return header;
      }
      // No header configured but HF_TOKEN is available: use it as the authorization header.
      return `Bearer ${env.HF_TOKEN}`;
    }),
});
/**
 * Creates an embedding endpoint that calls the Hugging Face Inference API.
 *
 * The configuration is parsed with `embeddingEndpointHfApiSchema` (zod's `parse`
 * throws when the input does not match the schema). The request URL is the
 * inference API base URL followed by `model.id`.
 *
 * @param input - Raw endpoint configuration, as accepted by `embeddingEndpointHfApiSchema`.
 * @returns A function that takes `{ inputs }`, splits `inputs` with `chunk(inputs, 128)`
 *   (presumably into batches of at most 128 items — confirm against `chunk`), POSTs all
 *   batches concurrently, and returns the per-batch results flattened in batch order.
 *   A batch whose HTTP response is not OK is logged and contributes no embeddings, so
 *   the result can contain fewer entries than `inputs`.
 */
export async function embeddingEndpointHfApi(
  input: z.input<typeof embeddingEndpointHfApiSchema>
): Promise<EmbeddingEndpoint> {
  const { model, authorization } = embeddingEndpointHfApiSchema.parse(input);
  const url = "https://api-inference.huggingface.co/models/" + model.id;

  return async ({ inputs }) => {
    const batchesInputs = chunk(inputs, 128);

    // Promise.all keeps the batch order, so flattened results follow the input order.
    const batchesResults = await Promise.all(
      batchesInputs.map(async (batchInputs) => {
        const response = await fetch(url, {
          method: "POST",
          headers: {
            Accept: "application/json",
            "Content-Type": "application/json",
            // Only send the header when a value is configured (or derived from HF_TOKEN).
            ...(authorization ? { Authorization: authorization } : {}),
          },
          body: JSON.stringify({ inputs: batchInputs }),
        });

        if (!response.ok) {
          // Log status and body in a single record: a Response object passed as an extra
          // logger argument carries no serializable data, so the status would be lost.
          const body = await response.text();
          logger.error(
            `Failed to get embeddings from Hugging Face API (${response.status} ${response.statusText}): ${body}`
          );
          // Best-effort: a failed batch yields no embeddings instead of failing the whole call.
          return [];
        }

        const embeddings: Embedding[] = await response.json();
        return embeddings;
      })
    );

    return batchesResults.flat();
  };
}
|