File size: 6,799 Bytes
f2019a4
23de817
f2019a4
 
 
 
 
 
 
82a9fe8
f2019a4
1e44fa4
846f47f
feacba8
f2019a4
31559f1
 
208476f
 
31559f1
cc4fb7d
9e80596
f2019a4
9df01d3
 
9e80596
f2019a4
31559f1
 
b2c6adc
f2019a4
31559f1
208476f
 
 
 
f2019a4
459a15e
c4ea8b5
 
459a15e
31559f1
 
6e2babb
b2c6adc
2bca27e
31559f1
2bca27e
31559f1
 
 
 
 
 
 
 
 
 
 
 
6e2babb
 
31559f1
 
 
 
 
 
 
 
f2019a4
9e80596
 
cc4fb7d
9e80596
 
 
 
 
 
f2019a4
9e80596
 
f2019a4
9e80596
 
f2019a4
9e80596
9934680
f2019a4
9e80596
 
 
 
 
 
 
 
 
 
 
459a15e
b053d03
f27ccbe
22b51ff
9e80596
 
22b51ff
 
 
 
 
 
 
31559f1
 
 
 
22b51ff
 
 
 
 
 
 
 
 
 
9e80596
459a15e
 
b053d03
 
208476f
9e80596
 
 
 
 
dcbc7a2
9e80596
dcbc7a2
9e80596
 
 
 
 
b053d03
9e80596
 
b053d03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9bfa30
34bbc8b
3481362
f2019a4
9e80596
 
f2019a4
 
 
 
 
 
b053d03
54eb092
d9bfa30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b053d03
7468778
b053d03
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModel, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from vllm import LLM, SamplingParams
import torch
import gradio as gr
import json
import os
import shutil
import requests
import numpy as np
import pandas as pd
from threading import Thread
from FlagEmbedding import BGEM3FlagModel
from sklearn.metrics.pairwise import cosine_similarity

from transformers import AutoModelForSequenceClassification

device = "cuda" if torch.cuda.is_available() else "cpu"

# Embedding model (BGE-M3) used to embed user queries for retrieval.
embedding_model = BGEM3FlagModel('BAAI/bge-m3',
                       use_fp16=True)  # Setting use_fp16 to True speeds up computation with a slight performance degradation

# Pre-computed passage embeddings and the matching passage texts.
# NOTE(review): row i of the .npy matrix is assumed to correspond to row i of
# the JSON file (vector_search indexes both with the same position) — confirm.
embeddings = np.load("embeddings_albert_tchap.npy")
embeddings_data = pd.read_json("embeddings_albert_tchap.json")
embeddings_text = embeddings_data["text_with_context"].tolist()

# Classifier/router (DeBERTa) deciding whether a query should use RAG.
classifier_model = AutoModelForSequenceClassification.from_pretrained("AgentPublic/chatrag-deberta")
classifier_tokenizer = AutoTokenizer.from_pretrained("AgentPublic/chatrag-deberta")

# Generative LLM (llama-based) that writes the answers.
model_name = "Pclanglais/Tchap"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
# Use the device detected above instead of a hard-coded "cuda:0".
model = model.to(device)

system_prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nTu es Albert, l'agent conversationnel des services publics qui peut décrire des documents de référence ou aider à des tâches de rédaction<|eot_id|>"
source_text = "Les sources utilisées par Albert-Tchap vont apparaître ici"
# Whether the current conversation uses retrieved sources. It is set by
# `user` on the first message; the default avoids a NameError in `predict`.
assess_rag = False


def classification_chatrag(query):
    """Decide whether the retrieval (RAG) step should be used for *query*.

    The DeBERTa router's output logit is passed through a sigmoid, converted
    to a rounded percentage and compared against 50.

    Args:
        query: The user's message.

    Returns:
        True when the rounded probability is strictly above 50 %, else False.
    """
    print(query)
    # Truncate so an unusually long message cannot exceed the classifier's
    # maximum sequence length and crash the request.
    encoding = classifier_tokenizer(query, return_tensors="pt", truncation=True)
    encoding = {k: v.to(classifier_model.device) for k, v in encoding.items()}

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = classifier_model(**encoding).logits

    probs = torch.sigmoid(logits.squeeze().cpu())
    # .item() requires a single-element tensor, i.e. a one-logit head.
    float_value = round(probs.item() * 100)
    print(float_value)

    if float_value > 50:
        print("We activate RAG")
        return True
    print("We remove RAG")
    return False

def vector_search(sentence_query):
    """Return the stored passage text closest to *sentence_query*.

    The query is embedded with the BGE-M3 dense encoder. It is compared by
    cosine similarity against every row of the module-level ``embeddings``
    matrix, and the entry of ``embeddings_text`` at the best-scoring position
    is returned.
    """
    encoded = embedding_model.encode(
        sentence_query,
        batch_size=12,
        max_length=256,  # a smaller value speeds up encoding if long inputs are not needed
    )
    # cosine_similarity expects a 2-D (n_samples, n_features) array.
    query_vector = encoded['dense_vecs'].reshape(1, -1)

    scores = cosine_similarity(query_vector, embeddings)
    best_row = np.argmax(scores)
    return embeddings_text[best_row]


class StopOnTokens(StoppingCriteria):
    """Stopping criterion that halts generation on specific token ids.

    Only the most recent token of the first sequence in the batch is checked.
    """

    def __init__(self, stop_ids=(29, 0)):
        """Store the token ids that end generation.

        Args:
            stop_ids: Iterable of integer token ids; the default keeps the
                original hard-coded values.
        """
        super().__init__()
        # NOTE(review): ids 29 and 0 look inherited from a generic streaming
        # example. The prompts in this app are Llama-3 style, and that
        # tokenizer may map these ids to ordinary punctuation — confirm with
        # `tokenizer.convert_ids_to_tokens([29, 0])`.
        self.stop_ids = frozenset(int(i) for i in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of sequence 0 is a stop id."""
        return int(input_ids[0][-1]) in self.stop_ids


def _build_prompt(history_transformer_format, use_rag, sources):
    """Render the chat history as Llama-3 style turns (system part excluded).

    Every completed turn is closed with ``<|eot_id|>`` after the assistant
    reply. The last turn is left open right after the assistant header so the
    model continues from there. When *use_rag* is true, *sources* is appended
    to the last user message under a ``### Source ###`` heading.
    """
    user_header = "<|start_header_id|>user<|end_header_id|>\n\n"
    assistant_header = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    last_index = len(history_transformer_format) - 1

    parts = []
    for index, item in enumerate(history_transformer_format):
        question = user_header + item[0]
        answer = assistant_header + (item[1] or "")
        if index == last_index:
            # Once we reach the ongoing turn we attach the retrieved source.
            if use_rag:
                question += "\n\n### Source ###\n" + sources
        else:
            # Close finished assistant turns, as the system prompt does.
            answer += "<|eot_id|>"
        parts.append(question + answer)
    return "".join(parts)


def predict(history_transformer_format):
    """Stream the model's reply to the last user message.

    Args:
        history_transformer_format: List of ``[user_message, bot_message]``
            pairs; the last pair is the turn being answered.

    Yields:
        The same history list, with the last bot message extended as new
        text arrives from the streamer.
    """
    print(history_transformer_format)
    stop = StopOnTokens()

    prompt = system_prompt + _build_prompt(history_transformer_format, assess_rag, source_text)
    print(prompt)

    # The prompt already begins with <|begin_of_text|>; disabling special
    # tokens prevents the tokenizer from prepending a second BOS if it is
    # configured to add one.
    model_inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=False,  # greedy decoding; temperature/top_p would be ignored
        stopping_criteria=StoppingCriteriaList([stop]),
    )
    # generate() runs in a background thread while this generator consumes
    # the streamer.
    Thread(target=model.generate, kwargs=generate_kwargs).start()

    history_transformer_format[-1][1] = ""
    for new_token in streamer:
        # Bare "<" chunks are dropped, as in the original implementation.
        if new_token != '<':
            history_transformer_format[-1][1] += new_token
            yield history_transformer_format

def user(message, history):
    """Register a new user message and prepare the sources panel.

    On the first message of a conversation (empty history), the router
    decides whether RAG is used. If so, the closest passage is retrieved.
    Both results are stored in the module-level ``assess_rag`` and
    ``source_text``, which ``predict`` reads to build the prompt.

    Args:
        message: The text typed by the user.
        history: Current chatbot history as ``[user, bot]`` pairs (may be
            empty or None after the chat is cleared).

    Returns:
        ``("", new_history, sources_html)``: an empty string to clear the
        textbox, the history extended with a pending ``[message, ""]`` turn,
        and the HTML shown in the sources panel.
    """
    global source_text
    global assess_rag
    history = list(history or [])

    # For now, we only query the vector database once, at the start.
    # NOTE(review): these globals are shared across all browser sessions.
    if not history:
        assess_rag = classification_chatrag(message)
        if assess_rag:
            source_text = vector_search(message)
        else:
            source_text = "Albert-Tchap n'utilise pas de sources comme votre requête n'a pas l'air d'en recueillir."

    history_transformer_format = history + [[message, ""]]
    print(history_transformer_format)

    # Build the display HTML in a local variable. Wrapping the global in
    # place leaked the markup into the LLM prompt and nested it again on
    # every later turn.
    sources_html = "<h3>Sources</h3><p>" + source_text + "</p>"
    return "", history_transformer_format, sources_html

# Define the Gradio interface.
# NOTE(review): title, description and examples are not passed to the
# Blocks layout below.
title = "Tchap"
description = "Le chatbot du service public"
examples = [
    [
        "Qui peut bénéficier de l'AIP?",  # user_message
        0.7  # temperature
    ]
]

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            clear = gr.Button("Clear")

            history = gr.State()  # currently unused by the event handlers
        with gr.Column(scale=1):
            user_output = gr.HTML()  # Sources panel filled by `user`

    # Wire the events only after every component exists: `user_output` is
    # created in the second column, so referencing it inside the first column
    # raised NameError at startup.
    msg.submit(user, inputs=[msg, chatbot], outputs=[msg, chatbot, user_output], queue=False).then(
        predict, chatbot, chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)


demo.queue()
demo.launch()