import gradio as gr
import torch

# The bitnet package supplies the BitNet model, config, and tokenizer classes
# that stock transformers does not ship for this architecture.
from bitnet.configuration_bitnet import BitNetConfig
from bitnet.modeling_bitnet import BitNetForCausalLM
from bitnet.tokenization_bitnet import BitNetTokenizer

# Module-level singletons so the model and tokenizer are loaded only once
_model = None
_tokenizer = None


def load_model():
    global _model, _tokenizer
    if _model is None or _tokenizer is None:
        model_id = "microsoft/bitnet-b1.58-2B-4T"
        # Load tokenizer, config, and model via the bitnet pip package
        _tokenizer = BitNetTokenizer.from_pretrained(model_id)
        config = BitNetConfig.from_pretrained(model_id)
        _model = BitNetForCausalLM.from_pretrained(
            model_id,
            config=config,
            torch_dtype=torch.bfloat16,
        )
    return _model, _tokenizer


def manage_history(history):
    # Keep at most 3 turns (each turn is a user + assistant pair = 2 messages)
    max_messages = 6
    if len(history) > max_messages:
        history = history[-max_messages:]

    # Also cap the total character count, dropping the oldest messages first
    total_chars = sum(len(msg["content"]) for msg in history)
    while total_chars > 300 and history:
        history.pop(0)  # Remove the oldest message
        total_chars = sum(len(msg["content"]) for msg in history)

    return history
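

# Illustrative sketch (not called by the app): demonstrates how manage_history
# trims a conversation. The sample messages below are made up for this example.
def _demo_manage_history():
    history = [
        {"role": "user", "content": "First question?"},
        {"role": "assistant", "content": "First answer."},
        {"role": "user", "content": "Second question?"},
        {"role": "assistant", "content": "Second answer."},
        {"role": "user", "content": "Third question?"},
        {"role": "assistant", "content": "Third answer."},
        {"role": "user", "content": "Fourth question?"},
        {"role": "assistant", "content": "Fourth answer."},
    ]
    trimmed = manage_history(history)
    # The turn cap leaves at most 6 messages, and the oldest messages are
    # dropped until the remaining content fits the 300-character budget.
    assert len(trimmed) <= 6
    assert sum(len(m["content"]) for m in trimmed) <= 300
    return trimmed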


def generate_response(user_input, system_prompt, max_new_tokens, temperature, top_p, top_k, history):
    model, tokenizer = load_model()

    # Include the (already trimmed) history so the model sees earlier turns
    messages = (
        [{"role": "system", "content": system_prompt}]
        + history
        + [{"role": "user", "content": user_input}]
    )
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    chat_input = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate the response
    chat_outputs = model.generate(
        **chat_input,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=True,
    )

    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(
        chat_outputs[0][chat_input["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )

    # Update the history, then trim it to the turn and character limits
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    history = manage_history(history)

    # Returned twice: once for the Chatbot display, once for the State
    return history, history


# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# BitNet b1.58 2B4T Demo")

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                ## About BitNet b1.58 2B4T

                BitNet b1.58 2B4T is the first open-source, native 1-bit large
                language model with 2 billion parameters, developed by Microsoft
                Research. Trained on 4 trillion tokens, it matches the performance
                of full-precision models of similar size while offering significant
                efficiency gains in memory, energy, and latency. Features include:

                - Transformer-based architecture with BitLinear layers
                - Native 1.58-bit weights and 8-bit activations
                - Maximum context length of 4096 tokens
                - Optimized for efficient inference with bitnet.cpp
                """
            )
        with gr.Column():
            gr.Markdown(
                """
                ## About Tonic AI

                Tonic AI is a vibrant community of AI enthusiasts and developers,
                always building cool demos and pushing the boundaries of what's
                possible with AI. We're passionate about creating innovative,
                accessible, and engaging AI experiences for everyone. Join us in
                exploring the future of AI!
                """
            )

    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            system_prompt = gr.Textbox(
                label="System Prompt",
                value="You are a helpful AI assistant.",
                placeholder="Enter system prompt...",
            )
            with gr.Accordion("Advanced Options", open=False):
                max_new_tokens = gr.Slider(minimum=10, maximum=500, value=50, step=10, label="Max New Tokens")
                temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top P")
                top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K")
            submit_btn = gr.Button("Send")
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")

    chat_history = gr.State([])

    submit_btn.click(
        fn=generate_response,
        inputs=[user_input, system_prompt, max_new_tokens, temperature, top_p, top_k, chat_history],
        outputs=[chatbot, chat_history],
    )

if __name__ == "__main__":
    # Preload the model so Gradio worker threads don't race to load it
    load_model()
    demo.launch(ssr_mode=False, share=True)
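
# Note: the stock transformers Auto classes (AutoModelForCausalLM,
# AutoTokenizer) can also load this checkpoint instead of the bitnet package,
# but only with a transformers build that includes BitNet support. A hedged
# sketch of that alternative path, not the one this demo takes:
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("microsoft/bitnet-b1.58-2B-4T")
#   model = AutoModelForCausalLM.from_pretrained(
#       "microsoft/bitnet-b1.58-2B-4T",
#       torch_dtype=torch.bfloat16,
#   )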