Pluto0616 committed
Commit 7378138 · 1 Parent(s): ac3dfcf
Files changed (2)
  1. app.py +294 -0
  2. requirements.txt +218 -0
app.py ADDED
@@ -0,0 +1,294 @@
+"""This script refers to the dialogue example of streamlit, the interactive
+generation code of chatglm2 and transformers.
+We mainly modified part of the code logic to adapt to the
+generation of our model.
+Please refer to these links below for more information:
+    1. streamlit chat example:
+        https://docs.streamlit.io/knowledge-base/tutorials/build-conversational-apps
+    2. chatglm2:
+        https://github.com/THUDM/ChatGLM2-6B
+    3. transformers:
+        https://github.com/huggingface/transformers
+Please run with the command `streamlit run path/to/web_demo.py
+    --server.address=0.0.0.0 --server.port 7860`.
+Using `python path/to/web_demo.py` may cause unknown problems.
+"""
+# isort: skip_file
+import copy
+import warnings
+from dataclasses import asdict, dataclass
+from typing import Callable, List, Optional
+import os
+import streamlit as st
+import torch
+from torch import nn
+from transformers.generation.utils import (LogitsProcessorList,
+                                           StoppingCriteriaList)
+from transformers.utils import logging
+
+from transformers import AutoTokenizer, AutoModelForCausalLM  # isort: skip
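+# Clone the model repository from the Hugging Face Hub at startup so the
+# weights referenced by `model_name_or_path` are available locally.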
+os.system('git lfs install')
+os.system("git clone https://huggingface.co/Pluto0616/intern_study_L1_5")
+
+logger = logging.get_logger(__name__)
+model_name_or_path = "intern_study_L1_5"
+
+
+@dataclass
+class GenerationConfig:
+    # this config is used for chat to provide more diversity
+    max_length: int = 32768
+    top_p: float = 0.8
+    temperature: float = 0.8
+    do_sample: bool = True
+    repetition_penalty: float = 1.005
+
+
+@torch.inference_mode()
+def generate_interactive(
+    model,
+    tokenizer,
+    prompt,
+    generation_config: Optional[GenerationConfig] = None,
+    logits_processor: Optional[LogitsProcessorList] = None,
+    stopping_criteria: Optional[StoppingCriteriaList] = None,
+    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
+                                                List[int]]] = None,
+    additional_eos_token_id: Optional[int] = None,
+    **kwargs,
+):
+    inputs = tokenizer([prompt], padding=True, return_tensors='pt')
+    input_length = len(inputs['input_ids'][0])
+    for k, v in inputs.items():
+        inputs[k] = v.cuda()
+    input_ids = inputs['input_ids']
+    _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+    if generation_config is None:
+        generation_config = model.generation_config
+    generation_config = copy.deepcopy(generation_config)
+    model_kwargs = generation_config.update(**kwargs)
+    bos_token_id, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
+        generation_config.bos_token_id,
+        generation_config.eos_token_id,
+    )
+    if isinstance(eos_token_id, int):
+        eos_token_id = [eos_token_id]
+    if additional_eos_token_id is not None:
+        eos_token_id.append(additional_eos_token_id)
+    has_default_max_length = kwargs.get(
+        'max_length') is None and generation_config.max_length is not None
+    if has_default_max_length and generation_config.max_new_tokens is None:
+        warnings.warn(
+            f"Using 'max_length''s default \
+                ({repr(generation_config.max_length)}) \
+                to control the generation length. "
+            'This behaviour is deprecated and will be removed from the \
+                config in v5 of Transformers -- we'
+            ' recommend using `max_new_tokens` to control the maximum \
+                length of the generation.',
+            UserWarning,
+        )
+    elif generation_config.max_new_tokens is not None:
+        generation_config.max_length = generation_config.max_new_tokens + \
+            input_ids_seq_length
+        if not has_default_max_length:
+            logger.warn(  # pylint: disable=W4902
+                f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
+                f"and 'max_length'(={generation_config.max_length}) seem to "
+                "have been set. 'max_new_tokens' will take precedence. "
+                'Please refer to the documentation for more information. '
+                '(https://huggingface.co/docs/transformers/main/'
+                'en/main_classes/text_generation)',
+                UserWarning,
+            )
+
+    if input_ids_seq_length >= generation_config.max_length:
+        input_ids_string = 'input_ids'
+        logger.warning(
+            f'Input length of {input_ids_string} is {input_ids_seq_length}, '
+            f"but 'max_length' is set to {generation_config.max_length}. "
+            'This can lead to unexpected behavior. You should consider'
+            " increasing 'max_new_tokens'.")
+
+    # 2. Set generation parameters if not already defined
+    logits_processor = logits_processor if logits_processor is not None \
+        else LogitsProcessorList()
+    stopping_criteria = stopping_criteria if stopping_criteria is not None \
+        else StoppingCriteriaList()
+
+    logits_processor = model._get_logits_processor(
+        generation_config=generation_config,
+        input_ids_seq_length=input_ids_seq_length,
+        encoder_input_ids=input_ids,
+        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+        logits_processor=logits_processor,
+    )
+
+    stopping_criteria = model._get_stopping_criteria(
+        generation_config=generation_config,
+        stopping_criteria=stopping_criteria)
+    logits_warper = model._get_logits_warper(generation_config)
+
+    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+    scores = None
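+    # Hand-rolled decoding loop: generate one token per iteration and yield
+    # the partial response each time so the UI can stream it.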
+    while True:
+        model_inputs = model.prepare_inputs_for_generation(
+            input_ids, **model_kwargs)
+        # forward pass to get next token
+        outputs = model(
+            **model_inputs,
+            return_dict=True,
+            output_attentions=False,
+            output_hidden_states=False,
+        )
+
+        next_token_logits = outputs.logits[:, -1, :]
+
+        # pre-process distribution
+        next_token_scores = logits_processor(input_ids, next_token_logits)
+        next_token_scores = logits_warper(input_ids, next_token_scores)
+
+        # sample
+        probs = nn.functional.softmax(next_token_scores, dim=-1)
+        if generation_config.do_sample:
+            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+        else:
+            next_tokens = torch.argmax(probs, dim=-1)
+
+        # update generated ids, model inputs, and length for next step
+        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+        model_kwargs = model._update_model_kwargs_for_generation(
+            outputs, model_kwargs, is_encoder_decoder=False)
+        unfinished_sequences = unfinished_sequences.mul(
+            (min(next_tokens != i for i in eos_token_id)).long())
+
+        output_token_ids = input_ids[0].cpu().tolist()
+        output_token_ids = output_token_ids[input_length:]
+        for each_eos_token_id in eos_token_id:
+            if output_token_ids[-1] == each_eos_token_id:
+                output_token_ids = output_token_ids[:-1]
+        response = tokenizer.decode(output_token_ids)
+
+        yield response
+        # stop when each sentence is finished
+        # or if we exceed the maximum length
+        if unfinished_sequences.max() == 0 or stopping_criteria(
+                input_ids, scores):
+            break
+
+
+def on_btn_click():
+    del st.session_state.messages
+
+
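+# st.cache_resource keeps a single model/tokenizer instance alive across
+# Streamlit reruns instead of reloading the weights on every interaction.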
+@st.cache_resource
+def load_model():
+    model = (AutoModelForCausalLM.from_pretrained(
+        model_name_or_path,
+        trust_remote_code=True).to(torch.bfloat16).cuda())
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
+                                              trust_remote_code=True)
+    return model, tokenizer
+
+
+def prepare_generation_config():
+    with st.sidebar:
+        max_length = st.slider('Max Length',
+                               min_value=8,
+                               max_value=32768,
+                               value=32768)
+        top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
+        temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
+        st.button('Clear Chat History', on_click=on_btn_click)
+
+    generation_config = GenerationConfig(max_length=max_length,
+                                         top_p=top_p,
+                                         temperature=temperature)
+
+    return generation_config
+
+
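+# ChatML-style prompt templates (<|im_start|>/<|im_end|>) used by the
+# internlm2 chat models.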
+user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
+robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
+cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
+    <|im_start|>assistant\n'
+
+
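+# Rebuild the full ChatML prompt (system message + chat history + current
+# query) on every turn before calling the model.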
+def combine_history(prompt):
+    messages = st.session_state.messages
+    meta_instruction = ('You are a helpful, honest, '
+                        'and harmless AI assistant.')
+    total_prompt = f'<s><|im_start|>system\n{meta_instruction}<|im_end|>\n'
+    for message in messages:
+        cur_content = message['content']
+        if message['role'] == 'user':
+            cur_prompt = user_prompt.format(user=cur_content)
+        elif message['role'] == 'robot':
+            cur_prompt = robot_prompt.format(robot=cur_content)
+        else:
+            raise RuntimeError
+        total_prompt += cur_prompt
+    total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
+    return total_prompt
+
+
+def main():
+    st.title('internlm2_5-7b-chat-assistant')
+
+    # torch.cuda.empty_cache()
+    print('load model begin.')
+    model, tokenizer = load_model()
+    print('load model end.')
+
+    generation_config = prepare_generation_config()
+
+    # Initialize chat history
+    if 'messages' not in st.session_state:
+        st.session_state.messages = []
+
+    # Display chat messages from history on app rerun
+    for message in st.session_state.messages:
+        with st.chat_message(message['role'], avatar=message.get('avatar')):
+            st.markdown(message['content'])
+
+    # Accept user input
+    if prompt := st.chat_input('What is up?'):
+        # Display user message in chat message container
+
+        with st.chat_message('user', avatar='user'):
+
+            st.markdown(prompt)
+        real_prompt = combine_history(prompt)
+        # Add user message to chat history
+        st.session_state.messages.append({
+            'role': 'user',
+            'content': prompt,
+            'avatar': 'user'
+        })
+
+        with st.chat_message('robot', avatar='assistant'):
+
+            message_placeholder = st.empty()
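+            # Stream the reply into the placeholder; 92542 is the <|im_end|>
+            # token id of the InternLM2 tokenizer, passed as an extra EOS so
+            # generation stops at the end of the assistant turn.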
+            for cur_response in generate_interactive(
+                    model=model,
+                    tokenizer=tokenizer,
+                    prompt=real_prompt,
+                    additional_eos_token_id=92542,
+                    device='cuda:0',
+                    **asdict(generation_config),
+            ):
+                # Display robot response in chat message container
+                message_placeholder.markdown(cur_response + '▌')
+            message_placeholder.markdown(cur_response)
+        # Add robot response to chat history
+        st.session_state.messages.append({
+            'role': 'robot',
+            'content': cur_response,  # pylint: disable=undefined-loop-variable
+            'avatar': 'assistant',
+        })
+        torch.cuda.empty_cache()
+
+
+if __name__ == '__main__':
+    main()
+
requirements.txt ADDED
@@ -0,0 +1,218 @@
+accelerate==0.27.0
+addict==2.4.0
+aiohttp==3.9.3
+aiosignal==1.3.1
+aliyun-python-sdk-core==2.14.0
+aliyun-python-sdk-kms==2.16.2
+altair==5.2.0
+annotated-types==0.6.0
+anyio==4.2.0
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+arrow==1.3.0
+arxiv==2.1.0
+asttokens==2.4.1
+async-lru==2.0.4
+async-timeout==4.0.3
+attrs==23.2.0
+Babel==2.14.0
+beautifulsoup4==4.12.3
+bitsandbytes==0.42.0
+bleach==6.1.0
+blinker==1.7.0
+cachetools==5.3.2
+certifi==2024.2.2
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+comm==0.2.1
+contourpy==1.2.0
+crcmod==1.7
+cryptography==42.0.2
+cycler==0.12.1
+datasets==2.17.0
+debugpy==1.8.1
+decorator==5.1.1
+deepspeed==0.13.1
+defusedxml==0.7.1
+dill==0.3.8
+distro==1.9.0
+einops==0.8.0
+einx==0.3.0
+et-xmlfile==1.1.0
+exceptiongroup==1.2.0
+executing==2.0.1
+fastapi==0.112.0
+fastjsonschema==2.19.1
+feedparser==6.0.10
+filelock==3.14.0
+fonttools==4.48.1
+fqdn==1.5.1
+frozendict==2.4.4
+frozenlist==1.4.1
+fsspec==2023.10.0
+func-timeout==4.3.5
+gast==0.5.4
+gitdb==4.0.11
+GitPython==3.1.41
+google-search-results==2.4.2
+griffe==0.40.1
+h11==0.14.0
+hjson==3.1.0
+httpcore==1.0.3
+httpx==0.26.0
+huggingface-hub==0.24.2
+idna==3.6
+imageio==2.34.2
+importlib-metadata==7.0.1
+ipykernel==6.29.2
+ipython==8.21.0
+ipywidgets==8.1.2
+isoduration==20.11.0
+jedi==0.19.1
+Jinja2==3.1.3
+jmespath==0.10.0
+json5==0.9.14
+jsonpointer==2.4
+jsonschema==4.21.1
+jsonschema-specifications==2023.12.1
+kiwisolver==1.4.5
+lagent==0.2.1
+lazy_loader==0.4
+llvmlite==0.43.0
+lxml==5.1.0
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.8.2
+matplotlib-inline==0.1.6
+mdurl==0.1.2
+mistune==3.0.2
+mmengine==0.10.3
+modelscope==1.12.0
+mpi4py_mpich==3.1.5
+mpmath==1.3.0
+multidict==6.0.5
+multiprocess==0.70.16
+nbclient==0.9.0
+nbconvert==7.16.0
+nbformat==5.9.2
+nest-asyncio==1.6.0
+networkx==3.2.1
+ninja==1.11.1.1
+notebook==7.0.8
+notebook_shim==0.2.3
+numba==0.60.0
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.19.3
+nvidia-nvjitlink-cu12==12.3.101
+nvidia-nvtx-cu12==12.1.105
+openai==1.12.0
+opencv-python==4.9.0.80
+openpyxl==3.1.2
+oss2==2.17.0
+overrides==7.7.0
+packaging==24.1
+pandas==2.2.0
+pandocfilters==1.5.1
+parso==0.8.3
+peft==0.8.2
+pexpect==4.9.0
+phx-class-registry==4.1.0
+pillow==10.2.0
+platformdirs==4.2.0
+prometheus-client==0.19.0
+prompt-toolkit==3.0.43
+protobuf==4.25.2
+psutil==5.9.8
+ptyprocess==0.7.0
+pure-eval==0.2.2
+py-cpuinfo==9.0.0
+pyarrow==15.0.0
+pyarrow-hotfix==0.6
+pybase16384==0.3.7
+pycparser==2.21
+pycryptodome==3.20.0
+pydantic==2.6.1
+pydantic_core==2.16.2
+pydeck==0.8.1b0
+Pygments==2.17.2
+pynvml==11.5.0
+pyparsing==3.1.1
+python-dateutil==2.8.2
+python-json-logger==2.0.7
+python-pptx==0.6.23
+PyYAML==6.0.1
+pyzmq==25.1.2
+qtconsole==5.5.1
+QtPy==2.4.1
+referencing==0.33.0
+regex==2023.12.25
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rich==13.4.2
+rpds-py==0.17.1
+safetensors==0.4.2
+scikit-image==0.24.0
+scipy==1.12.0
+seaborn==0.13.2
+Send2Trash==1.8.2
+sentencepiece==0.1.99
+sgmllib3k==1.0.0
+simplejson==3.19.2
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.0
+sortedcontainers==2.4.0
+soupsieve==2.5
+stack-data==0.6.3
+starlette==0.37.2
+sympy==1.12
+tenacity==8.2.3
+termcolor==2.4.0
+terminado==0.18.0
+tifffile==2024.7.24
+tiktoken==0.6.0
+timeout-decorator==0.5.0
+tinycss2==1.2.1
+tokenizers==0.15.2
+toml==0.10.2
+tomli==2.0.1
+toolz==0.12.1
+torch==2.2.1
+torchvision==0.17.1
+tornado==6.4
+tqdm==4.65.2
+traitlets==5.14.1
+transformers==4.39.0
+transformers-stream-generator==0.0.4
+triton==2.2.0
+types-python-dateutil==2.8.19.20240106
+typing_extensions==4.9.0
+tzdata==2024.1
+tzlocal==5.2
+uri-template==1.3.0
+urllib3==1.26.18
+uvicorn==0.30.6
+validators==0.22.0
+watchdog==4.0.0
+wcwidth==0.2.13
+webcolors==1.13
+webencodings==0.5.1
+websocket-client==1.7.0
+widgetsnbextension==4.0.10
+XlsxWriter==3.1.9
+xtuner==0.1.23
+xxhash==3.4.1
+yapf==0.40.2
+yarl==1.9.4
+zipp==3.17.0