artek0chumak committed on
Commit
fd09410
·
1 Parent(s): e3e6a48
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitmodules +3 -0
  2. app.py +6 -8
  3. petals +1 -0
  4. petals/README.md +0 -203
  5. petals/cli/__init__.py +0 -0
  6. petals/cli/config.json +0 -20
  7. petals/cli/convert_model.py +0 -93
  8. petals/cli/deploy_server.sh +0 -79
  9. petals/cli/inference_one_block.py +0 -53
  10. petals/cli/local_server_config_example.cfg +0 -5
  11. petals/cli/remote_server_config_example.cfg +0 -6
  12. petals/cli/run_local_servers.sh +0 -109
  13. petals/cli/run_remote_servers.sh +0 -110
  14. petals/cli/run_server.py +0 -129
  15. petals/cli/speed_test.py +0 -1941
  16. petals/examples/prompt-tuning-personachat.ipynb +0 -339
  17. petals/pyproject.toml +0 -10
  18. petals/requirements-dev.txt +0 -6
  19. petals/requirements.txt +0 -8
  20. petals/src/__init__.py +0 -6
  21. petals/src/bloom/__init__.py +0 -2
  22. petals/src/bloom/block.py +0 -255
  23. petals/src/bloom/from_pretrained.py +0 -86
  24. petals/src/bloom/model.py +0 -583
  25. petals/src/bloom/ops.py +0 -246
  26. petals/src/client/__init__.py +0 -5
  27. petals/src/client/inference_session.py +0 -216
  28. petals/src/client/remote_forward_backward.py +0 -156
  29. petals/src/client/remote_generation.py +0 -257
  30. petals/src/client/remote_model.py +0 -197
  31. petals/src/client/remote_sequential.py +0 -103
  32. petals/src/client/sequence_manager.py +0 -153
  33. petals/src/client/sequential_autograd.py +0 -204
  34. petals/src/client/spending_policy.py +0 -14
  35. petals/src/data_structures.py +0 -41
  36. petals/src/dht_utils.py +0 -180
  37. petals/src/server/__init__.py +0 -0
  38. petals/src/server/backend.py +0 -84
  39. petals/src/server/block_selection.py +0 -106
  40. petals/src/server/cache.py +0 -143
  41. petals/src/server/handler.py +0 -421
  42. petals/src/server/runtime.py +0 -198
  43. petals/src/server/server.py +0 -475
  44. petals/src/server/task_pool.py +0 -178
  45. petals/src/server/task_prioritizer.py +0 -20
  46. petals/src/server/throughput.py +0 -127
  47. petals/src/utils/__init__.py +0 -0
  48. petals/src/utils/convert_8bit.py +0 -41
  49. petals/src/utils/generation_algorithms.py +0 -78
  50. petals/src/utils/generation_constraints.py +0 -84
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "petals"]
2
+ path = petals
3
+ url = https://github.com/bigscience-workshop/petals
app.py CHANGED
@@ -8,17 +8,15 @@ import gradio as gr
8
  from src.client.remote_model import DistributedBloomForCausalLM
9
 
10
 
11
- # MODEL_NAME = "bigscience/test-bloomd-6b3" # select model you like
12
- # INITIAL_PEERS = ["/ip4/193.106.95.184/tcp/31000/p2p/QmSg7izCDtowVTACbUmWvEiQZNY4wgCQ9T9Doo66K59X6q"]
13
 
14
- # tokenizer = transformers.BloomTokenizerFast.from_pretrained("bigscience/test-bloomd-6b3")
15
- # model = DistributedBloomForCausalLM.from_pretrained("bigscience/test-bloomd-6b3", initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32)
16
 
17
  def inference(text, seq_length=1):
18
- # input_ids = tokenizer([text], return_tensors="pt").input_ids
19
- # output = model.generate(input_ids, max_new_tokens=seq_length)
20
- # return tokenizer.batch_decode(output)[0]
21
- return text
22
 
23
  iface = gr.Interface(
24
  fn=inference,
 
8
  from src.client.remote_model import DistributedBloomForCausalLM
9
 
10
 
11
+ MODEL_NAME = "bigscience/bloom-petals"
 
12
 
13
+ tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
14
+ model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)
15
 
16
  def inference(text, seq_length=1):
17
+ input_ids = tokenizer([text], return_tensors="pt").input_ids
18
+ output = model.generate(input_ids, max_new_tokens=seq_length)
19
+ return tokenizer.batch_decode(output)[0]
 
20
 
21
  iface = gr.Interface(
22
  fn=inference,
petals ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit ab41223b17c17dd1035a42318b03d4b92decd063
petals/README.md DELETED
@@ -1,203 +0,0 @@
1
- <p align="center">
2
- <img src="https://i.imgur.com/7eR7Pan.png" width="400"><br>
3
- Decentralized platform for running 100B+ language models<br><br>
4
- <a href="https://github.com/bigscience-workshop/petals/actions">
5
- <img src="https://github.com/bigscience-workshop/petals/actions/workflows/run-tests.yaml/badge.svg?branch=main">
6
- </a>
7
- <a href="https://github.com/psf/black">
8
- <img src="https://img.shields.io/badge/code%20style-black-000000.svg">
9
- </a>
10
- </p>
11
-
12
- ## Key features
13
-
14
- - Run inference or fine-tune large language models like [BLOOM-176B](https://huggingface.co/bigscience/bloom) by joining compute resources with people all over the Internet. No need to have high-end GPUs.
15
- - It's difficult to fit the whole BLOOM-176B into GPU memory [unless](https://twitter.com/Tim_Dettmers/status/1559892918395031552) you have multiple high-end GPUs. Instead, **Petals** allows you to load and serve a small part of the model, then team up with people serving all the other parts to run inference or fine-tuning.
16
- - This way, one inference step takes ≈ 1 sec — much faster than possible with offloading. Enough for chatbots and other interactive apps.
17
- - Beyond traditional language model APIs — you can employ any fine-tuning and sampling methods by executing custom paths through the model or accessing its hidden states. This allows for the comforts of an API with the flexibility of PyTorch.
18
-
19
- <p align="center">
20
- <b><a href="https://arxiv.org/pdf/2209.01188.pdf">[Read paper]</a></b> | <b><a href="https://petals.ml/">[View website]</a></b>
21
- </p>
22
-
23
- ## How it works?
24
-
25
- <p align="center">
26
- <img src="https://i.imgur.com/RTYF3yW.png" width="800">
27
- </p>
28
-
29
- ### 🛠️ Examples
30
-
31
- Petals integrates seamlessly with PyTorch and the Hugging Face [Transformers](https://github.com/huggingface/transformers) library.
32
-
33
- This snippet shows how to **(a)** generate text with BLOOM and **(b)** solve a sequence classification task via soft prompt tuning:
34
-
35
- ```python
36
- # Initialize distributed BLOOM and connect to the swarm
37
- model = DistributedBloomForCausalLM.from_pretrained(
38
- "bigscience/bloom-petals", tuning_mode="ptune", initial_peers=SEE_BELOW
39
- ) # Embeddings & prompts are on your device, BLOOM blocks are distributed
40
-
41
- print("Generated:", model.generate(tokenized_prefix, max_new_tokens=5))
42
-
43
- # Training (updates only local prompts / adapters)
44
- optimizer = torch.optim.AdamW(model.parameters())
45
- for input_ids, labels in data_loader:
46
- outputs = model.forward(input_ids)
47
- loss = cross_entropy(outputs.logits, labels)
48
- optimizer.zero_grad()
49
- loss.backward()
50
- optimizer.step()
51
- ```
52
-
53
- ### 🚧 This project is in active development
54
-
55
- Be careful: some features may not work, interfaces may change, and we have no detailed docs yet (see [roadmap](https://github.com/bigscience-workshop/petals/issues/12)).
56
-
57
- A stable version of the code and a public swarm open to everyone will be released in November 2022. You can [subscribe](https://petals.ml/) to be emailed when it happens or fill in [this form](https://forms.gle/TV3wtRPeHewjZ1vH9) to help the public launch by donating GPU time. In the meantime, you can launch and use your own private swarm.
58
-
59
- ### 🔒 Privacy and security
60
-
61
- If you work with sensitive data, you should only use a private swarm (or a subset of servers in the public swarm) hosted by people and institutions you trust, who are authorized to process this data.
62
-
63
- This is important because it's technically possible for peers serving model layers to recover input data or model outputs. Also, if there are malicious peers, they may alter their outputs to influence the model outputs. See a more detailed discussion in Section 4 of our [paper](https://arxiv.org/pdf/2209.01188.pdf).
64
-
65
- ## FAQ
66
-
67
- 1. **What's the motivation for people to host model layers in the public swarm?**
68
-
69
- People who run inference and fine-tuning themselves get a certain speedup if they host a part of the model locally. Some may also be motivated to "give back" to the community helping them to run the model (similarly to how [BitTorrent](https://en.wikipedia.org/wiki/BitTorrent) users help others by sharing data they have already downloaded).
70
-
71
- Since this may not be enough for everyone, we are also working on introducing explicit __incentives__ ("bloom points") for people donating their GPU time to the public swarm. Once this system is ready, people who earned these points will be able to spend them on inference/fine-tuning with higher priority or increased security guarantees, or (maybe) exchange them for other rewards.
72
-
73
- 2. **Why is the platform named "Petals"?**
74
-
75
- "Petals" is a metaphor for people serving different parts of the model. Together, they host the entire language model &mdash; [BLOOM](https://huggingface.co/bigscience/bloom).
76
-
77
- While our platform focuses on BLOOM now, we aim to support more [foundation models](https://arxiv.org/abs/2108.07258) in future.
78
-
79
- ## Installation
80
-
81
- Here's how to install the dependencies with conda:
82
- ```
83
- conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
84
- pip install -r requirements.txt
85
- ```
86
-
87
- This script uses Anaconda to install cuda-enabled PyTorch.
88
- If you don't have anaconda, you can get it from [here](https://www.anaconda.com/products/distribution).
89
- If you don't want anaconda, you can install PyTorch [any other way](https://pytorch.org/get-started/locally/).
90
- If you want to run models with 8-bit weights, please install **PyTorch with CUDA 11** or newer for compatibility with [bitsandbytes](https://github.com/timDettmers/bitsandbytes).
91
-
92
- __OS support:__ Currently, Petals only supports Linux operating systems. On Windows 11, you can run Petals with GPU enabled inside WSL2 ([read more](https://learn.microsoft.com/en-us/windows/ai/directml/gpu-cuda-in-wsl)).
93
- For macOS, you can *probably* run everything normally if you manage to install dependencies, but we do not guarantee this.
94
-
95
-
96
- ## 🚀 Getting Started
97
-
98
- This is a toy example running on a local machine without GPU and with a tiny model.
99
- For detailed instructions with larger models, see ["Launch your own swarm"](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm).
100
-
101
- First, run a couple of servers, each in a separate shell. To launch your first server, run:
102
- ```bash
103
- python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
104
- --host_maddrs /ip4/127.0.0.1/tcp/31337 # use port 31337, local connections only
105
- ```
106
-
107
- This server will host 8 (out of 24) blocks of a [tiny 560M version](https://huggingface.co/bloom-testing/test-bloomd-560m-main) of the BLOOM model that was converted for Petals.
108
-
109
- > If you'd like to run a swarm of servers with the full BLOOM straight away, please see [this instruction](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) (you'll need several GPUs!). To run a different model, see [this wiki page](https://github.com/bigscience-workshop/petals/wiki/Run-a-custom-model-with-PETALS).
110
-
111
- Once the server has started, it will print out a ton of information, including an important line like this:
112
-
113
- ```bash
114
- Mon Day 01:23:45.678 [INFO] Running DHT node on ['/ip4/127.0.0.1/tcp/31337/p2p/ALongStringOfCharacters'], initial peers = []
115
- ```
116
-
117
- You can use this address (`/ip4/whatever/else`) to connect additional servers. Open another terminal and run:
118
-
119
- ```bash
120
- python -m cli.run_server bloom-testing/test-bloomd-560m-main --num_blocks 8 --torch_dtype float32 \
121
- --host_maddrs /ip4/127.0.0.1/tcp/0 \
122
- --initial_peers /ip4/127.0... # <-- TODO: Copy the address of another server here
123
- # e.g. --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmS1GecIfYouAreReadingThisYouNeedToCopyYourServerAddressCBBq
124
- ```
125
-
126
- You can assign `--initial_peers` to one or multiple addresses of other servers, not necessarily the first one.
127
- The only requirement is that at least one of them is running at the time.
128
-
129
- Before you proceed, __please run 3 servers__ for a total of 24 blocks (3x8). If you are running a different model,
130
- make sure your servers have enough total `--num_blocks` to cover that model.
131
-
132
- Once you have enough servers, you can use them to train the model and/or run inference:
133
- ```python
134
- import torch
135
- import torch.nn.functional as F
136
- import transformers
137
- from src import DistributedBloomForCausalLM
138
-
139
- initial_peers = [TODO_put_one_or_more_server_addresses_here] # e.g. ["/ip4/127.0.0.1/tcp/more/stuff/here"]
140
- tokenizer = transformers.BloomTokenizerFast.from_pretrained("bloom-testing/test-bloomd-560m-main")
141
- model = DistributedBloomForCausalLM.from_pretrained(
142
- "bloom-testing/test-bloomd-560m-main", initial_peers=initial_peers, low_cpu_mem_usage=True, torch_dtype=torch.float32
143
- ) # this model has only embeddings / logits, all transformer blocks rely on remote servers
144
-
145
-
146
- inputs = tokenizer("a cat sat", return_tensors="pt")["input_ids"]
147
- remote_outputs = model.generate(inputs, max_length=10)
148
- print(tokenizer.decode(remote_outputs[0])) # "a cat sat in the back of the car,"
149
-
150
- # "train" input embeddings by backprop through distributed transformer blocks
151
- model.transformer.word_embeddings.weight.requires_grad = True
152
- outputs = model.forward(input_ids=inputs)
153
- loss = F.cross_entropy(outputs.logits.flatten(0, 1), inputs.flatten())
154
- loss.backward()
155
- print("Gradients (norm):", model.transformer.word_embeddings.weight.grad.norm())
156
- ```
157
-
158
- Of course, this is a simplified code snippet. For actual training, see our example on "deep" prompt-tuning here: [examples/prompt-tuning-personachat.ipynb](./examples/prompt-tuning-personachat.ipynb).
159
-
160
- Here's a [more advanced tutorial](https://github.com/bigscience-workshop/petals/wiki/Launch-your-own-swarm) that covers 8-bit quantization and best practices for running Petals.
161
-
162
- ## 🛠️ Development
163
-
164
- Petals uses pytest with a few plugins. To install them, run `pip install -r requirements-dev.txt`
165
-
166
- To run minimalistic tests, spin up some servers:
167
-
168
- ```bash
169
- export MODEL_NAME=bloom-testing/test-bloomd-560m-main
170
- export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
171
- python -m cli.run_server $MODEL_NAME --block_indices 0:12 --throughput 1 --torch_dtype float32 \
172
- --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 &> server1.log &
173
- sleep 5 # wait for the first server to initialize DHT
174
- python -m cli.run_server $MODEL_NAME --block_indices 12:24 --throughput 1 --torch_dtype float32 \
175
- --initial_peers /ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g &> server2.log &
176
-
177
- tail -f server1.log server2.log # view logs for both servers
178
- # after you're done, kill servers with 'pkill -f cli.run_server'
179
- ```
180
-
181
- Then launch pytest:
182
-
183
- ```
184
- export MODEL_NAME=bloom-testing/test-bloomd-560m-main REF_NAME=bigscience/bloom-560m
185
- export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
186
- PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v
187
- ```
188
-
189
- The automated tests use a more complex server configuration that can be found [here](https://github.com/bigscience-workshop/petals/blob/main/.github/workflows/run-tests.yaml).
190
-
191
- ### Code style
192
-
193
- We use [black](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html) and [isort](https://pycqa.github.io/isort/) for all pull requests.
194
- Before committing your code, simply run `black . && isort .` and you will be fine.
195
-
196
- --------------------------------------------------------------------------------
197
-
198
- <p align="center">
199
- This project is a part of the <a href="https://bigscience.huggingface.co/">BigScience</a> research workshop.
200
- </p>
201
- <p align="center">
202
- <img src="https://petals.ml/bigscience.png" width="150">
203
- </p>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/cli/__init__.py DELETED
File without changes
petals/cli/config.json DELETED
@@ -1,20 +0,0 @@
1
- {
2
- "apply_residual_connection_post_layernorm": false,
3
- "attention_dropout": 0.0,
4
- "attention_softmax_in_fp32": true,
5
- "bos_token_id": 1,
6
- "eos_token_id": 2,
7
- "hidden_dropout": 0.0,
8
- "initializer_range": 0.02,
9
- "layer_norm_epsilon": 1e-05,
10
- "masked_softmax_fusion": true,
11
- "model_type": "bloom",
12
- "n_embed": 14336,
13
- "n_layer": 70,
14
- "num_attention_heads": 112,
15
- "pretraining_tp": 4,
16
- "slow_but_exact": false,
17
- "transformers_version": "4.20.0.dev0",
18
- "use_cache": true,
19
- "vocab_size": 250880
20
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/cli/convert_model.py DELETED
@@ -1,93 +0,0 @@
1
- import argparse
2
- import os
3
-
4
- import psutil
5
- import torch.backends.quantized
6
- import torch.nn as nn
7
- import transformers
8
- from hivemind.utils.logging import get_logger, use_hivemind_log_handler
9
- from huggingface_hub import Repository
10
- from tqdm.auto import tqdm
11
-
12
- from src import BloomModel
13
- from src.bloom.from_pretrained import BLOCK_BRANCH_PREFIX, CLIENT_BRANCH
14
- from src.client import DistributedBloomConfig
15
-
16
- use_hivemind_log_handler("in_root_logger")
17
- logger = get_logger(__file__)
18
-
19
- DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
20
-
21
-
22
- if __name__ == "__main__":
23
- parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.")
24
-
25
- parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained")
26
- parser.add_argument("--revision", type=str, default=None, help="Optional commit id from HF hub")
27
- parser.add_argument("--torch_dtype", type=str, default="auto", help="Load initial model in this dtype")
28
- parser.add_argument("--output_path", type=str, default="./converted_model", help="Track output repo to this folder")
29
- parser.add_argument("--output_repo", type=str, default="bigscience/test-bloomd", help="Push to this HF hub repo")
30
- parser.add_argument("--client_branch", type=str, default=CLIENT_BRANCH, help="Save client version to this branch")
31
- parser.add_argument(
32
- "--block_branch_prefix", type=str, default=BLOCK_BRANCH_PREFIX, help="Save blocks to branches with this prefix"
33
- )
34
- parser.add_argument(
35
- "--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
36
- )
37
- parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
38
- parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
39
- args = parser.parse_args()
40
-
41
- free_ram_gb = psutil.virtual_memory().available / 2**30
42
- if args.model == "bigscience/bloom" and free_ram_gb < 400:
43
- logger.warning(f"ACHTUNG! converting bloom-176b will use up 350-400GB RAM, you have {free_ram_gb:.3f} free")
44
-
45
- assert args.torch_dtype in DTYPE_MAP, f"torch_dtype must be one of {list(DTYPE_MAP.keys())}"
46
- if os.path.exists(args.output_path) and (
47
- len(os.listdir(args.output_path)) != 0 or not os.path.isdir(args.output_path)
48
- ):
49
- raise FileExistsError(f"Output path {args.output_path} already exists and is not an empty directory")
50
-
51
- logger.info(f"Loading source model {args.model} (this may take a few minutes)")
52
- config = DistributedBloomConfig.from_pretrained(
53
- args.model, use_auth_token=args.use_auth_token, revision=args.revision
54
- )
55
- config.dht_prefix = args.output_repo
56
-
57
- model = BloomModel.from_pretrained(
58
- args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
59
- )
60
- if args.resize_token_embeddings:
61
- logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
62
- model.resize_token_embeddings(args.resize_token_embeddings)
63
- config.vocab_size = args.resize_token_embeddings
64
-
65
- tokenizer = transformers.AutoTokenizer.from_pretrained(
66
- args.model, use_auth_token=args.use_auth_token, revision=args.revision
67
- )
68
- os.makedirs(args.output_path, exist_ok=True)
69
-
70
- repo = Repository(args.output_path, clone_from=args.output_repo, use_auth_token=args.use_auth_token)
71
- repo.git_pull()
72
-
73
- transformer_blocks = model.h
74
- logger.info(
75
- f"Saving transformer blocks to {args.output_repo}@{args.block_branch_prefix}0"
76
- f" - {args.output_repo}@{args.block_branch_prefix}{len(transformer_blocks)}"
77
- )
78
- for i, block in enumerate(tqdm(transformer_blocks)):
79
- repo.git_checkout(args.client_branch, create_branch_ok=True)
80
- with repo.commit(
81
- commit_message=args.commit_message, branch=args.block_branch_prefix + str(i), track_large_files=True
82
- ):
83
- torch.save(block.state_dict(), "./pytorch_model.bin")
84
-
85
- logger.info(f"Saving client-side modules to {args.output_repo}@{args.client_branch}")
86
- repo.git_checkout(args.client_branch, create_branch_ok=True)
87
- with repo.commit(commit_message=args.commit_message, branch=args.client_branch, track_large_files=True):
88
- model.h = nn.ModuleList()
89
- model.save_pretrained(".")
90
- tokenizer.save_pretrained(".")
91
- config.save_pretrained(".")
92
-
93
- logger.info(f"Converted {args.model} and pushed to {args.output_repo}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/cli/deploy_server.sh DELETED
@@ -1,79 +0,0 @@
1
- #!/usr/bin/env bash
2
-
3
- #################
4
- # Parse options #
5
- #################
6
-
7
- instructions() {
8
- echo "Usage: $0 [-m] [-i] [ -d ] [ -p ] [ -b ] [-a] [-t]" >&2
9
- echo " -m: model name"
10
- echo " -i: initial peer"
11
- echo " -d: device" >&2
12
- echo " -p: server identity path" >&2
13
- echo " -b: block_ids" >&2
14
- echo " -a: host maddrs" >&2
15
- echo " -t: whether to run local tests" >&2
16
- exit 1
17
- }
18
-
19
- if [ ! $# -ge 8 ]; then
20
- instructions
21
- fi
22
-
23
- while getopts ":m:i:d:p:b:a:t:" option; do
24
- case $option in
25
- m) MODEL_NAME=${OPTARG}
26
- ;;
27
- i) INITIAL_PEER=${OPTARG}
28
- ;;
29
- d) DEVICE=${OPTARG}
30
- ;;
31
- p) SERVER_ID_PATH=${OPTARG}
32
- ;;
33
- b) BLOCK_IDS=${OPTARG}
34
- ;;
35
- a) HOST_MADDR=${OPTARG} # TODO: allow several maddrs
36
- ;;
37
- t) RUN_LOCAL_TESTS=true
38
- ;;
39
- \?) instructions
40
- ;;
41
- esac
42
- done
43
-
44
-
45
- echo "=========="
46
- echo "= Config ="
47
- echo "=========="
48
- echo "Model name: ${MODEL_NAME}"
49
- echo "Initial peer: ${INITIAL_PEER}"
50
- echo "Device: ${DEVICE}"
51
- echo "Server name: ${SERVER_ID_PATH}"
52
- echo "Server address: ${HOST_MADDR}"
53
- echo "Bloom blocks: ${BLOCK_IDS}"
54
-
55
-
56
- ###########################
57
- # Install or activate env #
58
- ###########################
59
-
60
- # TODO fix bug with self calling
61
- source ~/miniconda3/etc/profile.d/conda.sh
62
- if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
63
- conda activate bloom-demo
64
- else
65
- conda create -y --name bloom-demo python=3.8.12 pip
66
- conda activate bloom-demo
67
-
68
- conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
69
- pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
70
- pip install -i https://pypi.org/simple -r requirements.txt
71
- pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
72
- fi
73
-
74
- ##############
75
- # Run server #
76
- ##############
77
-
78
- python -m cli.run_server --converted_model_name_or_path ${MODEL_NAME} --device ${DEVICE} --initial_peer ${INITIAL_PEER} \
79
- --block_indices ${BLOCK_IDS} --compression UNIFORM_8BIT --identity_path ${SERVER_ID_PATH} --host_maddrs ${HOST_MADDR} --load_in_8bit &> ${SERVER_ID_PATH}.log
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/cli/inference_one_block.py DELETED
@@ -1,53 +0,0 @@
1
- import argparse
2
-
3
- import torch
4
- from hivemind.utils.logging import get_logger, use_hivemind_log_handler
5
- from tqdm.auto import trange
6
-
7
- from src.bloom.block import BloomBlock
8
- from src.bloom.model import BloomConfig
9
- from src.bloom.ops import build_alibi_tensor
10
-
11
- use_hivemind_log_handler("in_root_logger")
12
- logger = get_logger(__file__)
13
-
14
- logger.warning("inference_one_block will soon be deprecated in favour of tests!")
15
-
16
-
17
- def print_device_info(device=None):
18
- """Prints device stats. Code from https://stackoverflow.com/a/53374933/12891528"""
19
- device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
20
- logger.info(f"Using device: {device}")
21
-
22
- # Additional Info when using cuda
23
- if device.type == "cuda":
24
- logger.info(torch.cuda.get_device_name(0))
25
- logger.info(f"Memory Usage:")
26
- logger.info(f"Allocated: {round(torch.cuda.memory_allocated(0) / 1024 ** 3, 1)} GB")
27
- logger.info(f"Cached: {round(torch.cuda.memory_cached(0) / 1024 ** 3, 1)} GB")
28
-
29
-
30
- if __name__ == "__main__":
31
- parser = argparse.ArgumentParser(description="Run a single bloom block locally on dummy data")
32
- parser.add_argument("--config", required=True, type=str, help="Path to a config json file")
33
- parser.add_argument("--state_dict", default=None, type=str, help="Optional path to saved block state dict")
34
- parser.add_argument("--layer_index", default=0, type=int, help="Optional path to saved block state dict")
35
- parser.add_argument("--num_steps", default=500, type=int, help="How many inference steps to run")
36
- parser.add_argument("--device", default=None, type=str, help="Run inference on this device")
37
- args = parser.parse_args()
38
-
39
- if args.device is None:
40
- args.device = "cuda" if torch.cuda.is_available() else "cpu"
41
-
42
- config = BloomConfig.from_json_file(args.config)
43
- block = BloomBlock(config, args.layer_index).to(args.device)
44
-
45
- cache = None
46
-
47
- for i in trange(args.num_steps):
48
- dummy_input = torch.randn(1, 1, config.hidden_size, device=args.device)
49
- alibi = build_alibi_tensor(i + 1, config.num_attention_heads).to(args.device)
50
- with torch.no_grad():
51
- outputs, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
52
-
53
- print_device_info(args.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/cli/local_server_config_example.cfg DELETED
@@ -1,5 +0,0 @@
1
- device=cpu
2
- block_ids=2:3
3
- id_path=./server.id
4
- maddr=/ip4/127.0.0.1/tcp/30000
5
- #
 
 
 
 
 
 
petals/cli/remote_server_config_example.cfg DELETED
@@ -1,6 +0,0 @@
1
- name=bloom-peer-0.bloom.net
2
- device=cpu
3
- block_ids=1:3
4
- id_path=./server.id
5
- maddr=/ip4/0.0.0.0/tcp/30000
6
- #
 
 
 
 
 
 
 
petals/cli/run_local_servers.sh DELETED
@@ -1,109 +0,0 @@
1
- # !/usr/bin/env bash
2
-
3
- #################
4
- # Parse options #
5
- #################
6
-
7
- instructions() {
8
- echo "Usage: $0 [-n] [-c]" >&2
9
- echo " -n: number of servers to run" >&2
10
- echo " -c: path to the server configs" >&2
11
- exit 1
12
- }
13
-
14
- if [ $# != 4 ]; then
15
- instructions
16
- fi
17
-
18
- while getopts ":n:c:t:" option; do
19
- case $option in
20
- n) NUM_SERVERS=${OPTARG}
21
- ;;
22
- c) CONFIG_PATH=${OPTARG}
23
- ;;
24
- \?) instructions
25
- ;;
26
- esac
27
- done
28
-
29
-
30
- ###########################
31
- # Install or activate env #
32
- ###########################
33
-
34
- source ~/miniconda3/etc/profile.d/conda.sh
35
- if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
36
- conda activate bloom-demo
37
- else
38
- conda create -y --name bloom-demo python=3.8.12 pip
39
- conda activate bloom-demo
40
-
41
- conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
42
- pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
43
- pip install -i https://pypi.org/simple -r requirements.txt
44
- pip install -i https://test.pypi.org/simple/ bitsandbytes-cuda113
45
- fi
46
-
47
-
48
- #######################
49
- # Create Initial peer #
50
- #######################
51
-
52
- hivemind-dht &> tmp.out &
53
- sleep 5
54
- INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-1])" )
55
- echo "Initial peer: ${INITIAL_PEER}"
56
-
57
-
58
- ##############################
59
- # Initialize the config file #
60
- ##############################
61
-
62
- typeset -A cfg
63
- cfg=( # set default values in config array
64
- [device]="cpu"
65
- [block_ids]="1:2"
66
- [id_path]="server.id"
67
- [maddr]="/ip4/127.0.0.1/tcp/30000"
68
- )
69
-
70
- ###############
71
- # Run servers #
72
- ###############
73
-
74
- for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
75
- do
76
- ###############
77
- # Read config #
78
- ###############
79
-
80
- while read line
81
- do
82
- if echo $line | grep -F = &>/dev/null
83
- then
84
- varname=$(echo "$line" | cut -d '=' -f 1)
85
- cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
86
- fi
87
- done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
88
-
89
- echo "=== Server #${SERVER_ID} ==="
90
- echo "Server ID: ${cfg[id_path]}"
91
- echo "Device: ${cfg[device]}"
92
- echo "Bloom block ids: ${cfg[block_ids]}"
93
- echo "Host maddr: ${cfg[maddr]}"
94
- echo ""
95
-
96
- ##############
97
- # Run server #
98
- ##############
99
-
100
- tmux new-session -d -s "Server_${SERVER_ID}" bash cli/deploy_server.sh -m "bigscience/test-bloomd" -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}
101
- done
102
-
103
- #####################
104
- # Kill initial peer #
105
- #####################
106
-
107
- sleep 10
108
- pkill -f hivemind-dht # TODO: kill only particular pids of hivemind-dht
109
- rm tmp.out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/cli/run_remote_servers.sh DELETED
@@ -1,110 +0,0 @@
1
- # !/usr/bin/env bash
2
-
3
- SSH_KEY_PATH="~/.ssh/<YOUR_KEY>"
4
-
5
- #################
6
- # Parse options #
7
- #################
8
-
9
- instructions() {
10
- echo "Usage: $0 [-u] [-n] [-c]" >&2
11
- echo " -u: username" >&2
12
- echo " -n: number of servers to run" >&2
13
- echo " -c: path to the server configs" >&2
14
- exit 1
15
- }
16
-
17
- if [ $# != 6 ]; then
18
- instructions
19
- fi
20
-
21
- while getopts ":u:n:c:" option; do
22
- case $option in
23
- u) USERNAME=${OPTARG}
24
- ;;
25
- n) NUM_SERVERS=${OPTARG}
26
- ;;
27
- c) CONFIG_PATH=${OPTARG}
28
- ;;
29
- \?) instructions
30
- ;;
31
- esac
32
- done
33
-
34
-
35
- ###########################
36
- # Install or activate env #
37
- ###########################
38
-
39
- source ~/miniconda3/etc/profile.d/conda.sh
40
- if conda env list | grep ".*bloom-demo.*" >/dev/null 2>/dev/null; then
41
- conda activate bloom-demo
42
- else
43
- conda create -y --name bloom-demo python=3.8.12 pip
44
- conda activate bloom-demo
45
-
46
- conda install -y -c conda-forge cudatoolkit-dev==11.3.1 cudatoolkit==11.3.1 cudnn==8.2.1.32
47
- pip install -i https://pypi.org/simple torch==1.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
48
- pip install -i https://pypi.org/simple -r requirements.txt
49
- fi
50
-
51
-
52
- #######################
53
- # Create Initial peer #
54
- #######################
55
-
56
- hivemind-dht &> tmp.out &
57
-
58
- sleep 5
59
- INITIAL_PEER=$(python -c "with open('tmp.out') as f: print(f.readlines()[1].split()[-2])" )
60
- rm tmp.out
61
- echo "Initial peer: ${INITIAL_PEER}"
62
-
63
-
64
- ##############################
65
- # Initialize the config file #
66
- ##############################
67
-
68
- typeset -A cfg
69
- cfg=( # set default values in config array
70
- [name]=""
71
- [device]="cpu"
72
- [block_ids]="1:2"
73
- [id_path]="server.id"
74
- [maddr]="/ip4/0.0.0.0/tcp/30000"
75
- )
76
-
77
- ###############
78
- # Run servers #
79
- ###############
80
-
81
- for SERVER_ID in $(seq 0 $(( $NUM_SERVERS - 1 )) )
82
- do
83
- ###############
84
- # Read config #
85
- ###############
86
-
87
- while read line
88
- do
89
- if echo $line | grep -F = &>/dev/null
90
- then
91
- varname=$(echo "$line" | cut -d '=' -f 1)
92
- cfg[$varname]=$(echo "$line" | cut -d '=' -f 2-)
93
- fi
94
- done < ${CONFIG_PATH}/server_${SERVER_ID}.cfg
95
-
96
- SERVER_NAME="${USERNAME}@${cfg[name]}"
97
- echo "=== Server #${SERVER_ID} ==="
98
- echo "Server name ${SERVER_NAME}"
99
- echo "Server ID: ${cfg[id_path]}"
100
- echo "Device: ${cfg[device]}"
101
- echo "Bloom block ids: ${cfg[block_ids]}"
102
- echo "Host maddr: ${cfg[maddr]}"
103
- echo "================="
104
-
105
- ##############
106
- # Run server #
107
- ##############
108
-
109
- ssh -i ${SSH_KEY_PATH} ${SERVER_NAME} "tmux new-session -d -s 'Server_${SERVER_ID}' 'cd bloom-demo && bash cli/deploy_server.sh -i ${INITIAL_PEER} -d ${cfg[device]} -p ${cfg[id_path]} -b ${cfg[block_ids]} -a ${cfg[maddr]}'"
110
- done
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/cli/run_server.py DELETED
@@ -1,129 +0,0 @@
1
- import argparse
2
-
3
- import configargparse
4
- from hivemind.proto.runtime_pb2 import CompressionType
5
- from hivemind.utils.limits import increase_file_limit
6
- from hivemind.utils.logging import get_logger, use_hivemind_log_handler
7
- from humanfriendly import parse_size
8
-
9
- from src.server.server import Server
10
-
11
- use_hivemind_log_handler("in_root_logger")
12
- logger = get_logger(__file__)
13
-
14
-
15
- def main():
16
- # fmt:off
17
- parser = configargparse.ArgParser(default_config_files=["config.yml"],
18
- formatter_class=argparse.ArgumentDefaultsHelpFormatter)
19
- parser.add('-c', '--config', required=False, is_config_file=True, help='config file path')
20
-
21
- group = parser.add_mutually_exclusive_group(required=True)
22
- group.add_argument('--converted_model_name_or_path', type=str, default=None,
23
- help="path or name of a pretrained model, converted with cli/convert_model.py")
24
- group.add_argument('model', nargs='?', type=str, help="same as --converted_model_name_or_path")
25
-
26
- parser.add_argument('--num_blocks', type=int, default=None, help="The number of blocks to serve")
27
- parser.add_argument('--block_indices', type=str, default=None, help="Specific block indices to serve")
28
- parser.add_argument('--prefix', type=str, default=None, help="Announce all blocks with this prefix. By default,"
29
- "use the same name as in the converted model.")
30
- parser.add_argument('--host_maddrs', nargs='+', default=['/ip4/0.0.0.0/tcp/0'], required=False,
31
- help='Multiaddrs to listen for external connections from other p2p instances; default: all IPv4 and TCP: /ip4/0.0.0.0/tcp/0')
32
- parser.add_argument('--announce_maddrs', nargs='+', default=None, required=False,
33
- help='Visible multiaddrs the host announces for external connections from other p2p instances')
34
-
35
- parser.add_argument('--compression', type=str, default='NONE', required=False, help='Tensor compression communication')
36
-
37
- parser.add_argument('--num_handlers', type=int, default=8, required=False,
38
- help='server will use this many processes to handle incoming requests')
39
- parser.add_argument('--min_batch_size', type=int, default=1,
40
- help='Minimum required batch size for all operations (in total tokens)')
41
- parser.add_argument('--max_batch_size', type=int, default=16384,
42
- help='The total number of tokens in the same batch will not exceed this value')
43
- parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
44
- help='Pre-form this many subsequent batches while GPU is processing the current one')
45
- parser.add_argument('--sender_threads', type=int, default=1, required=False,
46
- help='Use this many threads to pass results/exceptions from Runtime to Pools')
47
- parser.add_argument('--inference_max_length', type=int, default=16384,
48
- help='Maximum total sequence length permitted per inference, defaults to 16384 tokens')
49
- parser.add_argument('--cache_dir', type=str, default=None,
50
- help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
51
- parser.add_argument('--device', type=str, default=None, required=False,
52
- help='all blocks will use this device in torch notation; default: cuda if available else cpu')
53
- parser.add_argument("--torch_dtype", type=str, default="auto",
54
- help="Use this dtype to store block weights and do computations. "
55
- "By default, respect the dtypes in the pre-trained state dict.")
56
- parser.add_argument('--attn_cache_size', type=str, default=None,
57
- help='The size of GPU memory allocated for storing past attention keys/values between inference'
58
- ' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
59
- parser.add_argument('--revision', type=str, default='main',
60
- help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
61
- "and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
62
-
63
- parser.add_argument('--throughput',
64
- type=lambda value: value if value in ['auto', 'eval'] else float(value),
65
- default='auto',
66
- help='Expected server throughput (a float measured in RPS). '
67
- 'If set to "auto" (default), the script evaluates network and compute throughput '
68
- 'on the first run and uses these estimates for future runs. '
69
- 'If set to "eval", the script re-evaluates the throughput and overrides the cache.')
70
- parser.add_argument('--update_period', type=float, required=False, default=30,
71
- help='Server will report blocks to DHT once in this many seconds')
72
- parser.add_argument('--expiration', type=float, required=False, default=None,
73
- help='DHT entries will expire after this many seconds')
74
- parser.add_argument('--initial_peers', type=str, nargs='*', required=False, default=[],
75
- help='multiaddrs of one or more active DHT peers (if you want to join an existing DHT)')
76
- parser.add_argument('--increase_file_limit', action='store_true',
77
- help='On *nix, this will increase the max number of processes '
78
- 'a server can spawn before hitting "Too many open files"; Use at your own risk.')
79
- parser.add_argument('--stats_report_interval', type=int, required=False,
80
- help='Interval between two reports of batch processing performance statistics')
81
-
82
- parser.add_argument('--custom_module_path', type=str, required=False,
83
- help='Path of a file with custom nn.modules, wrapped into special decorator')
84
- parser.add_argument('--identity_path', type=str, required=False, help='Path to identity file to be used in P2P')
85
-
86
- parser.add_argument("--balance_quality", type=float, default=0.75,
87
- help="Rebalance the swarm if its throughput is worse than this share of the optimal "
88
- "throughput. Use 0.0 to disable rebalancing, values > 1.0 to force rebalancing "
89
- "on each check for debugging purposes.")
90
- parser.add_argument("--mean_balance_check_period", type=float, default=60,
91
- help="Check the swarm's balance every N seconds (and rebalance it if necessary)")
92
-
93
- parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
94
- parser.add_argument('--load_in_8bit', action='store_true', help='Convert the loaded model into mixed-8bit quantized model.')
95
-
96
- # fmt:on
97
- args = vars(parser.parse_args())
98
- args.pop("config", None)
99
-
100
- args["converted_model_name_or_path"] = args.pop("model") or args["converted_model_name_or_path"]
101
-
102
- if args.pop("increase_file_limit"):
103
- increase_file_limit()
104
-
105
- compression_type = args.pop("compression").upper()
106
- compression = getattr(CompressionType, compression_type)
107
-
108
- attn_cache_size = args.pop("attn_cache_size")
109
- if attn_cache_size is not None:
110
- attn_cache_size = parse_size(attn_cache_size)
111
- assert isinstance(
112
- attn_cache_size, (int, type(None))
113
- ), "unrecognized value for attention_cache_bytes, examples: 1.5GB or 1500MB or 1572864000 (bytes)"
114
-
115
- use_auth_token = args.pop("use_auth_token")
116
- args["use_auth_token"] = True if use_auth_token in ("True", "true", "") else use_auth_token
117
-
118
- server = Server(**args, compression=compression, attn_cache_size=attn_cache_size, start=True)
119
-
120
- try:
121
- server.join()
122
- except KeyboardInterrupt:
123
- logger.info("Caught KeyboardInterrupt, shutting down")
124
- finally:
125
- server.shutdown()
126
-
127
-
128
- if __name__ == "__main__":
129
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/cli/speed_test.py DELETED
@@ -1,1941 +0,0 @@
1
- #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- # Copyright 2012 Matt Martz
4
- # All Rights Reserved.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License"); you may
7
- # not use this file except in compliance with the License. You may obtain
8
- # a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14
- # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15
- # License for the specific language governing permissions and limitations
16
- # under the License.
17
-
18
- import csv
19
- import datetime
20
- import errno
21
- import math
22
- import os
23
- import platform
24
- import re
25
- import signal
26
- import socket
27
- import sys
28
- import threading
29
- import timeit
30
- import xml.parsers.expat
31
-
32
- try:
33
- import gzip
34
-
35
- GZIP_BASE = gzip.GzipFile
36
- except ImportError:
37
- gzip = None
38
- GZIP_BASE = object
39
-
40
- __version__ = "2.1.4b1"
41
-
42
-
43
- class FakeShutdownEvent(object):
44
- """Class to fake a threading.Event.isSet so that users of this module
45
- are not required to register their own threading.Event()
46
- """
47
-
48
- @staticmethod
49
- def isSet():
50
- "Dummy method to always return false" ""
51
- return False
52
-
53
- is_set = isSet
54
-
55
-
56
- # Some global variables we use
57
- DEBUG = False
58
- _GLOBAL_DEFAULT_TIMEOUT = object()
59
- PY25PLUS = sys.version_info[:2] >= (2, 5)
60
- PY26PLUS = sys.version_info[:2] >= (2, 6)
61
- PY32PLUS = sys.version_info[:2] >= (3, 2)
62
- PY310PLUS = sys.version_info[:2] >= (3, 10)
63
-
64
- # Begin import game to handle Python 2 and Python 3
65
- try:
66
- import json
67
- except ImportError:
68
- try:
69
- import simplejson as json
70
- except ImportError:
71
- json = None
72
-
73
- try:
74
- import xml.etree.ElementTree as ET
75
-
76
- try:
77
- from xml.etree.ElementTree import _Element as ET_Element
78
- except ImportError:
79
- pass
80
- except ImportError:
81
- from xml.dom import minidom as DOM
82
- from xml.parsers.expat import ExpatError
83
-
84
- ET = None
85
-
86
- try:
87
- from urllib2 import (
88
- AbstractHTTPHandler,
89
- HTTPDefaultErrorHandler,
90
- HTTPError,
91
- HTTPErrorProcessor,
92
- HTTPRedirectHandler,
93
- OpenerDirector,
94
- ProxyHandler,
95
- Request,
96
- URLError,
97
- urlopen,
98
- )
99
- except ImportError:
100
- from urllib.request import (
101
- AbstractHTTPHandler,
102
- HTTPDefaultErrorHandler,
103
- HTTPError,
104
- HTTPErrorProcessor,
105
- HTTPRedirectHandler,
106
- OpenerDirector,
107
- ProxyHandler,
108
- Request,
109
- URLError,
110
- urlopen,
111
- )
112
-
113
- try:
114
- from httplib import BadStatusLine, HTTPConnection
115
- except ImportError:
116
- from http.client import BadStatusLine, HTTPConnection
117
-
118
- try:
119
- from httplib import HTTPSConnection
120
- except ImportError:
121
- try:
122
- from http.client import HTTPSConnection
123
- except ImportError:
124
- HTTPSConnection = None
125
-
126
- try:
127
- from httplib import FakeSocket
128
- except ImportError:
129
- FakeSocket = None
130
-
131
- try:
132
- from Queue import Queue
133
- except ImportError:
134
- from queue import Queue
135
-
136
- try:
137
- from urlparse import urlparse
138
- except ImportError:
139
- from urllib.parse import urlparse
140
-
141
- try:
142
- from urlparse import parse_qs
143
- except ImportError:
144
- try:
145
- from urllib.parse import parse_qs
146
- except ImportError:
147
- from cgi import parse_qs
148
-
149
- try:
150
- from hashlib import md5
151
- except ImportError:
152
- from md5 import md5
153
-
154
- try:
155
- from argparse import SUPPRESS as ARG_SUPPRESS, ArgumentParser as ArgParser
156
-
157
- PARSER_TYPE_INT = int
158
- PARSER_TYPE_STR = str
159
- PARSER_TYPE_FLOAT = float
160
- except ImportError:
161
- from optparse import SUPPRESS_HELP as ARG_SUPPRESS, OptionParser as ArgParser
162
-
163
- PARSER_TYPE_INT = "int"
164
- PARSER_TYPE_STR = "string"
165
- PARSER_TYPE_FLOAT = "float"
166
-
167
- try:
168
- from cStringIO import StringIO
169
-
170
- BytesIO = None
171
- except ImportError:
172
- try:
173
- from StringIO import StringIO
174
-
175
- BytesIO = None
176
- except ImportError:
177
- from io import BytesIO, StringIO
178
-
179
- try:
180
- import __builtin__
181
- except ImportError:
182
- import builtins
183
- from io import FileIO, TextIOWrapper
184
-
185
- class _Py3Utf8Output(TextIOWrapper):
186
- """UTF-8 encoded wrapper around stdout for py3, to override
187
- ASCII stdout
188
- """
189
-
190
- def __init__(self, f, **kwargs):
191
- buf = FileIO(f.fileno(), "w")
192
- super(_Py3Utf8Output, self).__init__(buf, encoding="utf8", errors="strict")
193
-
194
- def write(self, s):
195
- super(_Py3Utf8Output, self).write(s)
196
- self.flush()
197
-
198
- _py3_print = getattr(builtins, "print")
199
- try:
200
- _py3_utf8_stdout = _Py3Utf8Output(sys.stdout)
201
- _py3_utf8_stderr = _Py3Utf8Output(sys.stderr)
202
- except OSError:
203
- # sys.stdout/sys.stderr is not a compatible stdout/stderr object
204
- # just use it and hope things go ok
205
- _py3_utf8_stdout = sys.stdout
206
- _py3_utf8_stderr = sys.stderr
207
-
208
- def to_utf8(v):
209
- """No-op encode to utf-8 for py3"""
210
- return v
211
-
212
- def print_(*args, **kwargs):
213
- """Wrapper function for py3 to print, with a utf-8 encoded stdout"""
214
- if kwargs.get("file") == sys.stderr:
215
- kwargs["file"] = _py3_utf8_stderr
216
- else:
217
- kwargs["file"] = kwargs.get("file", _py3_utf8_stdout)
218
- _py3_print(*args, **kwargs)
219
-
220
- else:
221
- del __builtin__
222
-
223
- def to_utf8(v):
224
- """Encode value to utf-8 if possible for py2"""
225
- try:
226
- return v.encode("utf8", "strict")
227
- except AttributeError:
228
- return v
229
-
230
- def print_(*args, **kwargs):
231
- """The new-style print function for Python 2.4 and 2.5.
232
-
233
- Taken from https://pypi.python.org/pypi/six/
234
-
235
- Modified to set encoding to UTF-8 always, and to flush after write
236
- """
237
- fp = kwargs.pop("file", sys.stdout)
238
- if fp is None:
239
- return
240
-
241
- def write(data):
242
- if not isinstance(data, basestring):
243
- data = str(data)
244
- # If the file has an encoding, encode unicode with it.
245
- encoding = "utf8" # Always trust UTF-8 for output
246
- if isinstance(fp, file) and isinstance(data, unicode) and encoding is not None:
247
- errors = getattr(fp, "errors", None)
248
- if errors is None:
249
- errors = "strict"
250
- data = data.encode(encoding, errors)
251
- fp.write(data)
252
- fp.flush()
253
-
254
- want_unicode = False
255
- sep = kwargs.pop("sep", None)
256
- if sep is not None:
257
- if isinstance(sep, unicode):
258
- want_unicode = True
259
- elif not isinstance(sep, str):
260
- raise TypeError("sep must be None or a string")
261
- end = kwargs.pop("end", None)
262
- if end is not None:
263
- if isinstance(end, unicode):
264
- want_unicode = True
265
- elif not isinstance(end, str):
266
- raise TypeError("end must be None or a string")
267
- if kwargs:
268
- raise TypeError("invalid keyword arguments to print()")
269
- if not want_unicode:
270
- for arg in args:
271
- if isinstance(arg, unicode):
272
- want_unicode = True
273
- break
274
- if want_unicode:
275
- newline = unicode("\n")
276
- space = unicode(" ")
277
- else:
278
- newline = "\n"
279
- space = " "
280
- if sep is None:
281
- sep = space
282
- if end is None:
283
- end = newline
284
- for i, arg in enumerate(args):
285
- if i:
286
- write(sep)
287
- write(arg)
288
- write(end)
289
-
290
-
291
- # Exception "constants" to support Python 2 through Python 3
292
- try:
293
- import ssl
294
-
295
- try:
296
- CERT_ERROR = (ssl.CertificateError,)
297
- except AttributeError:
298
- CERT_ERROR = tuple()
299
-
300
- HTTP_ERRORS = (HTTPError, URLError, socket.error, ssl.SSLError, BadStatusLine) + CERT_ERROR
301
- except ImportError:
302
- ssl = None
303
- HTTP_ERRORS = (HTTPError, URLError, socket.error, BadStatusLine)
304
-
305
- if PY32PLUS:
306
- etree_iter = ET.Element.iter
307
- elif PY25PLUS:
308
- etree_iter = ET_Element.getiterator
309
-
310
- if PY26PLUS:
311
- thread_is_alive = threading.Thread.is_alive
312
- else:
313
- thread_is_alive = threading.Thread.isAlive
314
-
315
-
316
- def event_is_set(event):
317
- try:
318
- return event.is_set()
319
- except AttributeError:
320
- return event.isSet()
321
-
322
-
323
- class SpeedtestException(Exception):
324
- """Base exception for this module"""
325
-
326
-
327
- class SpeedtestCLIError(SpeedtestException):
328
- """Generic exception for raising errors during CLI operation"""
329
-
330
-
331
- class SpeedtestHTTPError(SpeedtestException):
332
- """Base HTTP exception for this module"""
333
-
334
-
335
- class SpeedtestConfigError(SpeedtestException):
336
- """Configuration XML is invalid"""
337
-
338
-
339
- class SpeedtestServersError(SpeedtestException):
340
- """Servers XML is invalid"""
341
-
342
-
343
- class ConfigRetrievalError(SpeedtestHTTPError):
344
- """Could not retrieve config.php"""
345
-
346
-
347
- class ServersRetrievalError(SpeedtestHTTPError):
348
- """Could not retrieve speedtest-servers.php"""
349
-
350
-
351
- class InvalidServerIDType(SpeedtestException):
352
- """Server ID used for filtering was not an integer"""
353
-
354
-
355
- class NoMatchedServers(SpeedtestException):
356
- """No servers matched when filtering"""
357
-
358
-
359
- class SpeedtestMiniConnectFailure(SpeedtestException):
360
- """Could not connect to the provided speedtest mini server"""
361
-
362
-
363
- class InvalidSpeedtestMiniServer(SpeedtestException):
364
- """Server provided as a speedtest mini server does not actually appear
365
- to be a speedtest mini server
366
- """
367
-
368
-
369
- class ShareResultsConnectFailure(SpeedtestException):
370
- """Could not connect to speedtest.net API to POST results"""
371
-
372
-
373
- class ShareResultsSubmitFailure(SpeedtestException):
374
- """Unable to successfully POST results to speedtest.net API after
375
- connection
376
- """
377
-
378
-
379
- class SpeedtestUploadTimeout(SpeedtestException):
380
- """testlength configuration reached during upload
381
- Used to ensure the upload halts when no additional data should be sent
382
- """
383
-
384
-
385
- class SpeedtestBestServerFailure(SpeedtestException):
386
- """Unable to determine best server"""
387
-
388
-
389
- class SpeedtestMissingBestServer(SpeedtestException):
390
- """get_best_server not called or not able to determine best server"""
391
-
392
-
393
- def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, source_address=None):
394
- """Connect to *address* and return the socket object.
395
-
396
- Convenience function. Connect to *address* (a 2-tuple ``(host,
397
- port)``) and return the socket object. Passing the optional
398
- *timeout* parameter will set the timeout on the socket instance
399
- before attempting to connect. If no *timeout* is supplied, the
400
- global default timeout setting returned by :func:`getdefaulttimeout`
401
- is used. If *source_address* is set it must be a tuple of (host, port)
402
- for the socket to bind as a source address before making the connection.
403
- An host of '' or port 0 tells the OS to use the default.
404
-
405
- Largely vendored from Python 2.7, modified to work with Python 2.4
406
- """
407
-
408
- host, port = address
409
- err = None
410
- for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
411
- af, socktype, proto, canonname, sa = res
412
- sock = None
413
- try:
414
- sock = socket.socket(af, socktype, proto)
415
- if timeout is not _GLOBAL_DEFAULT_TIMEOUT:
416
- sock.settimeout(float(timeout))
417
- if source_address:
418
- sock.bind(source_address)
419
- sock.connect(sa)
420
- return sock
421
-
422
- except socket.error:
423
- err = get_exception()
424
- if sock is not None:
425
- sock.close()
426
-
427
- if err is not None:
428
- raise err
429
- else:
430
- raise socket.error("getaddrinfo returns an empty list")
431
-
432
-
433
- class SpeedtestHTTPConnection(HTTPConnection):
434
- """Custom HTTPConnection to support source_address across
435
- Python 2.4 - Python 3
436
- """
437
-
438
- def __init__(self, *args, **kwargs):
439
- source_address = kwargs.pop("source_address", None)
440
- timeout = kwargs.pop("timeout", 10)
441
-
442
- self._tunnel_host = None
443
-
444
- HTTPConnection.__init__(self, *args, **kwargs)
445
-
446
- self.source_address = source_address
447
- self.timeout = timeout
448
-
449
- def connect(self):
450
- """Connect to the host and port specified in __init__."""
451
- try:
452
- self.sock = socket.create_connection((self.host, self.port), self.timeout, self.source_address)
453
- except (AttributeError, TypeError):
454
- self.sock = create_connection((self.host, self.port), self.timeout, self.source_address)
455
-
456
- if self._tunnel_host:
457
- self._tunnel()
458
-
459
-
460
- if HTTPSConnection:
461
-
462
- class SpeedtestHTTPSConnection(HTTPSConnection):
463
- """Custom HTTPSConnection to support source_address across
464
- Python 2.4 - Python 3
465
- """
466
-
467
- default_port = 443
468
-
469
- def __init__(self, *args, **kwargs):
470
- source_address = kwargs.pop("source_address", None)
471
- timeout = kwargs.pop("timeout", 10)
472
-
473
- self._tunnel_host = None
474
-
475
- HTTPSConnection.__init__(self, *args, **kwargs)
476
-
477
- self.timeout = timeout
478
- self.source_address = source_address
479
-
480
- def connect(self):
481
- "Connect to a host on a given (SSL) port."
482
- try:
483
- self.sock = socket.create_connection((self.host, self.port), self.timeout, self.source_address)
484
- except (AttributeError, TypeError):
485
- self.sock = create_connection((self.host, self.port), self.timeout, self.source_address)
486
-
487
- if self._tunnel_host:
488
- self._tunnel()
489
-
490
- if ssl:
491
- try:
492
- kwargs = {}
493
- if hasattr(ssl, "SSLContext"):
494
- if self._tunnel_host:
495
- kwargs["server_hostname"] = self._tunnel_host
496
- else:
497
- kwargs["server_hostname"] = self.host
498
- self.sock = self._context.wrap_socket(self.sock, **kwargs)
499
- except AttributeError:
500
- self.sock = ssl.wrap_socket(self.sock)
501
- try:
502
- self.sock.server_hostname = self.host
503
- except AttributeError:
504
- pass
505
- elif FakeSocket:
506
- # Python 2.4/2.5 support
507
- try:
508
- self.sock = FakeSocket(self.sock, socket.ssl(self.sock))
509
- except AttributeError:
510
- raise SpeedtestException("This version of Python does not support HTTPS/SSL " "functionality")
511
- else:
512
- raise SpeedtestException("This version of Python does not support HTTPS/SSL " "functionality")
513
-
514
-
515
- def _build_connection(connection, source_address, timeout, context=None):
516
- """Cross Python 2.4 - Python 3 callable to build an ``HTTPConnection`` or
517
- ``HTTPSConnection`` with the args we need
518
-
519
- Called from ``http(s)_open`` methods of ``SpeedtestHTTPHandler`` or
520
- ``SpeedtestHTTPSHandler``
521
- """
522
-
523
- def inner(host, **kwargs):
524
- kwargs.update({"source_address": source_address, "timeout": timeout})
525
- if context:
526
- kwargs["context"] = context
527
- return connection(host, **kwargs)
528
-
529
- return inner
530
-
531
-
532
- class SpeedtestHTTPHandler(AbstractHTTPHandler):
533
- """Custom ``HTTPHandler`` that can build a ``HTTPConnection`` with the
534
- args we need for ``source_address`` and ``timeout``
535
- """
536
-
537
- def __init__(self, debuglevel=0, source_address=None, timeout=10):
538
- AbstractHTTPHandler.__init__(self, debuglevel)
539
- self.source_address = source_address
540
- self.timeout = timeout
541
-
542
- def http_open(self, req):
543
- return self.do_open(_build_connection(SpeedtestHTTPConnection, self.source_address, self.timeout), req)
544
-
545
- http_request = AbstractHTTPHandler.do_request_
546
-
547
-
548
- class SpeedtestHTTPSHandler(AbstractHTTPHandler):
549
- """Custom ``HTTPSHandler`` that can build a ``HTTPSConnection`` with the
550
- args we need for ``source_address`` and ``timeout``
551
- """
552
-
553
- def __init__(self, debuglevel=0, context=None, source_address=None, timeout=10):
554
- AbstractHTTPHandler.__init__(self, debuglevel)
555
- self._context = context
556
- self.source_address = source_address
557
- self.timeout = timeout
558
-
559
- def https_open(self, req):
560
- return self.do_open(
561
- _build_connection(
562
- SpeedtestHTTPSConnection,
563
- self.source_address,
564
- self.timeout,
565
- context=self._context,
566
- ),
567
- req,
568
- )
569
-
570
- https_request = AbstractHTTPHandler.do_request_
571
-
572
-
573
- def build_opener(source_address=None, timeout=10):
574
- """Function similar to ``urllib2.build_opener`` that will build
575
- an ``OpenerDirector`` with the explicit handlers we want,
576
- ``source_address`` for binding, ``timeout`` and our custom
577
- `User-Agent`
578
- """
579
-
580
- printer("Timeout set to %d" % timeout, debug=True)
581
-
582
- if source_address:
583
- source_address_tuple = (source_address, 0)
584
- printer("Binding to source address: %r" % (source_address_tuple,), debug=True)
585
- else:
586
- source_address_tuple = None
587
-
588
- handlers = [
589
- ProxyHandler(),
590
- SpeedtestHTTPHandler(source_address=source_address_tuple, timeout=timeout),
591
- SpeedtestHTTPSHandler(source_address=source_address_tuple, timeout=timeout),
592
- HTTPDefaultErrorHandler(),
593
- HTTPRedirectHandler(),
594
- HTTPErrorProcessor(),
595
- ]
596
-
597
- opener = OpenerDirector()
598
- opener.addheaders = [("User-agent", build_user_agent())]
599
-
600
- for handler in handlers:
601
- opener.add_handler(handler)
602
-
603
- return opener
604
-
605
-
606
- class GzipDecodedResponse(GZIP_BASE):
607
- """A file-like object to decode a response encoded with the gzip
608
- method, as described in RFC 1952.
609
-
610
- Largely copied from ``xmlrpclib``/``xmlrpc.client`` and modified
611
- to work for py2.4-py3
612
- """
613
-
614
- def __init__(self, response):
615
- # response doesn't support tell() and read(), required by
616
- # GzipFile
617
- if not gzip:
618
- raise SpeedtestHTTPError("HTTP response body is gzip encoded, " "but gzip support is not available")
619
- IO = BytesIO or StringIO
620
- self.io = IO()
621
- while 1:
622
- chunk = response.read(1024)
623
- if len(chunk) == 0:
624
- break
625
- self.io.write(chunk)
626
- self.io.seek(0)
627
- gzip.GzipFile.__init__(self, mode="rb", fileobj=self.io)
628
-
629
- def close(self):
630
- try:
631
- gzip.GzipFile.close(self)
632
- finally:
633
- self.io.close()
634
-
635
-
636
- def get_exception():
637
- """Helper function to work with py2.4-py3 for getting the current
638
- exception in a try/except block
639
- """
640
- return sys.exc_info()[1]
641
-
642
-
643
- def distance(origin, destination):
644
- """Determine distance between 2 sets of [lat,lon] in km"""
645
-
646
- lat1, lon1 = origin
647
- lat2, lon2 = destination
648
- radius = 6371 # km
649
-
650
- dlat = math.radians(lat2 - lat1)
651
- dlon = math.radians(lon2 - lon1)
652
- a = math.sin(dlat / 2) * math.sin(dlat / 2) + math.cos(math.radians(lat1)) * math.cos(
653
- math.radians(lat2)
654
- ) * math.sin(dlon / 2) * math.sin(dlon / 2)
655
- c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
656
- d = radius * c
657
-
658
- return d
659
-
660
-
661
- def build_user_agent():
662
- """Build a Mozilla/5.0 compatible User-Agent string"""
663
-
664
- ua_tuple = (
665
- "Mozilla/5.0",
666
- "(%s; U; %s; en-us)" % (platform.platform(), platform.architecture()[0]),
667
- "Python/%s" % platform.python_version(),
668
- "(KHTML, like Gecko)",
669
- "speedtest-cli/%s" % __version__,
670
- )
671
- user_agent = " ".join(ua_tuple)
672
- printer("User-Agent: %s" % user_agent, debug=True)
673
- return user_agent
674
-
675
-
676
- def build_request(url, data=None, headers=None, bump="0", secure=False):
677
- """Build a urllib2 request object
678
-
679
- This function automatically adds a User-Agent header to all requests
680
-
681
- """
682
-
683
- if not headers:
684
- headers = {}
685
-
686
- if url[0] == ":":
687
- scheme = ("http", "https")[bool(secure)]
688
- schemed_url = "%s%s" % (scheme, url)
689
- else:
690
- schemed_url = url
691
-
692
- if "?" in url:
693
- delim = "&"
694
- else:
695
- delim = "?"
696
-
697
- # WHO YOU GONNA CALL? CACHE BUSTERS!
698
- final_url = "%s%sx=%s.%s" % (schemed_url, delim, int(timeit.time.time() * 1000), bump)
699
-
700
- headers.update(
701
- {
702
- "Cache-Control": "no-cache",
703
- }
704
- )
705
-
706
- printer("%s %s" % (("GET", "POST")[bool(data)], final_url), debug=True)
707
-
708
- return Request(final_url, data=data, headers=headers)
709
-
710
-
711
- def catch_request(request, opener=None):
712
- """Helper function to catch common exceptions encountered when
713
- establishing a connection with a HTTP/HTTPS request
714
-
715
- """
716
-
717
- if opener:
718
- _open = opener.open
719
- else:
720
- _open = urlopen
721
-
722
- try:
723
- uh = _open(request)
724
- if request.get_full_url() != uh.geturl():
725
- printer("Redirected to %s" % uh.geturl(), debug=True)
726
- return uh, False
727
- except HTTP_ERRORS:
728
- e = get_exception()
729
- return None, e
730
-
731
-
732
- def get_response_stream(response):
733
- """Helper function to return either a Gzip reader if
734
- ``Content-Encoding`` is ``gzip`` otherwise the response itself
735
-
736
- """
737
-
738
- try:
739
- getheader = response.headers.getheader
740
- except AttributeError:
741
- getheader = response.getheader
742
-
743
- if getheader("content-encoding") == "gzip":
744
- return GzipDecodedResponse(response)
745
-
746
- return response
747
-
748
-
749
- def get_attributes_by_tag_name(dom, tag_name):
750
- """Retrieve an attribute from an XML document and return it in a
751
- consistent format
752
-
753
- Only used with xml.dom.minidom, which is likely only to be used
754
- with python versions older than 2.5
755
- """
756
- elem = dom.getElementsByTagName(tag_name)[0]
757
- return dict(list(elem.attributes.items()))
758
-
759
-
760
- def print_dots(shutdown_event):
761
- """Built in callback function used by Thread classes for printing
762
- status
763
- """
764
-
765
- def inner(current, total, start=False, end=False):
766
- if event_is_set(shutdown_event):
767
- return
768
-
769
- sys.stdout.write(".")
770
- if current + 1 == total and end is True:
771
- sys.stdout.write("\n")
772
- sys.stdout.flush()
773
-
774
- return inner
775
-
776
-
777
- def do_nothing(*args, **kwargs):
778
- pass
779
-
780
-
781
- class HTTPDownloader(threading.Thread):
782
- """Thread class for retrieving a URL"""
783
-
784
- def __init__(self, i, request, start, timeout, opener=None, shutdown_event=None):
785
- threading.Thread.__init__(self)
786
- self.request = request
787
- self.result = [0]
788
- self.starttime = start
789
- self.timeout = timeout
790
- self.i = i
791
- if opener:
792
- self._opener = opener.open
793
- else:
794
- self._opener = urlopen
795
-
796
- if shutdown_event:
797
- self._shutdown_event = shutdown_event
798
- else:
799
- self._shutdown_event = FakeShutdownEvent()
800
-
801
- def run(self):
802
- try:
803
- if (timeit.default_timer() - self.starttime) <= self.timeout:
804
- f = self._opener(self.request)
805
- while (
806
- not event_is_set(self._shutdown_event) and (timeit.default_timer() - self.starttime) <= self.timeout
807
- ):
808
- self.result.append(len(f.read(10240)))
809
- if self.result[-1] == 0:
810
- break
811
- f.close()
812
- except IOError:
813
- pass
814
- except HTTP_ERRORS:
815
- pass
816
-
817
-
818
- class HTTPUploaderData(object):
819
- """File like object to improve cutting off the upload once the timeout
820
- has been reached
821
- """
822
-
823
- def __init__(self, length, start, timeout, shutdown_event=None):
824
- self.length = length
825
- self.start = start
826
- self.timeout = timeout
827
-
828
- if shutdown_event:
829
- self._shutdown_event = shutdown_event
830
- else:
831
- self._shutdown_event = FakeShutdownEvent()
832
-
833
- self._data = None
834
-
835
- self.total = [0]
836
-
837
- def pre_allocate(self):
838
- chars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
839
- multiplier = int(round(int(self.length) / 36.0))
840
- IO = BytesIO or StringIO
841
- try:
842
- self._data = IO(("content1=%s" % (chars * multiplier)[0 : int(self.length) - 9]).encode())
843
- except MemoryError:
844
- raise SpeedtestCLIError("Insufficient memory to pre-allocate upload data. Please " "use --no-pre-allocate")
845
-
846
- @property
847
- def data(self):
848
- if not self._data:
849
- self.pre_allocate()
850
- return self._data
851
-
852
- def read(self, n=10240):
853
- if (timeit.default_timer() - self.start) <= self.timeout and not event_is_set(self._shutdown_event):
854
- chunk = self.data.read(n)
855
- self.total.append(len(chunk))
856
- return chunk
857
- else:
858
- raise SpeedtestUploadTimeout()
859
-
860
- def __len__(self):
861
- return self.length
862
-
863
-
864
- class HTTPUploader(threading.Thread):
865
- """Thread class for putting a URL"""
866
-
867
- def __init__(self, i, request, start, size, timeout, opener=None, shutdown_event=None):
868
- threading.Thread.__init__(self)
869
- self.request = request
870
- self.request.data.start = self.starttime = start
871
- self.size = size
872
- self.result = 0
873
- self.timeout = timeout
874
- self.i = i
875
-
876
- if opener:
877
- self._opener = opener.open
878
- else:
879
- self._opener = urlopen
880
-
881
- if shutdown_event:
882
- self._shutdown_event = shutdown_event
883
- else:
884
- self._shutdown_event = FakeShutdownEvent()
885
-
886
- def run(self):
887
- request = self.request
888
- try:
889
- if (timeit.default_timer() - self.starttime) <= self.timeout and not event_is_set(self._shutdown_event):
890
- try:
891
- f = self._opener(request)
892
- except TypeError:
893
- # PY24 expects a string or buffer
894
- # This also causes issues with Ctrl-C, but we will concede
895
- # for the moment that Ctrl-C on PY24 isn't immediate
896
- request = build_request(self.request.get_full_url(), data=request.data.read(self.size))
897
- f = self._opener(request)
898
- f.read(11)
899
- f.close()
900
- self.result = sum(self.request.data.total)
901
- else:
902
- self.result = 0
903
- except (IOError, SpeedtestUploadTimeout):
904
- self.result = sum(self.request.data.total)
905
- except HTTP_ERRORS:
906
- self.result = 0
907
-
908
-
909
- class SpeedtestResults(object):
910
- """Class for holding the results of a speedtest, including:
911
-
912
- Download speed
913
- Upload speed
914
- Ping/Latency to test server
915
- Data about server that the test was run against
916
-
917
- Additionally this class can return a result data as a dictionary or CSV,
918
- as well as submit a POST of the result data to the speedtest.net API
919
- to get a share results image link.
920
- """
921
-
922
- def __init__(self, download=0, upload=0, ping=0, server=None, client=None, opener=None, secure=False):
923
- self.download = download
924
- self.upload = upload
925
- self.ping = ping
926
- if server is None:
927
- self.server = {}
928
- else:
929
- self.server = server
930
- self.client = client or {}
931
-
932
- self._share = None
933
- self.timestamp = "%sZ" % datetime.datetime.utcnow().isoformat()
934
- self.bytes_received = 0
935
- self.bytes_sent = 0
936
-
937
- if opener:
938
- self._opener = opener
939
- else:
940
- self._opener = build_opener()
941
-
942
- self._secure = secure
943
-
944
- def __repr__(self):
945
- return repr(self.dict())
946
-
947
- def share(self):
948
- """POST data to the speedtest.net API to obtain a share results
949
- link
950
- """
951
-
952
- if self._share:
953
- return self._share
954
-
955
- download = int(round(self.download / 1000.0, 0))
956
- ping = int(round(self.ping, 0))
957
- upload = int(round(self.upload / 1000.0, 0))
958
-
959
- # Build the request to send results back to speedtest.net
960
- # We use a list instead of a dict because the API expects parameters
961
- # in a certain order
962
- api_data = [
963
- "recommendedserverid=%s" % self.server["id"],
964
- "ping=%s" % ping,
965
- "screenresolution=",
966
- "promo=",
967
- "download=%s" % download,
968
- "screendpi=",
969
- "upload=%s" % upload,
970
- "testmethod=http",
971
- "hash=%s" % md5(("%s-%s-%s-%s" % (ping, upload, download, "297aae72")).encode()).hexdigest(),
972
- "touchscreen=none",
973
- "startmode=pingselect",
974
- "accuracy=1",
975
- "bytesreceived=%s" % self.bytes_received,
976
- "bytessent=%s" % self.bytes_sent,
977
- "serverid=%s" % self.server["id"],
978
- ]
979
-
980
- headers = {"Referer": "http://c.speedtest.net/flash/speedtest.swf"}
981
- request = build_request(
982
- "://www.speedtest.net/api/api.php", data="&".join(api_data).encode(), headers=headers, secure=self._secure
983
- )
984
- f, e = catch_request(request, opener=self._opener)
985
- if e:
986
- raise ShareResultsConnectFailure(e)
987
-
988
- response = f.read()
989
- code = f.code
990
- f.close()
991
-
992
- if int(code) != 200:
993
- raise ShareResultsSubmitFailure("Could not submit results to " "speedtest.net")
994
-
995
- qsargs = parse_qs(response.decode())
996
- resultid = qsargs.get("resultid")
997
- if not resultid or len(resultid) != 1:
998
- raise ShareResultsSubmitFailure("Could not submit results to " "speedtest.net")
999
-
1000
- self._share = "http://www.speedtest.net/result/%s.png" % resultid[0]
1001
-
1002
- return self._share
1003
-
1004
- def dict(self):
1005
- """Return dictionary of result data"""
1006
-
1007
- return {
1008
- "download": self.download,
1009
- "upload": self.upload,
1010
- "ping": self.ping,
1011
- "server": self.server,
1012
- "timestamp": self.timestamp,
1013
- "bytes_sent": self.bytes_sent,
1014
- "bytes_received": self.bytes_received,
1015
- "share": self._share,
1016
- "client": self.client,
1017
- }
1018
-
1019
- @staticmethod
1020
- def csv_header(delimiter=","):
1021
- """Return CSV Headers"""
1022
-
1023
- row = [
1024
- "Server ID",
1025
- "Sponsor",
1026
- "Server Name",
1027
- "Timestamp",
1028
- "Distance",
1029
- "Ping",
1030
- "Download",
1031
- "Upload",
1032
- "Share",
1033
- "IP Address",
1034
- ]
1035
- out = StringIO()
1036
- writer = csv.writer(out, delimiter=delimiter, lineterminator="")
1037
- writer.writerow([to_utf8(v) for v in row])
1038
- return out.getvalue()
1039
-
1040
- def csv(self, delimiter=","):
1041
- """Return data in CSV format"""
1042
-
1043
- data = self.dict()
1044
- out = StringIO()
1045
- writer = csv.writer(out, delimiter=delimiter, lineterminator="")
1046
- row = [
1047
- data["server"]["id"],
1048
- data["server"]["sponsor"],
1049
- data["server"]["name"],
1050
- data["timestamp"],
1051
- data["server"]["d"],
1052
- data["ping"],
1053
- data["download"],
1054
- data["upload"],
1055
- self._share or "",
1056
- self.client["ip"],
1057
- ]
1058
- writer.writerow([to_utf8(v) for v in row])
1059
- return out.getvalue()
1060
-
1061
- def json(self, pretty=False):
1062
- """Return data in JSON format"""
1063
-
1064
- kwargs = {}
1065
- if pretty:
1066
- kwargs.update({"indent": 4, "sort_keys": True})
1067
- return json.dumps(self.dict(), **kwargs)
1068
-
1069
-
1070
- class Speedtest(object):
1071
- """Class for performing standard speedtest.net testing operations"""
1072
-
1073
- def __init__(self, config=None, source_address=None, timeout=10, secure=False, shutdown_event=None):
1074
- self.config = {}
1075
-
1076
- self._source_address = source_address
1077
- self._timeout = timeout
1078
- self._opener = build_opener(source_address, timeout)
1079
-
1080
- self._secure = secure
1081
-
1082
- if shutdown_event:
1083
- self._shutdown_event = shutdown_event
1084
- else:
1085
- self._shutdown_event = FakeShutdownEvent()
1086
-
1087
- self.get_config()
1088
- if config is not None:
1089
- self.config.update(config)
1090
-
1091
- self.servers = {}
1092
- self.closest = []
1093
- self._best = {}
1094
-
1095
- self.results = SpeedtestResults(
1096
- client=self.config["client"],
1097
- opener=self._opener,
1098
- secure=secure,
1099
- )
1100
-
1101
- @property
1102
- def best(self):
1103
- if not self._best:
1104
- self.get_best_server()
1105
- return self._best
1106
-
1107
- def get_config(self):
1108
- """Download the speedtest.net configuration and return only the data
1109
- we are interested in
1110
- """
1111
-
1112
- headers = {}
1113
- if gzip:
1114
- headers["Accept-Encoding"] = "gzip"
1115
- request = build_request("://www.speedtest.net/speedtest-config.php", headers=headers, secure=self._secure)
1116
- uh, e = catch_request(request, opener=self._opener)
1117
- if e:
1118
- raise ConfigRetrievalError(e)
1119
- configxml_list = []
1120
-
1121
- stream = get_response_stream(uh)
1122
-
1123
- while 1:
1124
- try:
1125
- configxml_list.append(stream.read(1024))
1126
- except (OSError, EOFError):
1127
- raise ConfigRetrievalError(get_exception())
1128
- if len(configxml_list[-1]) == 0:
1129
- break
1130
- stream.close()
1131
- uh.close()
1132
-
1133
- if int(uh.code) != 200:
1134
- return None
1135
-
1136
- configxml = "".encode().join(configxml_list)
1137
-
1138
- printer("Config XML:\n%s" % configxml, debug=True)
1139
-
1140
- try:
1141
- try:
1142
- root = ET.fromstring(configxml)
1143
- except ET.ParseError:
1144
- e = get_exception()
1145
- raise SpeedtestConfigError("Malformed speedtest.net configuration: %s" % e)
1146
- server_config = root.find("server-config").attrib
1147
- download = root.find("download").attrib
1148
- upload = root.find("upload").attrib
1149
- # times = root.find('times').attrib
1150
- client = root.find("client").attrib
1151
-
1152
- except AttributeError:
1153
- try:
1154
- root = DOM.parseString(configxml)
1155
- except ExpatError:
1156
- e = get_exception()
1157
- raise SpeedtestConfigError("Malformed speedtest.net configuration: %s" % e)
1158
- server_config = get_attributes_by_tag_name(root, "server-config")
1159
- download = get_attributes_by_tag_name(root, "download")
1160
- upload = get_attributes_by_tag_name(root, "upload")
1161
- # times = get_attributes_by_tag_name(root, 'times')
1162
- client = get_attributes_by_tag_name(root, "client")
1163
-
1164
- ignore_servers = [int(i) for i in server_config["ignoreids"].split(",") if i]
1165
-
1166
- ratio = int(upload["ratio"])
1167
- upload_max = int(upload["maxchunkcount"])
1168
- up_sizes = [32768, 65536, 131072, 262144, 524288, 1048576, 7340032]
1169
- sizes = {"upload": up_sizes[ratio - 1 :], "download": [350, 500, 750, 1000, 1500, 2000, 2500, 3000, 3500, 4000]}
1170
-
1171
- size_count = len(sizes["upload"])
1172
-
1173
- upload_count = int(math.ceil(upload_max / size_count))
1174
-
1175
- counts = {"upload": upload_count, "download": int(download["threadsperurl"])}
1176
-
1177
- threads = {"upload": int(upload["threads"]), "download": int(server_config["threadcount"]) * 2}
1178
-
1179
- length = {"upload": int(upload["testlength"]), "download": int(download["testlength"])}
1180
-
1181
- self.config.update(
1182
- {
1183
- "client": client,
1184
- "ignore_servers": ignore_servers,
1185
- "sizes": sizes,
1186
- "counts": counts,
1187
- "threads": threads,
1188
- "length": length,
1189
- "upload_max": upload_count * size_count,
1190
- }
1191
- )
1192
-
1193
- try:
1194
- self.lat_lon = (float(client["lat"]), float(client["lon"]))
1195
- except ValueError:
1196
- raise SpeedtestConfigError("Unknown location: lat=%r lon=%r" % (client.get("lat"), client.get("lon")))
1197
-
1198
- printer("Config:\n%r" % self.config, debug=True)
1199
-
1200
- return self.config
1201
-
1202
- def get_servers(self, servers=None, exclude=None):
1203
- """Retrieve a the list of speedtest.net servers, optionally filtered
1204
- to servers matching those specified in the ``servers`` argument
1205
- """
1206
- if servers is None:
1207
- servers = []
1208
-
1209
- if exclude is None:
1210
- exclude = []
1211
-
1212
- self.servers.clear()
1213
-
1214
- for server_list in (servers, exclude):
1215
- for i, s in enumerate(server_list):
1216
- try:
1217
- server_list[i] = int(s)
1218
- except ValueError:
1219
- raise InvalidServerIDType("%s is an invalid server type, must be int" % s)
1220
-
1221
- urls = [
1222
- "://www.speedtest.net/speedtest-servers-static.php",
1223
- "http://c.speedtest.net/speedtest-servers-static.php",
1224
- "://www.speedtest.net/speedtest-servers.php",
1225
- "http://c.speedtest.net/speedtest-servers.php",
1226
- ]
1227
-
1228
- headers = {}
1229
- if gzip:
1230
- headers["Accept-Encoding"] = "gzip"
1231
-
1232
- errors = []
1233
- for url in urls:
1234
- try:
1235
- request = build_request(
1236
- "%s?threads=%s" % (url, self.config["threads"]["download"]), headers=headers, secure=self._secure
1237
- )
1238
- uh, e = catch_request(request, opener=self._opener)
1239
- if e:
1240
- errors.append("%s" % e)
1241
- raise ServersRetrievalError()
1242
-
1243
- stream = get_response_stream(uh)
1244
-
1245
- serversxml_list = []
1246
- while 1:
1247
- try:
1248
- serversxml_list.append(stream.read(1024))
1249
- except (OSError, EOFError):
1250
- raise ServersRetrievalError(get_exception())
1251
- if len(serversxml_list[-1]) == 0:
1252
- break
1253
-
1254
- stream.close()
1255
- uh.close()
1256
-
1257
- if int(uh.code) != 200:
1258
- raise ServersRetrievalError()
1259
-
1260
- serversxml = "".encode().join(serversxml_list)
1261
-
1262
- printer("Servers XML:\n%s" % serversxml, debug=True)
1263
-
1264
- try:
1265
- try:
1266
- try:
1267
- root = ET.fromstring(serversxml)
1268
- except ET.ParseError:
1269
- e = get_exception()
1270
- raise SpeedtestServersError("Malformed speedtest.net server list: %s" % e)
1271
- elements = etree_iter(root, "server")
1272
- except AttributeError:
1273
- try:
1274
- root = DOM.parseString(serversxml)
1275
- except ExpatError:
1276
- e = get_exception()
1277
- raise SpeedtestServersError("Malformed speedtest.net server list: %s" % e)
1278
- elements = root.getElementsByTagName("server")
1279
- except (SyntaxError, xml.parsers.expat.ExpatError):
1280
- raise ServersRetrievalError()
1281
-
1282
- for server in elements:
1283
- try:
1284
- attrib = server.attrib
1285
- except AttributeError:
1286
- attrib = dict(list(server.attributes.items()))
1287
-
1288
- if servers and int(attrib.get("id")) not in servers:
1289
- continue
1290
-
1291
- if int(attrib.get("id")) in self.config["ignore_servers"] or int(attrib.get("id")) in exclude:
1292
- continue
1293
-
1294
- try:
1295
- d = distance(self.lat_lon, (float(attrib.get("lat")), float(attrib.get("lon"))))
1296
- except Exception:
1297
- continue
1298
-
1299
- attrib["d"] = d
1300
-
1301
- try:
1302
- self.servers[d].append(attrib)
1303
- except KeyError:
1304
- self.servers[d] = [attrib]
1305
-
1306
- break
1307
-
1308
- except ServersRetrievalError:
1309
- continue
1310
-
1311
- if (servers or exclude) and not self.servers:
1312
- raise NoMatchedServers()
1313
-
1314
- return self.servers
1315
-
1316
- def set_mini_server(self, server):
1317
- """Instead of querying for a list of servers, set a link to a
1318
- speedtest mini server
1319
- """
1320
-
1321
- urlparts = urlparse(server)
1322
-
1323
- name, ext = os.path.splitext(urlparts[2])
1324
- if ext:
1325
- url = os.path.dirname(server)
1326
- else:
1327
- url = server
1328
-
1329
- request = build_request(url)
1330
- uh, e = catch_request(request, opener=self._opener)
1331
- if e:
1332
- raise SpeedtestMiniConnectFailure("Failed to connect to %s" % server)
1333
- else:
1334
- text = uh.read()
1335
- uh.close()
1336
-
1337
- extension = re.findall('upload_?[Ee]xtension: "([^"]+)"', text.decode())
1338
- if not extension:
1339
- for ext in ["php", "asp", "aspx", "jsp"]:
1340
- try:
1341
- f = self._opener.open("%s/speedtest/upload.%s" % (url, ext))
1342
- except Exception:
1343
- pass
1344
- else:
1345
- data = f.read().strip().decode()
1346
- if f.code == 200 and len(data.splitlines()) == 1 and re.match("size=[0-9]", data):
1347
- extension = [ext]
1348
- break
1349
- if not urlparts or not extension:
1350
- raise InvalidSpeedtestMiniServer("Invalid Speedtest Mini Server: " "%s" % server)
1351
-
1352
- self.servers = [
1353
- {
1354
- "sponsor": "Speedtest Mini",
1355
- "name": urlparts[1],
1356
- "d": 0,
1357
- "url": "%s/speedtest/upload.%s" % (url.rstrip("/"), extension[0]),
1358
- "latency": 0,
1359
- "id": 0,
1360
- }
1361
- ]
1362
-
1363
- return self.servers
1364
-
1365
- def get_closest_servers(self, limit=5):
1366
- """Limit servers to the closest speedtest.net servers based on
1367
- geographic distance
1368
- """
1369
-
1370
- if not self.servers:
1371
- self.get_servers()
1372
-
1373
- for d in sorted(self.servers.keys()):
1374
- for s in self.servers[d]:
1375
- self.closest.append(s)
1376
- if len(self.closest) == limit:
1377
- break
1378
- else:
1379
- continue
1380
- break
1381
-
1382
- printer("Closest Servers:\n%r" % self.closest, debug=True)
1383
- return self.closest
1384
-
1385
- def get_best_server(self, servers=None):
1386
- """Perform a speedtest.net "ping" to determine which speedtest.net
1387
- server has the lowest latency
1388
- """
1389
-
1390
- if not servers:
1391
- if not self.closest:
1392
- servers = self.get_closest_servers()
1393
- servers = self.closest
1394
-
1395
- if self._source_address:
1396
- source_address_tuple = (self._source_address, 0)
1397
- else:
1398
- source_address_tuple = None
1399
-
1400
- user_agent = build_user_agent()
1401
-
1402
- results = {}
1403
- for server in servers:
1404
- cum = []
1405
- url = os.path.dirname(server["url"])
1406
- stamp = int(timeit.time.time() * 1000)
1407
- latency_url = "%s/latency.txt?x=%s" % (url, stamp)
1408
- for i in range(0, 3):
1409
- this_latency_url = "%s.%s" % (latency_url, i)
1410
- printer("%s %s" % ("GET", this_latency_url), debug=True)
1411
- urlparts = urlparse(latency_url)
1412
- try:
1413
- if urlparts[0] == "https":
1414
- h = SpeedtestHTTPSConnection(urlparts[1], source_address=source_address_tuple)
1415
- else:
1416
- h = SpeedtestHTTPConnection(urlparts[1], source_address=source_address_tuple)
1417
- headers = {"User-Agent": user_agent}
1418
- path = "%s?%s" % (urlparts[2], urlparts[4])
1419
- start = timeit.default_timer()
1420
- h.request("GET", path, headers=headers)
1421
- r = h.getresponse()
1422
- total = timeit.default_timer() - start
1423
- except HTTP_ERRORS:
1424
- e = get_exception()
1425
- printer("ERROR: %r" % e, debug=True)
1426
- cum.append(3600)
1427
- continue
1428
-
1429
- text = r.read(9)
1430
- if int(r.status) == 200 and text == "test=test".encode():
1431
- cum.append(total)
1432
- else:
1433
- cum.append(3600)
1434
- h.close()
1435
-
1436
- avg = round((sum(cum) / 6) * 1000.0, 3)
1437
- results[avg] = server
1438
-
1439
- try:
1440
- fastest = sorted(results.keys())[0]
1441
- except IndexError:
1442
- raise SpeedtestBestServerFailure("Unable to connect to servers to " "test latency.")
1443
- best = results[fastest]
1444
- best["latency"] = fastest
1445
-
1446
- self.results.ping = fastest
1447
- self.results.server = best
1448
-
1449
- self._best.update(best)
1450
- printer("Best Server:\n%r" % best, debug=True)
1451
- return best
1452
-
1453
- def download(self, callback=do_nothing, threads=None):
1454
- """Test download speed against speedtest.net
1455
-
1456
- A ``threads`` value of ``None`` will fall back to those dictated
1457
- by the speedtest.net configuration
1458
- """
1459
-
1460
- urls = []
1461
- for size in self.config["sizes"]["download"]:
1462
- for _ in range(0, self.config["counts"]["download"]):
1463
- urls.append("%s/random%sx%s.jpg" % (os.path.dirname(self.best["url"]), size, size))
1464
-
1465
- request_count = len(urls)
1466
- requests = []
1467
- for i, url in enumerate(urls):
1468
- requests.append(build_request(url, bump=i, secure=self._secure))
1469
-
1470
- max_threads = threads or self.config["threads"]["download"]
1471
- in_flight = {"threads": 0}
1472
-
1473
- def producer(q, requests, request_count):
1474
- for i, request in enumerate(requests):
1475
- thread = HTTPDownloader(
1476
- i,
1477
- request,
1478
- start,
1479
- self.config["length"]["download"],
1480
- opener=self._opener,
1481
- shutdown_event=self._shutdown_event,
1482
- )
1483
- while in_flight["threads"] >= max_threads:
1484
- timeit.time.sleep(0.001)
1485
- thread.start()
1486
- q.put(thread, True)
1487
- in_flight["threads"] += 1
1488
- callback(i, request_count, start=True)
1489
-
1490
- finished = []
1491
-
1492
- def consumer(q, request_count):
1493
- _is_alive = thread_is_alive
1494
- while len(finished) < request_count:
1495
- thread = q.get(True)
1496
- while _is_alive(thread):
1497
- thread.join(timeout=0.001)
1498
- in_flight["threads"] -= 1
1499
- finished.append(sum(thread.result))
1500
- callback(thread.i, request_count, end=True)
1501
-
1502
- q = Queue(max_threads)
1503
- prod_thread = threading.Thread(target=producer, args=(q, requests, request_count))
1504
- cons_thread = threading.Thread(target=consumer, args=(q, request_count))
1505
- start = timeit.default_timer()
1506
- prod_thread.start()
1507
- cons_thread.start()
1508
- _is_alive = thread_is_alive
1509
- while _is_alive(prod_thread):
1510
- prod_thread.join(timeout=0.001)
1511
- while _is_alive(cons_thread):
1512
- cons_thread.join(timeout=0.001)
1513
-
1514
- stop = timeit.default_timer()
1515
- self.results.bytes_received = sum(finished)
1516
- self.results.download = (self.results.bytes_received / (stop - start)) * 8.0
1517
- if self.results.download > 100000:
1518
- self.config["threads"]["upload"] = 8
1519
- return self.results.download
1520
-
1521
- def upload(self, callback=do_nothing, pre_allocate=True, threads=None):
1522
- """Test upload speed against speedtest.net
1523
-
1524
- A ``threads`` value of ``None`` will fall back to those dictated
1525
- by the speedtest.net configuration
1526
- """
1527
-
1528
- sizes = []
1529
-
1530
- for size in self.config["sizes"]["upload"]:
1531
- for _ in range(0, self.config["counts"]["upload"]):
1532
- sizes.append(size)
1533
-
1534
- # request_count = len(sizes)
1535
- request_count = self.config["upload_max"]
1536
-
1537
- requests = []
1538
- for i, size in enumerate(sizes):
1539
- # We set ``0`` for ``start`` and handle setting the actual
1540
- # ``start`` in ``HTTPUploader`` to get better measurements
1541
- data = HTTPUploaderData(size, 0, self.config["length"]["upload"], shutdown_event=self._shutdown_event)
1542
- if pre_allocate:
1543
- data.pre_allocate()
1544
-
1545
- headers = {"Content-length": size}
1546
- requests.append((build_request(self.best["url"], data, secure=self._secure, headers=headers), size))
1547
-
1548
- max_threads = threads or self.config["threads"]["upload"]
1549
- in_flight = {"threads": 0}
1550
-
1551
- def producer(q, requests, request_count):
1552
- for i, request in enumerate(requests[:request_count]):
1553
- thread = HTTPUploader(
1554
- i,
1555
- request[0],
1556
- start,
1557
- request[1],
1558
- self.config["length"]["upload"],
1559
- opener=self._opener,
1560
- shutdown_event=self._shutdown_event,
1561
- )
1562
- while in_flight["threads"] >= max_threads:
1563
- timeit.time.sleep(0.001)
1564
- thread.start()
1565
- q.put(thread, True)
1566
- in_flight["threads"] += 1
1567
- callback(i, request_count, start=True)
1568
-
1569
- finished = []
1570
-
1571
- def consumer(q, request_count):
1572
- _is_alive = thread_is_alive
1573
- while len(finished) < request_count:
1574
- thread = q.get(True)
1575
- while _is_alive(thread):
1576
- thread.join(timeout=0.001)
1577
- in_flight["threads"] -= 1
1578
- finished.append(thread.result)
1579
- callback(thread.i, request_count, end=True)
1580
-
1581
- q = Queue(threads or self.config["threads"]["upload"])
1582
- prod_thread = threading.Thread(target=producer, args=(q, requests, request_count))
1583
- cons_thread = threading.Thread(target=consumer, args=(q, request_count))
1584
- start = timeit.default_timer()
1585
- prod_thread.start()
1586
- cons_thread.start()
1587
- _is_alive = thread_is_alive
1588
- while _is_alive(prod_thread):
1589
- prod_thread.join(timeout=0.1)
1590
- while _is_alive(cons_thread):
1591
- cons_thread.join(timeout=0.1)
1592
-
1593
- stop = timeit.default_timer()
1594
- self.results.bytes_sent = sum(finished)
1595
- self.results.upload = (self.results.bytes_sent / (stop - start)) * 8.0
1596
- return self.results.upload
1597
-
1598
-
1599
- def ctrl_c(shutdown_event):
1600
- """Catch Ctrl-C key sequence and set a SHUTDOWN_EVENT for our threaded
1601
- operations
1602
- """
1603
-
1604
- def inner(signum, frame):
1605
- shutdown_event.set()
1606
- printer("\nCancelling...", error=True)
1607
- sys.exit(0)
1608
-
1609
- return inner
1610
-
1611
-
1612
- def version():
1613
- """Print the version"""
1614
-
1615
- printer("speedtest-cli %s" % __version__)
1616
- printer("Python %s" % sys.version.replace("\n", ""))
1617
- sys.exit(0)
1618
-
1619
-
1620
- def csv_header(delimiter=","):
1621
- """Print the CSV Headers"""
1622
-
1623
- printer(SpeedtestResults.csv_header(delimiter=delimiter))
1624
- sys.exit(0)
1625
-
1626
-
1627
- def parse_args():
1628
- """Function to handle building and parsing of command line arguments"""
1629
- description = (
1630
- "Command line interface for testing internet bandwidth using "
1631
- "speedtest.net.\n"
1632
- "------------------------------------------------------------"
1633
- "--------------\n"
1634
- "https://github.com/sivel/speedtest-cli"
1635
- )
1636
-
1637
- parser = ArgParser(description=description)
1638
- # Give optparse.OptionParser an `add_argument` method for
1639
- # compatibility with argparse.ArgumentParser
1640
- try:
1641
- parser.add_argument = parser.add_option
1642
- except AttributeError:
1643
- pass
1644
- parser.add_argument(
1645
- "--no-download",
1646
- dest="download",
1647
- default=True,
1648
- action="store_const",
1649
- const=False,
1650
- help="Do not perform download test",
1651
- )
1652
- parser.add_argument(
1653
- "--no-upload", dest="upload", default=True, action="store_const", const=False, help="Do not perform upload test"
1654
- )
1655
- parser.add_argument(
1656
- "--single",
1657
- default=False,
1658
- action="store_true",
1659
- help="Only use a single connection instead of " "multiple. This simulates a typical file " "transfer.",
1660
- )
1661
- parser.add_argument(
1662
- "--bytes",
1663
- dest="units",
1664
- action="store_const",
1665
- const=("byte", 8),
1666
- default=("bit", 1),
1667
- help="Display values in bytes instead of bits. Does "
1668
- "not affect the image generated by --share, nor "
1669
- "output from --json or --csv",
1670
- )
1671
- parser.add_argument(
1672
- "--share",
1673
- action="store_true",
1674
- help="Generate and provide a URL to the speedtest.net " "share results image, not displayed with --csv",
1675
- )
1676
- parser.add_argument(
1677
- "--simple", action="store_true", default=False, help="Suppress verbose output, only show basic " "information"
1678
- )
1679
- parser.add_argument(
1680
- "--csv",
1681
- action="store_true",
1682
- default=False,
1683
- help="Suppress verbose output, only show basic "
1684
- "information in CSV format. Speeds listed in "
1685
- "bit/s and not affected by --bytes",
1686
- )
1687
- parser.add_argument(
1688
- "--csv-delimiter",
1689
- default=",",
1690
- type=PARSER_TYPE_STR,
1691
- help="Single character delimiter to use in CSV " 'output. Default ","',
1692
- )
1693
- parser.add_argument("--csv-header", action="store_true", default=False, help="Print CSV headers")
1694
- parser.add_argument(
1695
- "--json",
1696
- action="store_true",
1697
- default=False,
1698
- help="Suppress verbose output, only show basic "
1699
- "information in JSON format. Speeds listed in "
1700
- "bit/s and not affected by --bytes",
1701
- )
1702
- parser.add_argument(
1703
- "--list", action="store_true", help="Display a list of speedtest.net servers " "sorted by distance"
1704
- )
1705
- parser.add_argument(
1706
- "--server",
1707
- type=PARSER_TYPE_INT,
1708
- action="append",
1709
- help="Specify a server ID to test against. Can be " "supplied multiple times",
1710
- )
1711
- parser.add_argument(
1712
- "--exclude",
1713
- type=PARSER_TYPE_INT,
1714
- action="append",
1715
- help="Exclude a server from selection. Can be " "supplied multiple times",
1716
- )
1717
- parser.add_argument("--mini", help="URL of the Speedtest Mini server")
1718
- parser.add_argument("--source", help="Source IP address to bind to")
1719
- parser.add_argument("--timeout", default=10, type=PARSER_TYPE_FLOAT, help="HTTP timeout in seconds. Default 10")
1720
- parser.add_argument(
1721
- "--secure",
1722
- action="store_true",
1723
- help="Use HTTPS instead of HTTP when communicating " "with speedtest.net operated servers",
1724
- )
1725
- parser.add_argument(
1726
- "--no-pre-allocate",
1727
- dest="pre_allocate",
1728
- action="store_const",
1729
- default=True,
1730
- const=False,
1731
- help="Do not pre allocate upload data. Pre allocation "
1732
- "is enabled by default to improve upload "
1733
- "performance. To support systems with "
1734
- "insufficient memory, use this option to avoid a "
1735
- "MemoryError",
1736
- )
1737
- parser.add_argument("--version", action="store_true", help="Show the version number and exit")
1738
- parser.add_argument("--debug", action="store_true", help=ARG_SUPPRESS, default=ARG_SUPPRESS)
1739
-
1740
- options = parser.parse_args()
1741
- if isinstance(options, tuple):
1742
- args = options[0]
1743
- else:
1744
- args = options
1745
- return args
1746
-
1747
-
1748
- def validate_optional_args(args):
1749
- """Check if an argument was provided that depends on a module that may
1750
- not be part of the Python standard library.
1751
-
1752
- If such an argument is supplied, and the module does not exist, exit
1753
- with an error stating which module is missing.
1754
- """
1755
- optional_args = {
1756
- "json": ("json/simplejson python module", json),
1757
- "secure": ("SSL support", HTTPSConnection),
1758
- }
1759
-
1760
- for arg, info in optional_args.items():
1761
- if getattr(args, arg, False) and info[1] is None:
1762
- raise SystemExit("%s is not installed. --%s is " "unavailable" % (info[0], arg))
1763
-
1764
-
1765
- def printer(string, quiet=False, debug=False, error=False, **kwargs):
1766
- """Helper function to print a string with various features"""
1767
-
1768
- if debug and not DEBUG:
1769
- return
1770
-
1771
- if debug:
1772
- if sys.stdout.isatty():
1773
- out = "\033[1;30mDEBUG: %s\033[0m" % string
1774
- else:
1775
- out = "DEBUG: %s" % string
1776
- else:
1777
- out = string
1778
-
1779
- if error:
1780
- kwargs["file"] = sys.stderr
1781
-
1782
- if not quiet:
1783
- print_(out, **kwargs)
1784
-
1785
-
1786
- def shell():
1787
- """Run the full speedtest.net test"""
1788
-
1789
- global DEBUG
1790
- shutdown_event = threading.Event()
1791
-
1792
- signal.signal(signal.SIGINT, ctrl_c(shutdown_event))
1793
-
1794
- args = parse_args()
1795
-
1796
- # Print the version and exit
1797
- if args.version:
1798
- version()
1799
-
1800
- if not args.download and not args.upload:
1801
- raise SpeedtestCLIError("Cannot supply both --no-download and " "--no-upload")
1802
-
1803
- if len(args.csv_delimiter) != 1:
1804
- raise SpeedtestCLIError("--csv-delimiter must be a single character")
1805
-
1806
- if args.csv_header:
1807
- csv_header(args.csv_delimiter)
1808
-
1809
- validate_optional_args(args)
1810
-
1811
- debug = getattr(args, "debug", False)
1812
- if debug == "SUPPRESSHELP":
1813
- debug = False
1814
- if debug:
1815
- DEBUG = True
1816
-
1817
- if args.simple or args.csv or args.json:
1818
- quiet = True
1819
- else:
1820
- quiet = False
1821
-
1822
- if args.csv or args.json:
1823
- machine_format = True
1824
- else:
1825
- machine_format = False
1826
-
1827
- # Don't set a callback if we are running quietly
1828
- if quiet or debug:
1829
- callback = do_nothing
1830
- else:
1831
- callback = print_dots(shutdown_event)
1832
-
1833
- printer("Retrieving speedtest.net configuration...", quiet)
1834
- try:
1835
- speedtest = Speedtest(source_address=args.source, timeout=args.timeout, secure=args.secure)
1836
- except (ConfigRetrievalError,) + HTTP_ERRORS:
1837
- printer("Cannot retrieve speedtest configuration", error=True)
1838
- raise SpeedtestCLIError(get_exception())
1839
-
1840
- if args.list:
1841
- try:
1842
- speedtest.get_servers()
1843
- except (ServersRetrievalError,) + HTTP_ERRORS:
1844
- printer("Cannot retrieve speedtest server list", error=True)
1845
- raise SpeedtestCLIError(get_exception())
1846
-
1847
- for _, servers in sorted(speedtest.servers.items()):
1848
- for server in servers:
1849
- line = "%(id)5s) %(sponsor)s (%(name)s, %(country)s) " "[%(d)0.2f km]" % server
1850
- try:
1851
- printer(line)
1852
- except IOError:
1853
- e = get_exception()
1854
- if e.errno != errno.EPIPE:
1855
- raise
1856
- sys.exit(0)
1857
-
1858
- printer("Testing from %(isp)s (%(ip)s)..." % speedtest.config["client"], quiet)
1859
-
1860
- if not args.mini:
1861
- printer("Retrieving speedtest.net server list...", quiet)
1862
- try:
1863
- speedtest.get_servers(servers=args.server, exclude=args.exclude)
1864
- except NoMatchedServers:
1865
- raise SpeedtestCLIError("No matched servers: %s" % ", ".join("%s" % s for s in args.server))
1866
- except (ServersRetrievalError,) + HTTP_ERRORS:
1867
- printer("Cannot retrieve speedtest server list", error=True)
1868
- raise SpeedtestCLIError(get_exception())
1869
- except InvalidServerIDType:
1870
- raise SpeedtestCLIError(
1871
- "%s is an invalid server type, must " "be an int" % ", ".join("%s" % s for s in args.server)
1872
- )
1873
-
1874
- if args.server and len(args.server) == 1:
1875
- printer("Retrieving information for the selected server...", quiet)
1876
- else:
1877
- printer("Selecting best server based on ping...", quiet)
1878
- speedtest.get_best_server()
1879
- elif args.mini:
1880
- speedtest.get_best_server(speedtest.set_mini_server(args.mini))
1881
-
1882
- results = speedtest.results
1883
-
1884
- printer("Hosted by %(sponsor)s (%(name)s) [%(d)0.2f km]: " "%(latency)s ms" % results.server, quiet)
1885
-
1886
- if args.download:
1887
- printer("Testing download speed", quiet, end=("", "\n")[bool(debug)])
1888
- speedtest.download(callback=callback, threads=(None, 1)[args.single])
1889
- printer("Download: %0.2f M%s/s" % ((results.download / 1000.0 / 1000.0) / args.units[1], args.units[0]), quiet)
1890
- else:
1891
- printer("Skipping download test", quiet)
1892
-
1893
- if args.upload:
1894
- printer("Testing upload speed", quiet, end=("", "\n")[bool(debug)])
1895
- speedtest.upload(callback=callback, pre_allocate=args.pre_allocate, threads=(None, 1)[args.single])
1896
- printer("Upload: %0.2f M%s/s" % ((results.upload / 1000.0 / 1000.0) / args.units[1], args.units[0]), quiet)
1897
- else:
1898
- printer("Skipping upload test", quiet)
1899
-
1900
- printer("Results:\n%r" % results.dict(), debug=True)
1901
-
1902
- if not args.simple and args.share:
1903
- results.share()
1904
-
1905
- if args.simple:
1906
- printer(
1907
- "Ping: %s ms\nDownload: %0.2f M%s/s\nUpload: %0.2f M%s/s"
1908
- % (
1909
- results.ping,
1910
- (results.download / 1000.0 / 1000.0) / args.units[1],
1911
- args.units[0],
1912
- (results.upload / 1000.0 / 1000.0) / args.units[1],
1913
- args.units[0],
1914
- )
1915
- )
1916
- elif args.csv:
1917
- printer(results.csv(delimiter=args.csv_delimiter))
1918
- elif args.json:
1919
- printer(results.json())
1920
-
1921
- if args.share and not machine_format:
1922
- printer("Share results: %s" % results.share())
1923
-
1924
-
1925
- def main():
1926
- try:
1927
- shell()
1928
- except KeyboardInterrupt:
1929
- printer("\nCancelling...", error=True)
1930
- except (SpeedtestException, SystemExit):
1931
- e = get_exception()
1932
- # Ignore a successful exit, or argparse exit
1933
- if getattr(e, "code", 1) not in (0, 2):
1934
- msg = "%s" % e
1935
- if not msg:
1936
- msg = "%r" % e
1937
- raise SystemExit("ERROR: %s" % msg)
1938
-
1939
-
1940
- if __name__ == "__main__":
1941
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/examples/prompt-tuning-personachat.ipynb DELETED
@@ -1,339 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "a07e0f5e",
6
- "metadata": {},
7
- "source": [
8
- "<div>\n",
9
- "<img src=\"https://camo.githubusercontent.com/473dd9f992924d27457650251786464f72e54121ac6e9210add0f483ca849277/68747470733a2f2f692e696d6775722e636f6d2f3765523750616e2e706e67\" width=\"40%\"> \n",
10
- "</div>\n",
11
- "\n",
12
- "# Distributed Bloom for Text Generation using Prompt Tuning\n",
13
- "\n",
14
- "In this example, we show how to use [prompt tuning](https://aclanthology.org/2021.emnlp-main.243.pdf) to adapt a test 6B version of the [BLOOM](https://huggingface.co/bigscience/bloom) model for a specific downstream task. We will run this model in a decentralized fashion using [Petals](https://github.com/bigscience-workshop/petals). Petals servers will maintain the BLOOM blocks (they are kept unchanged during adaptation), and the gradient descent will learn a few prefix tokens stored on a Petals client.\n",
15
- "\n",
16
- "We will adapt the BLOOM model for the chatbot task using the [Personachat](https://huggingface.co/datasets/bavard/personachat_truecased) dataset. For a given dialogue context, the model has to provide a relevant answer.\n",
17
- "\n",
18
- "To open this notebook in colab: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bigscience-workshop/petals/blob/main/examples/prompt-tuning-personachat.ipynb)"
19
- ]
20
- },
21
- {
22
- "cell_type": "markdown",
23
- "id": "a3f8526f",
24
- "metadata": {},
25
- "source": [
26
- "First, we have to prepare all dependencies."
27
- ]
28
- },
29
- {
30
- "cell_type": "code",
31
- "execution_count": null,
32
- "id": "73bbc648",
33
- "metadata": {},
34
- "outputs": [],
35
- "source": [
36
- "# This block is only needed for Colab users. It will change nothing if you are running this notebook locally.\n",
37
- "import subprocess\n",
38
- "import sys\n",
39
- "\n",
40
- "\n",
41
- "IN_COLAB = 'google.colab' in sys.modules\n",
42
- "\n",
43
- "if IN_COLAB:\n",
44
- " subprocess.run(['git', 'clone', 'https://github.com/bigscience-workshop/petals'])\n",
45
- " subprocess.run(['pip', 'install', '-r', 'petals/requirements.txt'])\n",
46
- " subprocess.run(['pip', 'install', 'datasets', 'lib64'])\n",
47
- "\n",
48
- " try:\n",
49
- " subprocess.check_output([\"nvidia-smi\", \"-L\"])\n",
50
- " except subprocess.CalledProcessError as e:\n",
51
- " subprocess.run(['rm', '-r', '/usr/local/cuda/lib64'])"
52
- ]
53
- },
54
- {
55
- "cell_type": "code",
56
- "execution_count": null,
57
- "id": "b4ab6ca7",
58
- "metadata": {},
59
- "outputs": [],
60
- "source": [
61
- "import os\n",
62
- "import sys\n",
63
- "sys.path.insert(0, \"..\") # for colab change to sys.path.insert(0, './petals/')\n",
64
- " \n",
65
- "import torch\n",
66
- "import transformers\n",
67
- "import wandb\n",
68
- "from datasets import load_dataset\n",
69
- "from tqdm import tqdm\n",
70
- "from torch.optim import AdamW\n",
71
- "from torch.utils.data import DataLoader\n",
72
- "from transformers import get_scheduler\n",
73
- "\n",
74
- "# Import a Petals model\n",
75
- "from src.client.remote_model import DistributedBloomForCausalLM"
76
- ]
77
- },
78
- {
79
- "cell_type": "markdown",
80
- "id": "1bf07b5d",
81
- "metadata": {},
82
- "source": [
83
- "Let's set some hyperparameters for training:"
84
- ]
85
- },
86
- {
87
- "cell_type": "code",
88
- "execution_count": null,
89
- "id": "f04ba4d2",
90
- "metadata": {},
91
- "outputs": [],
92
- "source": [
93
- "MODEL_NAME = ... # select model you like\n",
94
- "INITIAL_PEERS = [...] # add your peers' addresses here, like \"/ip4/192.168.1.2/tcp/31000/p2p/Qma....\"\n",
95
- "NUM_PREFIX_TOKENS = 16\n",
96
- "DEVICE = 'cpu'\n",
97
- "BATCH_SIZE = 4\n",
98
- "LR = 1e-2\n",
99
- "WEIGHT_DECAY = 0.0\n",
100
- "NUM_SAMPLES = 1000\n",
101
- "SEED = 42\n",
102
- "MODEL_MAX_LENGTH = 256\n",
103
- "TUNING_MODE = 'ptune' # choose between ['ptune', 'deep_ptune'] "
104
- ]
105
- },
106
- {
107
- "cell_type": "markdown",
108
- "id": "d38316bd",
109
- "metadata": {},
110
- "source": [
111
- "Prepare tokenizer and distributed model, connect it to servers."
112
- ]
113
- },
114
- {
115
- "cell_type": "code",
116
- "execution_count": null,
117
- "id": "03c6e53e",
118
- "metadata": {},
119
- "outputs": [],
120
- "source": [
121
- "tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)\n",
122
- "tokenizer.padding_side = 'right'\n",
123
- "tokenizer.model_max_length = MODEL_MAX_LENGTH\n",
124
- "model = DistributedBloomForCausalLM.from_pretrained(\n",
125
- " MODEL_NAME, \n",
126
- " initial_peers=INITIAL_PEERS, \n",
127
- " pre_seq_len=NUM_PREFIX_TOKENS, \n",
128
- " tuning_mode=TUNING_MODE\n",
129
- ").to(DEVICE)"
130
- ]
131
- },
132
- {
133
- "cell_type": "markdown",
134
- "id": "042e3786",
135
- "metadata": {},
136
- "source": [
137
- "Let's prepare the Personachat dataset. We need two mapping functions, one to concatenate history and candidate answers, and another for tokenization."
138
- ]
139
- },
140
- {
141
- "cell_type": "code",
142
- "execution_count": null,
143
- "id": "9c44d516",
144
- "metadata": {},
145
- "outputs": [],
146
- "source": [
147
- "dataset = load_dataset(\"bavard/personachat_truecased\")\n",
148
- "\n",
149
- "\n",
150
- "def chunking(examples):\n",
151
- " inputs = [\n",
152
- " \"\\n-----\\n\".join(history) + \"\\n-----\\n\" + candidate\n",
153
- " for history, candidates in zip(examples[\"history\"], examples[\"candidates\"])\n",
154
- " for candidate in candidates\n",
155
- " ]\n",
156
- " return {\"chunks\": inputs}\n",
157
- "\n",
158
- "\n",
159
- "def tokenize(examples):\n",
160
- " outputs = {\n",
161
- " \"input_ids\": tokenizer(examples[\"chunks\"], padding='max_length', truncation=True)[\"input_ids\"]\n",
162
- " }\n",
163
- " outputs[\"labels\"] = outputs[\"input_ids\"]\n",
164
- " return outputs\n",
165
- "\n",
166
- "\n",
167
- "tokenized_datasets = (\n",
168
- " dataset\n",
169
- " .map(chunking, batched=True, remove_columns=dataset[\"train\"].column_names)\n",
170
- " .map(tokenize, batched=True, remove_columns=[\"chunks\"])\n",
171
- ")\n",
172
- "\n",
173
- "\n",
174
- "tokenized_datasets.set_format(\"torch\")\n",
175
- "train_dataset = tokenized_datasets[\"train\"].shuffle(seed=SEED)\n",
176
- "train_dataloader = DataLoader(\n",
177
- " train_dataset.select(list(range(NUM_SAMPLES))),\n",
178
- " shuffle=True,\n",
179
- " batch_size=BATCH_SIZE,\n",
180
- " drop_last=True,\n",
181
- ")"
182
- ]
183
- },
184
- {
185
- "cell_type": "markdown",
186
- "id": "ef4323fd",
187
- "metadata": {},
188
- "source": [
189
- "Before setting up optimizers, check the model parameters that will be trained."
190
- ]
191
- },
192
- {
193
- "cell_type": "code",
194
- "execution_count": null,
195
- "id": "9cc0ba34",
196
- "metadata": {},
197
- "outputs": [],
198
- "source": [
199
- "for n, p in model.named_parameters():\n",
200
- " if p.requires_grad:\n",
201
- " print(n, p.requires_grad, p.device)"
202
- ]
203
- },
204
- {
205
- "cell_type": "markdown",
206
- "id": "59cffce7",
207
- "metadata": {},
208
- "source": [
209
- "The optimizer will only work on **prompts**: they are the only trainable parameters. Let's initialize the optimizer and the learning rate scheduler."
210
- ]
211
- },
212
- {
213
- "cell_type": "code",
214
- "execution_count": null,
215
- "id": "ef9bf344",
216
- "metadata": {},
217
- "outputs": [],
218
- "source": [
219
- "optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)\n",
220
- "\n",
221
- "lr_scheduler = get_scheduler(\n",
222
- " name=\"linear\", optimizer=optimizer, num_warmup_steps=0, num_training_steps=len(train_dataloader)\n",
223
- ")"
224
- ]
225
- },
226
- {
227
- "cell_type": "markdown",
228
- "id": "423c56d5",
229
- "metadata": {},
230
- "source": [
231
- "Let's initialize wandb for logging and start the training loop!"
232
- ]
233
- },
234
- {
235
- "cell_type": "code",
236
- "execution_count": null,
237
- "id": "d9e46807",
238
- "metadata": {},
239
- "outputs": [],
240
- "source": [
241
- "wandb.init(\n",
242
- " project=\"bloom-personachat\",\n",
243
- " config={\n",
244
- " \"num_samples\": NUM_SAMPLES,\n",
245
- " \"batch_size\": BATCH_SIZE,\n",
246
- " \"learning_rate\": LR,\n",
247
- " \"weight_decay\": WEIGHT_DECAY,\n",
248
- " \"num_prefix_tokens\": NUM_PREFIX_TOKENS,\n",
249
- " \"model_name\": MODEL_NAME,\n",
250
- " \"seed\": SEED,\n",
251
- " }\n",
252
- ")\n",
253
- "\n",
254
- "for batch in tqdm(train_dataloader):\n",
255
- " batch = {k: v.to(DEVICE) for k, v in batch.items()}\n",
256
- "\n",
257
- " model.train()\n",
258
- " outputs = model(**batch)\n",
259
- " loss = outputs.loss\n",
260
- " loss.backward()\n",
261
- "\n",
262
- " optimizer.step()\n",
263
- " lr_scheduler.step()\n",
264
- " optimizer.zero_grad()\n",
265
- "\n",
266
- " wandb.log({\"Train Loss\": loss})"
267
- ]
268
- },
269
- {
270
- "cell_type": "markdown",
271
- "id": "0f36cb80",
272
- "metadata": {},
273
- "source": [
274
- "Try to talk with the trained model! Submit an empty input to stop the execution.\n",
275
- "\n",
276
- "\n",
277
- "__Note__: In this example, we use the whole dialogue as a prefix when generating each new reply. In the future, we will support a faster \"interactive\" dialogue mode, so generating a new reply will be able to reuse inference caches from the previous reply."
278
- ]
279
- },
280
- {
281
- "cell_type": "code",
282
- "execution_count": null,
283
- "id": "720181b7",
284
- "metadata": {},
285
- "outputs": [],
286
- "source": [
287
- "MAX_TOKENS = 16\n",
288
- "TOP_K = 100\n",
289
- "TEMPERATURE = 0.6\n",
290
- "dialog = \"\"\n",
291
- "\n",
292
- "while True:\n",
293
- " user_phrase = input()\n",
294
- " if len(user_phrase) == 0:\n",
295
- " break\n",
296
- " dialog += f\"{user_phrase}\\n-----\\n\"\n",
297
- " inputs = tokenizer([dialog], return_tensors='pt')['input_ids']\n",
298
- " outputs = model.generate(\n",
299
- " inputs,\n",
300
- " temperature=TEMPERATURE,\n",
301
- " do_sample=True,\n",
302
- " top_k=TOP_K,\n",
303
- " eos_token_id=tokenizer.eos_token_id,\n",
304
- " max_new_tokens=MAX_TOKENS,\n",
305
- " )\n",
306
- " bloom_answer = tokenizer.batch_decode(outputs)[0]\n",
307
- " bloom_answer = bloom_answer[len(dialog):].split(\"\\n\")[0]\n",
308
- " print(bloom_answer)\n",
309
- " dialog += f\"{bloom_answer}\\n-----\\n\""
310
- ]
311
- }
312
- ],
313
- "metadata": {
314
- "kernelspec": {
315
- "display_name": "Python 3.8.10 64-bit",
316
- "language": "python",
317
- "name": "python3"
318
- },
319
- "language_info": {
320
- "codemirror_mode": {
321
- "name": "ipython",
322
- "version": 3
323
- },
324
- "file_extension": ".py",
325
- "mimetype": "text/x-python",
326
- "name": "python",
327
- "nbconvert_exporter": "python",
328
- "pygments_lexer": "ipython3",
329
- "version": "3.8.9"
330
- },
331
- "vscode": {
332
- "interpreter": {
333
- "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
334
- }
335
- }
336
- },
337
- "nbformat": 4,
338
- "nbformat_minor": 5
339
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/pyproject.toml DELETED
@@ -1,10 +0,0 @@
1
- [tool.black]
2
- line-length = 120
3
- required-version = "22.3.0"
4
-
5
- [tool.isort]
6
- profile = "black"
7
- line_length = 120
8
- combine_as_imports = true
9
- combine_star = true
10
- known_local_folder = ["tests", "cli"]
 
 
 
 
 
 
 
 
 
 
 
petals/requirements-dev.txt DELETED
@@ -1,6 +0,0 @@
1
- pytest==6.2.5 # see https://github.com/pytest-dev/pytest/issues/9621
2
- pytest-forked
3
- pytest-asyncio==0.16.0
4
- black==22.3.0
5
- isort==5.10.1
6
- psutil
 
 
 
 
 
 
 
petals/requirements.txt DELETED
@@ -1,8 +0,0 @@
1
- torch>=1.12
2
- bitsandbytes==0.34.0
3
- accelerate==0.10.0
4
- huggingface-hub==0.7.0
5
- transformers==4.21.3
6
- protobuf>=3.12.2,<4.0.0
7
- git+https://github.com/learning-at-home/hivemind@94c985d2dc7a79a091e46c755e9f2f4469b164c7
8
- humanfriendly
 
 
 
 
 
 
 
 
 
petals/src/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from src.bloom import *
2
- from src.client import *
3
- from src.dht_utils import declare_active_modules, get_remote_module
4
-
5
- project_name = "bloomd"
6
- __version__ = "0.2"
 
 
 
 
 
 
 
petals/src/bloom/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from src.bloom.block import BloomBlock
2
- from src.bloom.model import BloomConfig, BloomForCausalLM, BloomModel, BloomPreTrainedModel
 
 
 
petals/src/bloom/block.py DELETED
@@ -1,255 +0,0 @@
1
- """
2
- Bloom intermediate layer
3
- Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
4
- See commit history for authorship.
5
- """
6
- import math
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.quantized.dynamic.modules.linear
11
-
12
- from src.bloom.ops import (
13
- BloomGelu,
14
- BloomScaledSoftmax,
15
- attention_mask_func,
16
- build_alibi_tensor,
17
- dropout_add,
18
- pre_process_alibi_for_pad,
19
- split_tensor_along_last_dim,
20
- )
21
-
22
-
23
- class BloomAttention(nn.Module):
24
- def __init__(self, config, layer_number=None):
25
- super().__init__()
26
-
27
- self.hidden_size = config.hidden_size
28
- self.num_heads = config.n_head
29
- self.head_dim = self.hidden_size // self.num_heads
30
- self.split_size = self.hidden_size
31
- self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
32
- self.masked_softmax_fusion = config.masked_softmax_fusion
33
- self.hidden_dropout = config.hidden_dropout
34
-
35
- if self.head_dim * self.num_heads != self.hidden_size:
36
- raise ValueError(
37
- f"`hidden_size` must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
38
- f" {self.num_heads})."
39
- )
40
-
41
- # Layer-wise attention scaling
42
- self.layer_number = max(1, layer_number)
43
- self.norm_factor = math.sqrt(self.head_dim) * self.layer_number
44
-
45
- # Scaled Softmax
46
- self.scale_mask_softmax = BloomScaledSoftmax(
47
- self.masked_softmax_fusion,
48
- attention_mask_func,
49
- self.attention_softmax_in_fp32,
50
- self.layer_number,
51
- )
52
-
53
- self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=True)
54
- self.dense = nn.Linear(self.hidden_size, self.hidden_size)
55
-
56
- self.attention_dropout = nn.Dropout(config.attention_dropout)
57
-
58
- def forward(
59
- self,
60
- hidden_states,
61
- residual,
62
- layer_past=None,
63
- attention_mask=None,
64
- alibi=None,
65
- head_mask=None,
66
- use_cache=False,
67
- output_attentions=False,
68
- ):
69
- if alibi is None:
70
- current_sequence_length = hidden_states.shape[1] + (0 if layer_past is None else layer_past[0].shape[1])
71
- alibi = build_alibi_tensor(
72
- current_sequence_length, n_head=self.num_heads, dtype=hidden_states.dtype, device=hidden_states.device
73
- )
74
-
75
- # hidden_states: [batch_size, seq_length, hidden_size]
76
- # apply preprocessing if the input is padded
77
- if attention_mask is not None:
78
- alibi = pre_process_alibi_for_pad(alibi, attention_mask)
79
- # otherwise repeat alibi tensor with the batch size
80
- else:
81
- alibi = alibi.repeat(hidden_states.shape[0], 1, 1)
82
-
83
- mixed_x_layer = self.query_key_value(hidden_states)
84
-
85
- # [batch_size, seq_length, 3 x hidden_size] --> [batch_size, seq_length, num_heads, 3 x head_dim]
86
- new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_heads, 3 * self.head_dim)
87
- mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
88
-
89
- # [batch_size, seq_length, num_heads, 3 x head_dim] --> 3 [batch_size, seq_length, num_heads, head_dim]
90
- (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
91
-
92
- if layer_past is not None:
93
- past_key, past_value = layer_past
94
- key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
95
- value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
96
-
97
- if use_cache is True:
98
- present = (key_layer, value_layer)
99
- else:
100
- present = None
101
-
102
- # [batch_size, num_heads, q_length, k_length]
103
- output_size = (query_layer.size(0), query_layer.size(2), query_layer.size(1), key_layer.size(1))
104
-
105
- # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
106
- query_layer = query_layer.transpose(1, 0).reshape(output_size[2], output_size[0] * output_size[1], -1)
107
-
108
- # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
109
- key_layer = key_layer.transpose(1, 0).reshape(output_size[3], output_size[0] * output_size[1], -1)
110
-
111
- # Raw attention scores. [batch_size * num_heads, q_length, k_length]
112
- beta = 1.0 / self.layer_number
113
-
114
- matmul_result = torch.baddbmm(
115
- alibi,
116
- query_layer.transpose(1, 0),
117
- key_layer.transpose(1, 0).transpose(1, 2),
118
- beta=beta,
119
- alpha=(1.0 / self.norm_factor),
120
- )
121
-
122
- # change view to [batch_size, num_heads, q_length, k_length]
123
- attention_scores = matmul_result.view(*output_size)
124
-
125
- # attention scores and attention mask [b, np, sq, sk]
126
- max_positions = max(attention_scores.shape[-1], attention_scores.shape[-2])
127
- attention_probs = self.scale_mask_softmax(attention_scores, attention_mask, max_positions).to(value_layer.dtype)
128
- attention_probs = self.attention_dropout(attention_probs)
129
-
130
- if head_mask is not None:
131
- attention_probs = attention_probs * head_mask
132
-
133
- # context layer shape: [batch_size, num_heads, q_length, head_dim]
134
- output_size = (value_layer.size(0), value_layer.size(2), query_layer.size(0), value_layer.size(3))
135
-
136
- # change view [k_length, batch_size x num_heads, head_dim]
137
- value_layer = value_layer.transpose(1, 0).reshape(value_layer.size(1), output_size[0] * output_size[1], -1)
138
-
139
- # change view [batch_size x num_heads, q_length, k_length]
140
- attention_probs_reshaped = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
141
-
142
- # matmul: [batch_size * num_heads, q_length, head_dim]
143
- context_layer = torch.bmm(attention_probs_reshaped, value_layer.transpose(0, 1))
144
-
145
- # change view [batch_size, num_heads, q_length, head_dim]
146
- context_layer = context_layer.view(*output_size)
147
-
148
- # [batch_size, num_heads, q_length, head_dim] --> [q_length, batch_size, num_heads, head_dim]
149
- context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
150
-
151
- # [q_length, batch_size, num_heads, head_dim] --> [q_length, batch_size, hidden_size]
152
- new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
153
-
154
- context_layer = context_layer.view(*new_context_layer_shape)
155
-
156
- # Output. [q_length, batch_size, hidden_size]
157
-
158
- # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
159
- output_tensor = self.dense(context_layer)
160
- output = output_tensor.transpose(1, 0)
161
-
162
- output = dropout_add(output, residual, self.hidden_dropout, self.training)
163
-
164
- outputs = (output, present)
165
- if output_attentions:
166
- outputs += (attention_probs,)
167
-
168
- return outputs
169
-
170
-
171
- class BloomMLP(nn.Module):
172
- def __init__(self, config):
173
- super().__init__()
174
- self.hidden_size = config.hidden_size
175
- self.dense_h_to_4h = nn.Linear(self.hidden_size, 4 * self.hidden_size)
176
- self.dense_4h_to_h = nn.Linear(4 * self.hidden_size, self.hidden_size)
177
- self.hidden_dropout = config.hidden_dropout
178
- self.gelu_impl = BloomGelu()
179
-
180
- def forward(self, hidden_states, residual):
181
- hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
182
- intermediate_output = self.dense_4h_to_h(hidden_states)
183
- output = dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
184
- return output
185
-
186
-
187
- class BloomBlock(nn.Module):
188
- def __init__(self, config, layer_number=None):
189
- super().__init__()
190
- self.hidden_size = config.hidden_size
191
-
192
- self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
193
- self.n_head = config.n_head
194
- self.self_attention = BloomAttention(config, layer_number=layer_number)
195
- self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_epsilon)
196
-
197
- self.mlp = BloomMLP(config)
198
-
199
- self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
200
- self.hidden_dropout = config.hidden_dropout
201
-
202
- def forward(
203
- self,
204
- hidden_states,
205
- layer_past=None,
206
- attention_mask=None,
207
- head_mask=None,
208
- use_cache=False,
209
- output_attentions=False,
210
- alibi=None,
211
- ):
212
- # hidden_states: [batch_size, seq_length, hidden_size]
213
-
214
- # Layer norm at the beginning of the transformer layer.
215
- layernorm_output = self.input_layernorm(hidden_states)
216
-
217
- # Layer norm post the self attention.
218
- if self.apply_residual_connection_post_layernorm:
219
- residual = layernorm_output
220
- else:
221
- residual = hidden_states
222
-
223
- # Self attention.
224
- attn_outputs = self.self_attention(
225
- layernorm_output,
226
- residual,
227
- layer_past=layer_past,
228
- attention_mask=attention_mask,
229
- alibi=alibi,
230
- head_mask=head_mask,
231
- use_cache=use_cache,
232
- output_attentions=output_attentions,
233
- )
234
-
235
- attention_output = attn_outputs[0]
236
-
237
- outputs = attn_outputs[1:]
238
-
239
- layernorm_output = self.post_attention_layernorm(attention_output)
240
-
241
- # Get residual
242
- if self.apply_residual_connection_post_layernorm:
243
- residual = layernorm_output
244
- else:
245
- residual = attention_output
246
-
247
- # MLP.
248
- output = self.mlp(layernorm_output, residual)
249
-
250
- if use_cache:
251
- outputs = (output,) + outputs
252
- else:
253
- outputs = (output,) + outputs[1:]
254
-
255
- return outputs # hidden_states, present, attentions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/bloom/from_pretrained.py DELETED
@@ -1,86 +0,0 @@
1
- """
2
- Utils for fetching pretrained model parts. Currently, this relies on huggingface transformers' from_pretrained code.
3
- If necessary, one can rewrite this to implement a different behavior, such as:
4
- - loading files from a local data source (e.g. S3)
5
- - load files via BitTorrent ( https://pypi.org/project/libtorrent/ ) or IPFS( https://docs.ipfs.io/how-to )
6
- - fetch the weights over IPoAC, using a fleet of trained pigeons ( http://www.faqs.org/rfcs/rfc1149.html )
7
-
8
- """
9
- from __future__ import annotations
10
-
11
- from typing import Optional, OrderedDict, Union
12
-
13
- import torch
14
- from hivemind.utils.logging import get_logger, use_hivemind_log_handler
15
- from transformers.modeling_utils import WEIGHTS_NAME
16
- from transformers.utils.hub import cached_path, hf_bucket_url
17
-
18
- from src.bloom import BloomBlock, BloomConfig
19
-
20
# Configure hivemind's log handler with the "in_root_logger" mode, then create this module's logger.
use_hivemind_log_handler("in_root_logger")
logger = get_logger(__file__)

# Revision names used by _load_state_dict: CLIENT_BRANCH when no block index is given,
# otherwise BLOCK_BRANCH_PREFIX followed by the block index.
CLIENT_BRANCH = "main"
BLOCK_BRANCH_PREFIX = "block_"
# Fixed arguments that _load_state_dict forwards to `cached_path`.
USER_AGENT = {"file_type": "model", "framework": "pytorch", "from_auto_class": False}
FORCE_DOWNLOAD = False
RESUME_DOWNLOAD = False
LOCAL_FILES_ONLY = False
29
-
30
-
31
def load_pretrained_block(
    converted_model_name_or_path: str,
    block_index: int,
    config: Optional[BloomConfig] = None,
    torch_dtype: Union[torch.dtype, str] = "auto",
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
) -> BloomBlock:
    """Load one BloomBlock from a converted model. See convert_model.py (or README.md) on how to convert it.

    :param converted_model_name_or_path: name or path passed to ``BloomConfig.from_pretrained`` and to
        ``_load_state_dict``.
    :param block_index: index of the block to load; also used as the block's ``layer_number``.
    :param config: model config; fetched with ``BloomConfig.from_pretrained`` when omitted.
    :param torch_dtype: ``"auto"`` keeps the dtype of each checkpoint tensor; otherwise a ``torch.dtype``
        listed in ``DTYPE_MAP`` or the matching name (e.g. ``"float16"``).
    :param use_auth_token: forwarded to the config and weight loaders.
    :param cache_dir: forwarded to ``_load_state_dict``.
    :returns: the block with the checkpoint weights loaded (``strict=True``).
    """
    if config is None:
        config = BloomConfig.from_pretrained(converted_model_name_or_path, use_auth_token=use_auth_token)
    block = BloomBlock(config, layer_number=block_index)
    state_dict = _load_state_dict(
        converted_model_name_or_path, block_index, use_auth_token=use_auth_token, cache_dir=cache_dir
    )

    # The signature allows strings, so resolve dtype names ("float16", ...) to torch dtypes.
    # "auto" maps to itself and unknown names fall through to the assert below.
    if isinstance(torch_dtype, str):
        torch_dtype = DTYPE_MAP.get(torch_dtype, torch_dtype)

    # Set the parameter dtypes first, then copy the weights exactly once.
    # (Loading before the cast only duplicated the copy performed by the strict load below.)
    if torch_dtype == "auto":
        with torch.no_grad():
            for name, param in block.named_parameters():
                assert name in state_dict, f"{name} not in state dict"
                param.data = param.data.to(state_dict[name].dtype)
    else:
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
        block = block.to(dtype=torch_dtype)

    report = block.load_state_dict(state_dict, strict=True)
    logger.info(f"Loaded {converted_model_name_or_path} block {block_index}, {report}")
    return block
60
-
61
-
62
def _load_state_dict(
    pretrained_model_name_or_path: str,
    block_index: Optional[int] = None,
    use_auth_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
) -> OrderedDict[str, torch.Tensor]:
    """Resolve the weights file for one revision (downloading or reusing the cache) and load it on CPU.

    The revision is ``BLOCK_BRANCH_PREFIX + str(block_index)`` when ``block_index`` is given and
    ``CLIENT_BRANCH`` otherwise.
    """
    if block_index is None:
        revision = CLIENT_BRANCH
    else:
        revision = BLOCK_BRANCH_PREFIX + str(block_index)
    weights_url = hf_bucket_url(pretrained_model_name_or_path, filename=WEIGHTS_NAME, revision=revision, mirror=None)

    # Fetch the file, or reuse the cached copy if there is one.
    fetch_options = dict(
        cache_dir=cache_dir,
        force_download=FORCE_DOWNLOAD,
        proxies=None,
        resume_download=RESUME_DOWNLOAD,
        local_files_only=LOCAL_FILES_ONLY,
        use_auth_token=use_auth_token,
        user_agent=USER_AGENT,
    )
    local_weights_file = cached_path(weights_url, **fetch_options)

    # NOTE(review): torch.load unpickles the file; only point this at repositories you trust.
    return torch.load(local_weights_file, map_location="cpu")
84
-
85
-
86
# Dtype names mapped to torch dtypes ("auto" maps to itself); load_pretrained_block checks
# an explicit torch_dtype against these values.
DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/bloom/model.py DELETED
@@ -1,583 +0,0 @@
1
- """
2
- PyTorch BLOOM model that implements several memory-efficient modes.
3
- Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
4
- See commit history for authorship.
5
- """
6
- from typing import Tuple, Union
7
-
8
- import torch
9
- import torch.nn.functional as F
10
- import torch.utils.checkpoint
11
- from hivemind import use_hivemind_log_handler
12
- from torch import nn
13
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
14
- from transformers.file_utils import (
15
- add_code_sample_docstrings,
16
- add_start_docstrings,
17
- add_start_docstrings_to_model_forward,
18
- )
19
- from transformers.modeling_outputs import (
20
- BaseModelOutputWithPastAndCrossAttentions,
21
- CausalLMOutputWithCrossAttentions,
22
- SequenceClassifierOutputWithPast,
23
- )
24
- from transformers.modeling_utils import PreTrainedModel
25
- from transformers.models.bloom.configuration_bloom import BloomConfig
26
- from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
27
- from transformers.utils import logging
28
-
29
- from src.bloom.block import BloomBlock
30
-
31
# Configure hivemind's log handler with the "in_root_logger" mode, then create this module's logger.
use_hivemind_log_handler("in_root_logger")
logger = logging.get_logger(__file__)

# Values handed to `add_code_sample_docstrings` on the model `forward` methods.
_CHECKPOINT_FOR_DOC = "bigscience/Bloom"
_CONFIG_FOR_DOC = "BloomConfig"
_TOKENIZER_FOR_DOC = "BloomTokenizer"


# Shared class-level documentation, attached to the model classes through `add_start_docstrings`.
BLOOM_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`MemoryEfficientBloomConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared argument documentation, attached to `forward` through `add_start_docstrings_to_model_forward`.
BLOOM_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
            `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
            `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
            sequence tokens in the vocabulary.

            If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
            `input_ids`.

            Indices can be obtained using [`BloomTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
            Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
            `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.

            If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
            `past_key_values`).
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
109
-
110
-
111
@add_start_docstrings(
    "The bare Bloom Model transformer outputting raw hidden-states without any specific head on top.",
    BLOOM_START_DOCSTRING,
)
class BloomModel(BloomPreTrainedModel):
    # Word embeddings -> embedding layer norm -> stack of BloomBlock layers -> final layer norm.
    # (Kept as comments: the decorators above build the class/forward docstrings.)

    def __init__(self, config):
        """Build the embedding, the ``config.num_hidden_layers`` blocks and the final layer norm."""
        super().__init__(config)
        assert not config.slow_but_exact, "slow_but_exact mode was removed for code simplicity"

        self.embed_dim = config.hidden_size
        self.n_head = config.n_head

        # Embedding + LN Embedding
        self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
        self.word_embeddings_layernorm = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.h = nn.ModuleList([BloomBlock(config, layer_number=i) for i in range(config.num_hidden_layers)])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word embedding module."""
        return self.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        """Replace the word embedding module."""
        self.word_embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        # Flags left as None fall back to the values stored in the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")

        # The warning must not share an if/elif chain with the shape handling below: when it did,
        # passing position_ids skipped the `input_shape` assignment and forward() crashed later.
        if position_ids is not None:
            logger.warning("position_ids are ignored in this bloom implementation")

        if input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_head x N x N
        # head_mask has shape n_layer x batch x n_head x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Note: it supports only float32 or bfloat16 inputs
        hidden_states = self.word_embeddings_layernorm(inputs_embeds)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        # No alibi tensor is built here: every block is called with alibi=None.
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions, alibi=None)

                    return custom_forward

                # Positional inputs map to (hidden_states, layer_past=None, attention_mask, head_mask).
                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    alibi=None,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                # The attention entry sits after the cache entry when the block returned one.
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = hidden_states.view(output_shape)

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )
271
-
272
-
273
@add_start_docstrings(
    """
    The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input
    embeddings).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForCausalLM(BloomPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config):
        """Create the BloomModel backbone and a bias-free linear head over the vocabulary."""
        super().__init__(config)
        self.transformer = BloomModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        """Return the language modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language modeling head."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        """Assemble the keyword arguments for one generation step.

        With a non-empty ``past`` only the last token of ``input_ids`` (and of the derived positions) is kept.
        ``position_ids`` are derived from ``attention_mask`` only when a mask is given and no ``position_ids``
        were passed; in every other case ``position_ids`` is returned as ``None``.
        """
        attention_mask = kwargs.get("attention_mask", None)
        derive_positions = attention_mask is not None and kwargs.get("position_ids", None) is None

        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        position_ids = None
        if derive_positions:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        return dict(
            input_ids=input_ids,
            past_key_values=past,
            use_cache=kwargs.get("use_cache"),
            position_ids=position_ids,
            attention_mask=attention_mask,
        )

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=CausalLMOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(transformer_outputs[0])

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n, then flatten the tokens
            shifted_logits = lm_logits[..., :-1, :].contiguous()
            shifted_labels = labels[..., 1:].contiguous()
            flat_logits = shifted_logits.view(-1, shifted_logits.size(-1))
            criterion = CrossEntropyLoss()
            loss = criterion(flat_logits, shifted_labels.view(-1))

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=loss,
                logits=lm_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        plain_output = (lm_logits,) + transformer_outputs[1:]
        if loss is None:
            return plain_output
        return (loss,) + plain_output

    @staticmethod
    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        reordered_layers = []
        for layer_past in past:
            # Select along dim 0 of every cached tensor, with the indices moved to that tensor's device.
            reordered_layers.append(
                tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
            )
        return tuple(reordered_layers)
398
-
399
-
400
@add_start_docstrings(
    """
    The modified language modeling head which does not create extra tensor for the linear layer with weights tied to the input
    embeddings. Thus, it reduces initial memory consumption which might be crucial for large dictionaries.
    In addition, it provides an effcient way to deal with half-precision word embeddings on CPU.
    """,
    BLOOM_START_DOCSTRING,
)
class LMHead(nn.Module):
    # Acts like a bias-free linear layer whose weight is the shared word-embedding matrix.
    # (Kept as a comment: the decorator above builds the class docstring.)

    def __init__(self, config, word_embeddings: nn.Embedding):
        """Keep a reference to the shared embedding and read the CPU chunk size from ``config``."""
        super().__init__()
        self.word_embeddings = word_embeddings
        self.chunk_size = config.chunk_size_for_efficient_fp16_on_cpu

    # The weight has shape (num_embeddings, embedding_dim) and is used as F.linear(x, weight).
    # Following nn.Linear's (out_features, in_features) layout, the input width is embedding_dim
    # and the output width is num_embeddings; the two properties used to be swapped.
    @property
    def in_features(self) -> int:
        """Size of the last dimension of the expected input (the embedding dimension)."""
        return self.word_embeddings.embedding_dim

    @property
    def out_features(self) -> int:
        """Size of the last dimension of the produced logits (the number of embeddings)."""
        return self.word_embeddings.num_embeddings

    @property
    def weight(self):
        """The shared embedding matrix."""
        return self.word_embeddings.weight

    @property
    def bias(self):
        """This head has no bias."""
        return None

    def forward(self, hidden_states):
        """Project ``hidden_states`` onto the embedding matrix and return float32 logits."""
        word_embeddings = self.word_embeddings.weight

        # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
            lm_logits = self.chunked_forward(hidden_states)
        else:
            # Switch dtype in case word_embeddings are fp16/bf16
            hidden_states = hidden_states.to(word_embeddings.dtype)
            lm_logits = F.linear(hidden_states, word_embeddings).float()
        return lm_logits

    def chunked_forward(self, hidden_states):
        """Splits word embeddings on chunks and iteratively casts them into fp32 to perform matmul more efficiently on CPU.
        chunk_size: provides trade-off between efficiency and extra memory consumption.
        """
        assert self.chunk_size > 0, "Chunk size for chunked forward must be positive"

        word_embeddings = self.word_embeddings.weight
        num_embeddings = self.word_embeddings.num_embeddings

        hidden_states = hidden_states.float()
        # Allocate next to the float32 input so the result does not depend on torch's global
        # default dtype or device.
        output = hidden_states.new_zeros((*hidden_states.shape[:-1], num_embeddings))

        for i in range(0, num_embeddings, self.chunk_size):
            # Only one chunk of the embedding matrix is held in float32 at a time.
            chunk = word_embeddings[i : i + self.chunk_size].float()
            output[..., i : i + self.chunk_size] = F.linear(hidden_states, chunk)
        return output
458
-
459
-
460
@add_start_docstrings(
    """
    The Bloom Model transformer with a sequence classification head on top (linear layer).
    [`BloomForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.
    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    BLOOM_START_DOCSTRING,
)
class BloomForSequenceClassification(BloomPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]

    def __init__(self, config):
        """Create the BloomModel backbone and a bias-free linear scoring layer with ``config.num_labels`` outputs."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = BloomModel(config)
        self.score = nn.Linear(config.hidden_size, config.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BLOOM_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.score(transformer_outputs[0])

        batch_source = input_ids if input_ids is not None else inputs_embeds
        batch_size = batch_source.shape[0]

        pad_token_id = self.config.pad_token_id
        if pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if pad_token_id is not None and input_ids is not None:
            # Index of the last non-padding token in every row.
            sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
        else:
            # Fall back to the last position of every row.
            sequence_lengths = -1
            if pad_token_id is not None:
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        row_index = torch.arange(batch_size, device=logits.device)
        pooled_logits = logits[row_index, sequence_lengths]

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                # Infer the problem type once and store it on the config.
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    loss = criterion(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(pooled_logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(pooled_logits, labels)

        if return_dict:
            return SequenceClassifierOutputWithPast(
                loss=loss,
                logits=pooled_logits,
                past_key_values=transformer_outputs.past_key_values,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        plain_output = (pooled_logits,) + transformer_outputs[1:]
        if loss is None:
            return plain_output
        return (loss,) + plain_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/bloom/ops.py DELETED
@@ -1,246 +0,0 @@
1
- """
2
- Utility operations used in the the BLOOM model
3
- Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
4
- See commit history for authorship.
5
- """
6
- import math
7
-
8
- import torch
9
- import torch.autograd
10
- import torch.nn.functional as F
11
- from torch import nn
12
-
13
-
14
- def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
15
- """Split a tensor along its last dimension.
16
-
17
- Args:
18
- tensor: ([`torch.tensor`], *required*):
19
- input tensor to split
20
- num_partitions ([`int`], *required*):
21
- number of partitions to split the tensor
22
- contiguous_split_chunks ([`bool`], *optional*, default=`False`)::
23
- If True, make each chunk contiguous in memory.
24
- """
25
- # Get the size and dimension.
26
- last_dim = tensor.dim() - 1
27
- numerator, denominator = tensor.size()[last_dim], num_partitions
28
- if not (numerator % denominator == 0):
29
- raise ValueError(f"{numerator} is not divisible by {denominator}")
30
- last_dim_size = numerator // denominator
31
- # Split.
32
- tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
33
- # Note: torch.split does not create contiguous tensors by default.
34
- if contiguous_split_chunks:
35
- return tuple(chunk.contiguous() for chunk in tensor_list)
36
-
37
- return tensor_list
38
-
39
-
40
- def attention_mask_func(attention_scores, attention_mask, causal_mask):
41
- if attention_mask.dtype == torch.bool:
42
- attention_mask_bool = ~attention_mask
43
- else:
44
- attention_mask_bool = (1 - attention_mask).bool()
45
-
46
- query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
47
- padded_causal_mask = (
48
- attention_mask_bool[:, None, key_length - query_length : key_length, None]
49
- + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
50
- ).bool()
51
- padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
52
- # Make use of floats
53
- return (
54
- attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
55
- padded_causal_mask,
56
- )
57
-
58
-
59
- def build_alibi_tensor(
60
- max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
61
- ) -> torch.Tensor:
62
- """
63
- Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
64
- relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
65
- `softmax(l+a) = softmax(l)`. Based on
66
- https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
67
- Args:
68
- Returns tensor shaped (n_head, 1, max_seq_len)
69
- max_seq_len: (`int`, *required*):
70
- max sequence length
71
- n_head: (`int`, *required*):
72
- number of heads
73
- dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
74
- dtype of the output tensor
75
- device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
76
- device of the output alibi tensor
77
- """
78
- closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
79
- base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
80
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
81
- slopes = torch.pow(base, powers)
82
-
83
- if closest_power_of_2 != n_head:
84
- extra_base = torch.tensor(
85
- 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
86
- )
87
- num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
88
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
89
- slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
90
-
91
- lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
92
- return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
93
-
94
-
95
- def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
96
- """
97
- Args:
98
- Pre-process the alibi tensor for padding.
99
- alibi: ([`torch.tensor`], *required*):
100
- alibi tensor to pre-process
101
- attention_mask: ([`torch.tensor`], *required*):
102
- attention mask to pre-process
103
- """
104
- assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
105
- unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
106
- # ^-- [batch, max_len], values correspond to element indices after removing padding
107
- # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
108
- alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
109
- return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
110
-
111
-
112
- def dropout_add(x, residual, prob, training):
113
- """
114
- Dropout add function
115
-
116
- Args:
117
- x (`torch.tensor`, *required*):
118
- input tensor
119
- residual (`torch.tensor`, *rquired*):
120
- esidual tensor
121
- prob (`float`, *required*):
122
- dropout probability
123
- training (`bool`, *required*):
124
- training mode
125
- """
126
- out = nn.functional.dropout(x, p=prob, training=training)
127
- out = residual + out
128
- return out
129
-
130
-
131
- def bloom_gelu_forward(x):
132
- """
133
- Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference) to
134
- make the model jitable.
135
-
136
- Args:
137
- x (`torch.tensor`, *required*):
138
- input hidden states
139
- """
140
- return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
141
-
142
-
143
- def bloom_gelu_back(g, x):
144
- """
145
- gradient of tanh approximation of gelu gradient of actual gelu is: 0.5 * (1. + torch.erf(x * 0.70710678)) +
146
- 0.3989423 * x * torch.exp(-0.5 * x * x)
147
-
148
- Args:
149
- g (`torch.tensor`, *required*):
150
- gradient output tensor
151
- x (`torch.tensor`, *required*):
152
- input tensor
153
- """
154
- x = x[0] # x is a tuple of 1 element, needs to unpack it first
155
- tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
156
- # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
157
- ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
158
- return ff * g
159
-
160
-
161
- class GeLUFunction(torch.autograd.Function):
162
- @staticmethod
163
- def forward(ctx, input):
164
- ctx.save_for_backward(input)
165
- return bloom_gelu_forward(input)
166
-
167
- @staticmethod
168
- def backward(ctx, grad_output):
169
- input = ctx.saved_tensors
170
- tmp = bloom_gelu_back(grad_output, input)
171
- return tmp
172
-
173
-
174
- class BloomGelu(nn.Module):
175
- """
176
- BloomBiasGelu wrapper function that make use of the simple function on inference mode to make the model
177
- torchscriptable and use the autograd function in training mode to get the accurate results of the gradients Partly
178
- copied from Megatron-DeepSpeed code and adapted for our needs
179
-
180
- See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
181
-
182
- """
183
-
184
- def __init__(self):
185
- super().__init__()
186
-
187
- def forward(self, x):
188
- if self.training:
189
- return GeLUFunction.apply(x)
190
- else:
191
- return bloom_gelu_forward(x)
192
-
193
-
194
- class BloomScaledSoftmax(nn.Module):
195
- """
196
- fused operation: scaling + mask + softmax
197
-
198
- Args:
199
- input_in_fp16 (`bool`, *required*):
200
- flag to indicate if input in fp16 data format.
201
- input_in_bf16 (`bool`, *required*):
202
- flag to indicate if input in bf16 data format.
203
- scaled_masked_softmax_fusion (`bool`, *required*):
204
- flag to indicate user want to use softmax fusion
205
- mask_func (`function`, *required*):
206
- mask function to be applied.
207
- softmax_in_fp32 (`bool`, *required*):
208
- if true, softmax in performed at fp32 precision.
209
- scale (`float`, *required*):
210
- scaling factor used in input tensor scaling.
211
- """
212
-
213
- def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
214
- super().__init__()
215
- self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
216
- self.mask_func = mask_func
217
- self.softmax_in_fp32 = softmax_in_fp32
218
- self.scale = scale
219
-
220
- if not (self.scale is None or softmax_in_fp32):
221
- raise ValueError("softmax should be in fp32 when scaled")
222
-
223
- def forward(self, input, mask, max_positions):
224
- input_dtype = input.dtype
225
- input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
226
- softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype
227
-
228
- if self.scale is not None:
229
- input = input * self.scale
230
-
231
- if mask is None:
232
- mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
233
-
234
- mask = mask.to(input.device)
235
- causal_mask = (
236
- torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
237
- .view(1, 1, max_positions, max_positions)
238
- .to(input.device)
239
- )
240
- mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
241
- probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
242
-
243
- if input_in_16bit and self.softmax_in_fp32:
244
- probs = probs.to(dtype=input_dtype)
245
-
246
- return probs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/client/__init__.py DELETED
@@ -1,5 +0,0 @@
1
- from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
2
- from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
3
- from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
4
- from src.client.sequence_manager import RemoteSequenceManager
5
- from src.client.spending_policy import NoSpendingPolicy, SpendingPolicyBase
 
 
 
 
 
 
petals/src/client/inference_session.py DELETED
@@ -1,216 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import asyncio
4
- import contextlib
5
- from typing import AsyncIterator, List, Optional
6
-
7
- import torch
8
- from hivemind import (
9
- P2P,
10
- MSGPackSerializer,
11
- anext,
12
- deserialize_torch_tensor,
13
- get_logger,
14
- nested_flatten,
15
- serialize_torch_tensor,
16
- use_hivemind_log_handler,
17
- )
18
- from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
19
- from hivemind.p2p import StubBase
20
- from hivemind.proto import runtime_pb2
21
-
22
- from src.client.sequence_manager import RemoteSequenceManager
23
- from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
24
- from src.server.handler import TransformerConnectionHandler
25
- from src.utils.misc import DUMMY, is_dummy
26
-
27
- use_hivemind_log_handler("in_root_logger")
28
- logger = get_logger(__file__)
29
-
30
-
31
- class RemoteTransformerBlockInferenceSession:
32
- """
33
- An interface to a single multi-step *inference* session for a specific remote module on a specific server
34
-
35
- :note: this inference session is *not* fault-tolerant out of the box
36
- """
37
-
38
- def __init__(
39
- self,
40
- uid: ModuleUID,
41
- rpc_info: RPCInfo,
42
- inputs_queue: asyncio.Queue,
43
- outputs_aiter: AsyncIterator,
44
- *,
45
- max_length: int,
46
- points: int = 0,
47
- ):
48
- self.uid, self.rpc_info = uid, rpc_info
49
- self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
50
- # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
51
- # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
52
- self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
53
- self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
54
- self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
55
- self.stepped = False
56
- self.closed = False
57
-
58
- @classmethod
59
- async def _create(
60
- cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
61
- ) -> RemoteTransformerBlockInferenceSession:
62
- """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
63
- inputs_queue = asyncio.Queue()
64
- outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
65
- return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
66
-
67
- @staticmethod
68
- async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
69
- while True:
70
- next_input_message = await asyncio.wait_for(queue.get(), timeout)
71
- yield next_input_message
72
- if not next_input_message.uid and not next_input_message.tensors:
73
- break # this message means "done sending"
74
-
75
- def step(
76
- self,
77
- new_hidden_states: torch.Tensor,
78
- prompts: Optional[torch.Tensor] = None,
79
- hypo_ids: Optional[torch.Tensor] = None,
80
- ):
81
- """
82
- Inference step: send a chunk of input tesors and receive a chunk of outputs
83
- :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
84
- if specified, deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]
85
- """
86
- if self.closed:
87
- raise Exception("Session is closed, cannot perform step")
88
- if prompts is None or is_dummy(prompts):
89
- prompts = DUMMY
90
- else:
91
- assert prompts.ndim == 4, "deep promts should have shape [num_layers, batch_size, prefix_len, hid_size]"
92
- assert prompts.shape[0] == self.num_blocks
93
- assert prompts.shape[1] in (new_hidden_states.shape[0], 1)
94
- assert prompts.shape[2] <= new_hidden_states.shape[1]
95
- assert prompts.shape[3] == new_hidden_states.shape[2]
96
-
97
- if hypo_ids is None or is_dummy(hypo_ids):
98
- hypo_ids = DUMMY
99
- else:
100
- assert len(hypo_ids) == len(new_hidden_states)
101
- assert hypo_ids.dtype == torch.int64
102
-
103
- # serialize inputs and put them into the queue
104
- inputs = (new_hidden_states, prompts, hypo_ids)
105
- outputs_serialized = RemoteExpertWorker.run_coroutine(
106
- self._step(
107
- runtime_pb2.ExpertRequest(
108
- uid=self.uid,
109
- tensors=[
110
- serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
111
- for tensor, proto in zip(inputs, nested_flatten(self.rpc_info["inference_schema"]))
112
- ],
113
- metadata=self._serialized_metadata if not self.stepped else None,
114
- )
115
- )
116
- )
117
- outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
118
- assert outputs[0].shape == inputs[0].shape, f"expected outputs[0] to be hidden states but got {outputs[0]}"
119
- return outputs[0]
120
-
121
- async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
122
- """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
123
- await self._inputs_queue.put(inputs_serialized)
124
- self.stepped = True
125
- return await anext(self._outputs_stream)
126
-
127
- def close(self):
128
- """Finish a given inference session, close the underlying connection"""
129
- if self._outputs_stream is None:
130
- return # already closed
131
- RemoteExpertWorker.run_coroutine(self._aclose_stream())
132
- self._outputs_stream = self._inputs_queue = None
133
- self.closed = True
134
-
135
- async def _aclose_stream(self):
136
- """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
137
- if self._outputs_stream is None:
138
- return # already closed
139
- if self.stepped:
140
- await self._inputs_queue.put(runtime_pb2.ExpertRequest()) # empty request will trigger end of session
141
- try:
142
- await anext(self._outputs_stream)
143
- except StopAsyncIteration:
144
- pass
145
-
146
- def __del__(self):
147
- self.close()
148
-
149
- def __enter__(self):
150
- assert not self.closed
151
- return self
152
-
153
- def __exit__(self, *exc_details):
154
- self.close()
155
-
156
-
157
- class RemoteSequentialInferenceSession:
158
- """
159
- An interface to a multi-step *inference* session for a sequence of remote transformer blocks
160
- """
161
-
162
- def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
163
- self.sequence_manager = sequence_manager
164
- self.p2p = p2p
165
- self.closed = False
166
- self.chosen_spans: List[RemoteSpanInfo] = []
167
- self.stack = contextlib.ExitStack()
168
- self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
169
- self.metadata = metadata
170
- self.timeout = timeout
171
-
172
- def __enter__(self):
173
- assert not self.closed and not self.chosen_spans
174
- self.stack.__enter__()
175
- # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
176
- self.chosen_spans.extend(self.sequence_manager.make_sequence())
177
-
178
- for chosen_span in self.chosen_spans:
179
- stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
180
- span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
181
- inference_session = RemoteExpertWorker.run_coroutine(
182
- RemoteTransformerBlockInferenceSession._create(
183
- stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
184
- )
185
- )
186
- self.inference_sessions.append(inference_session)
187
- self.stack.enter_context(inference_session)
188
-
189
- return self
190
-
191
- def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
192
- assert not self.closed
193
- if torch.is_grad_enabled():
194
- logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
195
- if prompts is None or is_dummy(prompts):
196
- prompts = DUMMY
197
- else:
198
- assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
199
- for session in self.inference_sessions:
200
- outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
201
- assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
202
- inputs = outputs
203
- return inputs
204
-
205
- def close(self, *exc_details):
206
- """Finish a given inference session, close the underlying connection"""
207
- if not self.closed:
208
- self.stack.__exit__(*exc_details or (None, None, None))
209
- self.inference_sessions.clear()
210
- self.closed = True
211
-
212
- def __exit__(self, *exc_details):
213
- self.close(*exc_details)
214
-
215
- def __del__(self):
216
- self.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/client/remote_forward_backward.py DELETED
@@ -1,156 +0,0 @@
1
- """
2
- Utility functions that call RPC forward or backward on a single remote server
3
- """
4
- import asyncio
5
- from typing import Iterable, List, Sequence, Tuple
6
-
7
- import torch
8
- from hivemind import nested_compare, nested_flatten, nested_pack, serialize_torch_tensor
9
- from hivemind.compression.serialization import deserialize_tensor_stream, deserialize_torch_tensor
10
- from hivemind.p2p import StubBase
11
- from hivemind.p2p.p2p_daemon_bindings.control import DEFAULT_MAX_MSG_SIZE, MAX_UNARY_PAYLOAD_SIZE
12
- from hivemind.proto import runtime_pb2
13
- from hivemind.utils.asyncio import amap_in_executor, iter_as_aiter
14
- from hivemind.utils.streaming import split_for_streaming
15
-
16
- from src.data_structures import ModuleUID, RPCInfo
17
-
18
-
19
- async def run_remote_forward(
20
- uid: ModuleUID, stub: StubBase, rpc_info: RPCInfo, *inputs: torch.Tensor, metadata: bytes = b"", **kwargs
21
- ) -> Tuple[torch.Tensor, ...]:
22
- """
23
- Serializes input tensors and calls "rpc_forward" on a remote server.
24
- Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L198
25
- but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
26
- """
27
-
28
- # Note: *inputs are flattened input tensors that follow the expert's info['input_schema']
29
- # detach to avoid pickling the computation graph
30
- assert len(kwargs) == len(rpc_info["keyword_names"]), f"Keyword args should be {rpc_info['keyword_names']}"
31
- kwargs = {key: kwargs[key] for key in rpc_info["keyword_names"]}
32
-
33
- # Note: we put keyword arguments in the same order as on a server to prevent f(a=1, b=2) != f(b=2, a=1) errors
34
- forward_inputs = (inputs, kwargs)
35
-
36
- # Modify forward_schema to support prompts
37
- args_schema, kwargs_schema = rpc_info["forward_schema"]
38
- # TODO: rm this assert when support arbitrary number of input tensors
39
- assert len(args_schema) == 1 and len(inputs) == 2
40
- forward_schema_with_prompts = (tuple(args_schema * len(inputs)), kwargs_schema)
41
-
42
- if not nested_compare(forward_inputs, forward_schema_with_prompts):
43
- raise TypeError(f"Inputs do not match expert input schema. Did you pass the right number of parameters?")
44
-
45
- forward_inputs = nested_flatten(forward_inputs)
46
- inputs = tuple(tensor.cpu().detach() for tensor in forward_inputs)
47
-
48
- # Asynchronous serialization
49
- loop = asyncio.get_running_loop()
50
- serialized_tensors = await asyncio.gather(
51
- *(
52
- loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
53
- for tensor, proto in zip(inputs, nested_flatten(forward_schema_with_prompts))
54
- )
55
- )
56
-
57
- # call RPC on remote server
58
- size = sum(t.element_size() * t.nelement() for t in inputs)
59
- if size > MAX_UNARY_PAYLOAD_SIZE:
60
- deserialized_outputs = await _forward_stream(uid, serialized_tensors, stub, **kwargs)
61
- else:
62
- deserialized_outputs = await _forward_unary(uid, serialized_tensors, stub, **kwargs)
63
-
64
- return nested_pack(deserialized_outputs, structure=rpc_info["outputs_schema"])
65
-
66
-
67
- async def _forward_stream(
68
- uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
69
- ) -> List[torch.Tensor]:
70
- split = (p for t in serialized_tensors for p in split_for_streaming(t, DEFAULT_MAX_MSG_SIZE))
71
-
72
- outputs = await stub.rpc_forward_stream(
73
- amap_in_executor(
74
- lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
75
- iter_as_aiter(split),
76
- ),
77
- )
78
-
79
- tensors_stream = amap_in_executor(lambda msg: msg.tensors, outputs)
80
- return await deserialize_tensor_stream(tensors_stream)
81
-
82
-
83
- async def _forward_unary(
84
- uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
85
- ) -> List[torch.Tensor]:
86
- outputs: runtime_pb2.ExpertResponse = await stub.rpc_forward(
87
- runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
88
- )
89
- return [deserialize_torch_tensor(t) for t in outputs.tensors]
90
-
91
-
92
- async def _backward_stream(
93
- uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
94
- ) -> List[torch.Tensor]:
95
- split = (part for tensor in serialized_tensors for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE))
96
-
97
- grad_inputs = await stub.rpc_backward_stream(
98
- amap_in_executor(
99
- lambda tensor: runtime_pb2.ExpertRequest(uid=uid, tensors=[tensor], **kwargs),
100
- iter_as_aiter(split),
101
- ),
102
- )
103
- tensors_stream = amap_in_executor(lambda msg: msg.tensors, grad_inputs)
104
- return await deserialize_tensor_stream(tensors_stream)
105
-
106
-
107
- async def run_remote_backward(
108
- uid: ModuleUID,
109
- stub: StubBase,
110
- rpc_info: RPCInfo,
111
- inputs: torch.Tensor,
112
- grad_outputs: List[torch.Tensor],
113
- *extra_tensors: torch.Tensor,
114
- **kwargs,
115
- ) -> Sequence[torch.Tensor]:
116
- """
117
- Serializes grad outputs and calls "rpc_backward" on a remote server.
118
- Mostly adapted from https://github.com/learning-at-home/hivemind/blob/7a7c93aefffc9494c39e7b170c07cb06d8c09c4c/hivemind/moe/client/expert.py#L221
119
- but without RemoteExpertWorker.run_coroutine() call that leads to deadlock here.
120
- """
121
-
122
- grad_outputs_cpu = tuple(tensor.cpu() for tensor in grad_outputs)
123
- inputs_and_grad_outputs = tuple(nested_flatten((inputs, grad_outputs_cpu, *extra_tensors)))
124
-
125
- # Modify forward_schema to support prompts
126
- args_schema, kwargs_schema = rpc_info["forward_schema"]
127
- assert len(args_schema) == 1 and isinstance(inputs, torch.Tensor)
128
- # TODO generalize this
129
- prompts_schema = next(iter(args_schema))
130
- backward_schema = tuple(nested_flatten((rpc_info["forward_schema"], rpc_info["outputs_schema"], prompts_schema)))
131
-
132
- # Asynchronous serialization
133
- loop = asyncio.get_running_loop()
134
- serialized_tensors = await asyncio.gather(
135
- *(
136
- loop.run_in_executor(None, serialize_torch_tensor, tensor.to(proto.dtype), proto.compression)
137
- for tensor, proto in zip(inputs_and_grad_outputs, backward_schema)
138
- )
139
- )
140
-
141
- size = sum(t.element_size() * t.nelement() for t in inputs_and_grad_outputs)
142
- if size > MAX_UNARY_PAYLOAD_SIZE:
143
- deserialized_grad_inputs = await _backward_stream(uid, serialized_tensors, stub, **kwargs)
144
- else:
145
- deserialized_grad_inputs = await _backward_unary(uid, serialized_tensors, stub, **kwargs)
146
-
147
- return deserialized_grad_inputs
148
-
149
-
150
- async def _backward_unary(
151
- uid: str, serialized_tensors: Iterable[runtime_pb2.Tensor], stub, **kwargs
152
- ) -> List[torch.Tensor]:
153
- grad_inputs: runtime_pb2.ExpertResponse = await stub.rpc_backward(
154
- runtime_pb2.ExpertRequest(uid=uid, tensors=list(serialized_tensors), **kwargs)
155
- )
156
- return [deserialize_torch_tensor(t) for t in grad_inputs.tensors]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/client/remote_generation.py DELETED
@@ -1,257 +0,0 @@
1
- from typing import List, Optional
2
-
3
- import torch
4
- import torch.nn.functional as F
5
-
6
- from src.utils.generation_algorithms import DecodingAlgorithm, GreedyAlgorithm, NucleusAlgorithm, TopKAlgorithm
7
- from src.utils.generation_constraints import ABCBloomConstraint, EosConstraint, MaxNewTokensConstraint
8
-
9
-
10
class RemoteGenerationMixin:
    """
    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`BloomForCausalLM`].
    This class can be used for:
        - *greedy decoding*.
        - *sampling* restricted by top-k or by top-p (nucleus).

    This class is similar to transformer's [`generation_utils.GenerationMixin`], it can be used instead of it. However, it has some differences.
    """

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        do_sample: Optional[bool] = None,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        bos_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        max_length: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        decoding_algorithm: Optional[DecodingAlgorithm] = None,
        provided_constraints: Optional[List[ABCBloomConstraint]] = None,
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head.

        :param inputs: The input tokens to the model, a 2d tensor [batch, length]; a single bos token is used if None.
        :param do_sample: Whether to sample from the model predictions or take the argmax.
        :param temperature: The temperature to use for sampling.
        :param top_k: The number of results to return.
        :param top_p: The cumulative probability of results to return.
        :param bos_token_id: The id of the beginning of sentence token.
        :param eos_token_id: The id of the end of sentence token.
        :param pad_token_id: The id of the padding token.
        :param max_length: The maximum total length (prefix + new tokens); mutually exclusive with max_new_tokens.
        :param max_new_tokens: The maximum number of tokens to generate; mutually exclusive with max_length.
        :param decoding_algorithm: The decoding algorithm to use; overrides do_sample / top_k / top_p if given.
        :param provided_constraints: A list of constraints to use (None means no extra constraints).
        :param model_kwargs: Additional arguments to pass to the model.
        :returns: the prompt (without trailing padding columns) concatenated with the generated tokens.
        """

        assert (
            model_kwargs.get("logits_processor", None) is None
        ), "For RemoteGenerationMixin models use BloomConstraints instead of logits_processor"
        assert (
            model_kwargs.get("logits_wrapper", None) is None
        ), "For RemoteGenerationMixin models use DecodingAlgorithm instead of logits_wrapper"
        assert (
            model_kwargs.get("stopping_criteria", None) is None
        ), "For RemoteGenerationMixin models use BloomConstraints instead of stopping_criteria"
        if inputs is not None:
            assert isinstance(inputs, torch.Tensor) and inputs.ndim == 2, "inputs must be a 2d tensor [batch, length]"
        prefix_length = 0 if inputs is None else inputs.size(1)
        prefix_length += self.config.pre_seq_len

        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id

        assert (max_length is None) != (max_new_tokens is None), "please set max_length or max_new_tokens (not both)"
        if max_length is not None and max_new_tokens is None:
            max_new_tokens = max_length - prefix_length
            # report prefix_length: `inputs` may still be None here, and the prefix also includes pre_seq_len
            assert max_new_tokens > 0, f"Provided max_length is less than prefix size: {max_length} < {prefix_length}"
        elif max_length is None and max_new_tokens is not None:
            max_length = prefix_length + max_new_tokens

        if inputs is None:
            assert bos_token_id is not None, "You have to provide a bos_token_id if you do not provide inputs"
            inputs = torch.tensor([[bos_token_id]])

        if decoding_algorithm is None:
            if do_sample:
                decoding_algorithm = self._choose_sample_algorithm(temperature, top_k, top_p)
            else:
                decoding_algorithm = GreedyAlgorithm()

        constraints = self._get_constraints(
            inputs=inputs,
            eos_token_id=eos_token_id,
            pad_token_id=pad_token_id,
            max_new_tokens=max_new_tokens,
            provided_constraints=provided_constraints,
        )

        with self.transformer.h.inference_session(max_length=max_length) as sess:
            outputs = []
            if torch.any(inputs == pad_token_id):  # TODO: move to prepare_inputs
                # cut as many trailing columns as the largest per-row pad count
                outputs += [inputs[:, : inputs.size(1) - (inputs == pad_token_id).sum(-1).max()]]
            else:
                outputs += [inputs]
            last_token_id = None
            seq_idx = outputs[0].size(1)
            hypo_ids = torch.arange(outputs[0].size(0))
            while True:
                embs = self.transformer.word_embeddings(outputs[-1])
                intermediate_prompts = None
                # trainable prompts are prepended only on the first step (when only the prefix is in outputs)
                if self.config.pre_seq_len > 0 and len(outputs) == 1:
                    prompts, intermediate_prompts = self.transformer.get_prompt(embs.size(0))
                    embs = torch.cat([prompts, embs], dim=1)
                embs = self.transformer.word_embeddings_layernorm(embs)
                hidden_state = sess.step(embs, prompts=intermediate_prompts, hypo_ids=hypo_ids)[:, -1]
                hidden_state = self.transformer.ln_f(hidden_state)
                lm_logits = self.lm_head(hidden_state)

                for constraint in constraints:
                    lm_logits = constraint(last_token_id, lm_logits, hypo_ids)
                last_token_id, hypo_ids = decoding_algorithm(lm_logits)
                if seq_idx < inputs.size(1):  # TODO: why is it not a constraint?
                    # inside the columns that were cut above: keep the given token wherever it is not padding
                    pad_token_mask = inputs[:, seq_idx : seq_idx + 1] == pad_token_id
                    last_token_id = (~pad_token_mask) * inputs[
                        :, seq_idx : seq_idx + 1
                    ] + pad_token_mask * last_token_id

                if torch.all(last_token_id == eos_token_id):
                    break

                outputs.append(last_token_id)
                seq_idx += 1

        return torch.cat(outputs, dim=-1)

    def greedy_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: Optional[List[ABCBloomConstraint]] = None,
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses greedy search.

        :param input_ids: The input tokens to the model.
        :param max_length: Forwarded to generate() as max_new_tokens, i.e. it bounds the number of new tokens.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end of sentence token.
        :param provided_constraints: A list of constraints to use.
        :param model_kwargs: Additional kwargs to pass to generate().
        """
        # NOTE(review): max_length is passed as max_new_tokens, not as a total length - confirm this is intended
        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=GreedyAlgorithm(),
            provided_constraints=provided_constraints,
            **model_kwargs,
        )

    def sample(
        self,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: Optional[List[ABCBloomConstraint]] = None,
        **model_kwargs,
    ) -> torch.LongTensor:
        """
        Generates sequences of token ids for models with a language modeling head. Uses sampling.
        If top_k is provided, uses top_k sampling. If top_p is provided, uses nucleus sampling.
        Exactly one of top_k / top_p has to be provided.

        :param input_ids: The input tokens to the model.
        :param temperature: The temperature to use for sampling.
        :param top_k: The number of samples to use for top_k sampling.
        :param top_p: The cumulative probability to use for nucleus sampling.
        :param max_length: Forwarded to generate() as max_new_tokens, i.e. it bounds the number of new tokens.
        :param pad_token_id: The id of the padding token.
        :param eos_token_id: The id of the end of sentence token.
        :param provided_constraints: A list of constraints to use.
        :param model_kwargs: Additional kwargs to pass to generate().
        :raises ValueError: if both or none of top_k / top_p are usable.
        """
        # NOTE(review): max_length is passed as max_new_tokens, not as a total length - confirm this is intended
        return self.generate(
            inputs=input_ids,
            max_new_tokens=max_length,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoding_algorithm=self._choose_sample_algorithm(temperature, top_k, top_p),
            provided_constraints=provided_constraints,
            **model_kwargs,
        )

    def beam_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: Optional[List[ABCBloomConstraint]] = None,
        **model_kwargs,
    ) -> torch.LongTensor:
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError

    def beam_sample(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: Optional[List[ABCBloomConstraint]] = None,
        **model_kwargs,
    ) -> torch.LongTensor:
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError

    def group_beam_search(
        self,
        input_ids: torch.LongTensor,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        provided_constraints: Optional[List[ABCBloomConstraint]] = None,
        **model_kwargs,
    ) -> torch.LongTensor:
        """Not implemented; always raises NotImplementedError."""
        raise NotImplementedError

    def _choose_sample_algorithm(
        self,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> DecodingAlgorithm:
        """
        Build the sampling algorithm selected by top_k / top_p.

        :raises ValueError: if both top_k and top_p are given, or if neither is usable (None or 0).
        """
        if (top_k is not None) and (top_p is not None):
            raise ValueError("You have to provide only top_k or top_p for sampling")
        if top_k:
            return TopKAlgorithm(top_k, temperature)
        elif top_p:
            return NucleusAlgorithm(top_p, temperature)
        # fail here with a clear message instead of returning None and crashing later when it is called
        raise ValueError("You have to provide a non-zero top_k or top_p for sampling")

    def _get_constraints(
        self,
        inputs: Optional[torch.Tensor] = None,
        eos_token_id: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        provided_constraints: Optional[List[ABCBloomConstraint]] = None,
    ) -> List[ABCBloomConstraint]:
        """
        Collect the constraints for one generate() call: the user-provided ones first,
        then MaxNewTokensConstraint (if max_new_tokens is set), then EosConstraint.
        A new list is returned; provided_constraints is not modified.
        """
        constraints = list(provided_constraints) if provided_constraints is not None else []
        if max_new_tokens is not None:
            constraints.append(MaxNewTokensConstraint(inputs, max_new_tokens, eos_token_id, pad_token_id))
        constraints.append(EosConstraint(inputs, eos_token_id, pad_token_id))
        return constraints
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/client/remote_model.py DELETED
@@ -1,197 +0,0 @@
1
- # this code is in active development, interfaces may change
2
- from typing import Optional, Tuple
3
-
4
- import hivemind
5
- import torch
6
- import torch.nn as nn
7
- from hivemind import get_logger, use_hivemind_log_handler
8
- from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
9
-
10
- from src.bloom.model import (
11
- BloomConfig,
12
- BloomForCausalLM,
13
- BloomForSequenceClassification,
14
- BloomModel,
15
- BloomPreTrainedModel,
16
- LMHead,
17
- )
18
- from src.client.remote_generation import RemoteGenerationMixin
19
- from src.client.remote_sequential import RemoteSequential
20
- from src.utils.misc import DUMMY
21
-
22
- use_hivemind_log_handler("in_root_logger")
23
- logger = get_logger(__file__)
24
-
25
-
26
class DistributedBloomConfig(BloomConfig):
    """
    BloomConfig extended with the information needed to find the model in a swarm.
    To create a distributed model, one must provide dht_prefix and either initial_peers or dht.
    """

    # a list of initial peers for hivemind DHT
    initial_peers: Tuple[str, ...] = ()
    # a prefix for all dht keys that correspond to this model (usually equal to model name)
    dht_prefix: str
    # a running DHT instance, e.g. when using the same DHT for multiple models
    dht: Optional[hivemind.DHT] = None
    # a chunk size for a LM head for efficient half-precision on CPU
    chunk_size_for_efficient_fp16_on_cpu: int = 10000
    # a number of tokens for prompt tuning.
    pre_seq_len: int = 0
    # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
    tuning_mode: Optional[str] = None
38
-
39
-
40
class DistributedBloomModel(BloomModel):
    """BloomModel, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        """
        Build the local parts of the model and replace the transformer layers with a RemoteSequential.

        :param config: must define dht_prefix and either initial_peers or a running dht
        :raises NotImplementedError: if config.tuning_mode is set to a mode without "ptune" in its name
        """
        assert config.dht_prefix, "Could not find dht_prefix in config, please create model with dht_prefix=..."
        assert config.initial_peers or config.dht, "Please specify initial_peers=list(...) or dht=hivemind.DHT(...)"

        n_layer, config.n_layer = config.n_layer, 0  # temporarily set n_layer to 0 to prevent layer initialization
        try:
            super().__init__(config)
        finally:
            # restore the real value even if the parent constructor fails, so the caller's config stays intact
            config.n_layer = n_layer
        assert len(self.h) == 0

        dht = (
            config.dht
            if config.dht is not None
            else hivemind.DHT(initial_peers=config.initial_peers, client_mode=True, start=True)
        )
        assert isinstance(dht, hivemind.DHT) and dht.is_alive(), "dht must be a running hivemind.DHT instance"
        self.h = RemoteSequential(config, dht, config.dht_prefix)

        # Forbid accumulate grads for embeddings and layernorm
        self.set_requires_grad(False)

        if config.tuning_mode and "ptune" in config.tuning_mode:
            assert config.pre_seq_len > 0, "The number of prefix tokens must be > 0"
            self.pre_seq_len = config.pre_seq_len
            self.prompt_embeddings = nn.Embedding(self.pre_seq_len, config.hidden_size)
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()

            if config.tuning_mode == "deep_ptune":
                self.intermediate_prompt_embeddings = nn.Embedding(
                    self.pre_seq_len,
                    config.num_hidden_layers * config.hidden_size
                    # ^-- TODO: should be num_hidden_layers - 1
                )
                self.intermediate_prompt_embeddings.weight.data.zero_()
        elif config.tuning_mode:
            # read the mode from config: the model itself never defines a tuning_mode attribute
            raise NotImplementedError(f"{config.tuning_mode} mode is not supported for now")

    def set_requires_grad(self, value):
        """Set requires_grad to `value` for every parameter currently registered on this module."""
        for p in self.parameters():
            p.requires_grad = value

    def get_prompt(self, batch_size):
        """
        Compute trainable prompts for a batch.

        :param batch_size: number of sequences in the batch
        :returns: (prompts, intermediate_prompts); prompts are the embeddings of the prefix tokens for each
          batch element; intermediate_prompts are arranged as [len(self.h), batch_size, pre_seq_len, hidden_size]
          in "deep_ptune" mode and are DUMMY otherwise
        """
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1)
        prefix_tokens = prefix_tokens.to(self.word_embeddings.weight.device)
        prompts = self.prompt_embeddings(prefix_tokens)

        if self.config.tuning_mode == "deep_ptune":
            intermediate_prompts = self.intermediate_prompt_embeddings(prefix_tokens)
            intermediate_prompts = intermediate_prompts.view(
                batch_size, self.pre_seq_len, len(self.h), self.config.hidden_size  # TODO: should be len(self.h) - 1
            )
            # move the per-block axis to the front
            intermediate_prompts = intermediate_prompts.permute([2, 0, 1, 3])
        else:
            intermediate_prompts = DUMMY
        return prompts, intermediate_prompts

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Embed the inputs locally, run them through the remote blocks and apply the final layernorm.

        :param input_ids: token ids; mutually exclusive with inputs_embeds
        :param inputs_embeds: precomputed embeddings; mutually exclusive with input_ids
        :param attention_mask: not supported, must be None
        :param kwargs: ignored; values other than None / False are reported at debug level
        :returns: BaseModelOutputWithPastAndCrossAttentions with only last_hidden_state filled in
        :raises ValueError: if both or neither of input_ids / inputs_embeds are given
        """
        assert attention_mask is None, "DistributedBloomModel does not support attention masks right now"

        for k, v in kwargs.items():
            if not (v is None or v is False):
                logger.debug(f"Extra keyword arguments are not yet supported (got {k} = {v})")

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        uses_ptune = bool(self.config.tuning_mode and "ptune" in self.config.tuning_mode)

        if uses_ptune:
            batch_size = inputs_embeds.shape[0]
            prompts, intermediate_prompts = self.get_prompt(batch_size)
            inputs_embeds = torch.cat([prompts, inputs_embeds], dim=1)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds.float())
        output_shape = input_shape + (hidden_states.size(-1),)

        if uses_ptune:
            hidden_states = self.h(hidden_states, prompts=intermediate_prompts)
        else:
            hidden_states = self.h(hidden_states)

        # Remove prefix
        if uses_ptune:
            hidden_states = hidden_states[:, self.pre_seq_len :]

        # Add last hidden state
        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
152
-
153
-
154
class DistributedBloomForCausalLM(RemoteGenerationMixin, BloomForCausalLM):
    """BloomForCausalLM, but all transformer layers are hosted by the swarm"""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        """Create the distributed transformer and an LM head that shares its word embeddings."""
        BloomPreTrainedModel.__init__(self, config)
        self.transformer = DistributedBloomModel(config)
        self.lm_head = LMHead(config, self.transformer.word_embeddings)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the word embedding module of the transformer."""
        return self.transformer.word_embeddings

    def get_output_embeddings(self):
        """Return the LM head, or None when config.tie_word_embeddings is set."""
        if self.config.tie_word_embeddings:
            return None
        return self.lm_head

    def set_input_embeddings(self, new_embeddings: nn.Embedding):
        """Use `new_embeddings` both as the input embeddings and as the embeddings of the LM head."""
        assert isinstance(new_embeddings, nn.Embedding)
        self.transformer.word_embeddings = self.lm_head.word_embeddings = new_embeddings
        assert self.lm_head.bias is None or len(self.lm_head.bias) == new_embeddings.num_embeddings

    def set_output_embeddings(self, new_lm_head: nn.Linear):
        """
        Copy the weight (and the bias, if the LM head has one) of `new_lm_head` into the LM head in-place.

        :raises ValueError: if `new_lm_head` has a bias but the LM head has nowhere to store it
        """
        with torch.no_grad():
            self.lm_head.word_embeddings.weight[...] = new_lm_head.weight
            if self.lm_head.bias is None:
                # the head may have no bias (see set_input_embeddings); indexing None would raise a TypeError
                if new_lm_head.bias is not None:
                    raise ValueError("Cannot set a bias: the LM head of this model has no bias")
            else:
                self.lm_head.bias[...] = new_lm_head.bias
184
-
185
-
186
class DistributedBloomForSequenceClassification(BloomForSequenceClassification):
    """Sequence classification model whose transformer is a DistributedBloomModel."""

    config_class = DistributedBloomConfig

    def __init__(self, config: DistributedBloomConfig):
        """Create the distributed transformer and a bias-free linear layer with config.num_labels outputs."""
        BloomPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.transformer = DistributedBloomModel(config)
        self.score = nn.Linear(in_features=config.hidden_size, out_features=config.num_labels, bias=False)

        # Weights are initialized and final processing is applied here
        self.post_init()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/client/remote_sequential.py DELETED
@@ -1,103 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from typing import Optional, Union
4
-
5
- import torch
6
- from hivemind import DHT, P2P, get_logger, use_hivemind_log_handler
7
- from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
8
- from torch import nn
9
-
10
- import src
11
- from src.client.inference_session import RemoteSequentialInferenceSession
12
- from src.client.sequence_manager import RemoteSequenceManager
13
- from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
14
- from src.data_structures import UID_DELIMITER
15
- from src.utils.misc import DUMMY
16
-
17
- use_hivemind_log_handler("in_root_logger")
18
- logger = get_logger(__file__)
19
-
20
-
21
class RemoteSequential(nn.Module):
    """
    A sequence of transformer blocks hosted by the swarm.
    """

    def __init__(
        self,
        config: src.DistributedBloomConfig,
        dht: DHT,
        dht_prefix: Optional[str] = None,
        p2p: Optional[P2P] = None,
        sequence_manager: Optional[RemoteSequenceManager] = None,
    ):
        """
        :param config: model config; provides n_layer and the default dht_prefix
        :param dht: DHT instance used to find the blocks
        :param dht_prefix: prefix of the block uids, defaults to config.dht_prefix
        :param p2p: P2P instance to reuse; obtained from dht.replicate_p2p() if None
        :param sequence_manager: existing manager to reuse (e.g. for a sub-sequence); created if None
        """
        logger.warning(f"{self.__class__.__name__} is in active development; expect adventures")
        super().__init__()
        self.config = config
        self.dht = dht
        self.dht_prefix = dht_prefix or config.dht_prefix
        self.p2p = RemoteExpertWorker.run_coroutine(dht.replicate_p2p()) if p2p is None else p2p

        num_blocks = self.config.n_layer if sequence_manager is None else len(sequence_manager)
        # build uids from self.dht_prefix so that an explicitly passed dht_prefix is not ignored
        block_uids = [f"{self.dht_prefix}{UID_DELIMITER}{i}" for i in range(num_blocks)]
        if sequence_manager is None:
            logger.debug(f"Creating new sequence manager for block uids: {block_uids}")
            self.sequence_manager = RemoteSequenceManager(dht, block_uids, self.p2p)
            self.is_subsequence = False
        else:
            logger.debug(f"Reusing sequence manager with {len(sequence_manager)} modules")
            self.sequence_manager = sequence_manager
            assert isinstance(sequence_manager.block_uids, list)
            # a reused manager is a sub-sequence if its uids differ from uids 0..num_blocks-1
            self.is_subsequence = self.sequence_manager.block_uids != block_uids

    def forward(self, inputs: torch.Tensor, prompts: torch.Tensor = DUMMY):
        """Run inputs (and optional prompts) through the remote blocks via _RemoteSequentialAutogradFunction."""
        outputs = _RemoteSequentialAutogradFunction.apply(inputs, prompts, self.sequence_manager)
        return outputs

    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequential:
        """Return a RemoteTransformerBlock for an int index or a RemoteSequential for a slice."""
        assert isinstance(ix, (int, slice))
        if isinstance(ix, int):
            return RemoteTransformerBlock(
                self.config,
                self.dht,
                dht_prefix=self.dht_prefix,
                p2p=self.p2p,
                sequence_manager=self.sequence_manager[ix],
            )
        else:
            return RemoteSequential(
                self.config,
                self.dht,
                dht_prefix=self.dht_prefix,
                p2p=self.p2p,
                sequence_manager=self.sequence_manager[ix],
            )

    def __iter__(self):
        """Yield every block of the sequence as a single-block module."""
        for block_index in range(len(self)):
            yield self[block_index]

    def __len__(self):
        """Return the number of blocks tracked by the sequence manager."""
        return len(self.sequence_manager)

    def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
        """Refresh the sequence manager and open an inference session; kwargs are passed to the session."""
        self.sequence_manager.update_()
        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)

    def extra_repr(self) -> str:
        """Show the first and the last block uid in the module repr."""
        return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"
89
-
90
-
91
class RemoteTransformerBlock(RemoteSequential):
    """A RemoteSequential that must contain exactly one transformer block

    Deprecated: this class only exists for backward compatibility and
    is going to be removed soon; use ``RemoteSequential`` directly instead.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to RemoteSequential and check that the result has a single block."""
        super().__init__(*args, **kwargs)
        num_blocks = len(self)
        assert num_blocks == 1, "Remote Block is a sequence size 1"

    def extra_repr(self):
        """Show the uid of the only block in the module repr."""
        only_uid = self.sequence_manager.block_uids[0]
        return f"{only_uid}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/client/sequence_manager.py DELETED
@@ -1,153 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import random
4
- import threading
5
- from typing import List, Optional, Sequence, Tuple, Union
6
-
7
- from hivemind import DHT, P2P, DHTExpiration, MSGPackSerializer
8
- from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
9
- from hivemind.proto import runtime_pb2
10
- from hivemind.utils.logging import get_logger, use_hivemind_log_handler
11
-
12
- from src.client.spending_policy import NoSpendingPolicy
13
- from src.data_structures import ModuleUID, RemoteModuleInfo, RemoteSpanInfo, ServerState
14
- from src.dht_utils import get_remote_module_infos
15
- from src.server.handler import TransformerConnectionHandler
16
-
17
- use_hivemind_log_handler("in_root_logger")
18
- logger = get_logger(__file__)
19
-
20
-
21
class RemoteSequenceManager:
    """
    Keeps and updates the meta-information about which peers host which blocks.
    In future, this class is intended to maintain latency statistics, ban non-responsive peers, etc.
    """

    def __init__(self, dht: DHT, block_uids: Sequence[ModuleUID], p2p: P2P, max_retries: int = 3):
        """
        :param dht: DHT instance used to look up block infos
        :param block_uids: uids of the managed blocks, at least one
        :param p2p: P2P instance used to contact servers
        :param max_retries: how many times rpc_info tries to query a server before re-raising
        """
        assert len(block_uids) > 0, "Sequences must contain at least one block"
        self.dht, self.p2p = dht, p2p
        self.block_uids: List[ModuleUID] = list(block_uids)
        self.block_infos: List[Optional[RemoteModuleInfo]] = [None] * len(self.block_uids)
        self.spans_by_priority: List[RemoteSpanInfo] = []  # sorted from best to worst
        self.spans_containing_block: Tuple[List[RemoteSpanInfo], ...] = tuple([] for _ in range(len(self.block_uids)))
        self.last_update_time: DHTExpiration = -float("inf")
        self.max_retries = max_retries
        self._rpc_info = None
        self.lock_changes = threading.Lock()
        self.update_()

        for uid, info in zip(self.block_uids, self.block_infos):
            assert info is not None, f"Found no remote peers for block {uid}"
        assert self.spans_by_priority and self.spans_containing_block

    def make_sequence(self, start_index: int = 0, end_index: Optional[int] = None) -> List[RemoteSpanInfo]:
        """
        Form a sequence of remote servers that collectively serve all consecutive layers

        :param start_index: optional index of the first module in a sequence, default = the first of block_uids
        :param end_index: optional index of the last module (non-inclusive), default = after last of block uids
        """
        end_index = end_index if end_index is not None else len(self.block_uids)
        span_sequence = []
        current_index = start_index
        while current_index < end_index:
            candidate_spans = self.spans_containing_block[current_index]
            chosen_span = random.choice(candidate_spans)  # TODO this should be replaced with proper load balancing

            assert chosen_span.start <= current_index < chosen_span.end
            # the emitted span starts at current_index and runs to the end of the chosen server span
            span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
            current_index = chosen_span.end

        return span_sequence

    @staticmethod
    def _index_to_slice(ix: Union[int, slice]) -> slice:
        """Convert an int index (negative values included) to a slice selecting exactly that element."""
        if isinstance(ix, slice):
            return ix
        start = int(ix)
        # for ix == -1 the end would be 0 and the slice would be empty, so leave the end open instead
        return slice(start, start + 1 or None, 1)

    def __getitem__(self, ix: Union[int, slice]) -> RemoteSequenceManager:
        """Get a RemoteSequenceManager for a sub-sequence of blocks"""
        assert isinstance(ix, (int, slice))
        ix = self._index_to_slice(ix)
        with self.lock_changes:
            subseq = RemoteSequenceManager(self.dht, self.block_uids[ix], self.p2p, max_retries=self.max_retries)
            # reuse the infos this manager already has instead of the ones fetched by the constructor
            subseq.block_infos = self.block_infos[ix]
            subseq.spans_by_priority, subseq.spans_containing_block = subseq.compute_spans(subseq.block_infos)
            subseq.last_update_time = self.last_update_time
        return subseq

    def update_(self):
        """Fetch block infos and recompute the spans while holding lock_changes."""
        with self.lock_changes:
            self.update_block_infos_()
            self.spans_by_priority, self.spans_containing_block = self.compute_spans(self.block_infos)

    def update_block_infos_(self):
        """
        Query get_remote_module_infos for all block uids and store every usable result in self.block_infos.
        Missing or malformed entries are logged and the previously stored info is kept.
        """
        new_block_infos = get_remote_module_infos(self.dht, self.block_uids, expiration_time=float("inf"))
        assert len(new_block_infos) == len(self.block_uids)
        for block_index, (uid, info) in enumerate(zip(self.block_uids, new_block_infos)):
            if info is None:
                logger.warning(f"Found no block info for block {uid}")
                continue
            if not isinstance(info, RemoteModuleInfo):
                logger.warning(f"Unexpected dht entry type for {uid}: {info}")
                # the checks below read info.servers / info.uid, so an entry of another type cannot be used
                continue
            if not info.servers:
                logger.warning(f"Found no active peers for block {uid}")
            if info.uid != uid:
                logger.warning(f"The DHT entry for {uid} actually points to {info.uid}")
            self.block_infos[block_index] = info

    @staticmethod
    def compute_spans(block_infos: Sequence[RemoteModuleInfo]):
        """
        Group consecutive blocks served by the same ONLINE peer into spans.

        :param block_infos: one info (or None) per block, in block order
        :returns: (closed_spans, spans_containing_block); closed_spans are sorted from the longest to the
          shortest, spans_containing_block[i] lists the spans that cover block i
        """
        closed_spans = []
        active_spans = {}
        for block_index, info in enumerate(block_infos):
            if info is not None:
                for peer_id, server in info.servers.items():
                    if server.state != ServerState.ONLINE:
                        continue
                    if peer_id not in active_spans:
                        active_spans[peer_id] = RemoteSpanInfo(start=block_index, end=block_index + 1, peer_id=peer_id)
                    else:  # peer_id in active_spans
                        active_spans[peer_id].end = block_index + 1

            # close the spans of peers that do not serve this block online, and all spans at the last block
            for peer_id in list(active_spans.keys()):
                if (
                    info is None
                    or peer_id not in info.servers
                    or info.servers[peer_id].state != ServerState.ONLINE
                    or block_index == len(block_infos) - 1
                ):
                    closed_spans.append(active_spans.pop(peer_id))
        assert not active_spans, f"spans: {active_spans}"

        closed_spans.sort(key=lambda span: span.end - span.start, reverse=True)

        spans_containing_block = tuple(list() for _ in range(len(block_infos)))
        for span in closed_spans:
            for block_index in range(span.start, span.end):
                spans_containing_block[block_index].append(span)

        return closed_spans, spans_containing_block

    def __len__(self):
        """Return the number of managed blocks."""
        return len(self.block_uids)

    @property
    def rpc_info(self):
        """Return the rpc_info queried from one of the servers that hold the first block"""
        if self._rpc_info is None:
            for attempt_no in range(1, self.max_retries + 1):
                try:
                    self.update_()
                    peer_id = random.choice(list(self.block_infos[0].servers.keys()))
                    stub = TransformerConnectionHandler.get_stub(self.p2p, peer_id)
                    outputs = RemoteExpertWorker.run_coroutine(
                        stub.rpc_info(runtime_pb2.ExpertUID(uid=self.block_uids[0]))
                    )
                    self._rpc_info = MSGPackSerializer.loads(outputs.serialized_info)
                    break
                except Exception as e:
                    # the last failed attempt propagates its exception, earlier ones are only logged
                    if attempt_no >= self.max_retries:
                        raise
                    logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
        return self._rpc_info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/client/sequential_autograd.py DELETED
@@ -1,204 +0,0 @@
1
- """
2
- A PyTorch autograd function that runs forward/backward on a sequence of remote servers in a fault-tolerant manner
3
- """
4
- import asyncio
5
- import itertools
6
- import logging
7
- from typing import List, Optional, Sequence, Tuple
8
-
9
- import torch
10
- from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
11
-
12
- from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
13
- from src.client.sequence_manager import RemoteSequenceManager
14
- from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
15
- from src.server.handler import TransformerConnectionHandler
16
- from src.utils.misc import DUMMY, is_dummy
17
-
18
- MAX_TOKENS_IN_BATCH = 1024
19
-
20
-
21
async def sequential_forward(
    inputs: torch.Tensor,
    prompts: torch.Tensor,
    sequence_manager: RemoteSequenceManager,
    start_index: int = 0,
    end_index: Optional[int] = None,
    min_backoff: float = 1.0,
) -> Tuple[torch.Tensor, Sequence[torch.Tensor], Sequence[RemoteSpanInfo]]:
    """
    Constructs a routing path from <start_index> to <end_index>.
    Performs chained forward for each subsequence of blocks on the path.
    If some subsequence fails, reconstructs the remaining path and tries to finish the forward.

    :param inputs: a 3-dimensional tensor that is fed to the first span on the path
    :param prompts: either a dummy tensor or a tensor with one entry per block in sequence_manager.block_uids
    :param sequence_manager: supplies block uids, the p2p instance, rpc_info and make_sequence
    :param start_index: index of the first block to run
    :param end_index: index past the last block to run; defaults to the total number of blocks
    :param min_backoff: base delay in seconds, doubled after each failed attempt within the current retry loop
    :returns: (final outputs, inputs that were fed to each completed span, the completed spans)
    """

    assert isinstance(inputs, torch.Tensor) and inputs.ndim == 3, f"{type(inputs)}: {inputs.ndim}"

    end_index = end_index if end_index is not None else len(sequence_manager.block_uids)
    assert start_index >= 0 and end_index <= len(sequence_manager.block_uids)
    assert is_dummy(prompts) or len(prompts) == len(
        sequence_manager.block_uids
    )  # should be n_layers - 1 but add extra prompts for convenience

    sequences = sequence_manager.make_sequence(start_index, end_index)
    intermediate_inputs = []
    done_sequences = []
    outputs = inputs

    while len(sequences) > 0:
        for attempt_no in itertools.count():
            span = sequences.pop(0)
            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
            try:
                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                inputs_and_prompts = [inputs, prompts[span.start : span.end]]

                (outputs,) = await run_remote_forward(span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts)

                assert isinstance(outputs, torch.Tensor)
                assert outputs.shape == inputs.shape, f"Expected output {inputs.shape}, got {outputs.shape}"

                # Save intermediate inputs and subsequences if the forward is already done for them
                intermediate_inputs.append(inputs)
                done_sequences.append(span)

                inputs = outputs
                break
            except Exception as e:
                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
                await asyncio.sleep(min_backoff * 2**attempt_no)

                # Rebuild the route only up to end_index: without the upper bound the
                # replacement path could run past the range the caller asked for.
                backup_sequences = sequence_manager.make_sequence(span.start, end_index)
                assert backup_sequences[0].start == span.start
                sequences = backup_sequences

    return outputs, intermediate_inputs, done_sequences
76
-
77
-
78
async def sequential_backward(
    grad_outputs: Sequence[torch.Tensor],
    intermediate_inputs: List[torch.Tensor],
    prompts: torch.Tensor,
    forward_sequences: List[RemoteSpanInfo],
    sequence_manager: RemoteSequenceManager,
    min_backoff: float = 1.0,
) -> Sequence[torch.Tensor]:
    """
    Run the backward pass span by span, starting from the last forward span.

    When a span fails, the forward pass over that span's block range is recomputed and the
    resulting spans and inputs are pushed back so that the backward can be retried on them.
    Both ``intermediate_inputs`` and ``forward_sequences`` are consumed (popped) in place.

    :returns: a pair (grad_outputs, grad_prompts); grad_prompts is None if no span returned prompt gradients
    """
    assert len(intermediate_inputs) == len(forward_sequences)

    prompt_grads_last_first = []
    while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
        attempt_no = 0
        while True:
            span_inputs = intermediate_inputs.pop()
            span = forward_sequences.pop()
            uid_chain = sequence_manager.block_uids[span.start : span.end]
            span_uids: str = CHAIN_DELIMITER.join(uid_chain)
            try:
                stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                span_prompts = prompts[span.start : span.end]
                grad_outputs, *span_grad_prompts = await run_remote_backward(
                    span_uids, stub, sequence_manager.rpc_info, span_inputs, grad_outputs, span_prompts
                )
                grad_outputs = [grad_outputs]
                prompt_grads_last_first.extend(span_grad_prompts)
                break
            except Exception as e:
                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
                await asyncio.sleep(min_backoff * 2**attempt_no)

                # Redo the forward over the failed range to obtain a fresh set of spans
                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
                    span_inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
                )
                assert len(intermediate_inputs) == len(forward_sequences)
                assert backup_forward_sequences[0].start == span.start
                assert backup_forward_sequences[-1].end == span.end

                forward_sequences.extend(backup_forward_sequences)
                intermediate_inputs.extend(backup_intermediate_inputs)
            attempt_no += 1

    # For now, we do not support mixed dummy and grad prompts
    # Concat in num_layer dimension
    if prompt_grads_last_first:
        grad_prompts = torch.cat(prompt_grads_last_first[::-1], dim=0)
    else:
        grad_prompts = None
    return grad_outputs, grad_prompts
124
-
125
-
126
async def _gather_forward(input_batches, prompt_batches, sequence_manager):
    """Run sequential_forward for every (input, prompt) batch pair concurrently via asyncio.gather"""
    pending = []
    for input_batch, prompt_batch in zip(input_batches, prompt_batches):
        pending.append(sequential_forward(input_batch, prompt_batch, sequence_manager))
    return await asyncio.gather(*pending)
134
-
135
-
136
async def _gather_backward(
    grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences, sequence_manager
):
    """Run sequential_backward for every batch concurrently via asyncio.gather"""
    per_batch_args = zip(grad_output_batches, intermediate_input_batches, prompt_batches, forward_sequences)
    pending = []
    for grad_output, input_batch, prompt_batch, spans in per_batch_args:
        # sequential_backward expects a sequence of gradients, hence the 1-tuple
        pending.append(sequential_backward((grad_output,), input_batch, prompt_batch, spans, sequence_manager))
    return await asyncio.gather(*pending)
148
-
149
-
150
class _RemoteSequentialAutogradFunction(torch.autograd.Function):
    """
    Autograd function that runs forward and backward through the whole chain of remote transformer blocks.
    Inputs are split along dim 0 into chunks sized from <MAX_TOKENS_IN_BATCH>, and the chunks are processed in parallel.
    """

    @staticmethod
    def forward(ctx, inputs: torch.Tensor, prompts: torch.Tensor, sequence_manager: RemoteSequenceManager):
        tokens_per_item = inputs.shape[1]
        batch_size = max(MAX_TOKENS_IN_BATCH // tokens_per_item, 1)
        input_batches: Sequence[torch.Tensor] = inputs.detach().split(batch_size)
        if is_dummy(prompts):
            prompt_batches = [DUMMY] * len(input_batches)
        else:
            # prompts are split along dim 1 so that every input chunk gets its matching slice
            prompt_batches: Sequence[torch.Tensor] = prompts.detach().split(batch_size, dim=1)

        sequence_manager.rpc_info  # lazy init
        results = RemoteExpertWorker.run_coroutine(_gather_forward(input_batches, prompt_batches, sequence_manager))
        assert len(results) == len(input_batches)

        output_batches = []
        intemediate_input_batches = []
        sequences_for_batches = []
        for batch_outputs, batch_intermediate_inputs, batch_spans in results:
            output_batches.append(batch_outputs)
            intemediate_input_batches.append(batch_intermediate_inputs)
            sequences_for_batches.append(batch_spans)

        # Everything backward needs is stashed on ctx (attribute names kept as in the original)
        ctx.prompt_batches = prompt_batches
        ctx.sequence_manager = sequence_manager
        ctx.intemediate_input_batches = intemediate_input_batches
        ctx.sequences_for_batches = sequences_for_batches
        return torch.cat(output_batches, dim=0)

    @staticmethod
    def backward(ctx, grad_outputs: torch.Tensor):
        intermediate_input_batches: List[Sequence[torch.Tensor]] = ctx.intemediate_input_batches
        forward_sequences: List[Sequence[RemoteSpanInfo]] = ctx.sequences_for_batches
        ctx.sequence_manager.rpc_info  # lazy init

        batch_size = max(MAX_TOKENS_IN_BATCH // grad_outputs.shape[1], 1)
        grad_output_batches: Sequence[torch.Tensor] = grad_outputs.split(batch_size)
        assert len(intermediate_input_batches) == len(grad_output_batches) == len(forward_sequences)

        backward_coroutine = _gather_backward(
            grad_output_batches,
            intermediate_input_batches,
            ctx.prompt_batches,
            forward_sequences,
            ctx.sequence_manager,
        )
        results = RemoteExpertWorker.run_coroutine(backward_coroutine)

        grad_input_batches = []
        grad_prompt_batches = []
        for batch_grad_outputs, batch_grad_prompts in results:
            grad_input_batches.append(batch_grad_outputs[0])
            grad_prompt_batches.append(batch_grad_prompts)

        grad_inputs = torch.cat(grad_input_batches, dim=0)
        # If any chunk produced no prompt gradients, report no prompt gradients at all
        if any(grad_prompt is None for grad_prompt in grad_prompt_batches):
            grad_prompts = None
        else:
            grad_prompts = torch.cat(grad_prompt_batches, dim=1)
        return (grad_inputs, grad_prompts, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/client/spending_policy.py DELETED
@@ -1,14 +0,0 @@
1
- from abc import ABC, abstractmethod
2
-
3
- from hivemind.proto.runtime_pb2 import ExpertRequest
4
-
5
-
6
class SpendingPolicyBase(ABC):
    """Abstract interface for objects that assign a number of points to a request."""

    @abstractmethod
    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
        """Return the points for ``request`` issued through ``method_name``; subclasses must override."""
10
-
11
-
12
class NoSpendingPolicy(SpendingPolicyBase):
    """Spending policy that always reports zero points."""

    def get_points(self, request: ExpertRequest, method_name: str, *args, **kwargs) -> float:
        """Return 0.0 for every request, ignoring all arguments."""
        zero_points = 0.0
        return zero_points
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/data_structures.py DELETED
@@ -1,41 +0,0 @@
1
- from dataclasses import dataclass
2
- from enum import Enum
3
- from typing import Any, Dict
4
-
5
- from hivemind import PeerID
6
-
7
- ModuleUID = str
8
- UID_DELIMITER = "." # delimits parts of one module uid, e.g. "bloom.transformer.h.4.self_attention"
9
- CHAIN_DELIMITER = " " # delimits multiple uids in a sequence, e.g. "bloom.layer3 bloom.layer4"
10
-
11
-
12
class ServerState(Enum):
    """State that a server declares for the modules it serves."""

    # The integer values are what gets stored in (and parsed back from) the DHT records
    OFFLINE = 0
    JOINING = 1
    ONLINE = 2
16
-
17
-
18
@dataclass
class ServerInfo:
    """What a single server announces about itself for a module."""

    # declared state of the server
    state: ServerState
    # declared throughput of the server
    throughput: float
22
-
23
-
24
@dataclass
class RemoteModuleInfo:
    """A remote module together with the servers that serve it"""

    # uid of the module
    uid: ModuleUID
    # per-peer information about every server that declared this module
    servers: Dict[PeerID, ServerInfo]
30
-
31
-
32
@dataclass
class RemoteSpanInfo:
    """A run of consecutive remote blocks, all served by the same remote peer"""

    # index of the first block in the span
    start: int
    # index one past the last block in the span (used as a slice end)
    end: int
    # the peer that serves the span
    peer_id: PeerID
38
- peer_id: PeerID
39
-
40
-
41
- RPCInfo = Dict[str, Any]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/dht_utils.py DELETED
@@ -1,180 +0,0 @@
1
- """
2
- Utilities for declaring and retrieving active model layers using a shared DHT.
3
- """
4
- from __future__ import annotations
5
-
6
- import math
7
- from functools import partial
8
- from typing import Dict, List, Optional, Sequence, Union
9
-
10
- from hivemind.dht import DHT, DHTNode, DHTValue
11
- from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
12
- from hivemind.p2p import PeerID
13
- from hivemind.utils import DHTExpiration, MPFuture, get_dht_time, get_logger, use_hivemind_log_handler
14
-
15
- import src
16
- from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ModuleUID, RemoteModuleInfo, ServerInfo, ServerState
17
-
18
- use_hivemind_log_handler("in_root_logger")
19
- logger = get_logger(__file__)
20
-
21
-
22
def declare_active_modules(
    dht: DHT,
    uids: Sequence[ModuleUID],
    expiration_time: DHTExpiration,
    state: ServerState,
    throughput: float,
    wait: bool = True,
) -> Union[Dict[ModuleUID, bool], MPFuture[Dict[ModuleUID, bool]]]:
    """
    Declare that your node serves the specified modules; update timestamps if declared previously

    :param uids: a module id or a sequence of module ids to declare
    :param expiration_time: expiration passed to the DHT for the declared records
    :param state: the server state to announce for these modules
    :param throughput: specify your performance in terms of compute throughput
    :param wait: if True, awaits for declaration to finish, otherwise runs in background
    :returns: if wait, returns store status for every key (True = store succeeded, False = store rejected)
    """
    if isinstance(uids, str):
        uids = [uids]
    elif not isinstance(uids, list):
        uids = list(uids)

    for uid in uids:
        # a uid must contain the uid delimiter and must not itself be a chain of uids
        assert isinstance(uid, ModuleUID) and UID_DELIMITER in uid and CHAIN_DELIMITER not in uid

    declare = partial(
        _declare_active_modules,
        uids=uids,
        expiration_time=expiration_time,
        state=state,
        throughput=throughput,
    )
    return dht.run_coroutine(declare, return_future=not wait)
55
-
56
-
57
async def _declare_active_modules(
    dht: DHT,
    node: DHTNode,
    uids: List[ModuleUID],
    expiration_time: DHTExpiration,
    state: ServerState,
    throughput: float,
) -> Dict[ModuleUID, bool]:
    """Store a (state value, throughput) record for every uid, under this peer's base58 id as the subkey."""
    num_uids = len(uids)
    if dht.num_workers is None:
        num_workers = num_uids
    else:
        num_workers = min(num_uids, dht.num_workers)

    own_subkey = dht.peer_id.to_base58()
    record = (state.value, throughput)
    return await node.store_many(
        keys=uids,
        subkeys=[own_subkey] * num_uids,
        values=[record] * num_uids,
        expiration_time=expiration_time,
        num_workers=num_workers,
    )
73
-
74
-
75
def get_remote_sequence(
    dht: DHT,
    start: int,
    stop: int,
    config: src.DistributedBloomConfig,
    dht_prefix: Optional[str] = None,
    return_future: bool = False,
) -> Union[src.RemoteSequential, MPFuture]:
    """
    Build a RemoteSequential over the blocks with indices in range(start, stop).

    :param return_future: forwarded to RemoteExpertWorker.run_coroutine
    """
    coroutine = _get_remote_sequence(dht, start, stop, config, dht_prefix)
    return RemoteExpertWorker.run_coroutine(coroutine, return_future=return_future)
86
-
87
-
88
async def _get_remote_sequence(
    dht: DHT,
    start: int,
    stop: int,
    config: src.DistributedBloomConfig,
    dht_prefix: Optional[str] = None,
) -> src.RemoteSequential:
    """Create a sequence manager for blocks start..stop-1 and wrap it into a RemoteSequential."""
    # NOTE: the uids are built from config.dht_prefix; the dht_prefix argument is only passed to RemoteSequential
    uids = []
    for block_index in range(start, stop):
        uids.append(f"{config.dht_prefix}{UID_DELIMITER}{block_index}")
    p2p = await dht.replicate_p2p()
    manager = src.RemoteSequenceManager(dht, uids, p2p)
    return src.RemoteSequential(config, dht, dht_prefix, p2p, manager)
99
-
100
-
101
def get_remote_module(
    dht: DHT,
    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
    config: src.DistributedBloomConfig,
    dht_prefix: Optional[str] = None,
    return_future: bool = False,
) -> Union[Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]], MPFuture]:
    """
    :param uid_or_uids: find one or more modules with these ids from across the DHT
    :param config: model config, usually taken by .from_pretrained(MODEL_NAME)
    :param return_future: if False (default), return when finished. Otherwise return MPFuture and run in background.
    :returns: one RemoteTransformerBlock for a single uid, otherwise a list of them
    """
    coroutine = _get_remote_module(dht, uid_or_uids, config, dht_prefix)
    return RemoteExpertWorker.run_coroutine(coroutine, return_future=return_future)
117
-
118
-
119
async def _get_remote_module(
    dht: DHT,
    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
    config: src.DistributedBloomConfig,
    dht_prefix: Optional[str] = None,
) -> Union[src.RemoteTransformerBlock, List[src.RemoteTransformerBlock]]:
    """Create one RemoteTransformerBlock per uid, each with its own single-uid sequence manager."""
    single_uid = isinstance(uid_or_uids, ModuleUID)
    if single_uid:
        uids = [uid_or_uids]
    else:
        uids = uid_or_uids
    p2p = await dht.replicate_p2p()

    modules = []
    for uid in uids:
        manager = src.RemoteSequenceManager(dht, [uid], p2p)
        block = src.RemoteTransformerBlock(config, dht, dht_prefix=dht_prefix, p2p=p2p, sequence_manager=manager)
        modules.append(block)

    if single_uid:
        return modules[0]
    return modules
133
-
134
-
135
def get_remote_module_infos(
    dht: DHT,
    uid_or_uids: Union[ModuleUID, List[ModuleUID]],
    expiration_time: Optional[DHTExpiration] = None,
) -> List[Optional[RemoteModuleInfo]]:
    """
    Look up the servers declared for one or more module uids.

    :returns: a single entry when one uid is given, otherwise a list with one entry per uid;
      an entry is None when no usable server record was found for that uid
    """
    single_uid = isinstance(uid_or_uids, ModuleUID)
    if single_uid:
        uids = [uid_or_uids]
    else:
        uids = uid_or_uids

    fetch = partial(_get_remote_module_infos, uids=uids, expiration_time=expiration_time)
    infos = dht.run_coroutine(fetch, return_future=False)
    if single_uid:
        return infos[0]
    return infos
146
-
147
-
148
async def _get_remote_module_infos(
    dht: DHT, node: DHTNode, uids: List[ModuleUID], expiration_time: Optional[DHTExpiration]
) -> List[Optional[RemoteModuleInfo]]:
    """
    Fetch the DHT records for ``uids`` and parse them into RemoteModuleInfo objects.

    Malformed records and malformed per-peer entries are logged and skipped.
    A uid with no valid peer entries yields None in the returned list.
    """
    if expiration_time is None:
        expiration_time = get_dht_time()
    if dht.num_workers is None:
        num_workers = len(uids)
    else:
        num_workers = min(len(uids), dht.num_workers)
    found: Dict[ModuleUID, DHTValue] = await node.get_many(uids, expiration_time, num_workers=num_workers)

    modules: List[Optional[RemoteModuleInfo]] = [None for _ in uids]
    for index, uid in enumerate(uids):
        metadata = found[uid]
        if metadata is None:
            continue
        if not isinstance(metadata.value, dict):
            logger.error(f"Incorrect metadata for {uid}: {metadata}")
            continue

        servers = {}
        for peer_id, server_info in metadata.value.items():
            try:
                peer_id = PeerID.from_base58(peer_id)
                state, throughput = server_info.value
                is_valid = (
                    isinstance(state, int)
                    and isinstance(throughput, float)
                    and math.isfinite(throughput)
                    and throughput >= 0.0
                )
                if not is_valid:
                    raise ValueError(f"Invalid server info: {server_info}")
                servers[peer_id] = ServerInfo(ServerState(state), throughput)
            except (TypeError, ValueError) as e:
                logger.error(f"Incorrect peer entry for uid={uid}, peer_id={peer_id}: {e}")

        if servers:
            modules[index] = RemoteModuleInfo(uid, servers)
    return modules
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/server/__init__.py DELETED
File without changes
petals/src/server/backend.py DELETED
@@ -1,84 +0,0 @@
1
- """Code for serving bloom blocks via hivemind-server"""
2
- from typing import Any, Dict, Optional, Sequence, Tuple
3
-
4
- import torch
5
- from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
6
- from hivemind.moe.server.module_backend import ModuleBackend
7
- from hivemind.utils import get_logger
8
-
9
- from src.bloom.from_pretrained import BloomBlock
10
- from src.server.cache import MemoryCache
11
- from src.server.task_pool import PrioritizedTaskPool
12
- from src.utils.misc import is_dummy
13
-
14
- use_hivemind_log_handler("in_root_logger")
15
- logger = get_logger(__file__)
16
-
17
-
18
class TransformerBackend(ModuleBackend):
    """Wraps a BloomBlock so that it can serve forward, backward and incremental-inference requests"""

    def __init__(self, *args, memory_cache: MemoryCache, backend_dtype: Optional[torch.dtype] = None, **kwargs):
        super().__init__(*args, **kwargs)
        assert isinstance(self.module, BloomBlock)
        self.memory_cache = memory_cache
        for param_name, param in self.module.named_parameters():
            assert (
                not param.requires_grad
            ), f"Bloom layer parameters must not accumulate gradients, but {param_name} does"
        for buffer_name, buf in self.module.named_buffers():
            assert (
                not buf.requires_grad
            ), f"Bloom layer parameters must not accumulate gradients, but {buffer_name} does"

        # Read the batch size from the pool created by the parent class before replacing that pool
        max_batch_size = self.forward_pool.max_batch_size
        pool_handlers = (
            ("inference", self.inference_step),
            ("forward", self.forward),
            ("backward", self.backward),
        )
        for pool_kind, handler in pool_handlers:
            pool = PrioritizedTaskPool(handler, max_batch_size=max_batch_size, name=f"{self.name}_{pool_kind}")
            setattr(self, f"{pool_kind}_pool", pool)

        if backend_dtype:
            self.dtype = backend_dtype
        else:
            self.dtype = self.module.input_layernorm.weight.dtype

        # Inference takes the regular args plus two extra tensors (one of self.dtype, one int64)
        extra_descriptors = (
            BatchTensorDescriptor((), dtype=self.dtype),
            BatchTensorDescriptor((), dtype=torch.int64),
        )
        self.inference_schema = ((*self.args_schema, *extra_descriptors), self.kwargs_schema)

    def inference_step(self, cache_metadata: torch.IntTensor, *inputs: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Run one incremental forward step, reading past keys/values from the memory cache and writing new ones back.

        :param cache_metadata: element [0, 0] is the cache handle, element [0, 1] is the current prefix length
        :param inputs: exactly two tensors, (hidden_states, hypo_ids)
        :returns: a 1-tuple with the new hidden states
        """
        with torch.inference_mode():
            handle_entry, prefix_entry = cache_metadata[0, 0], cache_metadata[0, 1]
            attention_cache_handle = int(handle_entry.item())
            prefix_length = int(prefix_entry.item())
            hidden_states, hypo_ids = inputs
            assert (
                hidden_states.ndim == 3
            ), "expected hidden states to be 3-dimensional: [batch_size, seq_len, hid_size]"

            with self.memory_cache.use_cache(attention_cache_handle) as cache:
                assert isinstance(self.module, BloomBlock) and cache.shape[0] == 2 and cache.ndim == 5
                if not is_dummy(hypo_ids):
                    cache[:, :] = cache[:, hypo_ids]  # in-place reorder cache by hypo ids
                past_k = cache[0, :, :prefix_length]
                past_v = cache[1, :, :prefix_length]
                layer_past = (past_k, past_v)
                logger.debug(f"Metadata: {cache_metadata}, past_k.shape={past_k.shape}, past_v.shape={past_v.shape}")
                hidden_states, (new_k, new_v) = self.module.forward(
                    hidden_states, layer_past=layer_past, use_cache=True
                )

                # todo remove these asserts once we pass all tests
                new_length = new_v.shape[1]
                assert new_length > prefix_length
                assert new_k.shape[0] == past_k.shape[0] and new_v.shape[0] == past_v.shape[0]
                assert new_k.shape[1] == new_length and new_v.shape[1] == new_length
                assert new_k.shape[2:] == past_k.shape[2:] and new_v.shape[2:] == past_v.shape[2:]
                # Only positions past the old prefix are written back to the cache
                fresh = slice(prefix_length, new_length)
                cache[0, :, fresh, :] = new_k[:, fresh]
                cache[1, :, fresh, :] = new_v[:, fresh]
                return (hidden_states,)

    def get_pools(self) -> Sequence[PrioritizedTaskPool]:
        """Return the forward, backward and inference pools, in that order."""
        pools = (self.forward_pool, self.backward_pool, self.inference_pool)
        return pools

    def get_info(self) -> Dict[str, Any]:
        """Return the parent's info dict extended with this backend's inference_schema."""
        info = dict(super().get_info())
        info["inference_schema"] = self.inference_schema
        return info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/server/block_selection.py DELETED
@@ -1,106 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Dict, List, Optional, Tuple
3
-
4
- import numpy as np
5
- from hivemind import PeerID, get_logger
6
-
7
- from src.data_structures import RemoteModuleInfo, ServerState
8
-
9
- __all__ = ["choose_best_blocks", "should_choose_other_blocks"]
10
-
11
- logger = get_logger(__file__)
12
-
13
-
14
@dataclass
class Span:
    """A half-open range of block indices [start, end) together with a throughput value."""

    start: int
    end: int
    throughput: float

    @property
    def length(self):
        """Number of blocks covered by the span."""
        return self.end - self.start

    def move_to(self, new_start: int) -> None:
        """Shift the span so that it begins at ``new_start`` while keeping its length."""
        # capture the length first: it is derived from start/end, which are about to change
        span_length = self.length
        self.start = new_start
        self.end = new_start + span_length
26
-
27
-
28
def _compute_spans(module_infos: List[Optional[RemoteModuleInfo]]) -> Tuple[Dict[PeerID, Span], np.ndarray]:
    """
    Summarize which blocks each non-offline server covers and how much throughput each block receives.

    :param module_infos: one entry per block; None entries are skipped
    :returns: (spans, throughputs) where spans maps each peer to the range between the first and last
      block it declares, and throughputs[i] is the summed throughput of non-offline servers for block i
    """
    spans = {}
    throughputs = np.zeros(len(module_infos))
    for block, module in enumerate(module_infos):
        if module is None:
            continue

        for peer_id, server in module.servers.items():
            if server.state == ServerState.OFFLINE:
                continue

            if peer_id in spans:
                spans[peer_id].start = min(spans[peer_id].start, block)
                # Extend from the current end (the original compared against .start, which only
                # gave the right answer because blocks are visited in increasing order)
                spans[peer_id].end = max(spans[peer_id].end, block + 1)
            else:
                spans[peer_id] = Span(start=block, end=block + 1, throughput=server.throughput)

            throughputs[block] += server.throughput

    return spans, throughputs
48
-
49
-
50
def _choose_best_start(throughputs: np.ndarray, num_blocks: int, cur_start: Optional[int]) -> int:
    """
    Pick the start index of the window of ``num_blocks`` blocks with the weakest throughputs.

    Windows are compared by their sorted throughputs (lexicographically, smallest first);
    ties are broken in favour of ``cur_start`` and then of the smallest index.
    """
    last_start = len(throughputs) - num_blocks
    candidates = []
    for start in range(last_start + 1):
        window = sorted(throughputs[start : start + num_blocks])
        candidates.append((window, start != cur_start, start))
    _, _, best_start = min(candidates)
    return best_start
56
-
57
-
58
def choose_best_blocks(num_blocks: int, module_infos: List[Optional[RemoteModuleInfo]]) -> List[int]:
    """Return ``num_blocks`` consecutive block indices starting at the window picked by _choose_best_start."""
    throughputs = _compute_spans(module_infos)[1]
    first_block = _choose_best_start(throughputs, num_blocks, None)
    return [first_block + offset for offset in range(num_blocks)]
62
-
63
-
64
def should_choose_other_blocks(
    local_peer_id: PeerID, module_infos: List[Optional[RemoteModuleInfo]], balance_quality: float
) -> bool:
    """
    Decide whether this server should pick a different set of blocks.

    The function simulates moving the local span to its best position, then lets every span
    move greedily until nothing changes, and compares the minimum per-block throughput before
    and after this simulation.

    :param local_peer_id: peer id of this server; must be present among the declared servers
    :param module_infos: one entry per block, as accepted by _compute_spans
    :param balance_quality: threshold for initial / rebalanced minimum throughput; values above 1.0 force True
    :returns: True if the measured quality is below ``balance_quality`` (minus a small epsilon)
    """
    if balance_quality > 1.0:
        return True  # Forces rebalancing on each check (may be used for debugging purposes)

    spans, throughputs = _compute_spans(module_infos)
    initial_throughput = throughputs.min()

    assert local_peer_id in spans, "Span served by this server is not present in the DHT"
    local_span = spans[local_peer_id]
    throughputs[local_span.start : local_span.end] -= local_span.throughput

    new_start = _choose_best_start(throughputs, local_span.length, local_span.start)
    if local_span.start == new_start:
        return False  # This server is on its best place already
    local_span.move_to(new_start)

    throughputs[local_span.start : local_span.end] += local_span.throughput

    # Let every server (including this one) move greedily until the layout stops changing
    moved = True
    while moved:
        servers = list(spans.keys())
        np.random.shuffle(servers)

        moved = False
        for peer_id in servers:
            span = spans[peer_id]
            throughputs[span.start : span.end] -= span.throughput

            new_start = _choose_best_start(throughputs, span.length, span.start)
            if span.start != new_start:
                span.move_to(new_start)
                moved = True

            throughputs[span.start : span.end] += span.throughput

    new_throughput = throughputs.min()
    if new_throughput <= 0.0:
        # Some block stays unserved even after the simulated rebalancing, so the ratio below is
        # undefined; the original division yielded inf/nan here (plus a numpy warning) and the
        # final comparison was False, which is returned explicitly instead.
        return False

    actual_quality = initial_throughput / new_throughput
    logger.info(f"Swarm balance quality: {actual_quality * 100:.1f}%")

    eps = 1e-6
    return actual_quality < balance_quality - eps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/server/cache.py DELETED
@@ -1,143 +0,0 @@
1
- """
2
- A pytorch memory cache that can be allocated by ConnectionHandler (on cpu) and used over multiple calls to Runtime.
3
-
4
- For now, the only purpose of this code is to ensure that allocated memory will be deleted properly.
5
-
6
- """
7
- import asyncio
8
- import contextlib
9
- import ctypes
10
- import multiprocessing as mp
11
- import os
12
- import time
13
- from typing import AsyncContextManager, Dict, Optional, Union
14
-
15
- import hivemind
16
- import torch
17
- from hivemind import use_hivemind_log_handler
18
- from hivemind.utils import TensorDescriptor, get_logger
19
-
20
- use_hivemind_log_handler("in_root_logger")
21
- logger = get_logger(__file__)
22
-
23
- Handle = int
24
-
25
-
26
class MemoryCache:
    """A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""

    def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
        """
        :param device: device on which use_cache creates the cached tensors
        :param max_size_bytes: upper bound on the total size of allocated tensors; None means (2**64 - 1)
        """
        self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
        self.device = device
        self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
        self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
        self._handle_counter = mp.Value(ctypes.c_int64, 0, lock=False)
        self._active_handles: Optional[Dict[Handle, TensorDescriptor]] = None
        self._allocated_tensors: Optional[Dict[Handle, torch.Tensor]] = None
        # The process that constructs the cache is treated as the runtime process
        self.runtime_pid = os.getpid()

        self._pipe_recv, self._pipe_send = mp.Pipe(duplex=False)  # any ConnectionHandler -> runtime
        self._pending_messages = mp.Value(ctypes.c_int64, 0, lock=False)
        self._lock_acquire_memory = mp.Lock()
        self._memory_freed_event = mp.Event()

    @property
    def current_size_bytes(self) -> int:
        """Total number of bytes currently accounted for, stored in a shared value."""
        return self._current_size.value

    @current_size_bytes.setter
    def current_size_bytes(self, value: int):
        self._current_size.value = value

    @property
    def handle_counter(self) -> int:
        """The handle that will be given to the next allocation, stored in a shared value."""
        return self._handle_counter.value

    @handle_counter.setter
    def handle_counter(self, value: int):
        self._handle_counter.value = value

    @contextlib.asynccontextmanager
    async def allocate_cache(self, descr: TensorDescriptor) -> AsyncContextManager[Handle]:
        """
        Create a handle that is associated with buffers on unique device. If the cache is full, waits until
        enough memory is freed; raises AllocationFailed if the request can never fit into the cache.

        :param descr: allocate a tensor of this size, dtype, etc

        :note: This function should be called by connection handlers, it can be called concurrently from multiple processes.
        Furthermore, it can be called concurrently with at most one use_cache call in runtime.
        """
        assert os.getpid() != self.runtime_pid, "must be called by a ConnectionHandler, not runtime"
        assert descr.device is None and descr
        allocated_handle = None
        # element_size() works for every dtype, whereas torch.finfo only accepts floating-point ones
        bytes_per_element = torch.empty(0, dtype=descr.dtype).element_size()
        allocated_size_bytes = descr.numel() * bytes_per_element
        loop = asyncio.get_event_loop()
        try:
            async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
                if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
                    await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                    allocated_handle = int(self.handle_counter)
                    self.current_size_bytes += allocated_size_bytes
                    self.handle_counter += 1  # note: this will eventually overflow and it is okay
                    self._pending_messages.value += 1
                    self._pipe_send.send((allocated_handle, descr))

            yield allocated_handle
        finally:
            # Only undo the bookkeeping if the allocation above actually happened
            if allocated_handle is not None:
                async with hivemind.utils.enter_asynchronously(self._lock_metadata):
                    self._pending_messages.value += 1
                    self._pipe_send.send((allocated_handle, None))  # signal runtime to free that handle
                    self.current_size_bytes -= allocated_size_bytes
                self._memory_freed_event.set()

    def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
        """
        Block until ``allocated_size_bytes`` fit into the cache.

        :param timeout: maximum time to wait in seconds; None waits indefinitely
        :raises AllocationFailed: if the request exceeds max_size_bytes, or the timeout expires
        """
        # note: this function should only be called inside _lock_acquire_memory!
        if allocated_size_bytes > self.max_size_bytes:
            raise AllocationFailed(
                f"Could not allocate {allocated_size_bytes} bytes, max cache size = {self.max_size_bytes} bytes"
            )
        deadline = None if timeout is None else time.perf_counter() + timeout
        while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
            remaining_time = deadline - time.perf_counter() if timeout is not None else None
            if not self._memory_freed_event.wait(remaining_time):
                raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
            self._memory_freed_event.clear()

    @contextlib.contextmanager
    def use_cache(self, handle: Handle) -> torch.Tensor:
        """
        Return a tensor that was previously allocated with allocate_cache,

        :note: This method is called by ExpertBackend in runtime: a single process with NO process parallelism.
        However, runtime may call use_cache concurrently with one or more connection handlers calling allocate_cache
        """
        assert os.getpid() == self.runtime_pid
        # note: this specific function is not concurrent, so you can safely allocate/offload/defragment data here

        with self._lock_metadata:
            if self._allocated_tensors is None:
                self._allocated_tensors = {}

            # read creation/deletion requests from connection handlers
            for i in range(int(self._pending_messages.value)):
                recv_handle, recv_data = self._pipe_recv.recv()
                self._pending_messages.value -= 1
                if isinstance(recv_data, TensorDescriptor):
                    # a descriptor is a creation request: materialize the tensor on the cache device
                    self._allocated_tensors[recv_handle] = recv_data.make_zeros(device=self.device)
                elif recv_data is None:
                    # None is a deletion request for that handle
                    if recv_handle not in self._allocated_tensors:
                        logger.warning(
                            f"Sanity check failed: asked to delete handle {recv_handle}, but there is no such handle"
                        )
                    self._allocated_tensors.pop(recv_handle, None)
                else:
                    logger.error(f"MemoryCache pipe received unexpected message: {recv_data}")

        assert handle in self._allocated_tensors, f"Sanity check failed: no such handle ({handle})"
        yield self._allocated_tensors[handle]
140
-
141
-
142
class AllocationFailed(Exception):
    """Raised when the requested amount of cache memory could not be allocated."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/server/handler.py DELETED
@@ -1,421 +0,0 @@
1
- import contextlib
2
- from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
3
-
4
- import torch
5
- from hivemind import (
6
- DHT,
7
- MSGPackSerializer,
8
- P2PContext,
9
- TensorDescriptor,
10
- deserialize_tensor_stream,
11
- deserialize_torch_tensor,
12
- nested_flatten,
13
- serialize_torch_tensor,
14
- )
15
- from hivemind.moe.server.connection_handler import ConnectionHandler
16
- from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
17
- from hivemind.proto import runtime_pb2
18
- from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
19
- from hivemind.utils.logging import get_logger
20
- from hivemind.utils.streaming import split_for_streaming
21
-
22
- from src.data_structures import CHAIN_DELIMITER, ModuleUID
23
- from src.server.backend import TransformerBackend
24
- from src.server.task_pool import PrioritizedTaskPool
25
- from src.server.task_prioritizer import DummyTaskPrioritizer, TaskPrioritizerBase
26
- from src.utils.misc import DUMMY, is_dummy
27
-
28
- logger = get_logger(__file__)
29
-
30
-
31
class TransformerConnectionHandler(ConnectionHandler):
    """Handles three request types: forward, backward and forward-incremental (inference)"""

    module_backends: Dict[ModuleUID, TransformerBackend]

    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        inference_max_length: int,
        task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
    ):
        """
        :param dht: DHT instance, passed to the parent ConnectionHandler
        :param module_backends: mapping from block uid to the backend that serves it
        :param inference_max_length: upper bound on max_length accepted by rpc_inference
        :param task_prioritizer: computes a priority for every task submitted to a backend pool
        """
        super().__init__(dht, module_backends)
        for module_backend in self.module_backends.values():
            assert isinstance(module_backend, TransformerBackend)
        self.inference_max_length = inference_max_length
        self._prioritizer = task_prioritizer

    async def _gather_inputs(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> Tuple[str, List[torch.Tensor], Dict]:
        """
        Consume a stream of requests and assemble them into (block uid, tensors, metadata).

        The uid must be identical in every part of the stream; metadata is taken from the first part.

        :raises ValueError: if parts of one stream carry different uids
        """
        block_uid, metadata = None, None

        def _unpack(req: runtime_pb2.ExpertRequest) -> Iterable[runtime_pb2.Tensor]:
            nonlocal block_uid, metadata

            if block_uid is None:
                block_uid = req.uid
            elif block_uid != req.uid:
                raise ValueError("Block uids differ in one request")

            if metadata is None:
                metadata = MSGPackSerializer.loads(req.metadata) if req.metadata else {}

            return req.tensors

        tensors_stream = amap_in_executor(_unpack, requests)
        inputs = await deserialize_tensor_stream(tensors_stream)
        assert isinstance(block_uid, str) and isinstance(metadata, dict)
        return block_uid, inputs, metadata

    async def rpc_inference(
        self,
        requests: AsyncIterator[runtime_pb2.ExpertRequest],
        context: P2PContext,
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Compute a single step of inference using attention cache; update attention cache accordingly."""
        try:
            logger.debug("Opened rpc_inference()")
            # the first request defines the chain of blocks and the session parameters
            request = await anext(requests)
            requested_uids = self._check_uids(request.uid)
            metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
            requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
            max_length = metadata.get("max_length")
            points = metadata.get("points", 0)

            if not requested_uids:
                raise ValueError("User must specify at least one block for inference, but got none")
            assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
            assert isinstance(
                points, (float, int)
            ), f"rpc_inference should have number of points as a number or None, got {points}"
            if not 0 <= max_length <= self.inference_max_length:
                raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")

            point_per_piece = points / max_length if max_length > 0 else 0.0
            batch_size = request.tensors[0].size[0] if request.tensors else 1

            cache_metadata = torch.tensor(
                [[-1, -1] for _ in range(batch_size)], dtype=torch.int64
            )  # [cache_handle, prefix_length]
            prefix_length = 0

            async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
                assert len(cache_handles) == len(requested_backends)
                while request.tensors:  # iterate while user is willing to supply tensors
                    hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]

                    # Cast inputs to backend dtype
                    hidden_states = hidden_states.to(requested_backends[0].dtype)
                    assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"

                    # parse deep prompts (optional argument)
                    if prompts is None or is_dummy(prompts):
                        prompts = [DUMMY] * len(requested_backends)
                    else:
                        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

                    if not (len(requested_backends) == len(prompts)):
                        raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")

                    length_increment = hidden_states.shape[1]  # how many tokens are added this step (in each seq)
                    if prefix_length + length_increment > max_length:
                        raise ValueError(
                            f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
                            f" exceeds pre-allocated maximum {max_length}"
                        )

                    # run request tensors through all requested modules, update caches
                    for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
                        if not is_dummy(prompt):
                            hidden_states[:, : prompt.shape[1]] += prompt

                        cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
                        assert isinstance(
                            hidden_states, torch.Tensor
                        ), f"hidden states must be tensor, got {type(hidden_states)}"
                        assert (
                            hidden_states.ndim == 3
                        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
                        assert isinstance(
                            backend.inference_pool, PrioritizedTaskPool
                        ), "petals support only prioritized pools"
                        priority = self._prioritizer.prioritize(
                            cache_metadata,
                            hidden_states,
                            hypo_ids,
                            points=point_per_piece / len(requested_backends),
                            backend=backend,
                            type="inference",
                        )
                        (hidden_states,) = await backend.inference_pool.submit_task(
                            cache_metadata, hidden_states, hypo_ids, priority=priority
                        )

                    # serialize and send last layer outputs
                    yield runtime_pb2.ExpertResponse(
                        tensors=[
                            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                            for result, proto in zip(
                                (hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
                            )
                        ]
                    )

                    # prepare for next step
                    prefix_length += hidden_states.shape[1]
                    request = await (anext(requests))
        finally:
            logger.debug("Closed rpc_inference()")

    async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        """Run a forward pass through the requested chain of blocks and return the serialized hidden states."""
        # Parse request and prepare backends
        flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_forward should have number of points as number or None, got {points}"

        hidden_states = await _rpc_forward(
            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
        )
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3

        # Serialize output and respond to client
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
            ]
        )

    async def rpc_forward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Streaming variant of rpc_forward: inputs arrive in parts, outputs are sent back in parts."""
        # Parse requests and prepare backends
        uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uid_str)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_forward_stream should have number of points as number or None, got {points}"

        hidden_states = await _rpc_forward(
            *flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
        )
        assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"

        # Serialize the overall output
        serialized_output = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
        ]

        # Split the serialized_output for streaming and respond to client
        output_split = [
            part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]
        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
        """Run a backward pass through the requested chain of blocks and return the serialized gradients."""
        # Parse requests and prepare backends
        flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
        requested_uids = self._check_uids(request.uid)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_backward should have number of points as number or None, got {points}"

        grads = await _rpc_backward(
            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
        )

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize

        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_input and respond
        return runtime_pb2.ExpertResponse(
            tensors=[
                serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
                for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
            ]
        )

    async def rpc_backward_stream(
        self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
    ) -> AsyncIterator[runtime_pb2.ExpertResponse]:
        """Streaming variant of rpc_backward: inputs arrive in parts, gradients are sent back in parts."""
        uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
        requested_uids = self._check_uids(uids_header)
        requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
        points = metadata.get("points", 0)
        assert isinstance(
            points, (float, int)
        ), f"rpc_backward_stream should have number of points as number or None, got {points}"

        grads = await _rpc_backward(
            *flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
        )

        # Modify grad_inputs_schema to support grad_prompts
        assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2)  # TODO generalize
        grad_inputs_schema_with_prompts = (
            requested_backends[0].args_schema * len(grads),
            requested_backends[0].kwargs_schema,
        )  # TODO generalize

        # Serialize the overall grad_inputs
        serialized_grad_inputs = [
            serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
            for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
        ]
        # Split the serialized_grad_inputs for streaming and respond
        output_split = [
            part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
        ]

        async for part in as_aiter(*output_split):
            yield runtime_pb2.ExpertResponse(tensors=[part])

    def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
        """
        Split a chain of block uids and check that this peer serves every one of them.

        :param uids: block uids joined with CHAIN_DELIMITER
        :returns: a tuple of individual block uids, in the order they were given
        :raises RuntimeError: if no uids were given or if some uid is not among self.module_backends
        """
        # Check emptiness BEFORE splitting: "".split(delimiter) returns [""], never an empty list,
        # so checking the split result would never detect a missing uid string.
        if not uids:
            raise RuntimeError("User did not provide any uids")
        uid_chain = tuple(uids.split(CHAIN_DELIMITER))
        for uid in uid_chain:
            if uid not in self.module_backends:
                raise RuntimeError(f"Remote peer does not serve {uid}")
        return uid_chain

    @contextlib.asynccontextmanager
    async def _allocate_caches(
        self, backends: Sequence[TransformerBackend], batch_size: int, max_length: int
    ) -> Sequence[int]:
        """Allocate memory caches for each transformer block, return cache handles"""
        # every allocation is entered on one exit stack, so all of them are released together on exit
        async with contextlib.AsyncExitStack() as stack:
            handles = []
            for backend in backends:
                num_heads = backend.module.self_attention.num_heads
                head_dim = backend.module.self_attention.head_dim

                cache_descriptor = TensorDescriptor(
                    size=(2, batch_size, max_length, num_heads, head_dim), dtype=backend.dtype
                )
                # [key_or_value, batch_size, max_length, num_heads, head_dim]

                handles.append(await stack.enter_async_context(backend.memory_cache.allocate_cache(cache_descriptor)))

            yield handles
322
-
323
-
324
async def _rpc_forward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> torch.Tensor:
    """
    Run forward pass on deserialized inputs and prompts, used by rpc_forward and rpc_forward_stream

    :param flat_tensors: exactly two tensors: first layer inputs and prompts
    :note: prompts may be a dummy tensor (see is_dummy), in which case no prompts are added
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :param prioritizer: computes the priority of each task submitted to a backend's forward pool
    :param points: passed to the prioritizer, split evenly between the requested backends
    :returns: hidden states after the last layer [batch_size, seq_length, hid_size]
    """
    hidden_states, prompts = flat_tensors
    dtype = requested_backends[0].dtype
    # check parse input tensors and cast dtypes
    hidden_states = hidden_states.to(dtype)
    assert hidden_states.ndim == 3
    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        # one prompt per backend: split along the first dimension
        prompts = [p.squeeze(0) for p in prompts.to(dtype).split(1, dim=0)]

    # Run a chain of requested backends
    for backend, prompt in zip(requested_backends, prompts):
        if not is_dummy(prompt):
            hidden_states[:, : prompt.shape[1]] += prompt

        # check the pool that the task is actually submitted to (forward_pool)
        assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            hidden_states, points=points / len(requested_backends), backend=backend, type="forward"
        )
        (hidden_states,) = await backend.forward_pool.submit_task(
            hidden_states,
            priority=priority,
        )
        assert isinstance(hidden_states, torch.Tensor)
        assert (
            hidden_states.ndim == 3
        ), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"

    # Serialization is left to the caller
    return hidden_states
368
-
369
-
370
async def _rpc_backward(
    *flat_tensors: torch.Tensor,
    requested_backends: Sequence[TransformerBackend],
    prioritizer: TaskPrioritizerBase,
    points: int = 0,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
    """
    Run backward pass on deserialized inputs, used by rpc_backward and rpc_backward_stream

    First re-runs the forward chain to collect the input of every block, then runs the
    blocks' backward passes in reverse order.

    :param flat_tensors: exactly three tensors: first layer inputs, gradients w.r.t. last layer outputs, prompts
    :note: prompts may be a dummy tensor (see is_dummy), in which case no prompt gradients are returned
    :param requested_backends: a sequence of transformer blocks in the same order as they appear in forward pass
    :param prioritizer: computes the priority of each task submitted to a backend pool
    :param points: passed to the prioritizer, split evenly between the requested backends
    :returns: [grad_inputs] if there are no prompts, otherwise [grad_inputs, grad_prompts]
    """
    inputs, grad_outputs, prompts = flat_tensors
    # Cast inputs & grad outputs to backend dtype
    inputs = inputs.to(requested_backends[0].dtype)
    grad_outputs = grad_outputs.to(requested_backends[-1].dtype)

    if prompts is None or is_dummy(prompts):
        prompts = [DUMMY] * len(requested_backends)
    else:
        prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]

    # Run a forward chain to collect intermediate inputs
    # Note that we do not forward for the last module since we do not need its output
    inter_inputs = []
    for backend, prompt in zip(requested_backends[:-1], prompts[:-1]):
        assert inputs.ndim == 3, f"inputs to {type(backend)} must be a single 3d tensor of hidden states"
        if not is_dummy(prompt):
            inputs[:, : prompt.shape[1]] += prompt
        inter_inputs.append(inputs)
        # check the pool that the task is actually submitted to (forward_pool)
        assert isinstance(backend.forward_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inputs, points=points / len(requested_backends), backend=backend, type="forward_in_backward"
        )
        (inputs,) = await backend.forward_pool.submit_task(inputs, priority=priority)

        assert isinstance(inputs, torch.Tensor)

    # the last block only needs its (prompt-adjusted) input, not its output
    if not is_dummy(prompts[-1]):
        inputs[:, : prompts[-1].shape[1]] += prompts[-1]
    inter_inputs.append(inputs)

    assert len(inter_inputs) == len(prompts) == len(requested_backends), "internal shape error during backward"
    grad_prompts_reversed = []
    # Run a chain of requested backends
    for inp, prompt, backend in zip(*map(reversed, (inter_inputs, prompts, requested_backends))):
        # check the pool that the task is actually submitted to (backward_pool)
        assert isinstance(backend.backward_pool, PrioritizedTaskPool), "petals support only prioritized pools"
        priority = prioritizer.prioritize(
            inp, grad_outputs, points=points / len(requested_backends), backend=backend, type="backward"
        )
        (grad_outputs,) = await backend.backward_pool.submit_task(inp, grad_outputs, priority=priority)

        assert isinstance(grad_outputs, torch.Tensor)
        if not is_dummy(prompt):
            # the prompt was added to the first prompt.shape[1] positions, so its gradient is that slice
            grad_prompts_reversed.append(grad_outputs[:, : prompt.shape[1]].unsqueeze(0))

    grad_prompts = torch.cat(grad_prompts_reversed[::-1], dim=0) if grad_prompts_reversed else DUMMY
    return [grad_outputs] if is_dummy(grad_prompts) else [grad_outputs, grad_prompts]  # TODO un-duct-tape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/server/runtime.py DELETED
@@ -1,198 +0,0 @@
1
- import multiprocessing as mp
2
- import multiprocessing.pool
3
- import threading
4
- from collections import defaultdict
5
- from itertools import chain
6
- from queue import SimpleQueue
7
- from selectors import EVENT_READ, DefaultSelector
8
- from statistics import mean
9
- from time import time
10
- from typing import Dict, NamedTuple, Optional
11
-
12
- import torch
13
- from hivemind.moe.server.module_backend import ModuleBackend
14
- from hivemind.utils import get_logger
15
- from prefetch_generator import BackgroundGenerator
16
-
17
- logger = get_logger(__name__)
18
-
19
-
20
class Runtime(threading.Thread):
    """
    A group of processes that processes incoming requests for multiple module backends on a shared device.
    Runtime is usually created and managed by Server, humans need not apply.

    For debugging, you can start runtime manually with .start() or .run()

    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
    >>> runtime = Runtime(module_backends)
    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
    >>> runtime.ready.wait()  # await for runtime to load all blocks on device and create request pools
    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
    >>> print("Returned:", future.result())
    >>> runtime.shutdown()

    :param module_backends: a dict [block uid -> ModuleBackend]
    :param prefetch_batches: form up to this many batches in advance
    :param sender_threads: dispatches outputs from finished batches using this many asynchronous threads
    :param device: if specified, moves all blocks and data to this device via .to(device=device).
      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)

    :param stats_report_interval: interval to collect and log statistics about runtime performance
    """

    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

    def __init__(
        self,
        module_backends: Dict[str, ModuleBackend],
        prefetch_batches: int = 1,
        sender_threads: int = 1,
        device: torch.device = None,
        stats_report_interval: Optional[int] = None,
    ):
        super().__init__()
        self.module_backends = module_backends
        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
        self.shutdown_trigger = mp.Event()
        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches

        self.stats_report_interval = stats_report_interval
        if self.stats_report_interval is not None:
            self.stats_reporter = StatsReporter(self.stats_report_interval)

    def run(self):
        """Start all pools, then process batches from them until a shutdown is requested."""
        for pool in self.pools:
            if not pool.is_alive():
                pool.start()
        if self.device is not None:
            for backend in self.module_backends.values():
                backend.module.to(self.device)

        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                self.ready.set()
                if self.stats_report_interval is not None:
                    self.stats_reporter.start()
                logger.info("Started")

                batch_iterator = self.iterate_minibatches_from_pools()
                if self.prefetch_batches > 0:
                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)

                for pool, batch_index, batch in batch_iterator:
                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

                    start = time()
                    try:
                        outputs = pool.process_func(*batch)
                        # outputs are dispatched by the sender threads, not by this thread
                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])

                        batch_processing_time = time() - start

                        batch_size = outputs[0].size(0)
                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

                        if self.stats_report_interval is not None:
                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

                    except KeyboardInterrupt:
                        raise
                    except BaseException as exception:
                        # report the failure for this batch and keep serving the next ones
                        logger.exception(f"Caught {exception}, attempting to recover")
                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])

            finally:
                if not self.shutdown_trigger.is_set():
                    self.shutdown()

    def shutdown(self):
        """Gracefully terminate a running runtime."""
        logger.info("Shutting down")
        self.ready.clear()

        if self.stats_report_interval is not None:
            self.stats_reporter.stop.set()
            # joining a thread that was never started raises RuntimeError,
            # which would abort the shutdown before the pools are terminated
            if self.stats_reporter.is_alive():
                self.stats_reporter.join()

        logger.debug("Terminating pools")
        for pool in self.pools:
            if pool.is_alive():
                pool.shutdown()
        logger.debug("Pools terminated")

        # trigger background thread to shutdown
        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
        self.shutdown_trigger.set()

    def iterate_minibatches_from_pools(self, timeout=None):
        """
        Chooses pool according to priority, then copies exposed batch and frees the buffer

        :param timeout: passed to pool.load_batch_to_runtime
        :returns: a generator of (pool, batch_index, batch_tensors); stops once the shutdown trigger is received
        """
        with DefaultSelector() as selector:
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)
            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)

            while True:
                # wait until at least one batch_receiver becomes available
                logger.debug("Waiting for inputs from task pools")
                ready_fds = selector.select()
                ready_objects = {key.data for (key, events) in ready_fds}
                if self.SHUTDOWN_TRIGGER in ready_objects:
                    break  # someone asked us to shutdown, break from the loop

                logger.debug("Choosing the pool with first priority")

                pool = min(ready_objects, key=lambda pool: pool.priority)

                logger.debug(f"Loading batch from {pool.name}")
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                logger.debug(f"Loaded batch from {pool.name}")
                yield pool, batch_index, batch_tensors
155
-
156
-
157
# Size and processing time (in seconds) of a single processed batch
BatchStats = NamedTuple("BatchStats", (("batch_size", int), ("processing_time", float)))


class StatsReporter(threading.Thread):
    """
    A background thread that periodically logs per-pool throughput statistics.

    Statistics are submitted with report_stats and accumulated in a queue; every
    report_interval seconds the thread drains the queue and logs a summary for each pool.
    Set the ``stop`` event to make the thread exit.

    :param report_interval: time between two consecutive reports, in seconds
    """

    def __init__(self, report_interval: int):
        super().__init__()
        self.report_interval = report_interval
        self.stop = threading.Event()
        self.stats_queue = SimpleQueue()

    @staticmethod
    def _format_rate(count: float, total_time: float, plural: str, singular: str) -> str:
        """
        Format throughput as "<rate> <plural>/s", or as "<seconds> s/<singular>" when the rate is below one.

        :param count: number of processed items
        :param total_time: total processing time in seconds; may be zero for very fast batches
        :param plural: unit name used in the per-second form, e.g. "batches"
        :param singular: unit name used in the seconds-per-item form, e.g. "batch"
        """
        if total_time > 0:
            rate = count / total_time
        else:
            # timer resolution can make the measured time zero: report an infinite rate instead of crashing
            rate = float("inf") if count else 0.0

        if rate > 1 or rate == 0:
            return f"{rate:.2f} {plural}/s"
        # slow case: report seconds per item, i.e. the inverse of the rate
        return f"{1 / rate:.2f} s/{singular}"

    def run(self):
        """Log accumulated statistics every report_interval seconds until ``stop`` is set."""
        while not self.stop.wait(self.report_interval):
            pool_batch_stats = defaultdict(list)
            while not self.stats_queue.empty():
                pool_uid, batch_stats = self.stats_queue.get()
                pool_batch_stats[pool_uid].append(batch_stats)

            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
            for pool_uid, pool_stats in pool_batch_stats.items():
                total_batches = len(pool_stats)
                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)

                batch_performance = self._format_rate(total_batches, total_time, "batches", "batch")
                example_performance = self._format_rate(total_examples, total_time, "examples", "example")

                logger.info(
                    f"{pool_uid}: "
                    f"{total_batches} batches ({batch_performance}), "
                    f"{total_examples} examples ({example_performance}), "
                    f"avg batch size {avg_batch_size:.2f}"
                )

    def report_stats(self, pool_uid, batch_size, processing_time):
        """Record the size and processing time of one batch processed by the pool ``pool_uid``."""
        batch_stats = BatchStats(batch_size, processing_time)
        self.stats_queue.put_nowait((pool_uid, batch_stats))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/server/server.py DELETED
@@ -1,475 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import gc
4
- import multiprocessing as mp
5
- import random
6
- import threading
7
- import time
8
- from typing import Dict, List, Optional, Sequence, Union
9
-
10
- import numpy as np
11
- import psutil
12
- import torch
13
- from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
14
- from hivemind.moe.server.layers import add_custom_models_from_file
15
- from hivemind.moe.server.runtime import Runtime
16
- from hivemind.proto.runtime_pb2 import CompressionType
17
- from hivemind.utils.logging import get_logger, use_hivemind_log_handler
18
-
19
- from src import BloomConfig, declare_active_modules
20
- from src.bloom.from_pretrained import DTYPE_MAP, load_pretrained_block
21
- from src.data_structures import CHAIN_DELIMITER, UID_DELIMITER, ServerState
22
- from src.dht_utils import get_remote_module_infos
23
- from src.server import block_selection
24
- from src.server.backend import TransformerBackend
25
- from src.server.cache import MemoryCache
26
- from src.server.handler import TransformerConnectionHandler
27
- from src.server.throughput import get_host_throughput
28
- from src.utils.convert_8bit import replace_8bit_linear
29
-
30
- use_hivemind_log_handler("in_root_logger")
31
- logger = get_logger(__file__)
32
-
33
-
34
class Server(threading.Thread):
    """
    Runs ModuleContainer, periodically checks that the network is balanced,
    restarts the ModuleContainer with other layers if the imbalance is significant
    """

    def __init__(
        self,
        prefix: Optional[str],
        converted_model_name_or_path: str,
        throughput: Union[float, str],
        num_blocks: Optional[int] = None,
        block_indices: Optional[str] = None,
        num_handlers: int = 8,
        min_batch_size: int = 1,
        max_batch_size: int = 4096,
        inference_max_length: int = 4096,
        torch_dtype: str = "auto",
        revision: str = "main",
        cache_dir: Optional[str] = None,
        attn_cache_size: Optional[int] = None,
        device: Optional[Union[str, torch.device]] = None,
        initial_peers: Sequence[str] = (),
        compression=CompressionType.NONE,
        stats_report_interval: Optional[int] = None,
        custom_module_path=None,
        update_period: float = 30,
        expiration: Optional[float] = None,
        prefetch_batches: int = 1,
        sender_threads: int = 1,
        balance_quality: float = 0.75,
        mean_balance_check_period: float = 60,
        mean_block_selection_delay: float = 0.5,
        use_auth_token: Optional[str] = None,
        load_in_8bit: bool = False,
        *,
        start: bool,
        **kwargs,
    ):
        """Create a server with one or more bloom blocks. See run_server.py for documentation."""

        super().__init__()

        self.converted_model_name_or_path = converted_model_name_or_path
        self.num_handlers = num_handlers
        self.min_batch_size, self.max_batch_size = min_batch_size, max_batch_size
        self.inference_max_length = inference_max_length
        self.cache_dir = cache_dir
        self.attn_cache_size = attn_cache_size
        self.compression = compression
        self.stats_report_interval, self.update_period = stats_report_interval, update_period
        self.prefetch_batches, self.sender_threads = prefetch_batches, sender_threads
        self.use_auth_token = use_auth_token
        self.load_in_8bit = load_in_8bit

        if custom_module_path is not None:
            add_custom_models_from_file(custom_module_path)

        if prefix is None:
            # Fall back to the model name, which is only usable if it contains no DHT delimiters
            prefix = converted_model_name_or_path
            assert UID_DELIMITER not in prefix and CHAIN_DELIMITER not in prefix, (
                f"Cannot use model name as prefix (contains '{UID_DELIMITER}' or '{CHAIN_DELIMITER}'); "
                f"Please specify --prefix manually when starting a server"
            )
            logger.info(f"Automatic dht prefix: {prefix}")
        self.prefix = prefix

        if expiration is None:
            expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
        self.expiration = expiration

        # Extra keyword arguments are forwarded to the DHT constructor
        self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
        visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
        logger.info(f"Running DHT node on {visible_maddrs_str}, initial peers = {initial_peers}")

        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.memory_cache = MemoryCache(device, attn_cache_size)

        # Accept any real number (e.g. throughput=1), not only float instances
        assert isinstance(throughput, (int, float)) or throughput in ["auto", "eval"]
        if throughput in ["auto", "eval"]:
            throughput = get_host_throughput(device, force_eval=(throughput == "eval"))
        self.throughput = float(throughput)

        if isinstance(torch_dtype, str):
            torch_dtype = DTYPE_MAP[torch_dtype]
        assert torch_dtype in DTYPE_MAP.values(), f"torch_dtype must be one of {list(DTYPE_MAP.values())}"
        self.torch_dtype = torch_dtype

        self.block_config = BloomConfig.from_pretrained(
            converted_model_name_or_path,
            use_auth_token=use_auth_token,
            revision=revision,
        )
        self.module_uids = [f"{self.prefix}.{block_index}" for block_index in range(self.block_config.n_layer)]

        assert (block_indices is None) != (num_blocks is None), "please specify num_blocks or block_indices, not both"
        if block_indices is not None:
            try:
                first_block_index, last_block_index = block_indices.split(":")
                first_block_index, last_block_index = map(int, map(str.strip, (first_block_index, last_block_index)))
            except Exception as e:
                logger.error(f"Failed to parse --block_indices ({e}), must be start:end (e.g. 0:18)")
                raise
            # Materialize as a list so that the value matches the List[int] annotations used below
            block_indices = list(range(first_block_index, last_block_index))
        self.strict_block_indices, self.num_blocks = block_indices, num_blocks
        self.balance_quality = balance_quality
        self.mean_balance_check_period = mean_balance_check_period
        self.mean_block_selection_delay = mean_block_selection_delay

        self.stop = threading.Event()
        if start:
            self.start()

    def run(self):
        """
        Serve blocks until ``self.stop`` is set, switching to another set of blocks when the swarm is imbalanced.

        Each outer iteration chooses blocks, creates a started ModuleContainer for them and then re-checks the
        balance at random intervals (averaging ``mean_balance_check_period`` seconds). The container is always
        shut down before leaving the iteration.
        """
        while True:
            block_indices = self._choose_blocks()
            self.module_container = ModuleContainer.create(
                dht=self.dht,
                prefix=self.prefix,
                converted_model_name_or_path=self.converted_model_name_or_path,
                block_config=self.block_config,
                memory_cache=self.memory_cache,
                throughput=self.throughput,
                block_indices=block_indices,
                num_handlers=self.num_handlers,
                min_batch_size=self.min_batch_size,
                max_batch_size=self.max_batch_size,
                inference_max_length=self.inference_max_length,
                torch_dtype=self.torch_dtype,
                cache_dir=self.cache_dir,
                device=self.device,
                compression=self.compression,
                stats_report_interval=self.stats_report_interval,
                update_period=self.update_period,
                expiration=self.expiration,
                prefetch_batches=self.prefetch_batches,
                sender_threads=self.sender_threads,
                use_auth_token=self.use_auth_token,
                load_in_8bit=self.load_in_8bit,
                start=True,
            )
            try:
                self.module_container.ready.wait()

                while True:
                    timeout = random.random() * 2 * self.mean_balance_check_period
                    # TODO: Follow ModuleContainer status (to restart/stop if it crashes)
                    if self.stop.wait(timeout):
                        return

                    if self._should_choose_other_blocks():
                        logger.info("Swarm is imbalanced, server will load other blocks")
                        break  # Stop serving this set of modules
            finally:
                self.module_container.shutdown()

            self._clean_memory_and_fds()

    def _clean_memory_and_fds(self):
        """Drop the reference to the last module container, run garbage collection and log the open descriptor count."""
        del self.module_container
        gc.collect()  # In particular, this closes unused file descriptors

        cur_proc = psutil.Process()
        num_fds = [proc.num_fds() for proc in [cur_proc] + cur_proc.children(recursive=True)]
        logger.info(f"Cleanup complete, {sum(num_fds)} open file descriptors left")

    def _choose_blocks(self) -> List[int]:
        """Return the user-specified block indices if given, otherwise ask block_selection for the best blocks."""
        if self.strict_block_indices is not None:
            return self.strict_block_indices
        assert self.num_blocks is not None

        # If multiple servers (e.g., launched on the same machine by a script) get to this line at the same time,
        # this delay decreases the probability of a race condition while choosing the best blocks to serve.
        time.sleep(random.random() * 2 * self.mean_block_selection_delay)
        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
        return block_selection.choose_best_blocks(self.num_blocks, module_infos)

    def _should_choose_other_blocks(self) -> bool:
        """Return True if block_selection recommends switching blocks; always False with user-specified indices."""
        if self.strict_block_indices is not None:
            return False

        module_infos = get_remote_module_infos(self.dht, self.module_uids, expiration_time=np.inf)
        return block_selection.should_choose_other_blocks(self.dht.peer_id, module_infos, self.balance_quality)

    def shutdown(self):
        """Signal the serving loop to stop, wait for it to finish, then shut down the DHT node."""
        self.stop.set()

        # run() shuts its ModuleContainer down in a `finally` clause, and that shutdown announces the blocks as
        # offline through self.dht. Wait for the thread first so that the DHT is still alive at that point.
        if self.is_alive() and threading.current_thread() is not self:
            self.join()

        self.dht.shutdown()
        self.dht.join()
225
-
226
-
227
class ModuleContainer(threading.Thread):
    """Serves a set of specific Bloom layers for inference, forward, and backward. Announces itself over the DHT."""

    def __init__(
        self,
        dht: DHT,
        module_backends: Dict[str, TransformerBackend],
        *,
        inference_max_length: int,
        num_connection_handlers: int,
        throughput: float,
        update_period: float,
        expiration: Optional[float] = None,
        start: bool,
        **kwargs,
    ):
        """
        Build connection handlers, a Runtime and an announcer thread for the given backends.

        :param dht: DHT instance used for announcements and passed to every connection handler
        :param module_backends: mapping from module uid to the backend that serves it
        :param inference_max_length: forwarded to every TransformerConnectionHandler
        :param num_connection_handlers: how many TransformerConnectionHandler instances to create
        :param throughput: value announced over the DHT together with the module state
        :param update_period: seconds between consecutive ONLINE announcements
        :param expiration: announcements are declared with expiration time ``get_dht_time() + expiration``
        :param start: if True, start in a background thread and wait until the runtime reports ready
        :param kwargs: forwarded to Runtime
        """
        super().__init__()

        self.dht, self.module_backends = dht, module_backends
        self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
        self.conn_handlers = [
            TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
            for _ in range(num_connection_handlers)
        ]
        self.runtime = Runtime(self.module_backends, **kwargs)
        self.dht_handler_thread = ModuleAnnouncerThread(
            self.module_backends,
            dht,
            throughput=throughput,
            update_period=update_period,
            expiration=expiration,
            daemon=True,
        )
        self.checkpoint_saver = None  # no need to save checkpoints since we do not change model state

        if start:
            self.run_in_background(await_ready=True)

    def run(self):
        """
        Runs ModuleContainer in the current thread. Initializes dht if necessary, starts connection handlers,
        runs Runtime (self.runtime) to process incoming requests.
        """
        logger.info(f"Serving {len(self.module_backends)} blocks:")
        for expert_name, backend in self.module_backends.items():
            num_parameters = sum(p.numel() for p in backend.module.parameters() if p.requires_grad)
            logger.info(f"{expert_name}: {backend.module.__class__.__name__}, {num_parameters} parameters")

        if not self.dht.is_alive():
            self.dht.run_in_background(await_ready=True)

        if self.module_backends:
            self.dht_handler_thread.start()

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.start()

        for handler in self.conn_handlers:
            handler.run_in_background()

        self.runtime.run()

    # noinspection PyMethodOverriding
    @classmethod
    def create(
        cls,
        *,
        dht: DHT,
        prefix: str,
        converted_model_name_or_path: str,
        block_config: BloomConfig,
        memory_cache: MemoryCache,
        throughput: float,
        block_indices: List[int],
        num_handlers: Optional[int],
        min_batch_size: int,
        max_batch_size: int,
        inference_max_length: int,
        torch_dtype: torch.dtype,
        cache_dir: Optional[str],
        device: Union[str, torch.device],
        compression: CompressionType,
        stats_report_interval: Optional[int],
        update_period: float,
        expiration: Optional[float],
        prefetch_batches: int,
        sender_threads: int,
        use_auth_token: Optional[str],
        load_in_8bit: bool,
        start: bool,
    ) -> ModuleContainer:
        """
        Announce the chosen blocks as JOINING, load them, wrap each one into a TransformerBackend
        and return a ModuleContainer that serves them.

        If loading fails, the blocks are announced as OFFLINE before the error is re-raised.
        """
        module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
        declare_active_modules(
            dht,
            module_uids,
            expiration_time=get_dht_time() + expiration,
            state=ServerState.JOINING,
            throughput=throughput,
        )
        logger.info(f"Announced that blocks {block_indices} are joining")

        try:
            blocks = {}
            for module_uid, block_index in zip(module_uids, block_indices):
                block = load_pretrained_block(
                    converted_model_name_or_path,
                    block_index,
                    block_config,
                    torch_dtype=torch_dtype,
                    use_auth_token=use_auth_token,
                    cache_dir=cache_dir,
                )

                if load_in_8bit:
                    block = replace_8bit_linear(block)

                block = block.to(device)
                # The served blocks are never trained on the server side
                for param in block.parameters():
                    param.requires_grad = False

                blocks[module_uid] = TransformerBackend(
                    module_uid,
                    block,
                    memory_cache=memory_cache,
                    backend_dtype=None if torch_dtype == "auto" else torch_dtype,
                    args_schema=(
                        BatchTensorDescriptor(
                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                        ),
                    ),
                    kwargs_schema={},
                    outputs_schema=(
                        BatchTensorDescriptor(
                            1, 2048, block_config.hidden_size, dtype=torch.float32, compression=compression
                        ),
                    ),
                    min_batch_size=min_batch_size,
                    max_batch_size=max_batch_size,
                )
        except BaseException:
            # Do not leave the blocks announced as JOINING if they will never come online
            declare_active_modules(
                dht,
                module_uids,
                expiration_time=get_dht_time() + expiration,
                state=ServerState.OFFLINE,
                throughput=throughput,
            )
            logger.info(f"Announced that blocks {module_uids} are offline")
            raise

        return cls(
            dht,
            blocks,
            throughput=throughput,
            num_connection_handlers=num_handlers,
            inference_max_length=inference_max_length,
            device=device,
            stats_report_interval=stats_report_interval,
            update_period=update_period,
            expiration=expiration,
            prefetch_batches=prefetch_batches,
            sender_threads=sender_threads,
            start=start,
        )

    def run_in_background(self, await_ready=True, timeout=None):
        """
        Starts ModuleContainer in a background thread. if await_ready, this method will wait until the container
        is ready to process incoming requests or for :timeout: seconds max.

        :raises TimeoutError: if await_ready is set and the container did not become ready within :timeout: seconds
        """
        self.start()
        if await_ready and not self.ready.wait(timeout=timeout):
            raise TimeoutError(f"ModuleContainer didn't notify .ready in {timeout} seconds")

    @property
    def ready(self) -> mp.synchronize.Event:
        """
        An event (multiprocessing.Event) that is set when the container is ready to process requests.

        Example
        =======
        >>> container.start()
        >>> container.ready.wait(timeout=10)
        >>> print("Container ready" if container.ready.is_set() else "Container didn't start in 10 seconds")
        """
        return self.runtime.ready  # mp.Event that is true if self is ready to process batches

    def shutdown(self):
        """
        Gracefully terminate the container, process-safe.
        Please note that terminating container otherwise (e.g. by killing processes) may result in zombie processes.
        If you did already cause a zombie outbreak, your only option is to kill them with -9 (SIGKILL).
        """
        if self.module_backends:
            # The announcer thread is only started when there are backends to announce (see run)
            self.dht_handler_thread.stop.set()
            self.dht_handler_thread.join()

            declare_active_modules(
                self.dht,
                self.module_backends.keys(),
                expiration_time=get_dht_time() + self.expiration,
                state=ServerState.OFFLINE,
                throughput=self.throughput,
            )
            logger.info(f"Announced that blocks {list(self.module_backends.keys())} are offline")

        self.ready.clear()

        for handler in self.conn_handlers:
            handler.shutdown()
        logger.debug("Connection handlers terminated")

        if self.checkpoint_saver is not None:
            self.checkpoint_saver.stop.set()
            self.checkpoint_saver.join()

        logger.debug("Shutting down pools")
        for pool in self.runtime.pools:
            if pool.is_alive():
                pool.shutdown()

        logger.debug("Shutting down runtime")
        self.runtime.shutdown()

        logger.info("Module container shut down succesfully")
442
-
443
-
444
class ModuleAnnouncerThread(threading.Thread):
    """Background thread that keeps re-declaring the hosted modules as ONLINE so that DHT peers can see them"""

    def __init__(
        self,
        module_backends: Dict[str, TransformerBackend],
        dht: DHT,
        *,
        throughput: float,
        update_period: float = 30,
        expiration: float,
        **kwargs,
    ):
        """
        :param module_backends: mapping whose keys are the module uids to announce
        :param dht: DHT instance used for the announcements
        :param throughput: value published together with each announcement
        :param update_period: seconds to wait between two announcements
        :param expiration: each announcement is declared with expiration time ``get_dht_time() + expiration``
        :param kwargs: forwarded to threading.Thread
        """
        super().__init__(**kwargs)
        self.module_backends, self.dht = module_backends, dht
        self.throughput = throughput
        self.update_period, self.expiration = update_period, expiration
        self.stop = threading.Event()

    def run(self) -> None:
        """Announce immediately, then once per ``update_period`` seconds until ``self.stop`` is set."""

        def announce_online():
            declare_active_modules(
                self.dht,
                self.module_backends.keys(),
                expiration_time=get_dht_time() + self.expiration,
                state=ServerState.ONLINE,
                throughput=self.throughput,
            )

        announce_online()
        # Event.wait returns True as soon as the stop flag is set, which ends the loop
        while not self.stop.wait(self.update_period):
            announce_online()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/server/task_pool.py DELETED
@@ -1,178 +0,0 @@
1
- import ctypes
2
- import multiprocessing as mp
3
- import threading
4
- import time
5
- from dataclasses import dataclass, field
6
- from queue import PriorityQueue
7
- from typing import Any, Generator, List, Optional, Sequence, Tuple
8
-
9
- import torch
10
- from hivemind import MPFuture, get_logger, use_hivemind_log_handler
11
- from hivemind.moe.server.task_pool import TaskPoolBase
12
-
13
- use_hivemind_log_handler("in_root_logger")
14
- logger = get_logger(__file__)
15
-
16
-
17
@dataclass(order=True, frozen=True)
class Task:
    """An immutable unit of work; instances are ordered by (priority, time_submitted) only."""

    # Fields that take part in ordering and equality
    priority: float
    time_submitted: float
    # Payload fields, excluded from comparisons
    future: MPFuture = field(compare=False)
    args: Sequence[torch.Tensor] = field(compare=False)

    @property
    def uid(self) -> int:
        """Identifier of this task, taken from the ``_uid`` attribute of its future."""
        future = self.future
        return future._uid
27
-
28
-
29
class PrioritizedTaskPool(TaskPoolBase):
    """
    Collects requests from multiple ConnectionHandler instances, orders them by priority for Runtime, and hands
    the result (or exception) of each request back to the ConnectionHandler that submitted it. Runs a background
    process. One PrioritizedTaskPool serves one specific function (e.g. layer1.forward, layer2.forward or
    layer1.backward).

    :note: unlike hivemind.moe TaskPool, this pool does *not* combine incoming requests into batches.
      This would require grouping requests of different length.

    :param process_func: function to be applied to every formed batch; called by Runtime
        Note that process_func should accept only positional args (Tensors) and return a flat tuple of Tensors
    :param max_batch_size: process at most this many inputs in a batch (a task contains one or several inputs)
         Measured in the total number of tokens (i.e. batch size * sequence length)

    :param name: pool name, used for logging
    :param min_batch_size: process at least this many inputs in a batch, otherwise wait for more
    :param start: if True, start automatically at the end of __init__
    """

    def __init__(
        self,
        process_func: callable,
        max_batch_size: int,
        name: str,
        min_batch_size=1,
        daemon=True,
        start=False,
    ):
        super().__init__(process_func, daemon=daemon, name=name)
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size

        # Tasks arrive from ConnectionHandlers through this queue ...
        self.submitted_tasks = mp.SimpleQueue()
        # ... and are re-ordered here for Runtime; this queue is only valid inside Runtime
        self._ordered_tasks = PriorityQueue()

        self._prioritizer_thread = threading.Thread(
            name=self.name + "_prioritizer",
            target=self._prioritize_tasks,
            args=[self.submitted_tasks, self._ordered_tasks],
            daemon=True,
        )
        self._dispatched_tasks = {}
        self.batch_receiver, self.batch_sender = mp.Pipe(duplex=False)
        self._oldest_undispatched_timestamp = mp.Value(ctypes.c_double, 1.0)
        # (first task priority, first task timestamp); goes through the property setter below
        self.priority = float("inf"), float("inf")

        self._stop = mp.Event()
        if start:
            self.start()

    @staticmethod
    def _prioritize_tasks(submitted_tasks: mp.SimpleQueue, ordered_tasks: PriorityQueue):
        """Move tasks from the incoming queue into the local priority queue until a None sentinel arrives"""
        incoming = submitted_tasks.get()
        while incoming is not None:
            ordered_tasks.put(incoming, block=True)
            incoming = submitted_tasks.get()
        logger.debug("Shutting down prioritizer thread")

    def start(self):
        """Start the prioritizer thread, then the pool itself; neither may be running already."""
        assert not self.is_alive() and not self._prioritizer_thread.is_alive()
        self._prioritizer_thread.start()
        super().start()

    def shutdown(self, timeout: float = 3):
        """Ask the pool to stop and wait up to :timeout: seconds; terminate it if it is still alive afterwards."""
        self.submitted_tasks.put(None)  # Shuts down self._prioritizer_thread
        self._stop.set()

        self.join(timeout)
        if not self.is_alive():
            return
        logger.warning(f"{self.__class__.__name__} failed to shut down gracefully, sending SIGTERM")
        self.terminate()

    def submit_task(self, *args: torch.Tensor, priority: float = 0.0) -> MPFuture:
        """Add task to this pool's queue, return Future for its output"""
        submitted_at = time.monotonic()
        task = Task(priority, submitted_at, MPFuture(), args)

        if self.get_task_size(task) > self.max_batch_size:
            # Oversized tasks are rejected right away through their future
            task.future.set_exception(
                ValueError(f"Task size greater than max_batch_size ({self.max_batch_size}), it can't be processed")
            )
            return task.future

        self.submitted_tasks.put(task)
        self.batch_sender.send(None)  # use this pipe to count the number of unfinished batches
        task_key = (task.priority, task.time_submitted)
        if task_key < self.priority:
            self.priority = task_key
        return task.future

    def get_task_size(self, task: Task) -> int:
        """compute task processing complexity; defaults to the total number of tokens"""
        if not task.args or task.args[0].ndim < 2:
            return 1
        first_arg = task.args[0]
        return first_arg.shape[0] * first_arg.shape[1]

    def load_batch_to_runtime(
        self, timeout: Optional[float] = None, device: Optional[torch.device] = None
    ) -> Tuple[Any, List[torch.Tensor]]:
        """receive next batch of arrays"""
        task = self._ordered_tasks.get(block=True, timeout=timeout)
        batch_inputs = []
        for tensor in task.args:
            moved = tensor.detach().to(device, non_blocking=True)
            batch_inputs.append(moved.requires_grad_(tensor.requires_grad))

        self._dispatched_tasks[task.uid] = task
        self.batch_receiver.recv()  # reduce the number of active batches
        if not self._ordered_tasks.empty():
            # Peek at the head of the underlying heap without removing it
            next_task: Task = self._ordered_tasks.queue[0]
            self.priority = (next_task.priority, next_task.time_submitted)
        return task.uid, batch_inputs

    def send_outputs_from_runtime(self, uid: int, batch_outputs: List[torch.Tensor]):
        """send results for a processed batch, previously loaded through load_batch_to_runtime"""
        shared_outputs = []
        for tensor in batch_outputs:
            on_cpu = tensor.to(device="cpu").share_memory_().detach()
            shared_outputs.append(on_cpu.requires_grad_(tensor.requires_grad))

        task = self._dispatched_tasks.pop(uid, None)
        if task is not None:
            task.future.set_result(shared_outputs)
            return
        logger.error(
            f"Internal error: task task with index {uid} is missing from the dictionary; " f"Could not set result"
        )

    def send_exception_from_runtime(self, uid: int, exception: BaseException):
        """Deliver an exception to the future of a task previously loaded through load_batch_to_runtime"""
        task = self._dispatched_tasks.pop(uid, None)
        if task is not None:
            task.future.set_exception(exception)
            return
        logger.error(
            f"Internal error: task task with index {uid} is missing from the dictionary; "
            f"Could not set exception {exception}"
        )

    def run(self, *args, **kwargs):
        """Block until shutdown() sets the stop event."""
        self._stop.wait()

    @property
    def empty(self):
        """True if the counting pipe currently has nothing to receive."""
        has_pending = self.batch_receiver.poll()
        return not has_pending

    @property
    def priority(self) -> Tuple[float, float]:
        """The priority of this pool equals the (priority, timestamp) of the most important task in it."""
        first_priority = float(self._priority.value)
        first_timestamp = float(self._oldest_undispatched_timestamp.value)
        return first_priority, first_timestamp

    @priority.setter
    def priority(self, item: Tuple[float, float]):
        assert len(item) == 2
        new_priority, new_timestamp = item[0], item[1]
        self._priority.value = float(new_priority)
        self._oldest_undispatched_timestamp.value = float(new_timestamp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/server/task_prioritizer.py DELETED
@@ -1,20 +0,0 @@
1
- from abc import ABC, abstractmethod
2
-
3
- import torch
4
- from hivemind.moe.server.task_pool import Task
5
-
6
-
7
class TaskPrioritizerBase(ABC):
    """Abstract base for task prioritizers, whose responsibility is to evaluate the priority of a task"""

    @abstractmethod
    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        """
        Evaluate the task's value from the amount of points given, the task inputs and any additional kwargs.
        Lower priority is better.
        """
14
-
15
-
16
class DummyTaskPrioritizer(TaskPrioritizerBase):
    """Trivial prioritizer: every task gets the same priority of zero, whatever its inputs or points"""

    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        constant_priority = 0.0
        return constant_priority
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/server/throughput.py DELETED
@@ -1,127 +0,0 @@
1
- import fcntl
2
- import json
3
- import os
4
- import subprocess
5
- import tempfile
6
- import time
7
- from dataclasses import asdict, dataclass
8
- from pathlib import Path
9
- from typing import Dict, Union
10
-
11
- import torch
12
- from hivemind.utils.logging import get_logger, use_hivemind_log_handler
13
-
14
- from src import project_name
15
- from src.bloom.block import BloomBlock
16
- from src.bloom.model import BloomConfig
17
- from src.bloom.ops import build_alibi_tensor
18
-
19
- use_hivemind_log_handler("in_root_logger")
20
- logger = get_logger(__file__)
21
-
22
-
23
- DEFAULT_CACHE_PATH = Path(Path.home(), ".cache", project_name, "throughput.json")
24
- DEFAULT_LOCK_PATH = Path(tempfile.gettempdir(), project_name, "throughput.lock")
25
-
26
- SPEED_TEST_PATH = Path(Path(__file__).absolute().parents[2], "cli", "speed_test.py")
27
-
28
-
29
@dataclass
class ThroughputInfo:
    """Throughput measurements of this host, as stored in the JSON cache file."""

    # Network throughput in requests per second
    network_rps: float
    # Compute throughput in requests per second, keyed by device type (e.g. "cpu", "cuda")
    device_rps: Dict[str, float]
33
-
34
-
35
def get_host_throughput(
    device: Union[str, torch.device],
    force_eval: bool = False,
    cache_path: Union[str, Path] = DEFAULT_CACHE_PATH,
    lock_path: Union[str, Path] = DEFAULT_LOCK_PATH,
) -> float:
    """
    Return the throughput of this host for the given device, in requests per second.

    The result is the minimum of the network throughput and the compute throughput of the device type.
    Measurements are cached in a JSON file and re-used unless a new evaluation is forced.

    :param device: device to report throughput for; only its type (e.g. "cuda") is used
    :param force_eval: if True, measure again even if cached measurements exist
    :param cache_path: JSON file where measurements are stored (str or Path)
    :param lock_path: file used as a system-wide lock while loading or measuring (str or Path)
    :raises KeyError: if no measurement exists for the device type after evaluation
    """
    # The defaults are Path objects, but the parameters were documented as str: accept both
    cache_path, lock_path = Path(cache_path), Path(lock_path)

    # We only keep the device type, assuming that the throughput is similar among all host's GPUs
    device = torch.device(device).type

    # We use the system-wide lock since only one process at a time can measure the host throughput
    os.makedirs(lock_path.parent, exist_ok=True)
    with open(lock_path, "wb") as lock_fd:
        logger.info("Loading throughput info")
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
        # The OS will release the lock when lock_fd is closed or the process is killed

        info = None
        try:
            if not force_eval and os.path.exists(cache_path):
                with open(cache_path) as cache_fd:
                    info = ThroughputInfo(**json.load(cache_fd))
                if device not in info.device_rps:
                    # The cache was written without a measurement for this device type
                    force_eval = True
        except Exception:
            logger.exception(f"Failed to read throughput info from {cache_path}")
            force_eval = True

        if force_eval or info is None:
            info = measure_throughput_info()
            try:
                os.makedirs(cache_path.parent, exist_ok=True)
                with open(cache_path, "w") as cache_fd:
                    json.dump(asdict(info), cache_fd)
            except Exception:
                # Failing to write the cache is not fatal: the fresh measurements are still returned
                logger.exception(f"Failed to save throughput info in {cache_path}")

    throughput = min(info.network_rps, info.device_rps[device])
    return throughput
73
-
74
-
75
def measure_throughput_info() -> ThroughputInfo:
    """Measure network throughput plus CPU (and, if available, CUDA) compute throughput of this host."""
    logger.info(
        "Measuring network, CPU, and GPU throughput. " "This takes about a minute and will be cached for future runs"
    )

    # We measure throughput in "(inference) requests per second" (RPS) using a fixed model
    reference_config = BloomConfig.from_pretrained("bigscience/test-bloomd-6b3")

    measured_network_rps = measure_network_rps(reference_config)

    measured_device_rps = {}
    for device_type in ("cpu", "cuda"):
        # CPU is always measured; CUDA only when torch reports it as available
        if device_type == "cuda" and not torch.cuda.is_available():
            continue
        measured_device_rps[device_type] = measure_device_rps(device_type, reference_config)

    return ThroughputInfo(network_rps=measured_network_rps, device_rps=measured_device_rps)
90
-
91
-
92
def measure_network_rps(config: BloomConfig) -> float:
    """
    Run the bundled speed test script and convert its result into requests per second.

    :param config: model config; its hidden_size determines how many bits one request is assumed to take
    :raises RuntimeError: if the speed test script exits with a non-zero return code
    """
    speed_test = subprocess.run([SPEED_TEST_PATH, "--json"], capture_output=True)
    if speed_test.returncode != 0:
        raise RuntimeError(
            f"Failed to measure network throughput (stdout: {speed_test.stdout}, stderr: {speed_test.stderr})"
        )
    network_info = json.loads(speed_test.stdout)

    # One request is counted as hidden_size values of 32 bits each
    bits_per_request = config.hidden_size * 32
    slowest_direction = min(network_info["download"], network_info["upload"])
    network_rps = slowest_direction / bits_per_request

    logger.info(
        f"Network throughput: "
        f"{network_info['download'] / 1e6:.2f} Mbit/s on download, "
        f"{network_info['upload'] / 1e6:.2f} Mbit/s on upload, "
        f"{network_rps:.2f} RPS"
    )
    return network_rps
108
-
109
-
110
def measure_device_rps(device: str, config: BloomConfig, layer_index: int = 0, n_steps: int = 500) -> float:
    """
    Measure how many single-token inference steps per second one BloomBlock performs on the given device.

    Each step feeds one new random token to the block while re-using the attention cache returned by the
    previous step. Only the time spent in the forward pass is counted.

    :param device: device to run the benchmark on (e.g. "cpu" or "cuda")
    :param config: config used to build the benchmarked BloomBlock
    :param layer_index: layer index passed to the BloomBlock constructor
    :param n_steps: number of consecutive inference steps to time
    :returns: n_steps divided by the total forward time in seconds
    """
    torch_device = torch.device(device)
    is_cuda = torch_device.type == "cuda"

    with torch.inference_mode():
        block = BloomBlock(config, layer_index).to(device)
        cache = None
        elapsed = 0.0
        for i in range(n_steps):
            dummy_input = torch.randn(1, 1, config.hidden_size, device=device)
            alibi = build_alibi_tensor(i + 1, config.num_attention_heads, dtype=torch.float32, device=device)

            # CUDA kernels run asynchronously: without synchronization the timer would only measure
            # how long it takes to launch them, which overestimates the throughput
            if is_cuda:
                torch.cuda.synchronize(torch_device)
            start_time = time.perf_counter()
            _, cache = block.forward(dummy_input, alibi=alibi, use_cache=True, layer_past=cache)
            if is_cuda:
                torch.cuda.synchronize(torch_device)
            elapsed += time.perf_counter() - start_time
        device_rps = n_steps / elapsed

    device_name = f"{torch.cuda.get_device_name(torch_device)} GPU" if is_cuda else "CPU"
    logger.info(f"Compute throughput ({device_name}): {device_rps:.2f} RPS")

    return device_rps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/utils/__init__.py DELETED
File without changes
petals/src/utils/convert_8bit.py DELETED
@@ -1,41 +0,0 @@
1
- import os
2
-
3
- import bitsandbytes as bnb
4
- import torch
5
-
6
- PETALS_8BIT_BACKWARD = bool(int(os.environ.get("PETALS_8BIT_BACKWARD", 0)))
7
-
8
-
9
- def replace_8bit_linear(model, threshold=6.0):
10
- """
11
- A helper function to convert all `torch.nn.Linear` modules to `bnb.nn.Linear8bit` modules from the `bitsandbytes`
12
- library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
13
- 8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
14
- version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
15
- bitsandbytes-cudaXXX` with `XXX` is your CUDA version (e.g., 11.6 = 116)
16
- The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` and 'score' that should
17
- be kept as a `torch.nn.Linear` module.
18
- Parameters:
19
- model (`torch.nn.Module`):
20
- Input model or `torch.nn.Module` as the function is run recursively.
21
- threshold (`float`, *optional*):
22
- `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
23
- `6.0` as described by the paper.
24
- """
25
- for n, module in model.named_children():
26
- if len(list(module.children())) > 0:
27
- replace_8bit_linear(module, threshold)
28
-
29
- if isinstance(module, torch.nn.Linear) and n not in ["lm_head", "score"]:
30
- model._modules[n] = bnb.nn.Linear8bitLt(
31
- module.in_features,
32
- module.out_features,
33
- module.bias is not None,
34
- has_fp16_weights=False,
35
- threshold=threshold,
36
- memory_efficient_backward=PETALS_8BIT_BACKWARD,
37
- )
38
- model._modules[n].weight = bnb.nn.Int8Params(
39
- module.weight.data, requires_grad=False, has_fp16_weights=False
40
- ).to(module.weight.dtype)
41
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/utils/generation_algorithms.py DELETED
@@ -1,78 +0,0 @@
1
- from abc import ABC
2
- from typing import Tuple
3
-
4
- import torch
5
-
6
- TokenIds = torch.Tensor
7
- HypoIds = torch.Tensor
8
-
9
-
10
- class DecodingAlgorithm(ABC):
11
- """
12
- An abstract class for decoding algorithms. Describe base function of those algorithms: they have to select new tokens and provide the corresponding hypothesis.
13
- """
14
-
15
- def __init__(self) -> None:
16
- pass
17
-
18
- def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
19
- """
20
- :param logits: A tensor of shape (batch_size, seq_lenth, vocab_size)
21
- :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size)
22
- """
23
- pass
24
-
25
-
26
- class GreedyAlgorithm(DecodingAlgorithm):
27
- """
28
- The simpliest algorithm for decoding. It selects the most probable token.
29
- """
30
-
31
- def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
32
- """
33
- Returns the most propable token. The second return object always are range of integers from 0 to batch_size - 1.
34
- """
35
- return logits.max(-1)[1].unsqueeze(1), torch.arange(logits.size(0))
36
-
37
-
38
- class SamplingAlgorithm(DecodingAlgorithm):
39
- def sample(self, logits: torch.Tensor, indices_to_remove: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
40
- """
41
- :param logits: A tensor of shape (batch_size * num_hypos, vocab_size)
42
- :param indices_to_remove: A bool tensor of shape (batch_size * num_hypos, vocab_size)
43
- :return: A tuple of selected token ids and corresponding hypothesis. The shape of the token ids is (batch_size, seq_length) and the shape of the hypothesis is (batch_size).
44
- """
45
- logits[indices_to_remove] = -float("Inf")
46
- probs = torch.softmax(logits / self.temperature, -1)
47
- return torch.multinomial(probs, num_samples=1), torch.arange(logits.size(0))
48
-
49
-
50
- class TopKAlgorithm(SamplingAlgorithm):
51
- # TODO: Add NumHypos, maxBatchSize
52
- def __init__(self, top_k: int, temperature: float = 1.0) -> None:
53
- self.top_k = top_k
54
- self.temperature = temperature
55
-
56
- def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
57
- indices_to_remove = logits < torch.topk(logits, self.top_k, dim=-1)[0][..., -1, None]
58
- return self.sample(logits, indices_to_remove)
59
-
60
-
61
- class NucleusAlgorithm(SamplingAlgorithm):
62
- def __init__(self, top_p: float, temperature: float = 1.0) -> None:
63
- self.top_p = top_p
64
- self.temperature = temperature
65
-
66
- def __call__(self, logits: torch.Tensor) -> Tuple[TokenIds, HypoIds]:
67
- sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
68
- probs = torch.softmax(sorted_logits / self.temperature, -1)
69
- cumulative_probs = torch.cumsum(probs, dim=-1)
70
- sorted_indices_to_remove = cumulative_probs > self.top_p
71
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
72
- sorted_indices_to_remove[..., 0] = False
73
- indices_to_remove = torch.zeros_like(sorted_indices_to_remove)
74
- indices_to_remove.scatter_(-1, sorted_indices, sorted_indices_to_remove)
75
- return self.sample(logits, indices_to_remove)
76
-
77
-
78
- # TODO: In generate function we need to check usage of top_k or sampling algorithm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
petals/src/utils/generation_constraints.py DELETED
@@ -1,84 +0,0 @@
1
- from abc import ABC
2
-
3
- import torch
4
-
5
-
6
- class ABCBloomConstraint(ABC):
7
- """
8
- Base class of all kind of decoding constraints. It can be used to implement a new constraint.
9
- """
10
-
11
- def __init__(self) -> None:
12
- pass
13
-
14
- def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
15
- """
16
- This method is called by the decoding algorithm to apply the constraint. It changes and returns new logits.
17
- :param tokens_id: The token id of the last choosen token.
18
- :param logits: The logits from the Bloom model.
19
- :param hypo_ids: The hypothesis ids of the last tokens.
20
- """
21
- pass
22
-
23
-
24
- class MaxNewTokensConstraint(ABCBloomConstraint):
25
- """
26
- Constraint that forbids to generate more than max_new_tokens tokens after the prefix.
27
-
28
- Args:
29
- prefix: The prefix of the sequence.
30
- max_new_tokens: The maximum number of tokens that can be generated after the prefix.
31
- eos_token_id: The id of the end of sentence token.
32
- pad_token_id: The id of the padding token.
33
- min_logits: The minimum logits that can be generated. Default: -1e6.
34
- """
35
-
36
- def __init__(
37
- self, prefix: torch.Tensor, max_new_tokens: int, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8
38
- ) -> None:
39
- self.max_new_tokens = max_new_tokens
40
- self.current_generated_tokens = None
41
- self.eos_token_id = eos_token_id
42
- self.min_logits = min_logits
43
-
44
- max_pad_size = (prefix == pad_token_id).sum(1).unsqueeze(1).max()
45
- self.current_generated_tokens = (prefix == pad_token_id).sum(1).unsqueeze(1) - max_pad_size
46
-
47
- def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
48
- if tokens_id is not None:
49
- self.current_generated_tokens += 1
50
-
51
- mask = self.current_generated_tokens >= self.max_new_tokens
52
- logits += self.min_logits * mask
53
- logits[mask[:, 0], self.eos_token_id] = 0
54
- return logits
55
-
56
-
57
- class EosConstraint(ABCBloomConstraint):
58
- """
59
- This constrained repeats EOS token if it was generated on the previous step.
60
- Args:
61
- prefix: The prefix of the sequence.
62
- eos_token_id: The id of the end of sentence token.
63
- pad_token_id: The id of the padding token.
64
- min_logits: The minimum logits that can be generated. Default: -1e6.
65
- """
66
-
67
- def __init__(self, prefix: torch.Tensor, eos_token_id: int, pad_token_id: int, min_logits: float = -1e8) -> None:
68
- self.eos_token_id = eos_token_id
69
- self.min_logits = min_logits
70
- self.past_tokens = None
71
-
72
- self.wait_until_starting = (prefix == pad_token_id).sum(1).unsqueeze(1)
73
-
74
- def __call__(self, tokens_id: torch.Tensor, logits: torch.Tensor, hypo_ids: torch.Tensor) -> torch.Tensor:
75
- if self.past_tokens is not None:
76
- mask = (self.wait_until_starting < 0) & (self.past_tokens == self.eos_token_id)
77
- logits += self.min_logits * mask
78
- logits[mask[:, 0], self.eos_token_id] = 0
79
-
80
- if tokens_id is not None:
81
- self.past_tokens = tokens_id
82
- self.wait_until_starting -= 1
83
-
84
- return logits