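"""
Gradio Space "Odi": a RAG chat assistant over reports from the Osservatori del
Politecnico di Milano, answering via llama-cpp-agent with quantized GGUF models.
"""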
import spaces
import os
import gradio as gr
from models import download_models
from rag_backend import Backend
from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
from llama_cpp_agent.providers import LlamaCppPythonProvider
from llama_cpp_agent.chat_history import BasicChatHistory
from llama_cpp_agent.chat_history.messages import Roles
import cv2
# download the GGUF models from the Hugging Face Hub (HF_TOKEN must be set in the environment)
huggingface_token = os.environ.get('HF_TOKEN')
download_models(huggingface_token)
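
# map each Osservatori topic keyword to the folder containing its source documents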
documents_paths = {
    'blockchain': 'data/blockchain',
    'metaverse': 'data/metaverse',
    'payment': 'data/payment'
}
# initialize backend (not ideal as global variable...)
backend = Backend()
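
# cv2 is only imported to pin its thread pool, presumably to avoid oversubscribing the Space's CPU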
cv2.setNumThreads(1)
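
# ZeroGPU: reserve a GPU slot for each call to respond(), for up to 20 seconds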
@spaces.GPU(duration=20)
def respond(
    message,
    history: list[tuple[str, str]],
    model,
    system_message,
    max_tokens,
    temperature,
    top_p,
    top_k,
    repeat_penalty,
):
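    """
    Stream a reply to the latest user message.

    The parameters after `history` are supplied by the `additional_inputs`
    widgets declared in the ChatInterface below, in the same order.
    """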
    chat_template = MessagesFormatterType.GEMMA_2
    print("HISTORY SO FAR ", history)

    # try to route the question to one of the document folders
    matched_path = None
    words = message.lower()
    for key, path in documents_paths.items():
        # only look for a topic keyword during the second interaction,
        # i.e. when the history holds exactly one previous exchange
        if len(history) == 1 and key in words:
            matched_path = path
            break
    print("matched_path", matched_path)
    if matched_path:  # only reachable on the second interaction
        original_message = history[0][0]
        print("** matched path!!")
        query_engine = backend.create_index_for_query_engine(matched_path)
        message = backend.generate_prompt(query_engine, original_message)
        gr.Info("Relevant context indexed from docs...")
    elif (not matched_path) and (len(history) > 1):
        print("Using context from storage db")
        query_engine = backend.load_index_for_query_engine()
        message = backend.generate_prompt(query_engine, message)
        gr.Info("Relevant context extracted from db...")
    # load the model only if none is loaded yet or a different model was selected
    if backend.llm is None or backend.llm_model != model:
        try:
            backend.load_model(model)
        except Exception as e:
            # respond() is a generator, so the error must be yielded rather than returned
            yield f"Error loading model: {str(e)}"
            return
    provider = LlamaCppPythonProvider(backend.llm)
    agent = LlamaCppAgent(
        provider,
        system_prompt=f"{system_message}",
        predefined_messages_formatter_type=chat_template,
        debug_output=True
    )

    settings = provider.get_provider_default_settings()
    settings.temperature = temperature
    settings.top_k = top_k
    settings.top_p = top_p
    settings.max_tokens = max_tokens
    settings.repeat_penalty = repeat_penalty
    settings.stream = True
    messages = BasicChatHistory()
    # add the previous user and assistant messages to the chat history
    for msn in history:
        user = {'role': Roles.user, 'content': msn[0]}
        assistant = {'role': Roles.assistant, 'content': msn[1]}
        messages.add_message(user)
        messages.add_message(assistant)
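
    # stream the answer token by token so the Gradio chat can update incrementally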
    try:
        stream = agent.get_chat_response(
            message,
            llm_sampling_settings=settings,
            chat_history=messages,
            returns_streaming_generator=True,
            print_output=False
        )

        outputs = ""
        for output in stream:
            outputs += output
            yield outputs
    except Exception as e:
        yield f"Error during response generation: {str(e)}"
demo = gr.ChatInterface(
    fn=respond,
    css="""
    .gradio-container {
        background-color: #B9D9EB;
        color: #003366;
    }""",
    additional_inputs=[
        gr.Dropdown([
                'Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf',
                'Mistral-Nemo-Instruct-2407-Q5_K_M.gguf',
                'gemma-2-2b-it-Q6_K_L.gguf',
                'openchat-3.6-8b-20240522-Q6_K.gguf',
                'Llama-3-Groq-8B-Tool-Use-Q6_K.gguf',
                'MiniCPM-V-2_6-Q6_K.gguf',
                'llama-3.1-storm-8b-q5_k_m.gguf',
                'orca-2-7b-patent-instruct-llama-2-q5_k_m.gguf'
            ],
            value="gemma-2-2b-it-Q6_K_L.gguf",
            label="Model"
        ),
        gr.Textbox(value="""Solamente all'inizio, presentati come Odi, un assistente ricercatore italiano creato dagli Osservatori del Politecnico di Milano e specializzato nel fornire risposte precise e pertinenti solo ad argomenti di innovazione digitale.
Solo nella tua prima risposta, se non è chiaro, chiedi all'utente di indicare a quale di queste tre sezioni degli Osservatori si riferisce la sua domanda: 'Blockchain', 'Payment' o 'Metaverse'. Nel fornire la risposta cita il report da cui la hai ottenuta.
Per le risposte successive, utilizza la cronologia della chat o il contesto fornito per aiutare l'utente a ottenere una risposta accurata.
Non rispondere mai a domande che non sono pertinenti a questi argomenti.""", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=3048, step=1, label="Max tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.2, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
        ),
        gr.Slider(
            minimum=0,
            maximum=100,
            value=30,
            step=1,
            label="Top-k",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=2.0,
            value=1.1,
            step=0.1,
            label="Repetition penalty",
        ),
    ],
retry_btn="Riprova",
undo_btn="Annulla",
clear_btn="Riavvia chat",
submit_btn="Invia",
title="Odi, l'assistente ricercatore degli Osservatori",
chatbot=gr.Chatbot(
scale=1,
likeable=False,
show_copy_button=True
),
examples=[["Ciao, in cosa puoi aiutarmi?"],["Quanto vale il mercato italiano?"], ["Per favore dammi informazioni sugli ambiti applicativi"], ["Chi è Francesco Bruschi?"], ["Svelami una buona ricetta milanese"] ],
cache_examples=False,
)
if __name__ == "__main__":
    demo.launch()