File size: 3,076 Bytes
ef3d4ad
529dafe
a5aec38
ef3d4ad
 
 
 
 
 
 
 
 
 
 
 
 
 
529dafe
ef3d4ad
 
 
 
8336dcb
 
 
 
 
 
 
529dafe
8336dcb
529dafe
8336dcb
529dafe
8336dcb
 
 
 
529dafe
 
ef3d4ad
e946c57
 
32d4c53
e946c57
529dafe
a9490de
 
 
 
ef3d4ad
529dafe
a3c275c
 
 
 
 
 
 
 
529dafe
e776497
 
 
 
ef3d4ad
 
 
 
 
e946c57
 
 
 
a5aec38
 
 
 
 
ef3d4ad
a5aec38
529dafe
01851b4
 
 
 
529dafe
 
 
 
 
 
 
e946c57
 
 
 
529dafe
 
ef3d4ad
529dafe
8d48750
 
 
 
a9490de
24a698a
 
 
529dafe
 
e946c57
 
 
 
 
e776497
 
 
 
529dafe
ef3d4ad
 
e656628
1651184
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import collections as cl
from pathlib import Path

import gradio as gr
from openai import OpenAI

from mylib import (
    FileManager,
    ChatController,
    MessageHandler,
    NumericCitations,
)

#
#
#
# Immutable per-session bundle built by ``load``:
#   database  -- FileManager; called with uploaded files and truth-tested
#                to decide whether documents are available
#   messenger -- MessageHandler; post-processes the chat controller's output
#   chat      -- ChatController; called with the user's question
ChatState = cl.namedtuple('ChatState', ['database', 'messenger', 'chat'])

#
#
#
def open_ai_vars(prefix='OPENAI_KWARGS_'):
    """Yield ``(name, value)`` pairs for environment variables starting with *prefix*.

    The prefix is stripped from each matching variable name and the
    remainder is lower-cased, so ``OPENAI_KWARGS_MODEL=x`` becomes
    ``('model', 'x')``. Values are passed through unchanged (as strings).

    Args:
        prefix: Environment-variable name prefix to select on.

    Yields:
        ``(key, value)`` tuples suitable for building a kwargs dict.
    """
    cut = len(prefix)
    yield from (
        (name[cut:].lower(), value)
        for (name, value) in os.environ.items()
        if name.startswith(prefix)
    )

def load():
    """Create the per-session :class:`ChatState`.

    Used as the ``value`` factory of the session ``gr.State``. A single
    OpenAI client is shared by the file manager, the message handler and
    the chat controller. Extra controller keyword arguments come from
    ``OPENAI_KWARGS_*`` environment variables (see ``open_ai_vars``), and
    the system prompt is read from ``static/system-prompt.txt``.

    Returns:
        A new ``ChatState(database, messenger, chat)``.
    """
    client = OpenAI()

    # Keep construction order: database, then messenger, then chat.
    database = FileManager(client)
    messenger = MessageHandler(client, NumericCitations)
    chat = ChatController(
        client,
        database,
        Path('static', 'system-prompt').with_suffix('.txt'),
        **dict(open_ai_vars()),
    )

    return ChatState(database=database, messenger=messenger, chat=chat)

def eject(state):
    """Release the resources held by a session's :class:`ChatState`.

    Registered as the ``delete_callback`` of the session ``gr.State``.
    Both the file manager and the chat controller are cleaned up. The chat
    cleanup runs even if the database cleanup raises, so one failure does
    not leak the other's resources. Any exception from the database
    cleanup is still propagated afterwards.

    Args:
        state: The session's ``ChatState``.
    """
    try:
        state.database.cleanup()
    finally:
        # Always attempt the second cleanup, even on failure of the first.
        state.chat.cleanup()

def upload(data, state):
    """Hand newly uploaded files to the session's file manager.

    Wired to the upload button. Its return value is displayed in the
    "Files uploaded" textbox.

    Args:
        data: The uploaded file(s) as delivered by ``gr.UploadButton``.
        state: The session's ``ChatState``.

    Returns:
        Whatever ``state.database(data)`` returns.

    Raises:
        gr.Error: If the file manager raises ``InterruptedError``. The
            message is surfaced to the user by Gradio.
    """
    try:
        return state.database(data)
    except InterruptedError as err:
        # Chain explicitly so server-side tracebacks keep the real cause.
        raise gr.Error(str(err)) from err

def prompt(message, history, state):
    """Answer a question about the uploaded documents.

    If the session has documents, the message is sent through the chat
    controller. The controller's output is post-processed by the message
    handler, and the ``(message, response)`` pair is appended to
    *history*. Otherwise a Gradio warning asks the user to upload
    documents first, and *history* is returned unchanged.

    Args:
        message: The user's question from the textbox.
        history: Current chatbot value, a list of ``(user, bot)`` pairs.
            ``None`` is treated as an empty conversation.
        state: The session's ``ChatState``.

    Returns:
        ``('', history)``, matching the submit event's outputs.
    """
    if history is None:
        # Guard against an empty chatbot value arriving as None rather than [].
        history = []

    if state.database:
        response = state.messenger(state.chat(message))
        history.append((message, response))
    else:
        gr.Warning('Please upload your documents to begin')

    return (
        '',       # clear the input textbox
        history,  # refresh the chatbot display
    )

#
#
#
with gr.Blocks() as demo:
    # Per-session state. ``load`` is the value factory and ``eject`` is the
    # delete callback that releases the session's resources.
    state = gr.State(
        value=load,
        delete_callback=eject,
    )
    howto = Path('static', 'howto').with_suffix('.md')

    # Collapsible usage instructions rendered from static/howto.md.
    with gr.Row():
        with gr.Accordion(label='Instructions', open=False):
            # Read explicitly as UTF-8 rather than the platform default.
            gr.Markdown(howto.read_text(encoding='utf-8'))

    with gr.Row():
        # Left column: file upload plus a read-only summary of the upload.
        with gr.Column():
            data = gr.UploadButton(
                label='Select and upload your files',
                file_count='multiple',
            )
            repository = gr.Textbox(
                label='Files uploaded',
                placeholder='Upload your files to begin!',
                interactive=False,
            )
            data.upload(
                fn=upload,
                inputs=[
                    data,
                    state,
                ],
                outputs=repository,
            )

        # Right column (twice as wide): the conversation and question box.
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                height='70vh',
                show_copy_button=True,
            )
            interaction = gr.Textbox(
                label='Ask a question about your documents and press "Enter"',
            )
            interaction.submit(
                fn=prompt,
                inputs=[
                    interaction,
                    chatbot,
                    state,
                ],
                outputs=[
                    interaction,
                    chatbot,
                ],
                # Scroll to the chatbot once the answer arrives. A bare
                # ``chatbot.change(scroll_to_output=True)`` without ``fn``
                # only returns a decorator and registers no event.
                scroll_to_output=True,
            )

if __name__ == '__main__':
    # Basic-auth credentials come from GRADIO_USER / GRADIO_PASS.
    # os.getenv returns None for unset variables, which would otherwise
    # configure a login with None credentials that no user can satisfy.
    # Fail fast with a clear message instead of launching an unusable app.
    auth = tuple(os.getenv(f'GRADIO_{x}') for x in ('USER', 'PASS'))
    if not all(auth):
        raise SystemExit('GRADIO_USER and GRADIO_PASS must both be set')
    demo.launch(auth=auth)