hayas commited on
Commit
f0dff07
·
1 Parent(s): 9cc5e5a
Files changed (6) hide show
  1. .pre-commit-config.yaml +55 -0
  2. .vscode/settings.json +21 -0
  3. README.md +5 -4
  4. app.py +137 -0
  5. requirements.txt +9 -0
  6. style.css +16 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.5.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ args: ["--fix=lf"]
14
+ - id: requirements-txt-fixer
15
+ - id: trailing-whitespace
16
+ - repo: https://github.com/myint/docformatter
17
+ rev: v1.7.5
18
+ hooks:
19
+ - id: docformatter
20
+ args: ["--in-place"]
21
+ - repo: https://github.com/pycqa/isort
22
+ rev: 5.12.0
23
+ hooks:
24
+ - id: isort
25
+ args: ["--profile", "black"]
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v1.6.1
28
+ hooks:
29
+ - id: mypy
30
+ args: ["--ignore-missing-imports"]
31
+ additional_dependencies:
32
+ ["types-python-slugify", "types-requests", "types-PyYAML"]
33
+ - repo: https://github.com/psf/black
34
+ rev: 23.10.1
35
+ hooks:
36
+ - id: black
37
+ language_version: python3.10
38
+ args: ["--line-length", "119"]
39
+ - repo: https://github.com/kynan/nbstripout
40
+ rev: 0.6.1
41
+ hooks:
42
+ - id: nbstripout
43
+ args:
44
+ [
45
+ "--extra-keys",
46
+ "metadata.interpreter metadata.kernelspec cell.metadata.pycharm",
47
+ ]
48
+ - repo: https://github.com/nbQA-dev/nbQA
49
+ rev: 1.7.0
50
+ hooks:
51
+ - id: nbqa-black
52
+ - id: nbqa-pyupgrade
53
+ args: ["--py37-plus"]
54
+ - id: nbqa-isort
55
+ args: ["--float-to-top"]
.vscode/settings.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "[python]": {
3
+ "editor.defaultFormatter": "ms-python.black-formatter",
4
+ "editor.formatOnType": true,
5
+ "editor.codeActionsOnSave": {
6
+ "source.organizeImports": true
7
+ }
8
+ },
9
+ "black-formatter.args": [
10
+ "--line-length=119"
11
+ ],
12
+ "isort.args": ["--profile", "black"],
13
+ "flake8.args": [
14
+ "--max-line-length=119"
15
+ ],
16
+ "ruff.args": [
17
+ "--line-length=119"
18
+ ],
19
+ "editor.formatOnSave": true,
20
+ "files.insertFinalNewline": true
21
+ }
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
- title: CALM2 7B Chat
3
- emoji: 📉
4
- colorFrom: gray
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.0.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: CALM2-7B-chat
3
+ emoji:
4
+ colorFrom: red
5
+ colorTo: purple
6
  sdk: gradio
7
  sdk_version: 4.0.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ suggested-hardware: t4-small
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ from threading import Thread
5
+ from typing import Iterator
6
+
7
+ import gradio as gr
8
+ import spaces
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
+
12
+ DESCRIPTION = "# CALM2-7B-chat"
13
+
14
+ if not torch.cuda.is_available():
15
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
16
+
17
+ MAX_MAX_NEW_TOKENS = 2048
18
+ DEFAULT_MAX_NEW_TOKENS = 1024
19
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "32768"))
20
+
21
+ if torch.cuda.is_available():
22
+ model_id = "cyberagent/calm2-7b-chat"
23
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
24
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
25
+
26
+
27
+ def apply_chat_template(conversation: list[dict[str, str]]) -> str:
28
+ prompt = "\n".join([f"{c['role']}: {c['content']}" for c in conversation])
29
+ prompt = f"{prompt}\nASSISTANT: "
30
+ return prompt
31
+
32
+
33
+ @spaces.GPU
34
+ @torch.inference_mode()
35
+ def generate(
36
+ message: str,
37
+ chat_history: list[tuple[str, str]],
38
+ max_new_tokens: int = 1024,
39
+ temperature: float = 0.7,
40
+ top_p: float = 0.95,
41
+ top_k: int = 50,
42
+ repetition_penalty: float = 1.0,
43
+ ) -> Iterator[str]:
44
+ conversation = []
45
+ for user, assistant in chat_history:
46
+ conversation.extend([{"role": "USER", "content": user}, {"role": "ASSISTANT", "content": assistant}])
47
+ conversation.append({"role": "USER", "content": message})
48
+
49
+ prompt = apply_chat_template(conversation)
50
+ input_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
51
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
52
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
53
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
54
+ input_ids = input_ids.to(model.device)
55
+
56
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
57
+ generate_kwargs = dict(
58
+ {"input_ids": input_ids},
59
+ streamer=streamer,
60
+ max_new_tokens=max_new_tokens,
61
+ do_sample=True,
62
+ top_p=top_p,
63
+ top_k=top_k,
64
+ temperature=temperature,
65
+ num_beams=1,
66
+ repetition_penalty=repetition_penalty,
67
+ )
68
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
69
+ t.start()
70
+
71
+ outputs = []
72
+ for text in streamer:
73
+ outputs.append(text)
74
+ yield "".join(outputs)
75
+
76
+
77
+ chat_interface = gr.ChatInterface(
78
+ fn=generate,
79
+ chatbot=gr.Chatbot(show_label=False, layout="panel", height=600),
80
+ additional_inputs_accordion_name="詳細設定",
81
+ additional_inputs=[
82
+ gr.Slider(
83
+ label="Max new tokens",
84
+ minimum=1,
85
+ maximum=MAX_MAX_NEW_TOKENS,
86
+ step=1,
87
+ value=DEFAULT_MAX_NEW_TOKENS,
88
+ ),
89
+ gr.Slider(
90
+ label="Temperature",
91
+ minimum=0.1,
92
+ maximum=4.0,
93
+ step=0.1,
94
+ value=0.7,
95
+ ),
96
+ gr.Slider(
97
+ label="Top-p (nucleus sampling)",
98
+ minimum=0.05,
99
+ maximum=1.0,
100
+ step=0.05,
101
+ value=0.95,
102
+ ),
103
+ gr.Slider(
104
+ label="Top-k",
105
+ minimum=1,
106
+ maximum=1000,
107
+ step=1,
108
+ value=50,
109
+ ),
110
+ gr.Slider(
111
+ label="Repetition penalty",
112
+ minimum=1.0,
113
+ maximum=2.0,
114
+ step=0.05,
115
+ value=1.0,
116
+ ),
117
+ ],
118
+ stop_btn=None,
119
+ examples=[
120
+ ["東京の観光名所を教えて。"],
121
+ ["落武者って何?"],
122
+ ["暴れん坊将軍って誰のこと?"],
123
+ ["人がヘリを食べるのにかかる時間は?"],
124
+ ],
125
+ )
126
+
127
+ with gr.Blocks(css="style.css") as demo:
128
+ gr.Markdown(DESCRIPTION)
129
+ gr.DuplicateButton(
130
+ value="Duplicate Space for private use",
131
+ elem_id="duplicate-button",
132
+ visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
133
+ )
134
+ chat_interface.render()
135
+
136
+ if __name__ == "__main__":
137
+ demo.queue(max_size=20).launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.24.1
2
+ bitsandbytes==0.41.1
3
+ gradio==4.0.2
4
+ protobuf==3.20.3
5
+ scipy==1.11.3
6
+ sentencepiece==0.1.99
7
+ spaces==0.18.0
8
+ torch==2.0.0
9
+ transformers==4.35.0
style.css ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: white;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }
11
+
12
+ .contain {
13
+ max-width: 900px;
14
+ margin: auto;
15
+ padding-top: 1.5rem;
16
+ }