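# ollama-Chat: a Gradio chat app for Hugging Face Spaces that downloads an
# Ollama binary, runs `ollama serve` in a background thread, and streams
# replies from locally pulled models.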
import os
import signal
import threading
import time
import subprocess
import spaces

OLLAMA = os.path.expanduser("~/ollama")
process = None
OLLAMA_SERVICE_THREAD = None

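# Fetch the standalone Ollama binary on first run and make it executable.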
if not os.path.exists(OLLAMA):
    subprocess.run("curl -L https://ollama.com/download/ollama-linux-amd64 -o ~/ollama", shell=True)
    os.chmod(OLLAMA, 0o755)

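# Run `ollama serve` in its own session (process group) so the whole group
# can be terminated cleanly later.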
def ollama_service_thread():
    global process
    process = subprocess.Popen("~/ollama serve", shell=True, preexec_fn=os.setsid)
    process.wait()
    
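# Stop the server: SIGTERM its process group, then join the service thread.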
def terminate():
    global process, OLLAMA_SERVICE_THREAD
    if process:
        os.killpg(os.getpgid(process.pid), signal.SIGTERM)
    if OLLAMA_SERVICE_THREAD:
        OLLAMA_SERVICE_THREAD.join()
    process = None
    OLLAMA_SERVICE_THREAD = None
    print("Ollama service stopped.")

# Uncomment and set the model you want pulled at startup
# model = "moondream" 
# model = os.environ.get("MODEL")

# subprocess.run(f"~/ollama pull {model}", shell=True)

import ollama
import gradio as gr
from ollama import Client
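# Client bound to the local server; the 120 s timeout leaves room for slow model loads.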
client = Client(host='http://localhost:11434', timeout=120)

HF_TOKEN = os.environ.get("HF_TOKEN", None)

TITLE = "<h1><center>ollama-Chat</center></h1>"

DESCRIPTION = """
<center>
<p>Feel free to test models with ollama.
<br>
Type <em>/pull model_name</em> to pull a model.
</p>
</center>
"""

CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""
INIT_SIGN = ""

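# Start the Ollama server in a background thread and flag the app as ready.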
def init():
    global OLLAMA_SERVICE_THREAD, INIT_SIGN
    OLLAMA_SERVICE_THREAD = threading.Thread(target=ollama_service_thread)
    OLLAMA_SERVICE_THREAD.start()
    print("Giving ollama serve a moment")
    time.sleep(10)
    INIT_SIGN = "FINISHED"

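# Slash commands typed into the chat box, e.g. "/init", "/pull qwen2:0.5b",
# "/list", "/bye".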
def ollama_func(command):
    if " " in command:
        c1, c2 = command.split(" ", 1)
    else:
        c1 = command
        c2 = ""
    function_map = {
        "/init": init,
        "/pull": lambda: ollama.pull(c2),
        "/list": ollama.list,
        "/bye": terminate,
    }
    if c1 in function_map:
        function_map[c1]()
        return "Running..."
    else:
        return "Unsupported command."
        
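# launch() mirrors init() but runs under @spaces.GPU(), presumably so the
# server comes up while a ZeroGPU device is attached.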
@spaces.GPU()      
def launch():
    global OLLAMA_SERVICE_THREAD
    OLLAMA_SERVICE_THREAD = threading.Thread(target=ollama_service_thread)
    OLLAMA_SERVICE_THREAD.start()
    print("Giving ollama serve a moment")
    time.sleep(10)
  
def stream_chat(message: str, history: list, model: str, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    print(f"message: {message}")
    # Rebuild the conversation from Gradio's (user, assistant) history pairs,
    # then append the new user message.
    conversation = []
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    print(f"Conversation is -\n{conversation}")

    response = client.chat(
        model=model,
        stream=True,
        messages=conversation,
        keep_alive="60s",
        options={
            'num_predict': max_new_tokens,
            'temperature': temperature,
            'top_p': top_p,
            'top_k': top_k,
            'repeat_penalty': penalty,
            'low_vram': True,
        },
    )

    # Accumulate streamed chunks so Gradio re-renders the growing reply.
    buffer = ""
    for chunk in response:
        buffer += chunk["message"]["content"]
        yield buffer


def main(message: str, history: list, model: str, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    if message.startswith("/"):
        yield ollama_func(message)
    else:
        if not INIT_SIGN:
            yield "Please initialize Ollama first (send /init)."
        else:
            if not process:
                launch()
            # Relay the streamed chunks straight to the chat UI.
            yield from stream_chat(
                message,
                history,
                model,
                temperature,
                max_new_tokens,
                top_p,
                top_k,
                penalty,
            )

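# Gradio UI: chat interface plus sliders for the sampling parameters.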
chatbot = gr.Chatbot(height=600, placeholder=DESCRIPTION)

with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=main,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Textbox(
                value="qwen2:0.5b",
                label="Model",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=2048,
                step=1,
                value=1024,
                label="Max New Tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
            ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
            ["Tell me a random fun fact about the Roman Empire."],
            ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
        ],
        cache_examples=False,
    )


if __name__ == "__main__":
    demo.launch()