import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, pipeline

# Define the model repository
REPO_NAME = 'schuler/experimental-JP47D20'
# REPO_NAME = 'schuler/experimental-JP47D21-KPhi-3-micro-4k-instruct'

# How to cache? See the lru_cache sketch just below the load_model call.
def load_model(repo_name):
    tokenizer = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
    generator_conf = GenerationConfig.from_pretrained(repo_name)
    model = AutoModelForCausalLM.from_pretrained(repo_name, trust_remote_code=True, torch_dtype=torch.bfloat16)
    return tokenizer, generator_conf, model

tokenizer, generator_conf, model = load_model(REPO_NAME)
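
# A minimal caching sketch answering the "How to cache?" question above:
# functools.lru_cache memoizes the loader so repeated calls in the same process
# (e.g. Gradio hot-reloads) reuse the already-loaded weights. The wrapper name
# is illustrative and not part of the original code.
from functools import lru_cache

@lru_cache(maxsize=1)
def load_model_cached(repo_name: str):
    return load_model(repo_name)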

# Surface pipeline-creation errors in the UI (they are appended to the default
# system message below) instead of crashing at startup.
generator = None
global_error = ''
try:
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    global_error = f"Failed to create the text-generation pipeline: {str(e)}"

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    result = 'none'
    try:
        # Build the conversation prompt
        prompt = ''
        messages = []
        if system_message:
            # Phi-3-style system tag; the original used <|assistant|> here,
            # which would mislabel the system prompt. <|system|> is assumed
            # from the model's Phi-3 lineage.
            prompt = f"<|system|>\n{system_message}\n<|end|>\n"
        for val in history:
            if val[0]:
                messages.append({"role": "user", "content": val[0]})
            if val[1]:
                messages.append({"role": "assistant", "content": val[1]})
    
        messages.append({"role": "user", "content": message})
    
        # `msg` avoids shadowing the `message` parameter (the original loop
        # variable shadowed it, so the final prompt line appended a raw dict).
        # The current user message is already the last entry in `messages`,
        # so only the assistant cue follows the loop.
        for msg in messages:
            role = "<|assistant|>" if msg['role'] == 'assistant' else "<|user|>"
            prompt += f"\n{role}\n{msg['content']}\n<|end|>\n"
        prompt += "\n<|assistant|>\n"

        """
        # Generate the response
        response_output = generator(
            prompt,
            generation_config=generator_conf,
            max_new_tokens=64,
            do_sample=True,
            top_p=0.25,
            repetition_penalty=1.2
        )
    
        generated_text = response_output[0]['generated_text']
    
        # st.session_state.last_response = generated_text
    
        # Extract the assistant's response
        result = generated_text[len(prompt):].strip()
        """
        result = message+':'+prompt
    except Exception as error:
        result = str(error)

    yield result
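
# Alternative prompt construction (a sketch, not wired in): if the checkpoint
# ships a chat template, tokenizer.apply_chat_template produces the tagged
# prompt and avoids the manual "<|user|>"/"<|assistant|>" bookkeeping in
# respond(). Whether this repo includes a template is an assumption.
def build_prompt_with_template(messages, system_message=''):
    chat = ([{"role": "system", "content": system_message}] if system_message else []) + messages
    return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)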

    
    """
    for message in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = message.choices[0].delta.content

        response += token
        yield response
    """


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot." + global_error, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
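
# Streaming sketch (not wired in): transformers' TextIteratorStreamer yields
# decoded text pieces as generate() produces them, and gr.ChatInterface renders
# partial output when the callback yields growing strings. generate() blocks,
# so it runs in a background thread. Names and parameters here are illustrative.
def stream_reply(prompt, max_tokens, temperature, top_p):
    from threading import Thread
    from transformers import TextIteratorStreamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt")
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ''
    for piece in streamer:
        partial += piece
        yield partial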


if __name__ == "__main__":
    demo.launch()