|
<script lang="ts"> |
|
import ChatWindow from '$lib/components/chat/ChatWindow.svelte'; |
|
import { pendingMessage } from '$lib/stores/pendingMessage'; |
|
import { onMount } from 'svelte'; |
|
import type { PageData } from './$types'; |
|
import { page } from '$app/stores'; |
|
import { |
|
PUBLIC_ASSISTANT_MESSAGE_TOKEN, |
|
PUBLIC_SEP_TOKEN, |
|
PUBLIC_USER_MESSAGE_TOKEN |
|
} from '$env/static/public'; |
|
import { HfInference } from '@huggingface/inference'; |
|
|
|
// Data for this page, supplied by the route's load function.
export let data: PageData;

// Local transcript; re-assigned from `data.messages` whenever `data` changes.
$: messages = data.messages;

// True while a reply is being streamed; bound to the chat window's `disabled` prop.
let loading = false;

// Prompt-template tokens taken from the public static environment.
const userToken = PUBLIC_USER_MESSAGE_TOKEN;
const assistantToken = PUBLIC_ASSISTANT_MESSAGE_TOKEN;
const sepToken = PUBLIC_SEP_TOKEN;

// Inference client pointed at this app's own `/api/conversation` route.
const hf = new HfInference();
const model = hf.endpoint($page.url.origin + '/api/conversation');
|
|
|
/**
 * Streams a completion for `inputs` from `model` and appends it, token by
 * token, to the transcript in `messages`.
 *
 * The first non-special token opens a new assistant message (with leading
 * whitespace trimmed); every later token is appended to that message.
 * Streaming stops when the iterator ends, yields a falsy item, or a token
 * contains a `<` that is not adjacent to a backtick.
 *
 * @param inputs Fully formatted prompt sent to the endpoint.
 * @returns A promise that resolves once streaming has stopped.
 */
async function getTextGenerationStream(inputs: string): Promise<void> {
	const response = model.textGenerationStream(
		{
			inputs,
			parameters: {
				stop: ['<|endoftext|>'],
				max_new_tokens: 1024,
				truncate: 1024,
				typical_p: 0.2
			}
		},
		{
			use_cache: false
		}
	);

	// Matches a `<` that is neither preceded nor followed by a backtick.
	// NOTE(review): presumably the first character of a leaked `<|endoftext|>`
	// marker, while `<` written as inline code is left alone — confirm.
	const endOfTextRegex = /(?<!`)<(?!`)/;

	// The loop variable is deliberately not called `data`, so it does not
	// shadow the component's `data` prop.
	for await (const output of response) {
		if (!output) break;
		if (output.token.special) continue;

		const text = output.token.text;
		const last = messages.at(-1);

		if (last?.from === 'assistant') {
			const isEndOfText = endOfTextRegex.test(text);

			// Strip the `<` the regex actually matched. A plain
			// `replace('<', '')` removes the first `<` in the token, which
			// may be a backtick-quoted one rather than the marker.
			last.content += isEndOfText ? text.replace(endOfTextRegex, '') : text;

			// Self-assignment so Svelte notices the in-place mutation.
			messages = messages;

			if (isEndOfText) break;
		} else {
			// First token of the reply: start a new assistant message.
			messages = [...messages, { from: 'assistant', content: text.trimStart() }];
		}
	}
}
|
|
|
/**
 * Adds `message` to the transcript as a user turn, builds the prompt from the
 * whole transcript and streams the assistant's reply.
 *
 * Messages that are empty or whitespace-only are ignored. `loading` is set to
 * true for the duration of the request and reset in `finally`, so it is
 * cleared even when streaming rejects; the rejection itself still propagates
 * to the caller.
 *
 * @param message Text entered by the user.
 * @returns A promise that resolves once the reply has finished streaming.
 */
async function writeMessage(message: string): Promise<void> {
	if (!message.trim()) return;

	try {
		loading = true;

		messages = [...messages, { from: 'user', content: message }];

		// Each turn is rendered as <role token><content><separator>; the
		// separator is skipped when the content already ends with it.
		const turns = messages.map((m) => {
			const roleToken = m.from === 'user' ? userToken : assistantToken;
			const separator = m.content.endsWith(sepToken) ? '' : sepToken;
			return roleToken + m.content + separator;
		});

		// The trailing assistant token leaves the prompt open for the reply.
		const inputs = turns.join('') + assistantToken;

		await getTextGenerationStream(inputs);
	} finally {
		loading = false;
	}
}
|
|
|
// On mount, consume any message left in the `pendingMessage` store and send
// it, but only when this conversation has no messages yet.
onMount(async () => {
	const pending = $pendingMessage;
	if (!pending) return;

	// Clear the store so the same message is not picked up again.
	$pendingMessage = '';

	if (messages.length !== 0) return;

	writeMessage(pending);
});
|
</script> |
|
|
|
<!-- Chat UI: disabled while a reply is streaming; submitted text is passed to writeMessage. -->
<ChatWindow
	disabled={loading}
	{messages}
	on:message={(event) => writeMessage(event.detail)}
/>
|
|