# Page banner captured together with this source: "Spaces: Sleeping" (not part of the program).
import json | |
import re | |
import nltk | |
import numpy as np | |
from flask import Flask, render_template, request | |
from transformers import BertTokenizer, TFBertForSequenceClassification | |
import logging | |
import random | |
import time | |
import os | |
# Ensure the NLTK data this app relies on is present, downloading it if needed.
def download_nltk_data():
    """Make sure the NLTK ``punkt`` and ``wordnet`` resources are available.

    A resource is considered present if ``nltk.data.find`` can locate it on any
    configured NLTK data path, either unzipped or as a ``.zip`` package. Missing
    resources are downloaded into ``~/nltk_data``, which is created if needed.
    """
    nltk_data_dir = os.path.join(os.path.expanduser("~"), "nltk_data")
    # exist_ok avoids the race between an existence check and the mkdir.
    os.makedirs(nltk_data_dir, exist_ok=True)
    for resource_path, package in (
        ("tokenizers/punkt", "punkt"),
        ("corpora/wordnet", "wordnet"),
    ):
        try:
            # find() searches every nltk.data.path entry and also resolves zipped
            # packages (e.g. corpora/wordnet.zip), so already-installed data is
            # not downloaded again on each start-up.
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package, download_dir=nltk_data_dir)
# Fetch NLTK resources at import time, before the web app is built.
download_nltk_data()

app = Flask(__name__)

# Process-wide logging at DEBUG verbosity.
logging.basicConfig(level=logging.DEBUG)

# BERT tokenizer and TensorFlow sequence-classification model, loaded by checkpoint name.
# NOTE(review): chatbot_response answers by intent pattern matching and never calls
# predict_class, so this (large) model load appears unused on the chat path. Confirm
# before removing. "bert-base-uncased" is a base checkpoint, so the classification head
# is presumably untrained unless fine-tuned weights are supplied elsewhere.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = TFBertForSequenceClassification.from_pretrained(model_name)
# Tokenization helper for the BERT model.
def preprocess_input(text):
    """Tokenize *text* for BERT, returning TensorFlow tensors.

    The encoding is truncated to at most 512 tokens and padded out to exactly
    512 tokens.
    """
    return tokenizer(
        text,
        return_tensors='tf',
        max_length=512,
        truncation=True,
        padding='max_length',
    )
# Classification helper built on the BERT model.
def predict_class(sentence, model):
    """Return the argmax class index the model assigns to *sentence*.

    The sentence is encoded with preprocess_input. The model's logits are
    reduced with argmax over the last axis, and ``[0]`` selects the first row
    of the batch. Logits and the prediction are logged at DEBUG level.
    """
    scores = model(preprocess_input(sentence)).logits
    best_index = np.argmax(scores, axis=-1)[0]
    logging.debug(f"Logits: {scores}")
    logging.debug(f"Predicted class: {best_index}")
    return best_index
# Canonicalize text so user messages and intent patterns can be compared.
def normalize_text(text):
    """Return *text* lower-cased, without punctuation, with whitespace collapsed.

    Punctuation (any character that is neither a word character nor
    whitespace) is removed *before* whitespace is collapsed. This way
    "hi - there" becomes "hi there" rather than "hi  there". Leading and
    trailing whitespace is stripped.
    """
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)  # Remove punctuation
    # Collapse runs of whitespace, including gaps left by removed punctuation.
    return re.sub(r'\s+', ' ', text).strip()
# Locate the intent whose patterns appear in the user's message.
def find_intent(user_message, intents_json):
    """Return the first intent with a pattern that appears in *user_message*.

    The message and each pattern are passed through normalize_text and compared
    as whole words. A pattern's words must occur as a contiguous run of the
    message's words, so the pattern "hi" matches "hi there" but not "this" or
    "which". Patterns that normalize to nothing (e.g. "?") are skipped, because
    an empty pattern would otherwise match every message. Intents without a
    "patterns" key are ignored.

    Args:
        user_message: raw text from the user.
        intents_json: parsed intents document. Presumably shaped like
            {"intents": [{"patterns": [...], "responses": [...]}, ...]};
            confirm against intents.json.

    Returns:
        The first matching intent dict, or None if nothing matches.
    """
    # Surround the word sequence with spaces so the substring test below can
    # only succeed on word boundaries.
    message_words = normalize_text(user_message).split()
    haystack = " " + " ".join(message_words) + " "
    for intent in intents_json["intents"]:
        for pattern in intent.get("patterns", []):
            pattern_words = normalize_text(pattern).split()
            if not pattern_words:
                continue
            needle = " " + " ".join(pattern_words) + " "
            if needle in haystack:
                return intent
    return None
# Turn a user message into a reply using the matched intent's canned responses.
def getResponse(user_message, intents_json):
    """Pick a reply to *user_message* from the intents document.

    Returns a random entry from the matched intent's "responses" list. If the
    matched intent has no responses, returns a generic fallback. If no intent
    matches, returns an apology.
    """
    matched = find_intent(user_message, intents_json)
    if not matched:
        return "Sorry, I didn't understand that."
    candidates = matched.get("responses", [])
    if not candidates:
        logging.debug("No responses found for intent.")
        return "I'm not sure what to say about that!"
    reply = random.choice(candidates)
    logging.debug(f"Response chosen: {reply}")
    return reply
# Route for the home page
@app.route("/")
def index():
    """Serve the chat page by rendering the 'chat.html' template."""
    return render_template('chat.html')
# Route to handle the chat messages.
# NOTE(review): "/get" is an assumed URL. It must match the endpoint that
# chat.html submits the "msg" form field to; confirm before deploying.
@app.route("/get", methods=["POST"])
def chat():
    """Return the chatbot's reply to the "msg" field of a submitted form.

    A request without a "msg" field raises a KeyError from ``request.form``,
    which Flask turns into a 400 Bad Request response.
    """
    msg = request.form["msg"]
    return chatbot_response(msg)
# Produce the chatbot's reply for a single user message.
def chatbot_response(user_message):
    """Generate the chatbot's reply to *user_message*.

    Sleeps for a random 0.5-1.5 s to mimic model latency. It then answers by
    pattern-matching against the intents in 'intents.json', a path resolved
    against the current working directory. The file is re-read on every call,
    so edits take effect without a restart.
    """
    # Simulated "thinking" delay; time.sleep blocks the thread serving this request.
    time.sleep(random.uniform(0.5, 1.5))
    # The context manager closes the handle promptly. Explicit UTF-8 avoids
    # locale-dependent decoding of non-ASCII responses in the JSON.
    with open('intents.json', encoding='utf-8') as intents_file:
        intents = json.load(intents_file)
    # Use pattern matching to get a response from the intents.
    res = getResponse(user_message, intents)
    logging.debug(f"Final chatbot response: {res}")
    return res
if __name__ == "__main__":
    # Bind to all interfaces so the server is reachable from outside the host.
    # The port can be overridden with the PORT environment variable (default 5000).
    app.run(host='0.0.0.0', port=int(os.environ.get("PORT", 5000)), debug=False)