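# Gradio chat demo for the Blexus/Quble_test_model_v1_INSTRUCT_v1 text-generation model.
# Generated responses are cached in a local JSON file, keyed by the formatted prompt,
# so that repeated identical prompts are answered without re-running the model.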
from transformers import pipeline
import gradio as gr
import json

# Initialize the pipeline with the new model
pipe = pipeline("text-generation", model="Blexus/Quble_test_model_v1_INSTRUCT_v1")

DATABASE_PATH = "database.json"

def load_database():
    try:
        with open(DATABASE_PATH, "r") as file:
            return json.load(file)
    except FileNotFoundError:
        return {}

def save_database(database):
    with open(DATABASE_PATH, "w") as file:
        json.dump(database, file)

def format_prompt(message, system, history):
    # Format prompt according to the new template
    prompt = f"SYSTEM: {system}\n<|endofsystem|>\n"
    for user_prompt, bot_response in history:
        prompt += f"USER: {user_prompt}\n\n\nASSISTANT: {bot_response}<|endoftext|>\n"
    prompt += f"USER: {message}\n\n\nASSISTANT:"
    return prompt
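
# Example of the prompt produced for a first user turn (no history), given the
# default system prompt and the message "Hello!":
#
#   SYSTEM: You are a helpful assistant, with no access to external functions.
#   <|endofsystem|>
#   USER: Hello!
#
#
#   ASSISTANT: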

def generate(
    # gr.ChatInterface passes (message, history, *additional_inputs), so `system` comes after `history`
    prompt, history, system, temperature=0.9, max_new_tokens=4096, top_p=0.9, repetition_penalty=1.2,
):
    database = load_database()  # Load the database
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    formatted_prompt = format_prompt(prompt, system, history)
    if formatted_prompt in database:
        # Reuse the cached response for an identical prompt
        response_text = database[formatted_prompt]
    else:
        # Use the pipeline to generate the response
        response = pipe(formatted_prompt, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p, repetition_penalty=repetition_penalty)[0]["generated_text"]
        # The pipeline returns the prompt plus the completion, so keep only the text after the last "ASSISTANT:" marker
        response_text = response.split("ASSISTANT:")[-1].strip()
        database[formatted_prompt] = response_text
        save_database(database)  # Save the updated database

    yield response_text

customCSS = """
#component-7 { /* default element ID of the chat component */
  height: 1600px; /* adjust the height as needed */
  flex-grow: 4;
}
"""

additional_inputs = [
    gr.Textbox(
        label="System prompt",
        value="You are a helpful assistant, with no access to external functions.",
        info="System prompt",
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=64,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]

with gr.Blocks(css=customCSS, theme=gr.themes.Soft()) as demo:
    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
    )

demo.queue().launch(debug=True)