Commit: add min_new_tokens ("min-new-token 추가")

Summary: app.py drops the text-generation pipeline in favor of a direct AutoModelForCausalLM / AutoTokenizer setup so generation can enforce a minimum reply length, and test.ipynb is reworked to exercise the new code with a few <usr>/<bot> prompts.

Files changed:
- app.py (+34, -19)
- test.ipynb (+120, -139)
app.py
CHANGED

@@ -2,13 +2,40 @@ import gradio as gr
 import torch
 import random
 import time
-from transformers import pipeline
-
-generator = pipeline(
-    'text-generation',
-    model="heegyu/bluechat-v0",
-    device="cuda:0" if torch.cuda.is_available() else 'cpu'
-)
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+
+
+model_name = "heegyu/bluechat-v0"
+device = "cuda:0" if torch.cuda.is_available() else 'cpu'
+model = AutoModelForCausalLM.from_pretrained(model_name)
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# generator = pipeline(
+#     'text-generation',
+#     model="heegyu/bluechat-v0",
+#     device="cuda:0" if torch.cuda.is_available() else 'cpu'
+# )
+
+def get_message(prompt, min_new_tokens=16, max_turn=4):
+    prompt = prompt.strip()
+    ids = tokenizer(prompt, return_tensors="pt").to(device)
+    min_length = ids['input_ids'].shape[1] + min_new_tokens
+
+    output = model.generate(
+        **ids,
+        no_repeat_ngram_size=3,
+        eos_token_id=2,  # 375=\n, 2=</s>, 0=open-end
+        max_new_tokens=128,
+        min_length=min_length,
+        do_sample=True,
+        top_p=0.7,
+        early_stopping=True
+    )  # [0]['generated_text']
+
+    output = tokenizer.decode(output.cpu()[0])
+    print(output)
+    return output[len(prompt):]
 
 def query(message, chat_history, max_turn=4):
     prompt = []
@@ -21,19 +48,7 @@ def query(message, chat_history, max_turn=4):
     prompt.append(f"<usr> {message}")
     prompt = "\n".join(prompt) + "\n<bot>"
 
-    output = generator(
-        prompt,
-        # repetition_penalty=1.3,
-        # no_repeat_ngram_size=2,
-        eos_token_id=2,  # \n
-        max_new_tokens=128,
-        do_sample=True,
-        top_p=0.9,
-    )[0]['generated_text']
-
-    print(output)
-
-    response = output[len(prompt):]
+    response = get_message(prompt, 8)
     return response.strip()
 
 with gr.Blocks() as demo:
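Two details of this diff are easy to miss. First, for decoder-only models `generate()`'s `min_length` counts the whole sequence, prompt included, which is why `get_message` adds the tokenized prompt length before passing it; that is how the commit implements "min new tokens". Second, app.py moves the input ids to `device` but never calls `model.to(device)`, so on a CUDA machine `generate` would raise a device-mismatch error (with no GPU both sides land on the CPU and it runs fine). Below is a minimal standalone sketch of the same technique with the device handling fixed; the model name, token ids, and sampling settings come from the commit, the rest is illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "heegyu/bluechat-v0"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)  # app.py omits .to(device)

prompt = "<usr> 안녕하세요\n<bot>"  # the <usr>/<bot> format that query() builds
ids = tokenizer(prompt, return_tensors="pt").to(device)

min_new_tokens = 8  # the value query() passes to get_message()
output = model.generate(
    **ids,
    no_repeat_ngram_size=3,
    eos_token_id=2,   # </s> for this tokenizer, per the commit's comment
    pad_token_id=2,   # optional: silences the pad_token_id warning seen in test.ipynb
    max_new_tokens=128,
    # min_length counts the prompt tokens too, hence the addition:
    min_length=ids["input_ids"].shape[1] + min_new_tokens,
    do_sample=True,
    top_p=0.7,
)
print(tokenizer.decode(output[0]))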
test.ipynb
CHANGED

@@ -2,161 +2,42 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": …,
+   "execution_count": 40,
    "metadata": {},
-   "outputs": [
-    [removed: a stderr UserWarning about the legacy ~/.huggingface/token path,
-     followed by eight "Downloading …" progress-widget outputs for config.json,
-     pytorch_model.bin, generation_config.json, tokenizer_config.json,
-     vocab.json, merges.txt, tokenizer.json and special_tokens_map.json]
-   ],
+   "outputs": [],
    "source": [
     "import torch\n",
     "import random\n",
     "import time\n",
-    "from transformers import pipeline\n",
+    "from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer\n",
     "\n",
-    "generator = pipeline(\n",
-    "    'text-generation',\n",
-    "    model=\"heegyu/bluechat-v0\",\n",
-    "    device=\"cuda:0\" if torch.cuda.is_available() else 'cpu'\n",
-    ")"
+    "model_name=\"heegyu/bluechat-v0\"\n",
+    "device=\"cuda:0\" if torch.cuda.is_available() else 'cpu'\n",
+    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": …,
+   "execution_count": 54,
    "metadata": {},
    "outputs": [],
    "source": [
     "\n",
-    "def query(prompt, max_turn=4):\n",
-    [… four source lines (the generator(...) call and its first arguments),
-     truncated in the page extraction …]
+    "def query(prompt, min_new_tokens=16, max_turn=4):\n",
+    "    ids = tokenizer(prompt.strip(), return_tensors=\"pt\").to(device)\n",
+    "    min_length = ids['input_ids'].shape[1] + min_new_tokens\n",
+    "    output = model.generate(\n",
+    "        **ids,\n",
+    "        no_repeat_ngram_size=3,\n",
+    "        eos_token_id=2, # 375=\\n 2=</s>, 0:open-end\n",
     "        max_new_tokens=128,\n",
+    "        min_length=min_length,\n",
     "        do_sample=True,\n",
     "        top_p=0.7,\n",
     "        early_stopping=True\n",
-    "    )[0]['generated_text']\n",
-    "\n",
+    "    ) # [0]['generated_text']\n",
+    "    output = tokenizer.decode(output.cpu()[0])\n",
     "    print(output)\n",
     "\n",
     "    # response = output[len(prompt):]\n",
@@ -165,19 +46,34 @@
   },
   {
    "cell_type": "code",
-   "execution_count": …,
+   "execution_count": 42,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.\n"
+     ]
+    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
+     "\n",
      "0 : 안녕하세요</s>\n",
      "1 : 반가워요</s>\n",
      "0 : 요즘 좋아하는 음악 있으신가요?</s>\n",
      "1 : 최근에 들어서인지 너무 많이 들어요</s>\n",
      "0 : 음 주로 어떤거요?</s>\n",
-     "1 : …
+     "1 : \n",
+     " music : music songs 수록곡을 즐겨들어요</s><bot> 앗 어떤 장르를 주로 들으시나요?</s>\n",
+     "1 : music songs 좋죠</s>\n",
+     "bot> 저도 요즘 들어 좋아하게 된 곡들 위주로 들어요 ㅎㅎ</s>\n",
+     "2 : music songs 어떤 노래들 자주 들어요?</s>\n",
+     "bot> 저 music songs someone이 제일 좋더라구요 ㅎㅎ</s>\n",
+     "1 : music songs는 어떤 곡들 주로 들어요?</s>\n",
+     "bot> 저 music songs는 주로 music songs를 많이 들어요 ㅎㅎ</s>\n"
     ]
    }
   ],
@@ -191,6 +87,91 @@
     "1 : \n",
     "\"\"\")"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 48,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<usr> 안녕하세요\n",
+      "<bot> 안녕하세요~ 저녁 드셨나요? ㅎㅎ? ㅎㅎ</s>\n"
+     ]
+    }
+   ],
+   "source": [
+    "query(\"\"\"\n",
+    "<usr> 안녕하세요\n",
+    "<bot>\n",
+    "\"\"\", 8)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 55,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<usr> 안녕하세요 식사 하셨나요?\n",
+      "<bot> 안녕하세요 네~ 점심 먹었어요 식사하셨나요?\n",
+      "네~ 뭐드셨나요?</s>\n"
+     ]
+    }
+   ],
+   "source": [
+    "query(\"\"\"\n",
+    "<usr> 안녕하세요 식사 하셨나요?\n",
+    "<bot>\n",
+    "\"\"\", 8)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 63,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "<usr> 창업에 관심이 있나요?\n",
+      "<bot> 네! 근데 요즘 창업에 대한 관심이 많이 떨어지더라구요</s>\n"
+     ]
+    }
+   ],
+   "source": [
+    "query(\"\"\"\n",
+    "<usr> 창업에 관심이 있나요?\n",
+    "<bot>\n",
+    "\"\"\", 8)"
+   ]
   }
  ],
 "metadata": {
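A possible follow-up beyond this commit: the repeated stderr lines appear because `generate()` receives no `pad_token_id` (GPT-style models define none), and newer transformers releases (4.26 and later, if I recall the version correctly) accept a `min_new_tokens` argument directly, which counts only freshly generated tokens and makes the manual prompt-length arithmetic in `get_message` unnecessary. A hypothetical rewrite of that `model.generate` call, under those assumptions:

# Hypothetical, not part of the commit: requires a transformers release
# where generate() accepts min_new_tokens. `model` and `ids` are the same
# objects built in get_message above.
output = model.generate(
    **ids,
    no_repeat_ngram_size=3,
    eos_token_id=2,
    pad_token_id=2,    # set explicitly so the "Setting `pad_token_id`..." warning disappears
    max_new_tokens=128,
    min_new_tokens=8,  # counts new tokens only; no prompt-length addition needed
    do_sample=True,
    top_p=0.7,
)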