|
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
import re
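
# Assumed dependencies: streamlit, torch, and transformers. The Gemma weights
# are gated on the Hugging Face Hub, so a logged-in account that has accepted
# Google's license may be required to download them.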
|
model_id = "google/gemma-1.1-2b-it"
dtype = torch.bfloat16


# Cache the tokenizer and model so Streamlit does not reload the weights on
# every rerun: Streamlit re-executes this whole script after each interaction.
@st.cache_resource
def load_model():
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
    )
    return tokenizer, model


tokenizer, model = load_model()
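
# On a machine with a GPU, the model could instead be placed onto it
# automatically. A minimal sketch, assuming the `accelerate` package is
# installed (not something this script requires as written):
#
#     model = AutoModelForCausalLM.from_pretrained(
#         model_id, torch_dtype=dtype, device_map="auto"
#     )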
|
st.title("💬 Chatbot")
st.caption("🚀 A streamlit chatbot powered by Google's Gemma")
|
# Create the chat history once per browser session.
if 'messages' not in st.session_state:
    st.session_state['messages'] = []

# Replay earlier turns on each rerun so the conversation stays on screen.
for message in st.session_state.messages:
    st.chat_message(message["role"]).write(message["content"])
|
if prompt := st.chat_input():
    # Show the user's message and record it in the history.
    st.chat_message("user").write(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    messages = st.session_state.messages
|
    # Render the whole history in Gemma's chat format; add_generation_prompt=True
    # appends the model-turn header so the model replies as the assistant.
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
    # The chat template already contains Gemma's special tokens, so don't add
    # them again; move the inputs onto the model's device before generating.
    inputs = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").to(model.device)
    outputs = model.generate(input_ids=inputs, max_new_tokens=150)
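
    # For token-by-token display instead of waiting for the full reply, the
    # generation could be streamed. A sketch, assuming Streamlit >= 1.31 for
    # write_stream:
    #
    #     from threading import Thread
    #     from transformers import TextIteratorStreamer
    #
    #     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    #     Thread(target=model.generate,
    #            kwargs=dict(input_ids=inputs, max_new_tokens=150, streamer=streamer)).start()
    #     msg = st.chat_message("assistant").write_stream(streamer)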
|
    # Decode only the newly generated tokens: outputs[0] also contains the
    # prompt tokens, so slice them off first.
    msg = tokenizer.decode(outputs[0][inputs.shape[-1]:])

    # Strip Gemma's turn markers, e.g. <end_of_turn>, from the decoded reply.
    msg = re.sub(r'<.*?>', '', msg).strip()
|
    st.chat_message("assistant").write(msg)
|
    st.session_state.messages.append({"role": "assistant", "content": msg})
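
# To try it locally (assuming this script is saved as app.py):
#
#     streamlit run app.py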