import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer, util
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import re
import traceback
import torch
import pandas as pd
import json


# Preprocessing text by lowercasing, removing punctuation, and extra spaces
def optimized_preprocess_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# Compute cosine similarity between two texts using TF-IDF
def optimized_compute_text_similarity(text1, text2):
    tfidf = TfidfVectorizer(stop_words='english', ngram_range=(1, 1))
    tfidf_matrix = tfidf.fit_transform([text1, text2])
    cosine_sim = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2]).flatten()
    return cosine_sim[0]
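
# Illustrative sketch (defined but never called by the app): how the two
# helpers above compose. The sentences are made up; expect a score in [0, 1],
# and note that unigram TF-IDF only rewards exact token overlap, not paraphrase.
def _example_text_similarity():
    a = optimized_preprocess_text("The cat sat on the mat!")
    b = optimized_preprocess_text("A cat is sitting on the mat.")
    return optimized_compute_text_similarity(a, b)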

# Compute SBERT similarity between question and context
def compute_sbert_similarity(question, context, model):
    embeddings = model.encode([question, context], convert_to_tensor=True)
    similarity = util.pytorch_cos_sim(embeddings[0], embeddings[1]).item()
    return similarity
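
# Illustrative sketch: SBERT scores semantic relatedness, so a question and a
# paraphrased answer sentence can score highly even with little word overlap.
# Uses the global `sbert_model` loaded further down; never called at import.
def _example_sbert_similarity():
    return compute_sbert_similarity(
        "Who wrote Hamlet?",
        "Hamlet is a tragedy by William Shakespeare.",
        sbert_model,
    )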

# Use hybrid approach: TF-IDF to narrow down top N contexts, then SBERT for refined similarity
def hybrid_sbert_approach(question, filtered_contexts, model, top_n=10):
    tfidf = TfidfVectorizer(stop_words='english')
    contexts_combined = [question] + filtered_contexts
    tfidf_matrix = tfidf.fit_transform(contexts_combined)
    
    # Calculate TF-IDF similarity and rank contexts
    similarity_scores = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:]).flatten()
    ranked_contexts = [filtered_contexts[i] for i in similarity_scores.argsort()[::-1][:top_n]]
    
    # Refine using SBERT
    sbert_similarities = [compute_sbert_similarity(question, context, model) for context in ranked_contexts]
    ranked_by_sbert = sorted(zip(ranked_contexts, sbert_similarities), key=lambda x: x[1], reverse=True)
    
    return [context for context, _ in ranked_by_sbert]
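
# Illustrative sketch of the hybrid retriever on a tiny made-up corpus. On the
# real SQuAD contexts, the cheap TF-IDF pass prunes thousands of paragraphs
# before the slower per-context SBERT scoring; top_n is small here only
# because the toy corpus is.
def _example_hybrid_retrieval():
    toy_contexts = [
        "The Eiffel Tower is located in Paris, France.",
        "Python is a popular programming language.",
        "The capital of France is Paris.",
    ]
    return hybrid_sbert_approach("Where is the Eiffel Tower?", toy_contexts, sbert_model, top_n=2)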

# RAG context builder using the hybrid TF-IDF + SBERT retriever
def optimized_generate_rag_context(question, filtered_contexts, selected_context_window=2):
    count = int(selected_context_window)
    # Retrieve a larger candidate pool than the final window so the SBERT
    # re-ranking can actually change which contexts end up selected
    hybrid_retrieved_contexts = hybrid_sbert_approach(question, filtered_contexts, sbert_model, top_n=max(10, count))
    rag_context = "\n".join(hybrid_retrieved_contexts[:count])
    return rag_context
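
# Illustrative usage (inputs hypothetical): the returned string is what gets
# passed as the `context` argument of the question-answering pipeline.
#   rag_context = optimized_generate_rag_context(
#       "Where is the Eiffel Tower?", filtered_contexts, selected_context_window=2)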

# Extract unique contexts and filter them by length
def extract_and_filter_contexts(data, min_length=151, max_length=3706):
    unique_contexts = data['context'].unique()
    filtered_contexts = [context for context in unique_contexts if min_length <= len(context) <= max_length]
    return filtered_contexts
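
# Illustrative usage, mirroring the call in generate_response() below: the
# length bounds keep RAG candidates that are neither trivial nor too long for
# the QA model's input window.
#   filtered_contexts = extract_and_filter_contexts(data, min_length=100, max_length=4000)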

# Compute the TF-IDF matrix for the question and contexts
def compute_tfidf_and_similarity_scores(question, contexts):
    tfidf = TfidfVectorizer(stop_words='english')
    contexts_combined = [question] + contexts
    tfidf_matrix = tfidf.fit_transform(contexts_combined)
    
    # Calculate the cosine similarity scores
    similarity_scores = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:]).flatten()
    return tfidf_matrix, similarity_scores

# Rank contexts based on similarity scores
def rank_contexts_by_similarity(contexts, similarity_scores):
    ranked_indices = similarity_scores.argsort()[::-1]
    ranked_contexts = [contexts[i] for i in ranked_indices]
    ranked_scores = similarity_scores[ranked_indices]
    return ranked_contexts, ranked_scores

# Select the top contexts based on the selected window
def select_top_contexts(selected_context_window, ranked_contexts, ranked_scores):
    count = int(selected_context_window)
    top_contexts = ranked_contexts[:count]
    top_scores = ranked_scores[:count]
    return top_contexts, top_scores
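
# Illustrative sketch chaining the three helpers above on a made-up corpus;
# this mirrors what generate_rag_context() below does with the real contexts.
def _example_tfidf_ranking():
    toy_contexts = [
        "The Great Wall of China is visible from low orbit.",
        "Photosynthesis converts sunlight into chemical energy.",
        "The Great Wall was built over many centuries.",
    ]
    _, scores = compute_tfidf_and_similarity_scores("When was the Great Wall built?", toy_contexts)
    ranked, ranked_scores = rank_contexts_by_similarity(toy_contexts, scores)
    return select_top_contexts(2, ranked, ranked_scores)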


# Helper to append a user message to the chat history (currently unused by the UI wiring below)
def maintain_chat_history(message, chat_history):
    if chat_history is None:
        chat_history = []
    chat_history.append({"role": "user", "content": message})
    return chat_history
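
# For reference, the chat history is a flat list of role-tagged dicts, e.g.:
#   [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]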

def generate_rag_context(question, filtered_contexts, selected_context_window=3):
    tfidf_matrix, similarity_scores = compute_tfidf_and_similarity_scores(question, filtered_contexts)
    ranked_contexts, ranked_scores = rank_contexts_by_similarity(filtered_contexts, similarity_scores)
    top_contexts, top_scores = select_top_contexts(selected_context_window, ranked_contexts, ranked_scores)
    rag_context = "\n".join(top_contexts)
    return rag_context

def load_squad_data(filepath):
    with open(filepath, 'r', encoding='utf-8') as f:
        squad_data = json.load(f)
    return squad_data
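
# For reference, train-v1.1.json follows the nested SQuAD v1.1 layout that
# raw_preprocess_data() below walks:
#   {"data": [{"title": ..., "paragraphs": [{"context": ...,
#       "qas": [{"question": ..., "answers": [{"text": ..., "answer_start": ...}]}]}]}]}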



# Preprocess the data: extract contexts, questions, and answers from the SQuAD data
def raw_preprocess_data(squad_data):
    contexts = []
    questions = []
    answers = []

    for group in squad_data['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qa in passage['qas']:
                question = qa['question']
                for answer in qa['answers']:
                    contexts.append(context)
                    questions.append(question)
                    # Make a copy to avoid modifying the original answer
                    answers.append({
                        'text': answer['text'],
                        'answer_start': answer['answer_start']
                    })

    return contexts, questions, answers


# Add the end index of the answer in the context
def add_end_idx(answers, contexts):
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        if context[start_idx:end_idx] == gold_text:
            answer['answer_end'] = end_idx
        else:
            # Try to find the correct position if there's a mismatch
            for n in range(1, 30):
                if context[start_idx - n:end_idx - n] == gold_text:
                    answer['answer_start'] = start_idx - n
                    answer['answer_end'] = end_idx - n
                    break
                elif context[start_idx + n:end_idx + n] == gold_text:
                    answer['answer_start'] = start_idx + n
                    answer['answer_end'] = end_idx + n
                    break
            else:
                answer['answer_start'] = -1
                answer['answer_end'] = -1
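
# Illustrative sketch (made-up data): a start index that is off by two gets
# repaired by the +/- n scan above.
def _example_add_end_idx():
    contexts = ["The tower is 330 metres tall."]
    answers = [{"text": "330 metres", "answer_start": 11}]  # true start is 13
    add_end_idx(answers, contexts)
    return answers[0]  # -> {'text': '330 metres', 'answer_start': 13, 'answer_end': 23}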


# Create a DataFrame from the contexts, questions, and answers
def create_dataframe(contexts, questions, answers):
    data = pd.DataFrame({
        'context': contexts,
        'question': questions,
        'answer_text': [answer['text'] for answer in answers],
        'answer_start': [answer['answer_start'] for answer in answers],
        'answer_end': [answer.get('answer_end', -1) for answer in answers]
    })

    # Remove samples with -1 start index
    data = data[data['answer_start'] != -1].reset_index(drop=True)
    return data
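
# For reference, the resulting DataFrame has one row per (context, question,
# answer) triple with columns: context, question, answer_text, answer_start,
# answer_end. Rows whose answer span could not be realigned are dropped.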

# Check if a GPU (CUDA) is available; otherwise, use the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Loading the pre-trained SBERT model globally for efficiency
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')

# Available models
electra_models = [
    "./models/fine_tuned_electra_model_1000",
    "./models/fine_tuned_electra_model_20000",
    "./models/fine_tuned_electra_model_5000",
    "./models/fine_tuned_electra_model_all"
]
other_models = [
    "./models/fine_tuned_bert_base_cased_1000",
    "./models/fine_tuned_bert_base_cased_all",
    "./models/fine_tuned_distilbert_base_uncased_10000",
    "./models/fine_tuned_distilgpt2_10000",
    "./models/fine_tuned_retro-reader_intensive_1000",
    "./models/fine_tuned_retro-reader_intensive_5000",
    "./models/fine_tuned_retro-reader_sketchy_1000"
]

DATA_DIR = './data'

# Load and preprocess data
squad_data = load_squad_data(DATA_DIR + '/train-v1.1.json')
contexts, questions, answers = raw_preprocess_data(squad_data)
add_end_idx(answers, contexts)
data = create_dataframe(contexts, questions, answers)

# Generate a response for one user message, with optional RAG retrieval and debug logging
def generate_response(message, chat_history, model_name, debug, rag, selected_context_window):
    try:
        if chat_history is None:
            chat_history = []
        context = message

        # Determine if the model is for question answering based on its name
        is_question_answering = "electra_model" in model_name

        # Initialize the tokenizer and model
        if is_question_answering:
            model = pipeline("question-answering", model=model_name, tokenizer=model_name, device=device)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            model = AutoModelForCausalLM.from_pretrained(model_name)
            model.to(device)

        # Append the new user message to the chat history
        chat_history.append({"role": "user", "content": message})

        if is_question_answering:
            if rag:
                filtered_contexts = extract_and_filter_contexts(data, min_length=100, max_length=4000)
                context = generate_rag_context(message, filtered_contexts, selected_context_window)
            else:
                context = "\n".join([turn["content"] for turn in chat_history if turn["role"] == "user"])

            if debug:
                print("context:\n" + context)
                print("message:\n" + message)

            # Call the pipeline for question-answering
            answer = model(question=message, context=context)
            response = answer['answer']

        else:
            # Prepare the conversation history for a regular chatbot
            conversation = ""
            for turn in chat_history:
                if turn["role"] == "user":
                    conversation += f"User: {turn['content']}\n"
                else:
                    conversation += f"Assistant: {turn['content']}\n"

            if debug:
                print("Conversation being sent to the model:\n", conversation)

            # Encode the input and generate a response
            inputs = tokenizer.encode(conversation + "Assistant:", return_tensors='pt').to(device)
            outputs = model.generate(
                inputs,
                max_length=inputs.shape[1] + 100,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,
                top_p=0.95,
                top_k=50,
                temperature=0.7,
                eos_token_id=tokenizer.eos_token_id,
            )
            response = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract the assistant's reply
            response = response[len(conversation):].strip()
            if "User:" in response:
                response = response.split("User:")[0].strip()

        # Append the assistant's response to the chat history
        chat_history.append({"role": "assistant", "content": response})
        if debug:
            print("Generated response:", response)
            print("Configurations:")
            print(f"Model Name: {model_name}")
            print(f"Is Question Answering: {is_question_answering}")
            print(f"RAG Enabled: {rag}")
            print(f"Selected Context Window: {selected_context_window}")

        # Return the updated chat history and the assistant's response
        # Pair each user turn with the assistant turn that follows it
        display_history = [
            [turn["content"], chat_history[i + 1]["content"]]
            for i, turn in enumerate(chat_history[:-1])
            if turn["role"] == "user"
        ]
        return display_history, chat_history

    except Exception as e:
        # Capture the traceback details
        error_message = f"An error occurred: {str(e)}"
        detailed_error = traceback.format_exc()
        chat_history.append({"role": "assistant", "content": error_message})
        if debug:
            print("Error Details:\n", detailed_error)

        # Ensure safe generation of the display history
        try:
            display_history = [
                [turn["content"], chat_history[i + 1]["content"]]
                for i, turn in enumerate(chat_history[:-1])
                if turn["role"] == "user"
            ]
        except Exception as history_error:
            if debug:
                print("Error while generating display history:", str(history_error))
            display_history = []

        return display_history, chat_history

# Gradio Interface Configuration
def run_prod_chatbot(local=True):
    with gr.Blocks() as demo:
        gr.Markdown("""
        <div style="text-align: center;">
            <h1><strong>SQuAD Q&A ChatBot</strong></h1>
            <h3>Authors: <a href="https://github.com/zainnobody">Zain Ali</a> & <a href="https://github.com/AIBenHopwood/">Ben Hopwood</a></h3>
            <p>
                <a href="https://github.com/zainnobody/AAI-520-Final-Project" target="_blank">Code: GitHub link</a> &nbsp;|&nbsp;
                <a href="https://huggingface.co/zainnobody/AAI-520-Final-Project-Models" target="_blank">Models: Huggingface link</a>
            </p>
        </div>
        
        <div style="text-align: center;">
            <p>
                This project aims to develop a chatbot capable of multi-turn, context-adaptive conversations across various topics, using the Stanford Question Answering Dataset (SQuAD) as the primary source for training.
            </p>
        </div>
        
        <div style="text-align: center;">
            <h4>University of San Diego - AAI 520</h4>
        </div>
        
                """)
        with gr.Row(variant="compact"):
            model_dropdown = gr.Dropdown(
                choices=electra_models + other_models,
                label="Select Model",
                value="./models/fine_tuned_electra_model_all"
            )
            # Column for Use RAG and Debug Mode checkboxes
            with gr.Column():
                rag_checkbox = gr.Checkbox(
                    label="Use RAG", 
                    value=True, 
                    interactive=True
                )
                debug_checkbox = gr.Checkbox(
                    label="Debug Mode", 
                    value=False
                )
            context_window_dropdown = gr.Dropdown(
                choices=[1, 2, 3],
                label="Select Context Window",
                value=1
            )
    
        # The is_question_answering checkbox below was replaced by auto-detection from
        # the model name; kept as a reminder that the other models do not use the QA pipeline.
        # is_question_answering_checkbox = gr.Checkbox(
        #     label="Use Question Answering (Electra Only)", 
        #     value=True
        # )
        
        chatbot = gr.Chatbot()
        state = gr.State([])
        
        with gr.Row():
            # Textbox taking 75% of the space
            msg = gr.Textbox(label="Your message", placeholder="Type your message here and press Enter", scale=3)
            # Send button taking 25% of the space and stretching full width
            send_btn = gr.Button("Send", scale=1)

        
            
        send_btn.click(generate_response,
                       inputs=[msg, state, model_dropdown, debug_checkbox, rag_checkbox, context_window_dropdown],
                       outputs=[chatbot, state])
        msg.submit(generate_response,
                   inputs=[msg, state, model_dropdown, debug_checkbox, rag_checkbox, context_window_dropdown],
                   outputs=[chatbot, state])

    if local:
        # share=True also creates a temporary public gradio.live link
        demo.launch(share=True)
    else:
        demo.launch(server_name="0.0.0.0")

# Launch the Gradio app
run_prod_chatbot()