# bert-chatbot-ui / app.py
import os

# Use the pure-Python protobuf implementation. This must be set before
# transformers/tensorflow pull in protobuf, or the setting has no effect.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

import numpy as np
import streamlit as st
import tensorflow as tf
from transformers import BertTokenizer, TFBertForSequenceClassification
# Paths to the models hosted on Hugging Face
basic_model_url = "https://huggingface.co/anshupatel4298/bert-chatbot-model/resolve/main/basic_chatbot_model.h5"
bert_repo = "anshupatel4298/bert-chatbot-model"

# Load the basic Keras model. load_model() expects a local path, not a URL,
# so fetch the .h5 file first (get_file caches it under ~/.keras).
basic_model_path = tf.keras.utils.get_file("basic_chatbot_model.h5", origin=basic_model_url)
basic_model = tf.keras.models.load_model(basic_model_path)
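# The basic model cannot tokenize raw text on its own: it needs the word
# index fitted during training. A minimal sketch of restoring it, assuming
# the fitted tokenizer was pickled to basic_tokenizer.pkl in the same hub
# repo (a hypothetical filename, not confirmed by the original code):
import pickle
from huggingface_hub import hf_hub_download

basic_tokenizer_path = hf_hub_download(repo_id=bert_repo, filename="basic_tokenizer.pkl")  # assumed artifact name
with open(basic_tokenizer_path, "rb") as f:
    basic_tokenizer = pickle.load(f)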
# Load the BERT model and tokenizer. The weights sit in the bert_model/
# subfolder of the repo, so pass subfolder= rather than appending it to the
# repo id, which from_pretrained does not accept.
bert_model = TFBertForSequenceClassification.from_pretrained(bert_repo, subfolder="bert_model")
bert_tokenizer = BertTokenizer.from_pretrained(bert_repo, subfolder="bert_model")
# Sequence length used at training time; inputs are padded/truncated to match
MAX_SEQUENCE_LENGTH = 100
# Streamlit UI
st.sidebar.title("Select Model")
model_choice = st.sidebar.selectbox("Choose a model:", ["Basic Model", "BERT Model"])
st.title("Chatbot Interface")
user_input = st.text_input("You:")
if st.button("Send"):
    if user_input:
        if model_choice == "Basic Model":
            # Preprocess input for the basic model using the tokenizer fitted
            # at training time. A fresh Tokenizer(), as the original code
            # created here, has an empty vocabulary and would map every
            # sentence to an empty sequence.
            tokenized_input = basic_tokenizer.texts_to_sequences([user_input])
            input_data = tf.keras.preprocessing.sequence.pad_sequences(tokenized_input, maxlen=MAX_SEQUENCE_LENGTH)
            prediction = basic_model.predict(input_data)
            response = np.argmax(prediction, axis=-1)[0]
        else:
            # Preprocess input for the BERT model
            inputs = bert_tokenizer(user_input, return_tensors="tf", max_length=MAX_SEQUENCE_LENGTH, truncation=True, padding="max_length")
            outputs = bert_model(**inputs)
            response = tf.argmax(outputs.logits, axis=-1).numpy()[0]
        st.write(f"Bot: {response}")