andresicedo commited on
Commit
57ceb42
·
verified ·
1 Parent(s): fa734c9

Update app.py

Browse files

added logic to classify the user input as one of the five faq response keys, else generate a response from the chatbot model

Files changed (1) hide show
  1. app.py +21 -3
app.py CHANGED
@@ -4,6 +4,9 @@ import gradio as gr # Import Gradio for the interface
4
  # Load a text-generation model
5
  chatbot = pipeline("text-generation", model="microsoft/DialoGPT-medium")
6
 
 
 
 
7
  # Customize the bot's knowledge base with predefined responses
8
  faq_responses = {
9
  "study tips": "Here are some study tips: 1) Break your study sessions into 25-minute chunks (Pomodoro Technique). 2) Test yourself frequently. 3) Stay organized using planners or apps like Notion or Todoist.",
@@ -15,10 +18,25 @@ faq_responses = {
15
 
16
  # Define the chatbot's response function
17
  def faq_chatbot(user_input):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # Check if the user's input matches any FAQ keywords
19
- for key, response in faq_responses.items():
20
- if key in user_input.lower():
21
- return response
22
 
23
  # If no FAQ match, use the AI model to generate a response
24
  conversation = chatbot(user_input, max_length=50, num_return_sequences=1)
 
4
  # Load a text-generation model
5
  chatbot = pipeline("text-generation", model="microsoft/DialoGPT-medium")
6
 
7
+ # Load the classification model
8
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
9
+
10
  # Customize the bot's knowledge base with predefined responses
11
  faq_responses = {
12
  "study tips": "Here are some study tips: 1) Break your study sessions into 25-minute chunks (Pomodoro Technique). 2) Test yourself frequently. 3) Stay organized using planners or apps like Notion or Todoist.",
 
18
 
19
  # Define the chatbot's response function
20
  def faq_chatbot(user_input):
21
+ # Classify the user input by passing the FAQ keywords as labels
22
+ classified_user_input = classifier(user_input, candidate_labels=list(faq_responses.keys()))
23
+
24
+ # Get the highest confidence score label, ie. the most likely of the FAQ
25
+ predicted_label = classified_user_input["labels"][0]
26
+ confidence_score = classified_user_input["scores"][0]
27
+
28
+ # Confidence threshold (adjust if needed)
29
+ threshold = 0.5
30
+
31
+ # If the classification confidence is high, return the corresponding FAQ response
32
+ if confidence_score > threshold:
33
+ return faq_responses[predicted_label]
34
+
35
+
36
  # Check if the user's input matches any FAQ keywords
37
+ # for key, response in faq_responses.items():
38
+ # if key in user_input.lower():
39
+ # return response
40
 
41
  # If no FAQ match, use the AI model to generate a response
42
  conversation = chatbot(user_input, max_length=50, num_return_sequences=1)