import * as webllm from "https://esm.run/@mlc-ai/web-llm";
// Conversation history sent to the model on every request (OpenAI-style turns).
// Seeded with the system prompt; later turns are pushed as the chat runs.
const messages = [
{
content: "You are a helpful AI agent assisting users. And your name is 'Kanha'",
role: "system",
},
];
// Model registry handed to WebLLM. The fine-tuned weights live on Hugging Face,
// while the compiled WebGPU kernel library (.wasm) is resolved from the prebuilt
// binaries shipped with this web-llm release via the package's own exported
// constants (webllm.modelLibURLPrefix / webllm.modelVersion).
// Note: the previous local `modelLibURLPrefix`/`modelVersion` constants were
// never referenced (L16 always used the webllm.* exports) and have been removed.
const appConfig = {
  model_list: [
    {
      model: "https://huggingface.co/Kanha-AI/llama-3.2-1b-test_200steps_bs1_r0_lr2e-06_nq-q4f16_1-MLC",
      model_id: "llama-3.2-1b-test_200steps_bs1_r0_lr2e-06_nq-q4f16_1-MLC",
      // Prebuilt Llama-3.2-1B q4f16_1 WebGPU library matching the weight format above.
      model_lib:
        webllm.modelLibURLPrefix + webllm.modelVersion + "/Llama-3.2-1B-Instruct-q4f16_1-ctx4k_cs1k-webgpu.wasm",
      vram_required_MB: 3672.07,
      low_resource_required: false,
      overrides: {
        context_window_size: 4096,
      },
    },
  ],
};
// Mutable module-scope state.
let selectedModel = "llama-3.2-1b-test_200steps_bs1_r0_lr2e-06_nq-q4f16_1-MLC"; // model_id from appConfig.model_list
let engine = null; // lazily-created MLCEngine singleton (see createEngine)
let isInitializing = false; // re-entrancy guard for initializeWebLLMEngine
/**
 * Lazily create (and memoize) the shared MLCEngine instance.
 *
 * Fix: the init-progress callback is now passed in the engine config instead of
 * being attached after the await — previously `setInitProgressCallback` ran only
 * once `CreateMLCEngine` had already finished downloading/compiling, so the
 * progress UI never updated during the initial model load.
 *
 * NOTE(review): two callers racing before the first await resolves could still
 * create two engines; acceptable here because onMessageSend gates on
 * isInitializing before streaming — confirm if new call sites are added.
 *
 * @returns {Promise<object>} the ready MLCEngine instance
 */
async function createEngine() {
  if (engine) return engine;
  engine = await webllm.CreateMLCEngine(selectedModel, {
    appConfig,
    initProgressCallback: updateEngineInitProgressCallback,
  });
  return engine;
}
/**
 * Init-progress hook for WebLLM: mirrors model-download progress into the UI.
 * @param {{progress: number, text: string}} report - progress fraction (0..1) and status text
 */
function updateEngineInitProgressCallback(report) {
  console.log("initialize", report.progress);
  const status = document.getElementById("download-status");
  status.textContent = report.text;
  status.classList.remove("hidden");
  const bar = document.getElementById("progress-bar");
  bar.style.width = `${report.progress * 100}%`;
}
/**
 * Best-effort mobile-OS detection from the user-agent string.
 * @returns {"Android"|"iOS"|"Other"}
 */
function getPlatform() {
  const ua = navigator.userAgent || navigator.vendor || window.opera;
  if (/android/i.test(ua)) return "Android";
  // window.MSStream screens out IE on Windows Phone, which spoofed iOS UAs.
  const isAppleMobile = /iPad|iPhone|iPod/.test(ua) && !window.MSStream;
  return isAppleMobile ? "iOS" : "Other";
}
/**
 * Map a platform name to human-readable steps for enabling WebGPU.
 * @param {string} platform - "Android", "iOS", or anything else
 * @returns {string} multi-line instructions shown when WebGPU is unavailable
 */
function getWebGPUInstructions(platform) {
  const instructionsByPlatform = {
    Android: "To enable WebGPU on Android:\n1. Go to chrome://flags\n2. Enable 'WebGPU Developer Features' and 'Unsafe WebGPU Support'\n3. Restart your browser",
    iOS: "To enable WebGPU on iOS:\n1. Open Settings\n2. Tap Safari\n3. Tap Advanced\n4. Tap Feature Flags\n5. Turn on WebGPU",
  };
  return (
    instructionsByPlatform[platform] ??
    "WebGPU might not be supported on your device. Please check if your browser is up to date."
  );
}
/**
 * Verify a usable WebGPU adapter and device can be acquired.
 * @returns {Promise<true>} resolves true when WebGPU is available
 * @throws {Error} with platform-specific guidance when WebGPU is missing or disabled
 */
async function checkWebGPUSupport() {
  const gpu = navigator.gpu;
  if (!gpu) {
    // No navigator.gpu at all: the browser/platform lacks WebGPU entirely.
    const instructions = getWebGPUInstructions(getPlatform());
    throw new Error(`WebGPU is not supported in this browser. ${instructions}`);
  }
  const adapter = await gpu.requestAdapter();
  if (!adapter) {
    throw new Error("Couldn't request WebGPU adapter. Please make sure WebGPU is enabled on your device.");
  }
  const device = await adapter.requestDevice();
  if (!device) {
    throw new Error("Couldn't request WebGPU device. Please make sure WebGPU is enabled on your device.");
  }
  return true;
}
/**
 * Verify the device reports at least 3 GB of RAM.
 *
 * navigator.deviceMemory is Chromium-only and reports a rounded value; when it
 * is unavailable we optimistically return true rather than blocking the user.
 *
 * Fix: the error message previously claimed "Required: 2GB" while the actual
 * check rejects anything below 3 GB — the message now matches the threshold.
 *
 * @returns {Promise<true>} resolves true when RAM is sufficient or unknown
 * @throws {Error} when reported RAM is below 3 GB
 */
async function checkRAM() {
  if (!navigator.deviceMemory) {
    console.warn("Device memory information is not available.");
    return true; // Assume it's okay if we can't check
  }
  const ramGB = navigator.deviceMemory;
  if (ramGB < 3) {
    throw new Error(`Insufficient RAM. Required: 3GB Free RAM, Available: ${ramGB}GB`);
  }
  return true;
}
/**
 * One-shot engine bootstrap used by onMessageSend: validates hardware
 * requirements, reveals the download UI, creates the engine and loads the
 * selected model. Guarded by isInitializing so overlapping calls are no-ops.
 * Rethrows failures so the caller can surface them in the chat.
 */
async function initializeWebLLMEngine() {
  if (isInitializing) return;
  isInitializing = true;
  const progressContainer = document.getElementById("progress-container");
  const statusElement = document.getElementById("download-status");
  try {
    // Fail fast on devices that cannot run the model at all.
    await checkRAM();
    await checkWebGPUSupport();
    progressContainer.classList.remove("hidden");
    statusElement.classList.remove("hidden");
    selectedModel = "llama-3.2-1b-test_200steps_bs1_r0_lr2e-06_nq-q4f16_1-MLC"; // Using the default model
    const generationConfig = {
      temperature: 1.0,
      top_p: 1,
    };
    const llm = await createEngine();
    await llm.reload(selectedModel, generationConfig);
    statusElement.textContent = "Model initialized successfully!";
  } catch (error) {
    console.error("Error initializing WebLLM engine:", error);
    statusElement.textContent = `Error initializing: ${error.message}\n\nFor more information and troubleshooting, please visit kanha.ai/faq`;
    statusElement.classList.remove("hidden");
    throw error; // Re-throw the error to be caught in onMessageSend
  } finally {
    progressContainer.classList.add("hidden");
    isInitializing = false;
  }
}
/**
 * Stream a chat completion from the engine, reporting partial text per chunk.
 * @param {Array<object>} messages - full conversation history (OpenAI-style turns)
 * @param {(partial: string) => void} onUpdate - receives accumulated text after each chunk
 * @param {(final: string, usage?: object) => void} onFinish - called once streaming completes
 * @param {(err: Error) => void} onError - called on any failure
 */
async function streamingGenerating(messages, onUpdate, onFinish, onError) {
  try {
    const llm = await createEngine();
    const stream = await llm.chat.completions.create({
      stream: true,
      messages,
      stream_options: { include_usage: true },
    });
    let accumulated = "";
    let usage;
    for await (const chunk of stream) {
      const delta = chunk.choices[0]?.delta.content;
      if (delta) {
        accumulated += delta;
      }
      if (chunk.usage) {
        usage = chunk.usage;
      }
      onUpdate(accumulated);
    }
    // getMessage() returns the engine's canonical final text for this turn.
    const finalMessage = await llm.getMessage();
    onFinish(finalMessage, usage);
  } catch (err) {
    onError(err);
  }
}
/**
 * Handle a user send: record the message in the history and UI, lazily
 * initialize the engine on first use, then stream the assistant's reply.
 *
 * Fix: the assistant's final reply is now pushed onto `messages`, so the model
 * sees its own previous answers on subsequent turns — previously the history
 * sent to the engine only ever contained system + user messages.
 */
async function onMessageSend() {
  const input = document.getElementById("user-input");
  const sendButton = document.getElementById("send");
  const message = {
    content: input.value.trim(),
    role: "user",
  };
  if (message.content.length === 0) {
    return;
  }
  sendButton.disabled = true;
  sendButton.innerHTML = '';
  messages.push(message);
  appendMessage(message);
  input.value = "";
  input.setAttribute("placeholder", "AI is thinking...");
  // Placeholder bubble; replaced in-place once tokens start streaming.
  const aiMessage = {
    content: "typing...",
    role: "assistant",
  };
  appendMessage(aiMessage);
  try {
    if (!engine) {
      await initializeWebLLMEngine();
    }
    const onFinishGenerating = (finalMessage, usage) => {
      updateLastMessage(finalMessage);
      // Keep the assistant turn in the history for future requests.
      messages.push({ content: finalMessage, role: "assistant" });
      sendButton.disabled = false;
      sendButton.innerHTML = '';
      input.setAttribute("placeholder", "Type your message here...");
      if (usage) {
        const usageText =
          `Prompt tokens: ${usage.prompt_tokens}, ` +
          `Completion tokens: ${usage.completion_tokens}, ` +
          `Prefill: ${usage.extra.prefill_tokens_per_s.toFixed(2)} tokens/sec, ` +
          `Decoding: ${usage.extra.decode_tokens_per_s.toFixed(2)} tokens/sec`;
        document.getElementById("chat-stats").classList.remove("hidden");
        document.getElementById("chat-stats").textContent = usageText;
      }
    };
    await streamingGenerating(
      messages,
      updateLastMessage,
      onFinishGenerating,
      onError
    );
  } catch (error) {
    onError(error);
    // Update the AI message to show the error
    updateLastMessage("I'm sorry, but I encountered an error: " + error.message);
  }
}
/**
 * Render one chat bubble for a message and scroll the chat to the bottom.
 * User text is inserted via textContent (never parsed as HTML); assistant
 * text is rendered as markdown, except the literal "typing..." placeholder.
 * @param {{role: string, content: string}} message
 */
function appendMessage(message) {
  const chatBox = document.getElementById("chat-box");
  const bubble = document.createElement("div");
  bubble.classList.add("message");
  const isUser = message.role === "user";
  bubble.classList.add(isUser ? "user-message" : "assistant-message");
  if (isUser) {
    bubble.textContent = message.content;
  } else if (message.content === "typing...") {
    bubble.classList.add("typing");
    bubble.textContent = message.content;
  } else {
    // NOTE(review): marked.parse output is injected unsanitized — assumes model
    // output is trusted; consider sanitizing if that assumption changes.
    bubble.innerHTML = marked.parse(message.content);
  }
  chatBox.appendChild(bubble);
  chatBox.scrollTop = chatBox.scrollHeight;
}
/**
 * Replace the newest chat bubble's contents with rendered markdown and clear
 * its "typing" indicator state. Used for both streaming updates and finals.
 * @param {string} content - markdown text from the assistant
 */
function updateLastMessage(content) {
  const bubbles = document.getElementById("chat-box").getElementsByClassName("message");
  const last = bubbles[bubbles.length - 1];
  last.innerHTML = marked.parse(content);
  last.classList.remove("typing");
}
/**
 * Generic failure handler: logs the error, surfaces its message in the status
 * area, and re-enables the send button so the user can retry.
 * @param {Error} err
 */
function onError(err) {
  console.error(err);
  const status = document.getElementById("download-status");
  status.textContent = `Error: ${err.message}\n\nFor more information and troubleshooting, please visit kanha.ai/faq`;
  status.classList.remove("hidden");
  const sendButton = document.getElementById("send");
  sendButton.disabled = false;
  sendButton.innerHTML = '';
}
// UI binding
document.getElementById("send").addEventListener("click", onMessageSend);
// "keydown" replaces the deprecated "keypress" event; the isComposing guard
// avoids sending while an IME composition (e.g. CJK input) is still active.
document.getElementById("user-input").addEventListener("keydown", function (event) {
  if (event.key === "Enter" && !event.isComposing) {
    event.preventDefault();
    onMessageSend();
  }
});