Luke nsarrazin HF staff commited on
Commit
2ce3b4b
·
unverified ·
1 Parent(s): 7d34920

Support Gemini 1.5 Pro from Vertex AI (#1041)

Browse files

* fix support for gemini on vertex ai

- use native messages api
- provide full content
- hide "Continue" button
- support "safety settings"
- fix confusion in readme

* respect new lines in model description

* ignore google service accounts matching `gcp-*.json`

* type checks

* copy service account, if exists, to container

* narrow response type

* fix streaming generation

---------

Co-authored-by: Nathan Sarrazin <[email protected]>

.gitignore CHANGED
@@ -11,4 +11,5 @@ SECRET_CONFIG
11
  .idea
12
  !.env.ci
13
  !.env
14
- !.env.template
 
 
11
  .idea
12
  !.env.ci
13
  !.env
14
+ !.env.template
15
+ gcp-*.json
Dockerfile CHANGED
@@ -37,5 +37,6 @@ ENV HOME=/home/user \
37
  COPY --from=builder-production --chown=1000 /app/node_modules /app/node_modules
38
  COPY --link --chown=1000 package.json /app/package.json
39
  COPY --from=builder --chown=1000 /app/build /app/build
 
40
 
41
  CMD pm2 start /app/build/index.js -i $CPU_CORES --no-daemon
 
37
  COPY --from=builder-production --chown=1000 /app/node_modules /app/node_modules
38
  COPY --link --chown=1000 package.json /app/package.json
39
  COPY --from=builder --chown=1000 /app/build /app/build
40
+ COPY --chown=1000 gcp-*.json /app/
41
 
42
  CMD pm2 start /app/build/index.js -i $CPU_CORES --no-daemon
README.md CHANGED
@@ -601,18 +601,24 @@ The service account credentials file can be imported as an environmental variabl
601
  GOOGLE_APPLICATION_CREDENTIALS = clientid.json
602
  ```
603
 
604
- Make sure docker has access to the file. Afterwards Google Vertex endpoints can be configured as following:
 
605
 
606
  ```
607
  MODELS=`[
608
  //...
609
  {
610
- "name": "gemini-1.0-pro", //model-name
611
- "displayName": "Vertex Gemini Pro 1.0",
612
- "location": "europe-west3",
613
- "apiEndpoint": "", //alternative api endpoint url
614
  "endpoints" : [{
615
- "type": "vertex"
 
 
 
 
 
 
 
616
  }]
617
  },
618
  ]`
 
601
  GOOGLE_APPLICATION_CREDENTIALS = clientid.json
602
  ```
603
 
604
+ Make sure your docker container has access to the file and the variable is correctly set.
605
+ Afterwards Google Vertex endpoints can be configured as following:
606
 
607
  ```
608
  MODELS=`[
609
  //...
610
  {
611
+ "name": "gemini-1.5-pro",
612
+ "displayName": "Vertex Gemini Pro 1.5",
 
 
613
  "endpoints" : [{
614
+ "type": "vertex",
615
+ "project": "abc-xyz",
616
+ "location": "europe-west3",
617
+ "model": "gemini-1.5-pro-preview-0409", // model-name
618
+
619
+ // Optional
620
+ "safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
621
+ "apiEndpoint": "", // alternative api endpoint url
622
  }]
623
  },
624
  ]`
package-lock.json CHANGED
@@ -8,6 +8,7 @@
8
  "name": "chat-ui",
9
  "version": "0.8.2",
10
  "dependencies": {
 
11
  "@huggingface/hub": "^0.5.1",
12
  "@huggingface/inference": "^2.6.3",
13
  "@iconify-json/bi": "^1.1.21",
@@ -72,7 +73,7 @@
72
  },
73
  "optionalDependencies": {
74
  "@anthropic-ai/sdk": "^0.17.1",
75
- "@google-cloud/vertexai": "^0.5.0",
76
  "aws4fetch": "^1.0.17",
77
  "cohere-ai": "^7.9.0",
78
  "openai": "^4.14.2"
@@ -630,9 +631,9 @@
630
  }
631
  },
632
  "node_modules/@google-cloud/vertexai": {
633
- "version": "0.5.0",
634
- "resolved": "https://registry.npmjs.org/@google-cloud/vertexai/-/vertexai-0.5.0.tgz",
635
- "integrity": "sha512-qIFHYTXA5UCLdm9JG+Xf1suomCXxRqa1PKdYjqXuhZsCm8mn37Rb0Tf8djlhDzuRVWyWoQTmsWpsk28ZTmbqJg==",
636
  "optional": true,
637
  "dependencies": {
638
  "google-auth-library": "^9.1.0"
 
8
  "name": "chat-ui",
9
  "version": "0.8.2",
10
  "dependencies": {
11
+ "@google-cloud/vertexai": "^1.1.0",
12
  "@huggingface/hub": "^0.5.1",
13
  "@huggingface/inference": "^2.6.3",
14
  "@iconify-json/bi": "^1.1.21",
 
73
  },
74
  "optionalDependencies": {
75
  "@anthropic-ai/sdk": "^0.17.1",
76
+ "@google-cloud/vertexai": "^1.1.0",
77
  "aws4fetch": "^1.0.17",
78
  "cohere-ai": "^7.9.0",
79
  "openai": "^4.14.2"
 
631
  }
632
  },
633
  "node_modules/@google-cloud/vertexai": {
634
+ "version": "1.1.0",
635
+ "resolved": "https://registry.npmjs.org/@google-cloud/vertexai/-/vertexai-1.1.0.tgz",
636
+ "integrity": "sha512-hfwfdlVpJ+kM6o2b5UFfPnweBcz8tgHAFRswnqUKYqLJsvKU0DDD0Z2/YKoHyAUoPJAv20qg6KlC3msNeUKUiw==",
637
  "optional": true,
638
  "dependencies": {
639
  "google-auth-library": "^9.1.0"
package.json CHANGED
@@ -82,7 +82,7 @@
82
  },
83
  "optionalDependencies": {
84
  "@anthropic-ai/sdk": "^0.17.1",
85
- "@google-cloud/vertexai": "^0.5.0",
86
  "aws4fetch": "^1.0.17",
87
  "cohere-ai": "^7.9.0",
88
  "openai": "^4.14.2"
 
82
  },
83
  "optionalDependencies": {
84
  "@anthropic-ai/sdk": "^0.17.1",
85
+ "@google-cloud/vertexai": "^1.1.0",
86
  "aws4fetch": "^1.0.17",
87
  "cohere-ai": "^7.9.0",
88
  "openai": "^4.14.2"
src/lib/server/endpoints/google/endpointVertex.ts CHANGED
@@ -1,8 +1,14 @@
1
- import { VertexAI, HarmCategory, HarmBlockThreshold } from "@google-cloud/vertexai";
2
- import { buildPrompt } from "$lib/buildPrompt";
3
- import type { TextGenerationStreamOutput } from "@huggingface/inference";
 
 
 
 
4
  import type { Endpoint } from "../endpoints";
5
  import { z } from "zod";
 
 
6
 
7
  export const endpointVertexParametersSchema = z.object({
8
  weight: z.number().int().positive().default(1),
@@ -11,10 +17,20 @@ export const endpointVertexParametersSchema = z.object({
11
  location: z.string().default("europe-west1"),
12
  project: z.string(),
13
  apiEndpoint: z.string().optional(),
 
 
 
 
 
 
 
 
 
14
  });
15
 
16
  export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
17
- const { project, location, model, apiEndpoint } = endpointVertexParametersSchema.parse(input);
 
18
 
19
  const vertex_ai = new VertexAI({
20
  project,
@@ -22,55 +38,104 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
22
  apiEndpoint,
23
  });
24
 
25
- const generativeModel = vertex_ai.getGenerativeModel({
26
- model: model.id ?? model.name,
27
- safety_settings: [
28
- {
29
- category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
30
- threshold: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  },
32
- ],
33
- generation_config: {},
34
- });
35
 
36
- return async ({ messages, preprompt, continueMessage }) => {
37
- const prompt = await buildPrompt({
38
- messages,
39
- continueMessage,
40
- preprompt,
41
- model,
 
 
 
 
 
 
 
 
 
 
42
  });
43
 
44
- const chat = generativeModel.startChat();
45
- const result = await chat.sendMessageStream(prompt);
46
- let tokenId = 0;
 
 
 
 
 
 
 
 
 
 
47
 
 
48
  return (async function* () {
49
  let generatedText = "";
50
 
51
  for await (const data of result.stream) {
52
- if (Array.isArray(data?.candidates) && data.candidates.length > 0) {
53
- const firstPart = data.candidates[0].content.parts[0];
54
- if ("text" in firstPart) {
55
- const content = firstPart.text;
56
- generatedText += content;
57
- const output: TextGenerationStreamOutput = {
58
- token: {
59
- id: tokenId++,
60
- text: content ?? "",
61
- logprob: 0,
62
- special: false,
63
- },
64
- generated_text: generatedText,
65
- details: null,
66
- };
67
- yield output;
68
- }
 
 
 
 
 
 
 
 
69
 
70
- if (!data.candidates.slice(-1)[0].finishReason) break;
71
- } else {
72
- break;
73
- }
74
  }
75
  })();
76
  };
 
1
+ import {
2
+ VertexAI,
3
+ HarmCategory,
4
+ HarmBlockThreshold,
5
+ type Content,
6
+ type TextPart,
7
+ } from "@google-cloud/vertexai";
8
  import type { Endpoint } from "../endpoints";
9
  import { z } from "zod";
10
+ import type { Message } from "$lib/types/Message";
11
+ import type { TextGenerationStreamOutput } from "@huggingface/inference";
12
 
13
  export const endpointVertexParametersSchema = z.object({
14
  weight: z.number().int().positive().default(1),
 
17
  location: z.string().default("europe-west1"),
18
  project: z.string(),
19
  apiEndpoint: z.string().optional(),
20
+ safetyThreshold: z
21
+ .enum([
22
+ HarmBlockThreshold.HARM_BLOCK_THRESHOLD_UNSPECIFIED,
23
+ HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
24
+ HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
25
+ HarmBlockThreshold.BLOCK_NONE,
26
+ HarmBlockThreshold.BLOCK_ONLY_HIGH,
27
+ ])
28
+ .optional(),
29
  });
30
 
31
  export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
32
+ const { project, location, model, apiEndpoint, safetyThreshold } =
33
+ endpointVertexParametersSchema.parse(input);
34
 
35
  const vertex_ai = new VertexAI({
36
  project,
 
38
  apiEndpoint,
39
  });
40
 
41
+ return async ({ messages, preprompt, generateSettings }) => {
42
+ const generativeModel = vertex_ai.getGenerativeModel({
43
+ model: model.id ?? model.name,
44
+ safetySettings: safetyThreshold
45
+ ? [
46
+ {
47
+ category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
48
+ threshold: safetyThreshold,
49
+ },
50
+ {
51
+ category: HarmCategory.HARM_CATEGORY_HARASSMENT,
52
+ threshold: safetyThreshold,
53
+ },
54
+ {
55
+ category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
56
+ threshold: safetyThreshold,
57
+ },
58
+ {
59
+ category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
60
+ threshold: safetyThreshold,
61
+ },
62
+ {
63
+ category: HarmCategory.HARM_CATEGORY_UNSPECIFIED,
64
+ threshold: safetyThreshold,
65
+ },
66
+ ]
67
+ : undefined,
68
+ generationConfig: {
69
+ maxOutputTokens: generateSettings?.max_new_tokens ?? 4096,
70
+ stopSequences: generateSettings?.stop,
71
+ temperature: generateSettings?.temperature ?? 1,
72
  },
73
+ });
 
 
74
 
75
+ // Preprompt is the same as the first system message.
76
+ let systemMessage = preprompt;
77
+ if (messages[0].from === "system") {
78
+ systemMessage = messages[0].content;
79
+ messages.shift();
80
+ }
81
+
82
+ const vertexMessages = messages.map(({ from, content }: Omit<Message, "id">): Content => {
83
+ return {
84
+ role: from === "user" ? "user" : "model",
85
+ parts: [
86
+ {
87
+ text: content,
88
+ },
89
+ ],
90
+ };
91
  });
92
 
93
+ const result = await generativeModel.generateContentStream({
94
+ contents: vertexMessages,
95
+ systemInstruction: systemMessage
96
+ ? {
97
+ role: "system",
98
+ parts: [
99
+ {
100
+ text: systemMessage,
101
+ },
102
+ ],
103
+ }
104
+ : undefined,
105
+ });
106
 
107
+ let tokenId = 0;
108
  return (async function* () {
109
  let generatedText = "";
110
 
111
  for await (const data of result.stream) {
112
+ if (!data?.candidates?.length) break; // Handle case where no candidates are present
113
+
114
+ const candidate = data.candidates[0];
115
+ if (!candidate.content?.parts?.length) continue; // Skip if no parts are present
116
+
117
+ const firstPart = candidate.content.parts.find((part) => "text" in part) as
118
+ | TextPart
119
+ | undefined;
120
+ if (!firstPart) continue; // Skip if no text part is found
121
+
122
+ const isLastChunk = !!candidate.finishReason;
123
+
124
+ const content = firstPart.text;
125
+ generatedText += content;
126
+ const output: TextGenerationStreamOutput = {
127
+ token: {
128
+ id: tokenId++,
129
+ text: content,
130
+ logprob: 0,
131
+ special: isLastChunk,
132
+ },
133
+ generated_text: isLastChunk ? generatedText : null,
134
+ details: null,
135
+ };
136
+ yield output;
137
 
138
+ if (isLastChunk) break;
 
 
 
139
  }
140
  })();
141
  };
src/routes/models/+page.svelte CHANGED
@@ -64,7 +64,9 @@
64
  <dt class="flex items-center gap-2 font-semibold">
65
  {model.displayName}
66
  </dt>
67
- <dd class="text-sm text-gray-500 dark:text-gray-400">{model.description || "-"}</dd>
 
 
68
  </a>
69
  {/each}
70
  </dl>
 
64
  <dt class="flex items-center gap-2 font-semibold">
65
  {model.displayName}
66
  </dt>
67
+ <dd class="whitespace-pre-wrap text-sm text-gray-500 dark:text-gray-400">
68
+ {model.description || "-"}
69
+ </dd>
70
  </a>
71
  {/each}
72
  </dl>