File size: 3,631 Bytes
352a6c0
 
75c74b0
cc932be
 
 
 
7b5b897
 
cc932be
 
94559fc
cc932be
 
 
64b4ed5
c0252bb
cc932be
 
 
352a6c0
96a08ea
 
94559fc
96a08ea
 
 
352a6c0
94559fc
352a6c0
 
 
 
 
 
 
 
96a08ea
 
b15fb69
 
b9d96b3
b15fb69
 
96a08ea
 
 
 
 
 
 
 
ad8bce1
 
c6e9b1e
 
 
273fe29
75c74b0
96a08ea
 
 
 
4329549
96a08ea
4329549
c0252bb
 
96a08ea
 
 
 
 
 
75c74b0
43c5e78
96a08ea
75c74b0
 
 
cc932be
96a08ea
352a6c0
cc932be
 
352a6c0
 
 
 
 
 
 
 
 
 
 
cc932be
352a6c0
 
 
 
 
 
 
 
7b5b897
4329549
c0252bb
352a6c0
 
 
4329549
352a6c0
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import sys

import gradio as gr
import spaces  # provides the @spaces.GPU decorator used below
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, pipeline

# Model repository identifier handed to load_model() below.
# Earlier checkpoint, kept for reference:
# REPO_NAME = 'schuler/experimental-JP47D20'
REPO_NAME = 'schuler/experimental-JP47D21-KPhi-3-micro-4k-instruct'

# TODO: decide how (or whether) the loaded artifacts should be cached.
@spaces.GPU()
def load_model(repo_name):
    """Load the tokenizer, generation config and causal LM for *repo_name*.

    The model is requested in bfloat16 with the "eager" attention
    implementation and is moved to the CUDA device before being returned.

    Args:
        repo_name: Repository identifier passed to each ``from_pretrained``.

    Returns:
        A ``(tokenizer, generation_config, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(repo_name, trust_remote_code=True)
    gen_config = GenerationConfig.from_pretrained(repo_name)

    model_kwargs = {
        "trust_remote_code": True,
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "eager",
    }
    lm = AutoModelForCausalLM.from_pretrained(repo_name, **model_kwargs)
    lm.to('cuda')

    return tok, gen_config, lm

# Load everything once at import time so each chat request reuses the same
# tokenizer, generation config and model.
tokenizer, generator_conf, model = load_model(REPO_NAME)

# Description of a pipeline construction failure; stays empty on success.
# It is used as the initial value of the "System message" textbox in the UI.
global_error = ''
# Bound up front so the name exists even when pipeline() raises below.
generator = None
try:
    # Plain call: a decorator such as @spaces.GPU() may only be applied to a
    # def/class statement, never to an assignment.
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
except Exception as e:
    global_error =  f"Failed to load model: {str(e)}"

def _build_prompt(message, history, system_message):
    """Render the conversation as one prompt string using ``<|role|>`` tags.

    Args:
        message: The newest user message; always appended as the last user turn.
        history: Sequence of ``(user_text, assistant_text)`` pairs. Falsy
            entries within a pair are skipped.
        system_message: Optional leading text; skipped when empty or ``None``.

    Returns:
        The prompt, ending with an open ``<|assistant|>`` tag for the model
        to continue from.
    """
    parts = []
    if system_message:
        # The system text is emitted under the assistant tag and is the only
        # turn followed by a newline.
        # NOTE(review): presumably deliberate for this model - confirm.
        parts.append("<|assistant|>" + system_message + "<|end|>\n")

    for turn in history:
        if turn[0]:
            parts.append(f"<|user|>{turn[0]}<|end|>")
        if turn[1]:
            parts.append(f"<|assistant|>{turn[1]}<|end|>")

    parts.append(f"<|user|>{message}<|end|>")
    parts.append("<|assistant|>")
    return "".join(parts)


@spaces.GPU()
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Generate the assistant's reply for the chat interface.

    Args:
        message: The newest user message.
        history: Earlier ``(user_text, assistant_text)`` pairs.
        system_message: Optional text placed at the start of the prompt.
        max_tokens: Forwarded to the pipeline as ``max_new_tokens``.
        temperature: Sampling temperature forwarded to the pipeline.
        top_p: Nucleus-sampling threshold forwarded to the pipeline.

    Yields:
        Exactly one string: the generated reply with the prompt prefix removed
        and surrounding whitespace stripped, or, if anything raised, a
        ``"<message>:<exception type>:<file>:<line>"`` description of the error.
    """
    try:
        prompt = _build_prompt(message, history, system_message)

        response_output = generator(
            prompt,
            generation_config=generator_conf,
            max_new_tokens=max_tokens,
            do_sample=True,
            top_p=top_p,
            repetition_penalty=1.2,
            temperature=temperature,
        )
        generated_text = response_output[0]['generated_text']

        # The generated text is expected to begin with the prompt; keep only
        # what follows it.
        result = generated_text[len(prompt):].strip()
    except Exception as error:
        # Report the failure in the chat instead of raising. Every piece is
        # formatted through an f-string because the exception type and line
        # number are not strings and cannot be joined with "+".
        tb = error.__traceback__
        fname = os.path.basename(tb.tb_frame.f_code.co_filename)
        result = f"{error}:{type(error).__name__}:{fname}:{tb.tb_lineno}"

    yield result


# For information on how to customize the ChatInterface, peruse the gradio
# docs: https://www.gradio.app/docs/chatinterface

# Extra controls for the chat UI. They are listed in the same order as the
# parameters respond() takes after (message, history).
extra_controls = [
    # Pre-filled with any pipeline construction error recorded at start-up.
    gr.Textbox(label="System message", value="" + global_error),
    gr.Slider(label="Max new tokens", minimum=1, maximum=2048, value=64, step=1),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, value=1.0, step=0.1),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.1,
        maximum=1.0,
        value=0.25,
        step=0.05,
    ),
]

demo = gr.ChatInterface(fn=respond, additional_inputs=extra_controls)


if __name__ == "__main__":
    demo.launch()