"""Gradio front end for chatting about user-uploaded files.

Configuration is read from the JSON file named by the FILE_CHAT_CONFIG
environment variable. This module reads the following top-level keys:

* ``openai`` -- must contain ``api_key``; the whole section is also
  handed to ``ChatController``.
* ``chat`` -- must contain ``prefix``; the whole section is also handed
  to ``ChatController``.
* ``gradio`` -- optional keyword arguments forwarded to
  ``Blocks.launch`` when the module is run as a script.
"""

import os
import json
import functools as ft
import collections as cl
from pathlib import Path

import gradio as gr
from openai import OpenAI

from mylib import (
    FileManager,
    ChatController,
    MessageHandler,
    NumericCitations,
)

#
# Configuration and per-session state
#

# Bundle of the per-session helper objects stored in ``gr.State``.
ChatState = cl.namedtuple('ChatState', 'database, messenger, chat')


@ft.cache
def scancfg():
    """Load and return the parsed JSON configuration.

    The result is cached, so the file is read at most once per process.

    Raises:
        RuntimeError: if FILE_CHAT_CONFIG is unset or empty.
        OSError: if the configured file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    path = os.getenv('FILE_CHAT_CONFIG')
    if not path:
        # Without this guard open(None) fails with an opaque TypeError
        # that does not mention which variable is missing.
        raise RuntimeError('FILE_CHAT_CONFIG environment variable is not set')

    with open(path, encoding='utf-8') as fp:
        return json.load(fp)


#
# Session lifecycle and event handlers
#

def load():
    """Build a fresh ``ChatState`` from the configuration.

    Passed to ``gr.State`` as a callable ``value`` so that state is
    constructed by Gradio rather than once at import time.
    """
    config = scancfg()
    (_openai, _chat) = map(config.get, ('openai', 'chat'))

    client = OpenAI(api_key=_openai['api_key'])
    database = FileManager(client, _chat['prefix'])
    messenger = MessageHandler(client, NumericCitations)
    chat = ChatController(client, database, _openai, _chat)

    return ChatState(database, messenger, chat)


def eject(state):
    """Release the resources held by a ``ChatState``.

    Registered as the ``delete_callback`` of the ``gr.State`` below.
    """
    try:
        state.database.cleanup()
    finally:
        # Run the chat cleanup even when the database cleanup raises, so
        # one failure does not leave the other resource behind.
        state.chat.cleanup()


def upload(data, state):
    """Hand uploaded files to the session's file manager.

    Args:
        data: value of the upload button (``file_count='multiple'``).
        state: the session's ``ChatState``.

    Returns:
        Whatever ``state.database(data)`` returns; it is shown in the
        "Files uploaded" textbox.

    Raises:
        gr.Error: when the file manager raises ``InterruptedError``, so
            the message is surfaced in the UI.
    """
    try:
        return state.database(data)
    except InterruptedError as err:
        # Chain the cause so the original traceback stays in the logs.
        raise gr.Error(str(err)) from err


def prompt(message, history, state):
    """Answer a user message and append the exchange to the chat history.

    Args:
        message: text submitted from the input textbox.
        history: current chatbot value; mutated in place.
        state: the session's ``ChatState``.

    Returns:
        A ``(textbox_value, chatbot_value)`` tuple matching the
        ``outputs`` of the submit event: an empty string to clear the
        input, and the (possibly extended) history.
    """
    # NOTE(review): relies on FileManager truthiness -- presumably falsy
    # until files have been uploaded; confirm against mylib.
    if state.database:
        response = state.messenger(state.chat(message))
        history.append((
            message,
            response,
        ))
    else:
        gr.Warning('Please upload your documents to begin')

    return (
        # textbox submit outputs
        '',       # clear the input text
        history,  # update the chat output
    )


#
# User interface
#
with gr.Blocks() as demo:
    state = gr.State(
        value=load,
        delete_callback=eject,
    )
    howto = Path('static/howto').with_suffix('.md')

    with gr.Row():
        with gr.Accordion(label='Instructions', open=False):
            # Explicit encoding: do not depend on the platform locale.
            gr.Markdown(howto.read_text(encoding='utf-8'))

    with gr.Row():
        with gr.Column():
            data = gr.UploadButton(
                label='Select and upload your files',
                file_count='multiple',
            )
            repository = gr.Textbox(
                label='Files uploaded',
                placeholder='Upload your files to begin!',
                interactive=False,
            )
            data.upload(
                fn=upload,
                inputs=[
                    data,
                    state,
                ],
                outputs=repository,
            )

        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                height='70vh',
                show_copy_button=True,
            )
            chatbot.change(scroll_to_output=True)

            interaction = gr.Textbox(
                label='Ask a question about your documents and press "Enter"',
            )
            interaction.submit(
                fn=prompt,
                inputs=[
                    interaction,
                    chatbot,
                    state,
                ],
                outputs=[
                    interaction,
                    chatbot,
                ],
            )

if __name__ == '__main__':
    # A config without a "gradio" section yields None, which cannot be
    # unpacked with **; treat it as "no overrides".
    overrides = scancfg().get('gradio') or {}
    # Merge instead of passing server_name next to **overrides, so a
    # configured server_name replaces the default rather than raising
    # "got multiple values for keyword argument".
    options = {'server_name': '0.0.0.0', **overrides}
    demo.queue().launch(**options)