Spaces:
Running
Running
Commit
·
d2f7c95
1
Parent(s):
3e623a9
fix Inference API slowness
Browse files
src/app/api/providers/huggingface/predictWithHuggingFace.ts
CHANGED
@@ -9,6 +9,7 @@ export async function predict({
|
|
9 |
systemPrompt,
|
10 |
userPrompt,
|
11 |
nbMaxNewTokens,
|
|
|
12 |
prefix,
|
13 |
}: LLMPredictionFunctionParams): Promise<string> {
|
14 |
|
@@ -18,7 +19,10 @@ export async function predict({
|
|
18 |
try {
|
19 |
for await (const output of hf.textGenerationStream({
|
20 |
// model: "mistralai/Mixtral-8x7B-v0.1",
|
21 |
-
model:
|
|
|
|
|
|
|
22 |
inputs: createZephyrPrompt([
|
23 |
{ role: "system", content: systemPrompt },
|
24 |
{ role: "user", content: userPrompt }
|
|
|
9 |
systemPrompt,
|
10 |
userPrompt,
|
11 |
nbMaxNewTokens,
|
12 |
+
turbo,
|
13 |
prefix,
|
14 |
}: LLMPredictionFunctionParams): Promise<string> {
|
15 |
|
|
|
19 |
try {
|
20 |
for await (const output of hf.textGenerationStream({
|
21 |
// model: "mistralai/Mixtral-8x7B-v0.1",
|
22 |
+
model:
|
23 |
+
turbo
|
24 |
+
? "HuggingFaceH4/zephyr-7b-beta"
|
25 |
+
: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
26 |
inputs: createZephyrPrompt([
|
27 |
{ role: "system", content: systemPrompt },
|
28 |
{ role: "user", content: userPrompt }
|
src/app/api/providers/types.ts
CHANGED
@@ -16,6 +16,7 @@ export type LLMPredictionFunctionParams = {
|
|
16 |
systemPrompt: string
|
17 |
userPrompt: string
|
18 |
nbMaxNewTokens: number
|
|
|
19 |
prefix?: string
|
20 |
// llmVendorConfig: LLMVendorConfig
|
21 |
}
|
|
|
16 |
systemPrompt: string
|
17 |
userPrompt: string
|
18 |
nbMaxNewTokens: number
|
19 |
+
turbo?: boolean
|
20 |
prefix?: string
|
21 |
// llmVendorConfig: LLMVendorConfig
|
22 |
}
|
src/app/api/v1/create/index.ts
CHANGED
@@ -15,10 +15,12 @@ export async function create(request: {
|
|
15 |
prompt?: string
|
16 |
width?: number
|
17 |
height?: number
|
|
|
18 |
}= {
|
19 |
prompt: "",
|
20 |
width: 1024,
|
21 |
height: 576,
|
|
|
22 |
}): Promise<ClapProject> {
|
23 |
const prompt = `${request?.prompt || ""}`.trim()
|
24 |
|
@@ -28,7 +30,8 @@ export async function create(request: {
|
|
28 |
|
29 |
const width = getValidNumber(request?.width, 256, 8192, 1024)
|
30 |
const height = getValidNumber(request?.height, 256, 8192, 576)
|
31 |
-
|
|
|
32 |
const userPrompt = `Movie story to generate: ${prompt}
|
33 |
|
34 |
Output: `
|
@@ -46,6 +49,7 @@ Output: `
|
|
46 |
userPrompt,
|
47 |
nbMaxNewTokens,
|
48 |
prefix,
|
|
|
49 |
})
|
50 |
|
51 |
console.log("api/v1/create(): rawString: ", rawString)
|
@@ -64,6 +68,7 @@ Output: `
|
|
64 |
userPrompt: userPrompt + ".", // we trick the Hugging Face cache
|
65 |
nbMaxNewTokens,
|
66 |
prefix,
|
|
|
67 |
})
|
68 |
|
69 |
console.log("api/v1/create(): rawString: ", rawString)
|
|
|
15 |
prompt?: string
|
16 |
width?: number
|
17 |
height?: number
|
18 |
+
turbo?: boolean
|
19 |
}= {
|
20 |
prompt: "",
|
21 |
width: 1024,
|
22 |
height: 576,
|
23 |
+
turbo: false,
|
24 |
}): Promise<ClapProject> {
|
25 |
const prompt = `${request?.prompt || ""}`.trim()
|
26 |
|
|
|
30 |
|
31 |
const width = getValidNumber(request?.width, 256, 8192, 1024)
|
32 |
const height = getValidNumber(request?.height, 256, 8192, 576)
|
33 |
+
const turbo = request?.turbo ? true : false
|
34 |
+
|
35 |
const userPrompt = `Movie story to generate: ${prompt}
|
36 |
|
37 |
Output: `
|
|
|
49 |
userPrompt,
|
50 |
nbMaxNewTokens,
|
51 |
prefix,
|
52 |
+
turbo,
|
53 |
})
|
54 |
|
55 |
console.log("api/v1/create(): rawString: ", rawString)
|
|
|
68 |
userPrompt: userPrompt + ".", // we trick the Hugging Face cache
|
69 |
nbMaxNewTokens,
|
70 |
prefix,
|
71 |
+
turbo,
|
72 |
})
|
73 |
|
74 |
console.log("api/v1/create(): rawString: ", rawString)
|
src/app/api/v1/create/route.ts
CHANGED
@@ -14,6 +14,7 @@ export async function POST(req: NextRequest) {
|
|
14 |
prompt: string
|
15 |
width: number
|
16 |
height: number
|
|
|
17 |
// can add more stuff for the V2 of Stories Factory
|
18 |
}
|
19 |
|
@@ -22,7 +23,8 @@ export async function POST(req: NextRequest) {
|
|
22 |
const clap = await create({
|
23 |
prompt: `${request?.prompt || ""}`.trim(),
|
24 |
width: getValidNumber(request?.width, 256, 8192, 1024),
|
25 |
-
height: getValidNumber(request?.height, 256, 8192, 576)
|
|
|
26 |
})
|
27 |
|
28 |
// TODO replace by Clap file streaming
|
|
|
14 |
prompt: string
|
15 |
width: number
|
16 |
height: number
|
17 |
+
turbo: boolean
|
18 |
// can add more stuff for the V2 of Stories Factory
|
19 |
}
|
20 |
|
|
|
23 |
const clap = await create({
|
24 |
prompt: `${request?.prompt || ""}`.trim(),
|
25 |
width: getValidNumber(request?.width, 256, 8192, 1024),
|
26 |
+
height: getValidNumber(request?.height, 256, 8192, 576),
|
27 |
+
turbo: request?.turbo ? true : false,
|
28 |
})
|
29 |
|
30 |
// TODO replace by Clap file streaming
|