jbilcke-hf (HF staff) committed
Commit d2f7c95 · 1 Parent(s): 3e623a9

fix Inference API slowness

src/app/api/providers/huggingface/predictWithHuggingFace.ts CHANGED
@@ -9,6 +9,7 @@ export async function predict({
   systemPrompt,
   userPrompt,
   nbMaxNewTokens,
+  turbo,
   prefix,
 }: LLMPredictionFunctionParams): Promise<string> {
 
@@ -18,7 +19,10 @@ export async function predict({
   try {
     for await (const output of hf.textGenerationStream({
       // model: "mistralai/Mixtral-8x7B-v0.1",
-      model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
+      model:
+        turbo
+          ? "HuggingFaceH4/zephyr-7b-beta"
+          : "mistralai/Mixtral-8x7B-Instruct-v0.1",
       inputs: createZephyrPrompt([
         { role: "system", content: systemPrompt },
         { role: "user", content: userPrompt }
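This change threads a turbo flag through predict() so a request can target the lighter HuggingFaceH4/zephyr-7b-beta model instead of mistralai/Mixtral-8x7B-Instruct-v0.1 when latency matters more than output quality. A minimal sketch of how a caller could use the updated signature (the prompts and token budget are illustrative placeholders, not values from the repo):

import { predict } from "./predictWithHuggingFace"

// Hypothetical call site: opting into the lighter model for lower latency.
const story = await predict({
  systemPrompt: "You are a screenwriter.",
  userPrompt: "Write a short synopsis about a lighthouse keeper.",
  nbMaxNewTokens: 400,
  prefix: "",
  turbo: true, // true => zephyr-7b-beta, false/undefined => Mixtral-8x7B-Instruct-v0.1
})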
src/app/api/providers/types.ts CHANGED
@@ -16,6 +16,7 @@ export type LLMPredictionFunctionParams = {
   systemPrompt: string
   userPrompt: string
   nbMaxNewTokens: number
+  turbo?: boolean
   prefix?: string
   // llmVendorConfig: LLMVendorConfig
 }
src/app/api/v1/create/index.ts CHANGED
@@ -15,10 +15,12 @@ export async function create(request: {
   prompt?: string
   width?: number
   height?: number
+  turbo?: boolean
 }= {
   prompt: "",
   width: 1024,
   height: 576,
+  turbo: false,
 }): Promise<ClapProject> {
   const prompt = `${request?.prompt || ""}`.trim()
 
@@ -28,7 +30,8 @@ export async function create(request: {
 
   const width = getValidNumber(request?.width, 256, 8192, 1024)
   const height = getValidNumber(request?.height, 256, 8192, 576)
-
+  const turbo = request?.turbo ? true : false
+
   const userPrompt = `Movie story to generate: ${prompt}
 
 Output: `
@@ -46,6 +49,7 @@ Output: `
     userPrompt,
     nbMaxNewTokens,
     prefix,
+    turbo,
   })
 
   console.log("api/v1/create(): rawString: ", rawString)
@@ -64,6 +68,7 @@ Output: `
     userPrompt: userPrompt + ".", // we trick the Hugging Face cache
     nbMaxNewTokens,
     prefix,
+    turbo,
   })
 
   console.log("api/v1/create(): rawString: ", rawString)
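create() now defaults turbo to false and simply forwards it to both predict() calls (including the cache-busting retry). For reference, a hypothetical call to the updated helper, with placeholder prompt and dimensions:

import { create } from "./index"

// Sketch of a caller opting into the faster model; values are illustrative.
const clap = await create({
  prompt: "A heist movie set in a floating city",
  width: 1024,
  height: 576,
  turbo: true,
})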
src/app/api/v1/create/route.ts CHANGED
@@ -14,6 +14,7 @@ export async function POST(req: NextRequest) {
     prompt: string
     width: number
     height: number
+    turbo: boolean
     // can add more stuff for the V2 of Stories Factory
   }
 
@@ -22,7 +23,8 @@ export async function POST(req: NextRequest) {
   const clap = await create({
     prompt: `${request?.prompt || ""}`.trim(),
     width: getValidNumber(request?.width, 256, 8192, 1024),
-    height: getValidNumber(request?.height, 256, 8192, 576)
+    height: getValidNumber(request?.height, 256, 8192, 576),
+    turbo: request?.turbo ? true : false,
   })
 
   // TODO replace by Clap file streaming
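Putting it together, a client can opt into the faster path by sending turbo in the POST body. A hedged client-side sketch, assuming the handler is served at /api/v1/create (inferred from the file layout) and that the response carries the serialized ClapProject; the field values are placeholders:

// Hypothetical fetch call from a browser or Next.js client component.
const res = await fetch("/api/v1/create", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    prompt: "A road trip across a desert planet",
    width: 1024,
    height: 576,
    turbo: true, // ask the server to use zephyr-7b-beta instead of Mixtral
  }),
})
const blob = await res.blob() // exact response handling depends on how the Clap file is returned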