from typing import Any, Callable, List, Tuple

import huggingface_hub
from dataclasses import dataclass
from datetime import datetime
from time import sleep
import inspect
import ast
from random import randint
from urllib.parse import quote

from black import Mode, format_str
import gradio as gr
from huggingface_hub import InferenceClient
from pymongo import MongoClient

from constants import *
from config import DemoConfig
from tools import Tools


@dataclass
class Function:
    name: str
    short_description: str
    description_function: Callable[[Any], str]
    explanation_function: Callable[[Any], str]


FUNCTIONS = [
    Function(
        name="get_current_location",
        short_description="Finding your city",
        description_function=lambda *_, **__: "Finding your city",
        explanation_function=lambda result: f"Found you in {result}!",
    ),
    Function(
        name="sort_results",
        short_description="Sorting results",
        description_function=lambda places, sort, descending=True, first_n=None: f"Sorting results by {sort} from "
        + ("lowest to highest" if not descending else "highest to lowest"),
        explanation_function=lambda result: "Done!",
    ),
    Function(
        name="get_latitude_longitude",
        short_description="Convert to coordinates",
        description_function=lambda location: f"Converting {location} into latitude and longitude coordinates",
        explanation_function=lambda result: "Converted!",
    ),
    Function(
        name="get_distance",
        short_description="Calculate distance",
        description_function=lambda place_1, place_2: "Calculating distances",
        explanation_function=lambda result: result[2],
    ),
    Function(
        name="get_recommendations",
        short_description="Read recommendations",
        description_function=lambda topics, **__: "Reading recommendations for the following "
        + (
            f"topics: {', '.join(topics)}"
            if len(topics) > 1
            else f"topic: {topics[0]}"
        ),
        explanation_function=lambda result: f"Read {len(result)} recommendations",
    ),
    Function(
        name="find_places_near_location",
        short_description="Look for places",
        description_function=lambda type_of_place, location, radius_miles=50: f"Looking for places near {location} within {radius_miles} miles with the following "
        + (
            f"types: {', '.join(type_of_place)}"
            if isinstance(type_of_place, list)
            else f"type: {type_of_place}"
        ),
        explanation_function=lambda result: "Found "
        + (f"{len(result)} places!" if len(result) > 1 else "1 place!"),
    ),
    Function(
        name="get_some_reviews",
        short_description="Fetching reviews",
        description_function=lambda place_names, **_: "Fetching reviews for the requested items",
        explanation_function=lambda result: f"Fetched {len(result)} reviews!",
    ),
    Function(
        name="out_of_domain",
        short_description="The provided query does not relate to locations, reviews, or recommendations.",
        description_function=lambda user_query: "Irrelevant query detected.",
        explanation_function=lambda user_query: "Irrelevant query detected.",
    ),
]
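# Illustrative note: each Function entry mirrors a method of the same name on Tools
# (resolved via getattr in FunctionsHelper.__init__ below). The two callables only
# produce UI text: description_function takes the same keyword arguments as the tool
# call and explanation_function takes its result, e.g. (hypothetical values)
# FUNCTIONS[2].description_function(location="Austin") returns
# "Converting Austin into latitude and longitude coordinates".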


class FunctionsHelper:
    FUNCTION_DEFINITION_TEMPLATE = '''Function:
def {name}{signature}:
    """
    {docstring}
    """

'''
    PROMPT_TEMPLATE = \
"""
{function_definitions}
Example:
User Query: I am driving from Austin to San Antonio, passing through San Marcos, give me good food for each stop.
Call: sort_results(places=get_recommendations(topics=["food"], lat_long=get_latitude_longitude(location="Austin")), sort="rating"); sort_results(places=get_recommendations(topics=["food"], lat_long=get_latitude_longitude(location="San Marcos")), sort="rating"); sort_results(places=get_recommendations(topics=["food"], lat_long=get_latitude_longitude(location="San Antonio")), sort="rating")

Example:
User Query: What's the nearest I need to walk to get some food near Stanford University?
Call: sort_results(places=get_recommendations(topics=["food"], lat_long=get_latitude_longitude(location="Stanford University")), sort="distance")

Example:
User Query: Can you tell me if I should go to San Francisco or San Jose for high end Japanese food?
Call: sort_results(places=get_recommendations(topics=["high-end food", "Japanese"], lat_long=get_latitude_longitude(location="San Francisco")), sort="rating"); sort_results(places=get_recommendations(topics=["high-end food", "Japanese"], lat_long=get_latitude_longitude(location="San Jose")), sort="rating"); get_some_reviews(place_names=sort_results(places=get_recommendations(topics=["high-end food", "Japanese"], lat_long=get_latitude_longitude(location="San Francisco")), sort="rating")); get_some_reviews(place_names=sort_results(places=get_recommendations(topics=["high-end food", "Japanese"], lat_long=get_latitude_longitude(location="San Jose")), sort="rating"))

Example:
User Query: What's this bird app stuff?
Call: out_of_domain(user_query="What's this bird app stuff")

Example:
User Query: What is your political affiliation?
Call: out_of_domain(user_query="What is your political affiliation?")

Example:
User Query: What are people saying about Chipotle in Austin?
Call: get_some_reviews(topics=["Chipotle"], lat_long=get_latitude_longitude(location="Austin"))

Example:
User Query: What are people saying about some of the better Chipotles in Austin?
Call: get_some_reviews(place_names=sort_results(places=get_recommendations(topics=["Chipotles"], lat_long=get_latitude_longitude(location="Austin")), sort="rating"))

User Query: {query}
Call:"""

    def __init__(self, tools: Tools) -> None:
        self.tools = tools

        function_definitions = ""
        for function in FUNCTIONS:
            f = getattr(tools, function.name)
            signature = inspect.signature(f)
            docstring = inspect.getdoc(f)
            function_str = self.FUNCTION_DEFINITION_TEMPLATE.format(
                name=function.name, signature=signature, docstring=docstring
            )
            function_definitions += function_str

        self.prompt_without_query = self.PROMPT_TEMPLATE.format(
            function_definitions=function_definitions, query="{query}"
        )

    def get_prompt(self, query: str):
        return self.prompt_without_query.format(query=query)

    def get_function_call_plan(self, function_call_str: str) -> List[str]:
        function_call_list = []
        locals_to_pass = {"function_call_list": function_call_list}
        for f in FUNCTIONS:
            name = f.name
            exec(
                f"def {name}(**_):\n\tfunction_call_list.append('{f.short_description}')",
                locals_to_pass,
            )

        calls = [c.strip() for c in function_call_str.split(";") if c.strip()]
        for call in calls:
            eval(call, locals_to_pass)
        return function_call_list

    def run_function_call(self, function_call_str: str):
        function_call_list = []
        locals_to_pass = {"function_call_list": function_call_list, "tools": self.tools}
        for f in FUNCTIONS:
            name = f.name
            locals_to_pass[f"{name}_description_function"] = f.description_function
            locals_to_pass[f"{name}_explanation_function"] = f.explanation_function
            function_definition = f"""
def {name}(**kwargs):
    result = tools.{f.name}(**kwargs)
    function_call_list.append(({name}_description_function(**kwargs), {name}_explanation_function(result)))
    return result
"""
            exec(function_definition, locals_to_pass)

        calls = [c.strip() for c in function_call_str.split(";") if c.strip()]
        for call in calls:
            locals_to_pass["function_call_list"] = function_call_list = []
            result = eval(call, locals_to_pass)
            yield result, function_call_list
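# Illustrative sketch (not executed here; it assumes a configured Tools instance).
# Because the exec-generated shims above are ordinary Python functions, nested calls in
# a Raven output are resolved innermost-first, e.g. for
#   sort_results(places=get_recommendations(topics=["food"], lat_long=get_latitude_longitude(location="Austin")), sort="rating")
# get_function_call_plan returns ["Convert to coordinates", "Read recommendations", "Sorting results"],
# and run_function_call invokes the corresponding Tools methods in that same order.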


class RavenDemo(gr.Blocks):
    def __init__(self, config: DemoConfig) -> None:
        theme = gr.themes.Soft(
            primary_hue=gr.themes.colors.blue,
            secondary_hue=gr.themes.colors.blue,
        )
        super().__init__(theme=theme, css=CSS, title="NexusRaven V2 Demo")

        self.config = config
        self.tools = Tools(config)
        self.functions_helper = FunctionsHelper(self.tools)

        mongo_client = MongoClient(host=config.mongo_endpoint)
        self.collection = mongo_client[config.mongo_collection]["logs"]

        self.raven_client = InferenceClient(
            model=config.raven_endpoint, token=config.hf_token
        )
        self.summary_model_client = InferenceClient(config.summary_model_endpoint)

        self.max_num_steps = 20
        self.function_call_name_set = set([f.name for f in FUNCTIONS])

        with self:
            gr.HTML(HEADER_HTML)
            with gr.Row():
                gr.Image(
                    "NexusRaven.png",
                    show_label=False,
                    show_share_button=True,
                    min_width=200,
                    scale=1,
                )
                with gr.Column(scale=4, min_width=800):
                    gr.Markdown(INTRO_TEXT, elem_classes="inner-large-font")

            with gr.Row():
                examples = [gr.Button(query_name) for query_name in EXAMPLE_QUERIES]

            user_input = gr.Textbox(
                placeholder="Ask anything about places, recommendations, or reviews!",
                show_label=False,
                autofocus=True,
            )
            should_chat = gr.Checkbox(
                label="Enable Chat Summary",
                info="If set, summarizes the returned results.",
                value=True,
            )
            raven_function_call = gr.Code(
                label="🐦‍⬛ NexusRaven V2 13B zero-shot generated function call",
                language="python",
                interactive=False,
                lines=10,
            )
            with gr.Accordion(
                "Executing plan generated by 🐦‍⬛ NexusRaven V2 13B", open=True
            ) as steps_accordion:
                steps = [
                    gr.Textbox(visible=False, show_label=False)
                    for _ in range(self.max_num_steps)
                ]

            with gr.Column():
                initial_relevant_places = self.get_relevant_places([])
                relevant_places = gr.State(initial_relevant_places)
                place_dropdown_choices = self.get_place_dropdown_choices(
                    initial_relevant_places
                )
                places_dropdown = gr.Dropdown(
                    choices=place_dropdown_choices,
                    value=place_dropdown_choices[0],
                    label="Relevant places",
                )
                gmaps_html = gr.HTML(self.get_gmaps_html(initial_relevant_places[0]))

            summary_model_summary = gr.Textbox(
                label="Chat summary",
                interactive=False,
                show_copy_button=True,
                lines=10,
                max_lines=1000,
                autoscroll=False,
                elem_classes="inner-large-font",
            )

            with gr.Accordion("Raven inputs", open=False):
                gr.Textbox(
                    label="Available functions",
                    value="`" + "`, `".join(f.name for f in FUNCTIONS) + "`",
                    interactive=False,
                    show_copy_button=True,
                )
                gr.Textbox(
                    label="Raven prompt",
                    value=self.functions_helper.get_prompt("{query}"),
                    interactive=False,
                    show_copy_button=True,
                    lines=20,
                )

            has_error = gr.State(False)

            user_input.submit(
                fn=self.on_submit,
                inputs=[user_input, should_chat],
                outputs=[
                    user_input,
                    raven_function_call,
                    summary_model_summary,
                    relevant_places,
                    places_dropdown,
                    gmaps_html,
                    steps_accordion,
                    *steps,
                    has_error,
                ],
                concurrency_limit=20,  # not a hyperparameter
                api_name=False,
            ).then(
                self.check_for_error,
                inputs=has_error,
                outputs=[],
            )

            for i, button in enumerate(examples):
                button.click(
                    fn=EXAMPLE_QUERIES.get,
                    inputs=button,
                    outputs=user_input,
                    api_name=f"button_click_{i}",
                )

            places_dropdown.input(
                fn=self.get_gmaps_html_from_dropdown,
                inputs=[places_dropdown, relevant_places],
                outputs=gmaps_html,
            )

    def on_submit(self, query: str, should_chat: bool, request: gr.Request):
        def get_returns():
            return (
                user_input,
                raven_function_call,
                summary_model_summary,
                relevant_places,
                places_dropdown,
                gmaps_html,
                steps_accordion,
                *steps,
                has_error,
            )

        def on_error():
            initial_return[0] = gr.Textbox(interactive=True, autofocus=False)
            initial_return[-1] = True
            return initial_return

        user_input = gr.Textbox(interactive=False)
        raven_function_call = ""
        summary_model_summary = ""
        relevant_places = []
        places_dropdown = ""
        gmaps_html = ""
        steps_accordion = gr.Accordion(open=True)
        steps = [gr.Textbox(value="", visible=False) for _ in range(self.max_num_steps)]
        has_error = False

        initial_return = list(get_returns())
        yield initial_return

        raven_prompt = self.functions_helper.get_prompt(
            query.replace("'", r"\'").replace('"', r"\"")
        )
        print(f"{'-' * 80}\nPrompt sent to Raven\n\n{raven_prompt}\n\n{'-' * 80}\n")
        stream = self.raven_client.text_generation(
            raven_prompt, **RAVEN_GENERATION_KWARGS
        )
        for s in stream:
            for c in s:
                raven_function_call += c
                raven_function_call = raven_function_call.removesuffix(
                    "Thought:"
                ).removesuffix("")
                yield get_returns()

        raw_raven_response = raven_function_call
        print(f"Raw Raven response before formatting: {raw_raven_response}")

        r_calls = [c.strip() for c in raven_function_call.split(";") if c.strip()]
        f_r_calls = []
        for r_c in r_calls:
            try:
                f_r_call = format_str(r_c.strip(), mode=Mode())
            except:
                yield on_error()
                return
            if not self.whitelist_function_names(f_r_call):
                yield on_error()
                return
            f_r_calls.append(f_r_call)

        raven_function_call = "; ".join(f_r_calls)
        yield get_returns()

        self._set_client_ip(request)
        function_call_plan = self.functions_helper.get_function_call_plan(
            raven_function_call
        )
        for i, v in enumerate(function_call_plan):
            steps[i] = gr.Textbox(value=f"{i+1}. {v}", visible=True)
        yield get_returns()
        sleep(0.1)

        results_gen = self.functions_helper.run_function_call(raven_function_call)
        results = []
        previous_num_calls = 0
        for result, function_call_list in results_gen:
            results.extend(result)
            for i, (description, explanation) in enumerate(function_call_list):
                i = i + previous_num_calls
                if len(description) > 100:
                    description = function_call_plan[i]
                to_stream = f"{i+1}. {description} ..."
                steps[i] = ""
                for c in to_stream:
                    steps[i] += c
                    sleep(0.005)
                    yield get_returns()
                to_stream = "." * randint(0, 5)
                for c in to_stream:
                    steps[i] += c
                    sleep(0.2)
                    yield get_returns()
                to_stream = f" {explanation}"
                for c in to_stream:
                    steps[i] += c
                    sleep(0.005)
                    yield get_returns()
            previous_num_calls += len(function_call_list)

        try:
            relevant_places = self.get_relevant_places(results)
        except:
            relevant_places = self.get_relevant_places([])
        gmaps_html = self.get_gmaps_html(relevant_places[0])
        places_dropdown_choices = self.get_place_dropdown_choices(relevant_places)
        places_dropdown = gr.Dropdown(
            choices=places_dropdown_choices, value=places_dropdown_choices[0]
        )
        steps_accordion = gr.Accordion(open=False)
        yield get_returns()

        while should_chat:
            try:
                summary_model_prompt = self.get_summary_model_prompt(results, query)
                print(
                    f"{'-' * 80}\nPrompt sent to summary model\n\n{summary_model_prompt}\n\n{'-' * 80}\n"
                )
                stream = self.summary_model_client.text_generation(
                    summary_model_prompt, **SUMMARY_MODEL_GENERATION_KWARGS
                )
                for s in stream:
                    s = s.removesuffix("")
                    for c in s:
                        summary_model_summary += c
                        summary_model_summary = (
                            summary_model_summary.lstrip().removesuffix("")
                        )
                        yield get_returns()
            except huggingface_hub.inference._text_generation.ValidationError:
                # The summary prompt was likely too long for the endpoint; retry with fewer results
                if len(results) > 1:
                    new_length = (3 * len(results)) // 4
                    results = results[:new_length]
                    continue
                else:
                    break
            break

        self.collection.insert_one(
            {
                "query": query,
                "raven_output": raw_raven_response,
                "summary_output": summary_model_summary,
            }
        )

        user_input = gr.Textbox(interactive=True, autofocus=False)
        yield get_returns()

    def check_for_error(self, has_error: bool) -> None:
        if has_error:
            raise gr.Error(ERROR_MESSAGE)

    def whitelist_function_names(self, function_call_str: str) -> bool:
        """
        Defensive function name whitelisting inspired by @evan-nexusflow
        """
        for expr in ast.walk(ast.parse(function_call_str)):
            if not isinstance(expr, ast.Call):
                continue
            expr: ast.Call
            function_name = expr.func.id
            if function_name not in self.function_call_name_set:
                return False
        return True
    def get_summary_model_prompt(self, results: List, query: str) -> str:
        # TODO check what outputs are returned and return them properly
        ALLOWED_KEYS = [
            "author_name",
            "text",
            "for_location",
            "time",
            "author_url",
            "language",
            "original_language",
            "name",
            "opening_hours",
            "rating",
            "user_ratings_total",
            "vicinity",
            "distance",
            "formatted_address",
            "price_level",
            "types",
        ]
        ALLOWED_KEYS = set(ALLOWED_KEYS)

        results_str = ""
        for idx, res in enumerate(results):
            if isinstance(res, str):
                results_str += f"{res}\n"
                continue

            assert isinstance(res, dict)
            item_str = ""
            for key, value in res.items():
                if key not in ALLOWED_KEYS:
                    continue
                key = key.replace("_", " ").capitalize()
                item_str += f"\t{key}: {value}\n"
            results_str += f"Result {idx + 1}\n{item_str}\n"

        current_time = datetime.now().strftime("%b %d, %Y %H:%M:%S")
        try:
            current_location = self.tools.get_current_location()[0]
        except:
            current_location = "Current location not found."

        prompt = SUMMARY_MODEL_PROMPT.format(
            current_location=current_location,
            current_time=current_time,
            results=results_str,
            query=query,
        )
        return prompt

    def get_relevant_places(self, results: List) -> List[Tuple[str, str]]:
        """
        Returns
        -------
        relevant_places: List[Tuple[str, str]]
            A list of tuples, where each tuple is (address, name)
        """
        # We use a dict to preserve ordering, while enforcing uniqueness
        relevant_places = dict()
        for result in results:
            if "formatted_address" in result and "name" in result:
                relevant_places[(result["formatted_address"], result["name"])] = None
            elif "formatted_address" in result and "for_location" in result:
                relevant_places[
                    (result["formatted_address"], result["for_location"])
                ] = None
            elif "vicinity" in result and "name" in result:
                relevant_places[(result["vicinity"], result["name"])] = None
        relevant_places = list(relevant_places.keys())

        if not relevant_places:
            current_location = self.tools.get_current_location()[0]
            relevant_places.append((current_location, current_location))

        return relevant_places

    def get_place_dropdown_choices(
        self, relevant_places: List[Tuple[str, str]]
    ) -> List[str]:
        return [p[1] for p in relevant_places]

    def get_gmaps_html(self, relevant_place: Tuple[str, str]) -> str:
        address, name = relevant_place
        return GMAPS_EMBED_HTML_TEMPLATE.format(
            address=quote(address), location=quote(name)
        )

    def get_gmaps_html_from_dropdown(
        self, place_name: str, relevant_places: List[Tuple[str, str]]
    ) -> str:
        relevant_place = [p for p in relevant_places if p[1] == place_name][0]
        return self.get_gmaps_html(relevant_place)

    def _set_client_ip(self, request: gr.Request) -> None:
        client_ip = request.client.host
        if (
            "headers" in request.kwargs
            and "x-forwarded-for" in request.kwargs["headers"]
        ):
            x_forwarded_for = request.kwargs["headers"]["x-forwarded-for"]
        else:
            x_forwarded_for = request.headers.get("x-forwarded-for", None)
        if x_forwarded_for:
            client_ip = x_forwarded_for.split(",")[0].strip()
        self.tools.client_ip = client_ip


demo = RavenDemo(DemoConfig.load_from_env())

if __name__ == "__main__":
    demo.launch(
        share=True,
        allowed_paths=["logo.png", "NexusRaven.png"],
        favicon_path="logo.png",
    )