Zhanming committed on
Commit d930631 · 1 Parent(s): c7fdaab

update application files

Files changed (2)
  1. model.py +78 -0
  2. style.css +16 -0
model.py ADDED
@@ -0,0 +1,78 @@
+ from threading import Thread
+ from typing import Iterator
+
+ from transformers.utils import logging
+ from ctransformers import AutoModelForCausalLM
+ from transformers import TextIteratorStreamer, AutoTokenizer
+
+ logging.set_verbosity_info()
+ logger = logging.get_logger("transformers")
+
+ # Default generation settings; run() below passes its own explicit kwargs.
+ config = {"max_new_tokens": 256, "repetition_penalty": 1.1,
+           "temperature": 0.1, "stream": True}
+ model_id = "TheBloke/Llama-2-7B-Chat-GGML"
+ device = "cpu"
+
+ # Load the quantized GGML weights with the AVX2 CPU backend; hf=True wraps
+ # the model in a transformers-compatible interface so it works with
+ # TextIteratorStreamer. The tokenizer comes from the original fp16 repo.
+ model = AutoModelForCausalLM.from_pretrained(model_id, model_type="llama", lib="avx2", hf=True)
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+
+
+ def get_prompt(message: str, chat_history: list[tuple[str, str]],
+                system_prompt: str) -> str:
+     """Assemble a Llama-2 chat prompt: <<SYS>> block plus [INST] ... [/INST] turns."""
+     texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+     # Leave the first user turn unstripped; strip every turn after it.
+     do_strip = False
+     for user_input, response in chat_history:
+         user_input = user_input.strip() if do_strip else user_input
+         do_strip = True
+         texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+     message = message.strip() if do_strip else message
+     texts.append(f'{message} [/INST]')
+     return ''.join(texts)
+
+
+ def get_input_token_length(message: str, chat_history: list[tuple[str, str]],
+                            system_prompt: str) -> int:
+     """Count prompt tokens so callers can reject inputs that exceed the context window."""
+     prompt = get_prompt(message, chat_history, system_prompt)
+     input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
+     return input_ids.shape[-1]
+
+
+ def run(message: str,
+         chat_history: list[tuple[str, str]],
+         system_prompt: str,
+         max_new_tokens: int = 1024,
+         temperature: float = 0.8,
+         top_p: float = 0.95,
+         top_k: int = 50) -> Iterator[str]:
+     """Generate a response, yielding the accumulated text as tokens stream in."""
+     prompt = get_prompt(message, chat_history, system_prompt)
+     inputs = tokenizer([prompt], return_tensors='pt', add_special_tokens=False).to(device)
+
+     streamer = TextIteratorStreamer(tokenizer,
+                                     timeout=15.0,
+                                     skip_prompt=True,
+                                     skip_special_tokens=True)
+     generate_kwargs = dict(
+         inputs,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+     )
+     # generate() blocks, so it runs on a background thread while this
+     # generator drains the streamer.
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)
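
Since run() yields the accumulated response so far rather than per-token deltas, a caller only needs to keep the last value. Below is a minimal smoke test of the module's API, not part of the commit: the system prompt and question are placeholder values, and the 4,000-token guard is an assumed safety margin under Llama-2's 4,096-token context window.

from model import get_input_token_length, run

system_prompt = "You are a helpful assistant."   # placeholder
history: list[tuple[str, str]] = []              # prior (user, assistant) turns
message = "What is the capital of France?"       # placeholder

# Reject overlong prompts before generating (assumed margin under the
# 4,096-token Llama-2 context window).
if get_input_token_length(message, history, system_prompt) < 4000:
    final = ""
    for partial in run(message, history, system_prompt, max_new_tokens=256):
        final = partial  # each yield is the full response so far
    print(final)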
style.css ADDED
@@ -0,0 +1,16 @@
+ h1 {
+   text-align: center;
+ }
+
+ #duplicate-button {
+   margin: auto;
+   color: white;
+   background: #1565c0;
+   border-radius: 100vh;
+ }
+
+ #component-0 {
+   max-width: 900px;
+   margin: auto;
+   padding-top: 1.5rem;
+ }
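
These selectors target a Gradio front end: #component-0 is the auto-generated id of the root Blocks container, and #duplicate-button is the elem_id conventionally given to a DuplicateButton on Spaces. The app file is not part of this commit, so the wiring below is a hypothetical sketch of how the stylesheet would be attached, not the Space's actual layout.

import gradio as gr

# Hypothetical app wiring; this commit contains no app file, so the layout
# here is illustrative only.
with open("style.css") as f:
    css = f.read()

with gr.Blocks(css=css) as demo:          # css applies to #component-0, the root container
    gr.Markdown("# Llama-2 7B Chat")      # rendered as the centered <h1>
    gr.DuplicateButton(elem_id="duplicate-button")
    # ... chat UI that streams responses from model.run ...

demo.queue().launch()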