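"""Gradio app for "Odi", a RAG research assistant for the Osservatori del
Politecnico di Milano.

Downloads GGUF models from the Hugging Face Hub, routes the user's question to
a topic-specific document index (blockchain, metaverse, or payment), and
streams answers generated with llama-cpp-agent.
"""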
import spaces
import os
import gradio as gr
from models import download_models
from rag_backend import Backend
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
import cv2

# Download the GGUF model files from the Hugging Face Hub (reads HF_TOKEN from the environment)
huggingface_token = os.environ.get('HF_TOKEN')
download_models(huggingface_token)

# Map topic keywords to the document folders used to build the RAG index
documents_paths = {
    'blockchain': 'data/blockchain',
    'metaverse': 'data/metaverse',
    'payment': 'data/payment'
}


# Initialize the RAG backend (a module-level global is not ideal, but keeps the demo simple)
backend = Backend()

# Cap OpenCV at a single thread (cv2 is not otherwise used in this file)
cv2.setNumThreads(1)

# Request a GPU from Hugging Face Spaces for up to 20 seconds per call
@spaces.GPU(duration=20)
def respond(
    message,
    history: list[tuple[str, str]],
    model,
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
    repeat_penalty,
):
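    """Streaming chat handler for the Gradio ChatInterface.

    On the second user turn, the message is matched against a topic keyword
    (blockchain, metaverse, payment) and a query engine is built over that
    topic's documents; later turns reuse the index already persisted in
    storage. The reply is generated with llama-cpp-agent and yielded
    incrementally as it streams.
    """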
    chat_template = MessagesFormatterType.GEMMA_2

    print("HISTORY SO FAR ", history)

    matched_path = None
    lowered_message = message.lower()

    # Match a topic keyword only on the second user turn (history has exactly one entry),
    # i.e. right after the assistant's first reply asking which section the question refers to
    for key, path in documents_paths.items():
        if len(history) == 1 and key in lowered_message:
            matched_path = path
            break
    print("matched_path", matched_path)

    if matched_path:  # only true on the second interaction: index the selected topic's documents
        original_message = history[0][0]
        print("** matched path!!")
        query_engine = backend.create_index_for_query_engine(matched_path)
        message = backend.generate_prompt(query_engine, original_message)
        
        gr.Info("Relevant context indexed from docs...")

    elif len(history) > 1:  # later turns: reuse the index already persisted in storage
        print("Using context from storage db")
        query_engine = backend.load_index_for_query_engine()
        message = backend.generate_prompt(query_engine, message)

        gr.Info("Relevant context extracted from db...")

    # Load model only if it's not already loaded or if a new model is selected
    if backend.llm is None or backend.llm_model != model:
        try:
            backend.load_model(model)
        except Exception as e:
            return f"Error loading model: {str(e)}"

    provider = LlamaCppPythonProvider(backend.llm)

    agent = LlamaCppAgent(
        provider,
        system_prompt=system_message,
        predefined_messages_formatter_type=chat_template,
        debug_output=True
    )

    # Apply the sampling parameters chosen in the UI
    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty
    settings.stream = True

    messages = BasicChatHistory()

    # add user and assistant messages to the history
    for msg in history:
        user = {'role': Roles.user, 'content': msg[0]}
        assistant = {'role': Roles.assistant, 'content': msg[1]}
        messages.add_message(user)
        messages.add_message(assistant)

    
    # Stream the response so the UI updates as tokens arrive
    try:
        stream = agent.get_chat_response(
            message, 
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False
        )

        outputs = ""
        for output in stream:
            outputs += output
            yield outputs
    except Exception as e:
        yield f"Error during response generation: {str(e)}"
        


demo = gr.ChatInterface(
    fn=respond,
    css="""
    .gradio-container {
        background-color: #B9D9EB;
        color: #003366;
    }""",
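    # Extra controls under the chat box: model choice, system prompt, and sampling parameters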
    additional_inputs=[
        gr.Dropdown([
                'Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf',
                'Mistral-Nemo-Instruct-2407-Q5_K_M.gguf',
                'gemma-2-2b-it-Q6_K_L.gguf',
                'openchat-3.6-8b-20240522-Q6_K.gguf',
                'Llama-3-Groq-8B-Tool-Use-Q6_K.gguf',
                'MiniCPM-V-2_6-Q6_K.gguf',
                'llama-3.1-storm-8b-q5_k_m.gguf',
                'orca-2-7b-patent-instruct-llama-2-q5_k_m.gguf'
            ],
            value="gemma-2-2b-it-Q6_K_L.gguf",
            label="Model"
        ),
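        # Default system prompt (Italian): introduces the assistant as "Odi" from the Osservatori
        # del Politecnico di Milano, asks the user to pick a section (Blockchain, Payment, or
        # Metaverse) in the first reply, and restricts answers to digital-innovation topics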
        gr.Textbox(value="""Solamente all'inizio, presentati come Odi, un assistente ricercatore italiano creato dagli Osservatori del Politecnico di Milano e specializzato nel fornire risposte precise e pertinenti solo ad argomenti di innovazione digitale. 
        Solo nella tua prima risposta, se non è chiaro, chiedi all'utente di indicare a quale di queste tre sezioni degli Osservatori si riferisce la sua domanda: 'Blockchain', 'Payment' o 'Metaverse'. Nel fornire la risposta cita il report da cui la hai ottenuta.
Per le risposte successive, utilizza la cronologia della chat o il contesto fornito per aiutare l'utente a ottenere una risposta accurata. 
Non rispondere mai a domande che non sono pertinenti a questi argomenti.""", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=3048, step=1, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.2, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
        ),
        gr.Slider(
            minimum=0,
            maximum=100,
            value=30,
            step=1,
            label="Top-k",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition penalty",
        ),
    ],
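    # Button labels localized to Italian (Retry, Undo, Restart chat, Send)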
    retry_btn="Riprova",
    undo_btn="Annulla",
    clear_btn="Riavvia chat",
    submit_btn="Invia",
    title="Odi, l'assistente ricercatore degli Osservatori",
    chatbot=gr.Chatbot(
        scale=1,
        likeable=False,
        show_copy_button=True
    ),
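    # Italian example prompts shown as clickable suggestions under the chat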
    examples=[
        ["Ciao, in cosa puoi aiutarmi?"],
        ["Quanto vale il mercato italiano?"],
        ["Per favore dammi informazioni sugli ambiti applicativi"],
        ["Chi è Francesco Bruschi?"],
        ["Svelami una buona ricetta milanese"]
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()