Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| from time import sleep | |
| import gradio as gr | |
| import uvicorn | |
| from datetime import datetime | |
| from typing import List, Tuple | |
| from starlette.config import Config | |
| from starlette.middleware.sessions import SessionMiddleware | |
| from starlette.responses import RedirectResponse | |
| from authlib.integrations.starlette_client import OAuth, OAuthError | |
| from fastapi import FastAPI, Request | |
| from shared import Client, User, OAuthProvider | |
# FastAPI application that the Gradio UI is mounted onto
app = FastAPI()
# Parsed contents of the `CONFIG` envvar; populated by `init_config`
config = dict()
# Map of LLM host name to its `Client`; populated by `init_config`
clients = dict()
# LLM host names in configuration order; populated by `init_config`
llm_host_names = list()
# authlib `OAuth` registry; populated by `init_oauth`
oauth = None
def init_oauth():
    """
    Initialize Google OAuth and cookie-based sessions for the FastAPI app.
    Reads `GOOGLE_CLIENT_ID`, `GOOGLE_CLIENT_SECRET` and `SECRET_KEY` from
    envvars, registers the `google` provider on the global `oauth` object and
    adds `SessionMiddleware` to `app`.
    """
    global oauth
    # Function-scope import so that this block does not depend on a change to
    # the module import list
    import secrets

    google_client_id = os.environ.get("GOOGLE_CLIENT_ID")
    google_client_secret = os.environ.get("GOOGLE_CLIENT_SECRET")
    secret_key = os.environ.get('SECRET_KEY')
    if not secret_key:
        # The session cookie is signed with this key and carries the user's
        # identity. A hard-coded fallback would let anyone forge a session,
        # so generate an unpredictable key instead.
        secret_key = secrets.token_urlsafe(32)
        print("SECRET_KEY is not set; using a random key for this process. "
              "Sessions will be invalidated on restart.")
    starlette_config = Config(environ={
        "GOOGLE_CLIENT_ID": google_client_id,
        "GOOGLE_CLIENT_SECRET": google_client_secret,
    })
    oauth = OAuth(starlette_config)
    oauth.register(
        name='google',
        server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
        client_kwargs={'scope': 'openid email profile'}
    )
    app.add_middleware(SessionMiddleware, secret_key=secret_key)
def init_config():
    """
    Load the demo configuration from the `CONFIG` envvar (a JSON string) and
    build one `Client` per configured LLM host. Host definitions are read from
    the `clients` key when it is present and non-empty, otherwise from the top
    level of the configuration. Each `api_url` and `api_key` is first looked
    up as an envvar name and is used as a literal value if no such envvar is
    set. Expected host format:
    {"<llm_host_name>": {"api_key": "<api_key>",
                         "api_url": "<api_url>",
                         "personas": {<optional personas>}
                         }
    }
    """
    global config
    global clients
    global llm_host_names

    def _envvar_or_literal(value):
        # Resolve `value` as an envvar reference, else keep it verbatim
        return os.environ.get(value, value)

    config = json.loads(os.environ['CONFIG'])
    client_config = config.get("clients") or config
    for name in client_config:
        host_config = client_config[name]
        model_personas = host_config.get("personas", {})
        clients[name] = Client(
            api_url=_envvar_or_literal(host_config['api_url']),
            api_key=_envvar_or_literal(host_config['api_key']),
            personas=model_personas,
        )
    llm_host_names = list(client_config.keys())
def get_allowed_models(user: User) -> List[str]:
    """
    Determine which configured endpoints the specified user may access. The
    permissions reported by each model's own configuration are used unless the
    demo configuration supplies an entry for that endpoint under
    `permissions_override`.
    :param user: User to get permissions for
    :return: One entry per configured endpoint, in `clients` order: the
        endpoint name when access is allowed, or an empty string when it is not
    """
    overrides = config.get("permissions_override", {})
    allowed_endpoints = []
    for name in clients:
        # The model's own permissions are always read, even when overridden
        model_permission = clients[name].config.inference.permissions
        permission = overrides.get(name, model_permission)
        # No permissions specified (None or empty dict) means a public model;
        # otherwise the user's Google domain must be in the allowed list
        is_allowed = (not permission) or (
            user.oauth == OAuthProvider.GOOGLE
            and user.permissions_id in permission.get("google_domains", [])
        )
        if is_allowed:
            allowed_endpoints.append(name)
            continue
        allowed_endpoints.append("")
        print(f"No permission to access {name}")
    return allowed_endpoints
def parse_radio_select(radio_select: tuple) -> Tuple[str, str]:
    """
    Parse radio selection to determine the requested model and persona. The
    first radio with a selection (any value other than `None`) wins; its
    position is used to look up the model name in `llm_host_names`.
    :param radio_select: List of radio selection states
    :return: Selected model, persona
    :raises ValueError: If no radio has a selected value
    """
    for index, selection in enumerate(radio_select):
        if selection is not None:
            return llm_host_names[index], selection
    # Raise an explicit error rather than leaking `StopIteration` to callers
    raise ValueError("No model and persona selected")
def get_login_button(request: gr.Request) -> gr.Button:
    """
    Build the button that matches the login status of the request: a link to
    `/login` for the "guest" user, otherwise a link to `/logout` labelled with
    the username.
    :param request: Gradio request to evaluate
    :return: Button for either login or logout action
    """
    username = get_user(request).username
    print(f"Getting login button for {username}")
    if username != "guest":
        return gr.Button(f"Logout {username}", link="/logout")
    return gr.Button("Login", link="/login")
def get_user(request: Request) -> User:
    """
    Build a `User` from the session data attached to the specified request
    :param request: FastAPI Request object with user session data
    :return: `User` with the Google email address as username when the session
        holds Google user info, otherwise a "guest" `User`
    """
    # Example Google session data:
    # {'iss': 'https://accounts.google.com',
    #  'azp': '***.apps.googleusercontent.com',
    #  'aud': '***.apps.googleusercontent.com',
    #  'sub': '###',
    #  'hd': 'neon.ai',
    #  'email': '***@neon.ai',
    #  'email_verified': True,
    #  'at_hash': '***',
    #  'nonce': '***',
    #  'name': '***',
    #  'picture': 'https://lh3.googleusercontent.com/a/***',
    #  'given_name': '***',
    #  'family_name': '***',
    #  'iat': ###,
    #  'exp': ###}
    if not request:
        return User(OAuthProvider.NONE, "guest", "")

    user_dict = request.session.get("user", {})
    if user_dict.get("iss") == "https://accounts.google.com":
        # `hd` (hosted domain) is absent for accounts that do not belong to a
        # hosted domain. Use an empty permissions id in that case, which
        # matches no allowed domain, instead of raising `KeyError`.
        user = User(OAuthProvider.GOOGLE, user_dict["email"],
                    user_dict.get("hd", ""))
    else:
        if user_dict:
            print(f"Unknown user session data: {user_dict}")
        user = User(OAuthProvider.NONE, "guest", "")
    print(user)
    return user
# NOTE: These handlers must be registered on `app`. The UI links to `/login`
# and `/logout`, and `login` resolves the callback URL by the route name
# `auth` (FastAPI names a route after its function by default).
@app.get('/logout')
async def logout(request: Request):
    """
    Remove the user session context and reload an un-authenticated session
    :param request: FastAPI Request object with user session data
    :return: Redirect to `/`
    """
    request.session.pop('user', None)
    return RedirectResponse(url='/')


@app.get('/login')
async def login(request: Request):
    """
    Start oauth flow for login with Google
    :param request: FastAPI Request object
    :return: Redirect to the Google authorization endpoint
    """
    from urllib.parse import urlparse, urlunparse
    redirect_uri = request.url_for('auth')
    # Ensure that the `redirect_uri` is https
    redirect_uri = urlunparse(
        urlparse(str(redirect_uri))._replace(scheme='https'))
    return await oauth.google.authorize_redirect(request, redirect_uri)


@app.get('/auth')
async def auth(request: Request):
    """
    Callback endpoint for Google oauth. On success, the `userinfo` from the
    token is stored in the session under `user`.
    :param request: FastAPI Request object
    :return: Redirect to `/`
    """
    try:
        access_token = await oauth.google.authorize_access_token(request)
    except OAuthError:
        # Failed or cancelled login; continue as an un-authenticated user
        return RedirectResponse(url='/')
    user_info = dict(access_token).get("userinfo")
    if user_info:
        request.session['user'] = dict(user_info)
    else:
        # Without user info there is no identity to store; remain a guest
        print("OAuth token did not include `userinfo`")
    return RedirectResponse(url='/')
def respond(
    message: str,
    history: List[Tuple[str, str]],
    conversational: bool,
    max_tokens: int,
    *radio_select,
):
    """
    Build a chat prompt from the selected persona, optional recent history and
    the new user message, then request a completion from the selected model's
    OpenAI-compatible client.
    :param message: String input from the user
    :param history: Optional list of chat history (<user message>,<llm message>)
    :param conversational: If true, include the last two history entries
    :param max_tokens: Maximum tokens for the LLM to generate
    :param radio_select: List of radio selection args to parse
    :return: String LLM response
    :raises gr.Error: If the selected model does not define the selected persona
    """
    model, persona = parse_radio_select(radio_select)
    client = clients[model]
    messages = []

    try:
        system_prompt = client.personas[persona]
    except KeyError:
        supported_personas = list(client.personas.keys())
        raise gr.Error(f"Model '{model}' does not support persona '{persona}', only {supported_personas}")

    # A persona may be configured without a system prompt
    if system_prompt is not None:
        messages.append({"role": "system", "content": system_prompt})

    recent_turns = history[-2:] if conversational else []
    for turn in recent_turns:
        user_text = turn[0]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        assistant_text = turn[1]
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": message})

    generation_options = {
        "add_special_tokens": True,
        "repetition_penalty": 1.05,
        "use_beam_search": True,
        "best_of": 5,
    }
    completion = client.openai.chat.completions.create(
        model=client.vllm_model_name,
        messages=messages,
        max_tokens=max_tokens,
        temperature=0,
        extra_body=generation_options,
    )
    return completion.choices[0].message.content
def get_model_options(request: gr.Request) -> List[gr.Radio]:
    """
    Get allowed models for the specified session. One Radio is returned per
    configured model so that the number of components stays constant; models
    the user may not access are labelled "Not Authorized" and offer no choices.
    :param request: Gradio request object to get user from
    :return: List of Radio objects for available models
    """
    if request:
        # `user` is a valid Google email address or 'guest'
        user = get_user(request.request)
    else:
        user = User(OAuthProvider.NONE, "guest", "")
    print(f"Getting models for {user.username}")
    allowed_llm_host_names = get_allowed_models(user)

    # Components
    radios = []
    persona_choices = []
    for name in allowed_llm_host_names:
        if name in clients:
            label = f"{name} ({clients[name].vllm_model_name})"
            choices = list(clients[name].personas.keys())
        else:
            # Disallowed models are reported as an empty name
            label = "Not Authorized"
            choices = []
        persona_choices.append(choices)
        radios.append(gr.Radio(choices=choices, value=None, label=label))

    # Select the first available option by default. The first model may be
    # unauthorized or have no personas, so search for the first usable one
    # rather than assuming index 0.
    for name, radio, choices in zip(allowed_llm_host_names, radios,
                                    persona_choices):
        if choices:
            radio.value = choices[0]
            print(f"Set default persona to {radio.value} for {name}")
            break
    else:
        print("No persona is available to select by default")
    return radios
def init_gradio() -> gr.Blocks:
    """
    Build the Gradio demo: a chat interface with an options accordion (one
    persona radio per model, a "conversational" checkbox and a max tokens
    slider) plus a login/logout button. The button and the radios are refreshed
    on page load by `get_login_button` and `get_model_options`.
    :return: Gradio Blocks containing the demo
    """
    # Created here and rendered later inside the accordion
    conversational_checkbox = gr.Checkbox(value=True, label="conversational")
    max_tokens_slider = gr.Slider(
        minimum=64,
        maximum=2048,
        value=512,
        step=64,
        label="Max new tokens",
    )
    radios = get_model_options(None)

    with gr.Blocks() as blocks:
        # Events
        radio_state = gr.State([radio.value for radio in radios])

        def radio_click(state, *new_state):
            """
            Reduce the radio selections to a single selected option
            :param state: Previous radio state representation (before selection)
            :param new_state: Current radio state (including selection)
            :return: Desired new state (current option selected, previous option
                deselected), followed by the same values unpacked
            """
            # Login and model options are triggered on load. This sleep is just
            # a hack to make sure those events run before this logic to select
            # the default model
            sleep(0.1)
            try:
                changed_index = next(
                    index for index, previous in enumerate(state)
                    if previous != new_state[index]
                )
                changed_value = new_state[changed_index]
            except StopIteration:
                # TODO: This is the result of some error in rendering a selected
                #   option.
                # Nothing changed; fall back to the current selection
                selected_values = [value for value in new_state
                                   if value is not None]
                changed_value = selected_values[0]
                changed_index = new_state.index(changed_value)
            clean_state = [changed_value if index == changed_index else None
                           for index in range(len(state))]
            return (clean_state, *clean_state)

        # Compile
        hf_config = config.get("huggingface_text") or dict()
        accordion_info = (hf_config.get("accordian_info")
                          or "Persona and LLM Options - Choose one:")
        version = (hf_config.get("version")
                   or f"v{datetime.now().strftime('%Y-%m-%d')}")
        title = (hf_config.get("title")
                 or f"Neon AI BrainForge Personas and Large Language Models ({version})")

        # `render=False` defers rendering of the accordion until the explicit
        # `accordion.render()` call below
        with gr.Accordion(label=accordion_info, open=True,
                          render=False) as accordion:
            for radio in radios:
                radio.render()
            conversational_checkbox.render()
            max_tokens_slider.render()

        extra_inputs = [conversational_checkbox, max_tokens_slider, *radios]
        _ = gr.ChatInterface(
            respond,
            additional_inputs=extra_inputs,
            additional_inputs_accordion=accordion,
            title=title,
            concurrency_limit=5,
        )

        # Render login/logout button
        login_button = gr.Button("Log In")
        blocks.load(get_login_button, None, login_button)

        accordion.render()
        blocks.load(get_model_options, None, radios)

    return blocks
if __name__ == "__main__":
    # Configuration is loaded first; `init_gradio` reads `config` and `clients`
    init_config()
    init_oauth()
    blocks = init_gradio()
    # Serve the Gradio UI from the root path of the FastAPI app
    app = gr.mount_gradio_app(app, blocks, path="/")
    uvicorn.run(app, host="0.0.0.0", port=7860)