# bert-chatbot-ui / app.py
# Author: anshupatel4298
# Commit a1c2edf ("Update app.py", verified)
import json
import re
import nltk
import numpy as np
from flask import Flask, render_template, request
from transformers import BertTokenizer, TFBertForSequenceClassification
import logging
import random
import time
import os
# Download the NLTK resources this app needs, skipping any already installed.
def download_nltk_data():
    """Ensure the NLTK 'punkt' and 'wordnet' resources are available.

    Each resource is looked up with ``nltk.data.find``, which searches every
    directory on ``nltk.data.path`` and also matches zipped packages (for
    example ``corpora/wordnet.zip``). Only resources that cannot be found
    anywhere are downloaded, into ``~/nltk_data``.
    """
    nltk_data_dir = os.path.join(os.path.expanduser("~"), "nltk_data")
    # exist_ok avoids the race between an exists() check and makedirs().
    os.makedirs(nltk_data_dir, exist_ok=True)
    # Make sure lookups also search the directory we download into.
    if nltk_data_dir not in nltk.data.path:
        nltk.data.path.append(nltk_data_dir)
    for resource_path, package in (
        ("tokenizers/punkt", "punkt"),
        ("corpora/wordnet", "wordnet"),
    ):
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package, download_dir=nltk_data_dir)
# --- Module-level setup: runs once, at import time ---
# Fetch NLTK resources up front. NOTE(review): in the code visible here,
# nltk is not used after this download step — confirm it is still needed.
download_nltk_data()
app = Flask(__name__)
# Set up logging. DEBUG level means the logging.debug calls in this module
# (model logits, chosen responses, final replies) are all emitted.
logging.basicConfig(level=logging.DEBUG)
# Load the BERT tokenizer and model
# NOTE(review): "bert-base-uncased" is a base checkpoint; loading it into
# TFBertForSequenceClassification presumably leaves the classification head
# randomly initialised — confirm whether a fine-tuned checkpoint was intended.
# bert_model is also not consulted by chatbot_response(), which answers
# purely by pattern matching against intents.json.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = TFBertForSequenceClassification.from_pretrained(model_name)
def preprocess_input(text):
    """Tokenize *text* for the BERT model.

    Returns the module tokenizer's TensorFlow-tensor encoding of *text*,
    truncated and padded to exactly 512 tokens.
    """
    encoding_options = {
        "return_tensors": 'tf',
        "max_length": 512,
        "truncation": True,
        "padding": 'max_length',
    }
    return tokenizer(text, **encoding_options)
def predict_class(sentence, model):
    """Classify *sentence* with *model* and return the arg-max class index.

    The sentence is encoded with preprocess_input, the model's logits are
    logged at DEBUG level, and the index of the highest logit for the first
    (only) example in the batch is returned.
    """
    encoded = preprocess_input(sentence)
    logits = model(encoded).logits
    logging.debug(f"Logits: {logits}")
    # Row 0: preprocess_input encodes a single sentence.
    top_index = np.argmax(logits, axis=-1)[0]
    logging.debug(f"Predicted class: {top_index}")
    return top_index
def normalize_text(text):
    """Canonicalize *text* for pattern matching.

    Lower-cases the text, removes punctuation (any character that is neither
    a word character nor whitespace), collapses runs of whitespace into a
    single space, and trims leading/trailing whitespace.

    Punctuation is removed *before* whitespace is collapsed, so input such as
    "hi - there" normalizes to "hi there" rather than "hi  there".
    """
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
    text = re.sub(r'\s+', ' ', text)  # Collapse whitespace runs
    return text.strip()
def find_intent(user_message, intents_json):
    """Return the first intent having a pattern that occurs in *user_message*.

    The message and each pattern are passed through normalize_text and then
    compared as whole words. A pattern must appear in the message as a
    complete word sequence, so "hi" matches "hi there" but not "this".
    Patterns that normalize to nothing (e.g. pure punctuation) are skipped,
    because an empty pattern would otherwise match every message.

    Args:
        user_message: Raw text typed by the user.
        intents_json: Parsed intents document with an "intents" list whose
            items each carry a "patterns" list of strings.

    Returns:
        The first matching intent dict, or None if no pattern matches.
    """
    # Surround with single spaces so containment only succeeds on word
    # boundaries; split/join canonicalizes any leftover whitespace.
    padded_message = f" {' '.join(normalize_text(user_message).split())} "
    for intent in intents_json["intents"]:
        for pattern in intent["patterns"]:
            words = normalize_text(pattern).split()
            if words and f" {' '.join(words)} " in padded_message:
                return intent
    return None
def getResponse(user_message, intents_json):
    """Pick a reply for *user_message* from the matching intent.

    Returns a random entry from the matched intent's "responses" list, a
    fallback line when that intent has no responses, or an apology when no
    intent matches at all.
    """
    matched = find_intent(user_message, intents_json)
    if not matched:
        return "Sorry, I didn't understand that."
    candidates = matched.get("responses", [])
    if not candidates:
        logging.debug("No responses found for intent.")
        return "I'm not sure what to say about that!"
    reply = random.choice(candidates)
    logging.debug(f"Response chosen: {reply}")
    return reply
@app.route("/")
def index():
    """Render and return the chat UI template 'chat.html'."""
    page = 'chat.html'
    return render_template(page)
@app.route("/get", methods=["GET", "POST"])
def chat():
    """Return the chatbot's reply to the ``msg`` request parameter.

    The route accepts both GET and POST. The message is therefore read from
    ``request.values``, which combines the query string and the form body.
    A GET request carries its parameters in the query string, so
    ``request.form`` alone would never contain ``msg`` for GET.

    A missing ``msg`` raises Werkzeug's ``BadRequestKeyError``, which Flask
    turns into an HTTP 400 response.
    """
    msg = request.values["msg"]
    return chatbot_response(msg)
def chatbot_response(user_message):
    """Build the chatbot's reply to *user_message*.

    Sleeps for a random 0.5-1.5 s to simulate model latency. It then loads
    the intents from 'intents.json' (resolved against the current working
    directory and re-read on every call) and answers via pattern matching
    with getResponse. The BERT model is not consulted here.

    Returns:
        The reply string chosen by getResponse.
    """
    # Simulate model processing time (blocks the calling thread).
    time.sleep(random.uniform(0.5, 1.5))
    # Use a context manager so the file handle is closed promptly, and an
    # explicit encoding so decoding does not depend on the platform locale.
    with open('intents.json', encoding='utf-8') as intents_file:
        intents = json.load(intents_file)
    # Use pattern matching to get response from intents
    res = getResponse(user_message, intents)
    logging.debug(f"Final chatbot response: {res}")
    return res
if __name__ == "__main__":
    # Serve on every network interface (0.0.0.0), port 5000, debugger disabled.
    app.run(debug=False, port=5000, host='0.0.0.0')