Spaces:
Sleeping
Sleeping
File size: 4,930 Bytes
631f6af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# HF libraries
from langchain_community.llms import HuggingFaceHub
from langchain_core.prompts import ChatPromptTemplate
from langchain import hub
import gradio as gr
from langchain.agents import AgentExecutor
from langchain.agents.format_scratchpad import format_log_to_str
from langchain.agents.output_parsers import (
ReActJsonSingleInputOutputParser,
)
# Import things that are needed generically
from typing import List, Dict
from langchain.tools.render import render_text_description
import os
from dotenv import load_dotenv
from innovation_pathfinder_ai.structured_tools.structured_tools import (
arxiv_search, get_arxiv_paper, google_search, wikipedia_search
)
# hacky and should be replaced with a database
from innovation_pathfinder_ai.source_container.container import (
all_sources
)
# from langchain_community.chat_message_histories import ChatMessageHistory
# from langchain_core.runnables.history import RunnableWithMessageHistory
# message_history = ChatMessageHistory()
# Read the local ".env" file; `config` keeps load_dotenv's return value.
config = load_dotenv(".env")

# API credentials, looked up in the process environment (None when unset).
HUGGINGFACEHUB_API_TOKEN = os.environ.get('HUGGINGFACEHUB_API_TOKEN')
GOOGLE_CSE_ID = os.environ.get('GOOGLE_CSE_ID')
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')

# LangChain tracing settings.
# NOTE(review): the two literals below are only bound as Python names here;
# nothing in this section writes them to os.environ. If the tracing client
# reads its configuration from environment variables, these assignments have
# no effect unless the same keys are also present in ".env" -- confirm.
LANGCHAIN_TRACING_V2 = "true"
LANGCHAIN_ENDPOINT = "https://api.smith.langchain.com"
LANGCHAIN_API_KEY = os.environ.get('LANGCHAIN_API_KEY')
LANGCHAIN_PROJECT = os.environ.get('LANGCHAIN_PROJECT')
# Language model used by the agent, accessed through the Hugging Face Hub
# wrapper with fixed generation parameters.
llm = HuggingFaceHub(
    repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
    model_kwargs=dict(
        temperature=0.1,
        max_new_tokens=1024,
        repetition_penalty=1.2,
        return_full_text=False,
    ),
)
# Tools handed to the agent. get_arxiv_paper is imported by this file but is
# currently not part of the list.
tools = [arxiv_search, wikipedia_search, google_search]

# Pull the "hwchase17/react-json" prompt from the LangChain hub and pre-fill
# its `tools` and `tool_names` variables from the list above.
prompt = hub.pull("hwchase17/react-json").partial(
    tools=render_text_description(tools),
    tool_names=", ".join(tool.name for tool in tools),
)
# define the agent
# The model is bound with a stop sequence so generation halts at
# "\nObservation".
chat_model_with_stop = llm.bind(stop=["\nObservation"])

# Maps the executor's payload onto the prompt variables: the raw question and
# the intermediate steps rendered as a text scratchpad.
_agent_inputs = {
    "input": lambda payload: payload["input"],
    "agent_scratchpad": lambda payload: format_log_to_str(
        payload["intermediate_steps"]
    ),
}

# Pipeline: input mapping -> prompt -> model -> ReAct JSON output parser.
agent = (
    _agent_inputs
    | prompt
    | chat_model_with_stop
    | ReActJsonSingleInputOutputParser()
)
# Executor that runs the agent with the configured tools.
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    max_iterations=10,  # upper bound on the number of iterations
    handle_parsing_errors=True,
    return_intermediate_steps=True,
    verbose=True,
    # max_execution_time=60,  # timeout at 60 sec (currently disabled)
)
if __name__ == "__main__":
def add_text(history, text):
    """Queue the user's message in the chat history and clear the textbox.

    Args:
        history: Current list of chat rows; it is not modified.
        text: The message the user just submitted.

    Returns:
        A 2-tuple: a new history list ending with ``(text, None)`` (the
        answer slot is left empty), and ``""`` as the new textbox value.
    """
    pending_turn = (text, None)
    return history + [pending_turn], ""
def bot(history):
    """Answer the most recent question and store the reply in the history.

    The question is taken from the last history row, sent through ``infer``,
    and the agent's ``'output'`` text is extended with the URLs that
    ``collect_urls`` extracts from ``all_sources``.

    Args:
        history: List of chat rows; the last row holds the pending question
            in position 0.

    Returns:
        The same ``history`` list, with its last row replaced by
        ``[question, answer_with_sources]``.
    """
    question = history[-1][0]
    response = infer(question, history)
    src_list = '\n'.join(collect_urls(all_sources))
    response_w_sources = response['output'] + "\n\n\n Sources: \n\n\n" + src_list
    # Replace the whole row instead of assigning into it: add_text builds rows
    # as tuples, and item assignment on a tuple raises TypeError.
    history[-1] = [question, response_w_sources]
    return history
def infer(question, history):
    """Run the agent executor on a single question.

    Args:
        question: The user's question, passed to the agent as ``"input"``.
        history: Chat history. Accepted for the caller's convenience but not
            forwarded to the agent.

    Returns:
        Whatever ``agent_executor.invoke`` returns; the caller reads its
        ``'output'`` entry.
    """
    return agent_executor.invoke({"input": question})
def vote(data: gr.LikeData):
    """Print whether the user up- or down-voted a chatbot response.

    Args:
        data: Like event payload; ``data.liked`` selects the wording and
            ``data.value`` is echoed in the message.
    """
    verdict = "upvoted" if data.liked else "downvoted"
    # Formatting via an f-string (rather than "+") keeps this from raising
    # TypeError if the value is not a plain string.
    print(f"You {verdict} this response: {data.value}")
def collect_urls(data_list):
    """Extract source URLs from a mixed list of strings and dicts.

    Two item shapes are recognised; anything else is ignored:

    * a string containing ``'link:'`` -- the text after the first ``'link:'``
      up to the next comma (or the end of the string) is taken as the URL,
      with surrounding whitespace removed;
    * a dict with an ``'Entry ID'`` key -- that value is used as-is.

    NOTE(review): a URL that itself contains a comma is cut at that comma,
    and only the first ``'link:'`` in each string is used.

    Args:
        data_list: Iterable of collected source records.

    Returns:
        List of extracted URLs, in input order.
    """
    marker = 'link:'
    urls = []
    for item in data_list:
        if isinstance(item, str) and marker in item:
            # Skip exactly the marker and let strip() deal with any space
            # after it, so "link:http://..." does not lose its first char.
            _, _, tail = item.partition(marker)
            url, _, _ = tail.partition(',')
            urls.append(url.strip())
        elif isinstance(item, dict) and 'Entry ID' in item:
            urls.append(item['Entry ID'])
    return urls
# CSS rule for the element with id "col-container": caps its width at 700px
# and centres it horizontally.
css = """
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
"""
# Greeting markup shown via gr.HTML. The <p> element is closed explicitly so
# the fragment is well-formed.
title = """
<div style="text-align: center;max-width: 700px;">
<p>Hello Dave, how can I help today?<br /></p>
</div>
"""
# Build the Gradio UI: one tab holding a greeting, a question box, the chat
# window and a "Clear" button.
# `css` is passed to gr.Blocks so the "#col-container" rule actually applies
# to the column below; without it the stylesheet is defined but never used.
# NOTE(review): the nesting of the `with` blocks was reconstructed from
# context -- confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Tab("Google|Wikipedia|Arxiv"):
        with gr.Column(elem_id="col-container"):
            gr.HTML(title)
            with gr.Row():
                question = gr.Textbox(
                    label="Question",
                    placeholder="Type your question and hit Enter ",
                )
            chatbot = gr.Chatbot([], elem_id="chatbot")
            # Like/dislike clicks on a message are reported to vote().
            chatbot.like(vote, None, None)
            clear = gr.Button("Clear")

        # On submit: first add the question to the chat and empty the textbox
        # (add_text, unqueued), then let bot() fill in the answer.
        question.submit(
            add_text, [chatbot, question], [chatbot, question], queue=False
        ).then(bot, chatbot, chatbot)
        # "Clear" resets the chat window by returning None for it.
        clear.click(lambda: None, None, chatbot, queue=False)

demo.queue()
demo.launch(debug=True)
x = 0  # for debugging purposes