|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
def ensure_divisibility(numerator, denominator):
    """Assert that ``numerator`` divides evenly by ``denominator``.

    Raises AssertionError (with a descriptive message) on a non-zero
    remainder; returns None otherwise.
    """
    remainder = numerator % denominator
    assert remainder == 0, "{} is not divisible by {}".format(numerator, denominator)
|
|
|
|
|
def divide(numerator, denominator):
    """Integer-divide ``numerator`` by ``denominator``.

    Asserts even divisibility first, then returns the exact quotient.
    """
    ensure_divisibility(numerator, denominator)
    quotient = numerator // denominator
    return quotient
|
|
|
|
|
def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor into equal chunks along its final dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.

    Returns a tuple of ``num_partitions`` tensors; the last-dim size must
    divide evenly by ``num_partitions`` (asserted via ``divide``).
    """
    split_dim = tensor.dim() - 1
    # Equal chunk width; divide() asserts exact divisibility.
    chunk_size = divide(tensor.size()[split_dim], num_partitions)
    chunks = torch.split(tensor, chunk_size, dim=split_dim)
    if not contiguous_split_chunks:
        # torch.split returns views; hand them back directly.
        return chunks
    return tuple(chunk.contiguous() for chunk in chunks)
|
|
|
|
|
def split_tensor_along_any_dim(
    tensor, num_partitions, seq_dim, contiguous_split_chunks=False
):
    """Split a tensor into equal chunks along a caller-chosen dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        seq_dim: dimension along which to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
            in memory.

    Returns a tuple of ``num_partitions`` tensors; ``tensor.size(seq_dim)``
    must divide evenly by ``num_partitions`` (asserted via ``divide``).
    """
    # Equal chunk width along the requested dimension.
    chunk_size = divide(tensor.size()[seq_dim], num_partitions)
    pieces = torch.split(tensor, chunk_size, dim=seq_dim)
    if contiguous_split_chunks:
        pieces = tuple(piece.contiguous() for piece in pieces)
    return pieces
|
|
|
|
|
class VocabUtility:
    """Utilities for partitioning a vocabulary across parallel ranks.

    The vocabulary is split into `world_size` contiguous chunks and each
    helper returns the (first, last) index pair owned by `rank`. Note that
    the returned range is half-open: valid indices lie in [first, last).
    """

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size, rank, world_size
    ):
        """Return the [first, last) vocab range for `rank` given the
        per-partition vocabulary size.

        `world_size` is not used here; it is kept so both range helpers
        share the same (size, rank, world_size) signature.
        """
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
        """Return the [first, last) vocab range for `rank` given the global
        vocabulary size.

        Asserts (via `divide`) that `global_vocab_size` is evenly divisible
        by `world_size`.
        """
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size
        )
|
|