Spaces:
Running
Running
import * as webllm from "https://esm.run/@mlc-ai/web-llm"; | |
// Chat transcript handed to the model on each request. It starts with the
// system prompt only; onMessageSend() pushes further turns onto it.
const messages = [
  { content: "You are a helpful AI agent assisting users. And your name is 'Kanha'", role: "system" },
];
// Base URL of the model repositories and the name of the checkpoint used by
// this app. The checkpoint name is also used as the WebLLM model id.
const modelLibURLPrefix = "https://huggingface.co/Kanha-AI/";
const modelVersion = "llama-3.2-1b-test_200steps_bs1_r0_lr2e-06_nq-q4f16_1-MLC";

// Custom WebLLM app configuration containing a single model record.
const appConfig = {
  model_list: [
    {
      // Weights location: <base URL><checkpoint name>.
      model: `${modelLibURLPrefix}${modelVersion}`,
      model_id: modelVersion,
      // NOTE(review): this deliberately reads `modelLibURLPrefix` and
      // `modelVersion` from the imported `webllm` namespace, NOT the
      // same-named local constants above. Presumably those exports point at
      // WebLLM's prebuilt model-library location - confirm against the package.
      model_lib:
        webllm.modelLibURLPrefix +
        webllm.modelVersion +
        "/Llama-3.2-1B-Instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm",
      vram_required_MB: 3672.07,
      low_resource_required: false,
      overrides: {
        context_window_size: 4096,
      },
    },
  ],
};

// Mutable module state: the id of the model to load, the lazily created
// engine instance (null until createEngine() succeeds), and a flag that is
// true while initializeWebLLMEngine() is running.
let selectedModel = modelVersion;
let engine = null;
let isInitializing = false;
/**
 * Returns the shared WebLLM engine, creating it on first use.
 *
 * The init-progress callback is passed in the engine config so that it is
 * already registered while `CreateMLCEngine` downloads and loads the model.
 * Registering it only after `CreateMLCEngine` has resolved (as was done
 * before) is too late to report progress for that initial load.
 *
 * @returns {Promise<object>} The engine stored in the module-level `engine`.
 * @throws Propagates any rejection from `webllm.CreateMLCEngine`.
 */
async function createEngine() {
  if (engine) {
    return engine;
  }
  engine = await webllm.CreateMLCEngine(selectedModel, {
    appConfig,
    initProgressCallback: updateEngineInitProgressCallback,
  });
  return engine;
}
/**
 * Mirrors an engine init-progress report into the page.
 *
 * Logs the progress value, shows the report text in the "download-status"
 * element (un-hiding it), and sets the width of "progress-bar" to
 * `progress * 100` percent.
 *
 * @param {{progress: number, text: string}} report Progress report object.
 */
function updateEngineInitProgressCallback(report) {
  const { progress, text } = report;
  console.log("initialize", progress);

  const statusLine = document.getElementById("download-status");
  statusLine.textContent = text;
  statusLine.classList.remove("hidden");

  const percent = progress * 100;
  document.getElementById("progress-bar").style.width = `${percent}%`;
}
/**
 * Classifies the current browser by sniffing its user-agent string.
 *
 * @returns {"Android"|"iOS"|"Other"} "Android" when the string contains
 *   "android" (case-insensitive); "iOS" when it names iPad/iPhone/iPod and
 *   `window.MSStream` is falsy; otherwise "Other".
 */
function getPlatform() {
  const agent = navigator.userAgent || navigator.vendor || window.opera;

  if (/android/i.test(agent)) {
    return "Android";
  }

  const namesAppleMobile = /iPad|iPhone|iPod/.test(agent);
  return namesAppleMobile && !window.MSStream ? "iOS" : "Other";
}
/**
 * Returns user-facing instructions for enabling WebGPU.
 *
 * @param {string} platform Platform label, e.g. the result of getPlatform().
 * @returns {string} Step-by-step text for "Android" or "iOS"; a generic
 *   hint for any other value.
 */
function getWebGPUInstructions(platform) {
  const stepsByPlatform = new Map([
    [
      "Android",
      "To enable WebGPU on Android:\n1. Go to chrome://flags\n2. Enable 'WebGPU Developer Features' and 'Unsafe WebGPU Support'\n3. Restart your browser",
    ],
    [
      "iOS",
      "To enable WebGPU on iOS:\n1. Open Settings\n2. Tap Safari\n3. Tap Advanced\n4. Tap Feature Flags\n5. Turn on WebGPU",
    ],
  ]);

  if (stepsByPlatform.has(platform)) {
    return stepsByPlatform.get(platform);
  }
  return "WebGPU might not be supported on your device. Please check if your browser is up to date.";
}
/**
 * Verifies that WebGPU can be used in this browser.
 *
 * Checks, in order, that `navigator.gpu` exists, that an adapter can be
 * requested from it, and that a device can be requested from that adapter.
 *
 * @returns {Promise<boolean>} Resolves to true when all three checks pass.
 * @throws {Error} With a user-facing message naming the step that failed;
 *   the "not supported" message includes platform-specific instructions.
 */
async function checkWebGPUSupport() {
  const gpu = navigator.gpu;
  if (!gpu) {
    const hint = getWebGPUInstructions(getPlatform());
    throw new Error(`WebGPU is not supported in this browser. ${hint}`);
  }

  const gpuAdapter = await gpu.requestAdapter();
  if (!gpuAdapter) {
    throw new Error("Couldn't request WebGPU adapter. Please make sure WebGPU is enabled on your device.");
  }

  const gpuDevice = await gpuAdapter.requestDevice();
  if (!gpuDevice) {
    throw new Error("Couldn't request WebGPU device. Please make sure WebGPU is enabled on your device.");
  }

  return true;
}
/**
 * Checks the memory reported by `navigator.deviceMemory`.
 *
 * When the browser reports nothing (or a falsy value) the check is skipped
 * with a console warning and treated as passing.
 *
 * NOTE(review): the check rejects any reported value below 3, while the
 * error text mentions "2GB Free RAM" - confirm which figure is intended.
 *
 * @returns {Promise<boolean>} Resolves to true when the check passes or is skipped.
 * @throws {Error} When the reported value is below 3.
 */
async function checkRAM() {
  const reportedGB = navigator.deviceMemory;

  if (!reportedGB) {
    console.warn("Device memory information is not available.");
    return true; // Nothing to compare against, so let the caller proceed.
  }

  if (reportedGB < 3) {
    throw new Error(`Insufficient RAM. Required: 2GB Free RAM, Available: ${reportedGB}GB`);
  }
  return true;
}
/**
 * Runs the system checks and loads the model into the engine.
 *
 * Does nothing (resolves to undefined) if an initialization is already in
 * progress. Otherwise it runs the RAM and WebGPU checks, reveals the
 * progress UI, obtains the engine via createEngine() and reloads it with the
 * selected model and sampling settings. The outcome is written to the
 * "download-status" element; the progress container is hidden again and the
 * in-progress flag cleared whether or not the load succeeded.
 *
 * @returns {Promise<void>}
 * @throws Re-throws whatever error made initialization fail, after showing it.
 */
async function initializeWebLLMEngine() {
  if (isInitializing) {
    return;
  }
  isInitializing = true;

  const progressPanel = document.getElementById("progress-container");
  const statusLine = document.getElementById("download-status");

  try {
    // System requirements first: memory, then WebGPU.
    await checkRAM();
    await checkWebGPUSupport();

    progressPanel.classList.remove("hidden");
    statusLine.classList.remove("hidden");

    // Always load the default model.
    selectedModel = "llama-3.2-1b-test_200steps_bs1_r0_lr2e-06_nq-q4f16_1-MLC";

    const loadedEngine = await createEngine();
    await loadedEngine.reload(selectedModel, { temperature: 1.0, top_p: 1 });

    statusLine.textContent = "Model initialized successfully!";
  } catch (error) {
    console.error("Error initializing WebLLM engine:", error);
    statusLine.textContent = `Error initializing: ${error.message}\n\nFor more information and troubleshooting, please visit kanha.ai/faq`;
    statusLine.classList.remove("hidden");
    // Let the caller (onMessageSend) see the failure as well.
    throw error;
  } finally {
    progressPanel.classList.add("hidden");
    isInitializing = false;
  }
}
/**
 * Requests a streamed chat completion and reports it through callbacks.
 *
 * @param {Array<object>} messages Chat history passed to the engine as-is.
 * @param {(text: string) => void} onUpdate Called once per streamed chunk
 *   with the text accumulated so far.
 * @param {(finalMessage: string, usage: object|undefined) => void} onFinish
 *   Called after the stream ends with the engine's final message and the
 *   last usage object seen on any chunk (undefined if none carried one).
 * @param {(err: unknown) => void} onError Called with anything thrown while
 *   creating the engine, streaming, or inside the other callbacks; nothing
 *   is re-thrown.
 * @returns {Promise<void>}
 */
async function streamingGenerating(messages, onUpdate, onFinish, onError) {
  try {
    const activeEngine = await createEngine();
    const stream = await activeEngine.chat.completions.create({
      stream: true,
      messages,
      stream_options: { include_usage: true },
    });

    let accumulated = "";
    let lastUsage;
    for await (const part of stream) {
      const piece = part.choices[0]?.delta.content;
      if (piece) {
        accumulated += piece;
      }
      if (part.usage) {
        lastUsage = part.usage;
      }
      onUpdate(accumulated);
    }

    onFinish(await activeEngine.getMessage(), lastUsage);
  } catch (err) {
    onError(err);
  }
}
// True while a reply is being generated. The Enter-key binding calls
// onMessageSend() even when the send button is disabled, so the button state
// alone cannot prevent overlapping requests.
let isSending = false;

/**
 * Sends the text in the input box to the model and streams the reply into
 * the chat.
 *
 * Empty input is ignored, and so is a call made while a previous reply is
 * still being generated. The user turn is added to `messages` and rendered,
 * a "typing..." placeholder bubble is shown, the engine is initialized on
 * first use, and the streamed reply replaces the placeholder. When the reply
 * completes it is also appended to `messages` so that later requests carry
 * the full conversation. On any failure the error is reported through
 * onError(), shown in place of the placeholder bubble, and the input
 * placeholder text is restored.
 *
 * @returns {Promise<void>}
 */
async function onMessageSend() {
  const input = document.getElementById("user-input");
  const sendButton = document.getElementById("send");

  if (isSending) {
    return;
  }

  const message = {
    content: input.value.trim(),
    role: "user",
  };
  if (message.content.length === 0) {
    return;
  }

  sendButton.disabled = true;
  sendButton.innerHTML = '<i class="fas fa-spinner fa-spin"></i>';
  messages.push(message);
  appendMessage(message);
  input.value = "";
  input.setAttribute("placeholder", "AI is thinking...");

  // Placeholder bubble that the streamed reply (or an error) will replace.
  appendMessage({
    content: "typing...",
    role: "assistant",
  });

  const restorePlaceholder = () => {
    input.setAttribute("placeholder", "Type your message here...");
  };

  const onFinishGenerating = (finalMessage, usage) => {
    // Keep the assistant turn in the history; without it the model never
    // sees its own earlier replies on follow-up questions.
    messages.push({ content: finalMessage, role: "assistant" });
    updateLastMessage(finalMessage);
    sendButton.disabled = false;
    sendButton.innerHTML = '<i class="fas fa-paper-plane"></i>';
    restorePlaceholder();
    if (usage) {
      const usageText =
        `Prompt tokens: ${usage.prompt_tokens}, ` +
        `Completion tokens: ${usage.completion_tokens}, ` +
        `Prefill: ${usage.extra.prefill_tokens_per_s.toFixed(2)} tokens/sec, ` +
        `Decoding: ${usage.extra.decode_tokens_per_s.toFixed(2)} tokens/sec`;
      const statsElement = document.getElementById("chat-stats");
      statsElement.classList.remove("hidden");
      statsElement.textContent = usageText;
    }
  };

  // Shared by initialization failures (thrown) and streaming failures
  // (delivered through the callback), so both leave the UI in the same state.
  const onGenerationError = (error) => {
    onError(error);
    updateLastMessage("I'm sorry, but I encountered an error: " + error.message);
    restorePlaceholder();
  };

  isSending = true;
  try {
    if (!engine) {
      await initializeWebLLMEngine();
    }
    await streamingGenerating(
      messages,
      updateLastMessage,
      onFinishGenerating,
      onGenerationError
    );
  } catch (error) {
    onGenerationError(error);
  } finally {
    isSending = false;
  }
}
/**
 * Adds a chat bubble for `message` to the "chat-box" element and scrolls the
 * box to the bottom.
 *
 * User messages and the literal "typing..." assistant placeholder are
 * inserted as plain text; any other assistant content is passed through
 * `marked.parse` and inserted as HTML.
 *
 * @param {{role: string, content: string}} message Message to render.
 */
function appendMessage(message) {
  const chatBox = document.getElementById("chat-box");
  const bubble = document.createElement("div");

  const fromUser = message.role === "user";
  const isTypingPlaceholder = !fromUser && message.content === "typing...";

  bubble.classList.add("message");
  bubble.classList.add(fromUser ? "user-message" : "assistant-message");
  if (isTypingPlaceholder) {
    bubble.classList.add("typing");
  }

  if (fromUser || isTypingPlaceholder) {
    bubble.textContent = message.content;
  } else {
    bubble.innerHTML = marked.parse(message.content);
  }

  chatBox.appendChild(bubble);
  chatBox.scrollTop = chatBox.scrollHeight;
}
/**
 * Replaces the content of the last ".message" element inside "chat-box"
 * with the rendered form of `content` and clears its "typing" marker.
 *
 * NOTE(review): the result of `marked.parse` is assigned to innerHTML as-is;
 * confirm whether the rendered output needs sanitizing.
 *
 * @param {string} content Markdown text to render into the last bubble.
 */
function updateLastMessage(content) {
  const bubbles = document
    .getElementById("chat-box")
    .getElementsByClassName("message");
  const newest = bubbles[bubbles.length - 1];

  newest.innerHTML = marked.parse(content);
  newest.classList.remove("typing");
}
/**
 * Reports an error to the console and to the page.
 *
 * Writes the error message (plus a pointer to the FAQ) into the
 * "download-status" element and un-hides it, then re-enables the "send"
 * button and restores its paper-plane icon.
 *
 * @param {{message: string}} err The error to report.
 */
function onError(err) {
  console.error(err);

  const statusLine = document.getElementById("download-status");
  statusLine.textContent = `Error: ${err.message}\n\nFor more information and troubleshooting, please visit kanha.ai/faq`;
  statusLine.classList.remove("hidden");

  const sendButton = document.getElementById("send");
  sendButton.disabled = false;
  sendButton.innerHTML = '<i class="fas fa-paper-plane"></i>';
}
// UI binding: clicking the send button, or pressing Enter inside the input
// box, submits the current message.
document.getElementById("send").addEventListener("click", onMessageSend);
document.getElementById("user-input").addEventListener("keypress", (event) => {
  if (event.key !== "Enter") {
    return;
  }
  // Suppress the key's default action; the message is sent explicitly.
  event.preventDefault();
  onMessageSend();
});