diff --git a/.gitattributes b/.gitattributes
index 39283779a6be0638423d8f887be6c18e2bfc82f6..d567b311d9e3ce55d7c5bf0c9ee6b85c8783ee12 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 lib/python3.11/site-packages/llvmlite/binding/libllvmlite.dylib filter=lfs diff=lfs merge=lfs -text
 lib/python3.11/site-packages/mlx/core.cpython-311-darwin.so filter=lfs diff=lfs merge=lfs -text
+lib/python3.11/site-packages/mlx/lib/libmlx.dylib filter=lfs diff=lfs merge=lfs -text
+lib/python3.11/site-packages/mlx/lib/mlx.metallib filter=lfs diff=lfs merge=lfs -text
diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/defines.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/defines.h
new file mode 100644
index 0000000000000000000000000000000000000000..bdd1419f2ebd61b1c1b50baa46be2f2c3991be31
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/defines.h
@@ -0,0 +1,16 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#ifdef __METAL__
+#define MTL_CONST constant
+#else
+#define MTL_CONST
+#endif
+
+static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5;
+static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5;
+static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4;
+static MTL_CONST constexpr int REDUCE_N_READS = 16;
+static MTL_CONST constexpr int SOFTMAX_N_READS = 4;
+static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096;
diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/erf.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/erf.h
new file mode 100644
index 0000000000000000000000000000000000000000..0a370d3048bc26dcd6b5a6a33eb13370c341a3ce
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/erf.h
@@ -0,0 +1,70 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_math>
+
+/*
+ * Approximation to the error function.
+ * Based on code from:
+ * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
+ */
+float erf(float a) {
+  float r, s, t, u;
+  t = metal::abs(a);
+  s = a * a;
+  if (t > 0.927734375f) {
+    // maximum error 0.99527 ulp
+    r = metal::fma(
+        -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
+    u = metal::fma(
+        -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
+    r = metal::fma(r, s, u);
+    r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
+    r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
+    r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
+    r = metal::fma(r, t, -t);
+    // TODO, replace with expm1 when implemented
+    r = 1.0f - metal::exp(r);
+    r = metal::copysign(r, a);
+  } else {
+    // maximum error 0.98929 ulp
+    r = -5.96761703e-4f; // -0x1.38e000p-11
+    r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
+    r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
+    r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
+    r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
+    r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
+    r = metal::fma(r, a, a);
+  }
+  return r;
+}
+
+float erfinv(float a) {
+  auto t = metal::fma(a, 0.0f - a, 1.0f);
+  t = metal::log(t);
+  float p;
+  if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793
+    p = 3.03697567e-10f; // 0x1.4deb44p-32
+    p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26
+    p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20
+    p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16
+    p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12
+    p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9
+    p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8
+    p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2
+    p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1
+  } else { // maximum ulp error = 2.35002
+    p = 5.43877832e-9f; // 0x1.75c000p-28
+    p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23
+    p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20
+    p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24
+    p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15
+    p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13
+    p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9
+    p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7
+    p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3
+    p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1
+  }
+  return a * p;
+}
\ No newline at end of file
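The two polynomials above are faithfully rounded single-precision approximations from the linked Stack Overflow answer. As a sanity check, the erf branch can be ported to host C++ almost verbatim (a sketch, not part of the library: metal::fma/abs/exp/copysign become the std:: equivalents, and the sample grid is arbitrary):

    #include <cmath>
    #include <cstdio>

    // Host-side port of the kernel's erf() polynomial for verification.
    float erf_approx(float a) {
      float r, s, t, u;
      t = std::fabs(a);
      s = a * a;
      if (t > 0.927734375f) {
        r = std::fma(-1.72853470e-5f, t, 3.83197126e-4f);
        u = std::fma(-3.88396438e-3f, t, 2.42546219e-2f);
        r = std::fma(r, s, u);
        r = std::fma(r, t, -1.06777877e-1f);
        r = std::fma(r, t, -6.34846687e-1f);
        r = std::fma(r, t, -1.28717512e-1f);
        r = std::fma(r, t, -t);
        r = 1.0f - std::exp(r);
        r = std::copysign(r, a);
      } else {
        r = -5.96761703e-4f;
        r = std::fma(r, s, 4.99119423e-3f);
        r = std::fma(r, s, -2.67681349e-2f);
        r = std::fma(r, s, 1.12819925e-1f);
        r = std::fma(r, s, -3.76125336e-1f);
        r = std::fma(r, s, 1.28379166e-1f);
        r = std::fma(r, a, a);
      }
      return r;
    }

    int main() {
      // Compare against double-precision std::erf as the reference.
      double max_err = 0.0;
      for (float x = -4.0f; x <= 4.0f; x += 1.0f / 256.0f)
        max_err = std::fmax(
            max_err, std::fabs(erf_approx(x) - std::erf(double(x))));
      std::printf("max abs error: %.3g\n", max_err); // ~1e-7 expected
    }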
diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/conv.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/conv.h
new file mode 100644
index 0000000000000000000000000000000000000000..1db3ebac895d25f5563745ba9bda3fb233990950
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/conv.h
@@ -0,0 +1,481 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_simdgroup>
+#include <metal_simdgroup_matrix>
+#include <metal_stdlib>
+
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/conv_params.h"
+
+#define MLX_MTL_CONST static constant constexpr const
+
+using namespace metal;
+
+///////////////////////////////////////////////////////////////////////////////
+// Loading helper
+///////////////////////////////////////////////////////////////////////////////
+
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int vec_size,
+    int tgp_size,
+    int tgp_padding = 0>
+struct Conv2DInputBlockLoader {
+  // Destination dimensions
+  MLX_MTL_CONST int dst_fd = BM;
+  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
+  MLX_MTL_CONST int n_vecs = BK / vec_size;
+
+  // Stride along block row within the block
+  MLX_MTL_CONST int bstride = tgp_size / n_vecs;
+  MLX_MTL_CONST int n_rows = dst_fd / bstride;
+
+  // Thread location indices
+  const short thread_idx;
+  const short bi;
+  const short bj;
+
+  // threadgroup and device memory
+  threadgroup T* dst;
+  const device T* src;
+
+  const constant MLXConvParams<2>& params;
+
+  int weight_h;
+  int weight_w;
+
+  int offsets_n[n_rows];
+  int offsets_oh[n_rows];
+  int offsets_ow[n_rows];
+
+  /* Constructor */
+  METAL_FUNC Conv2DInputBlockLoader(
+      const device T* src_,
+      threadgroup T* dst_,
+      const constant MLXConvParams<2>& params_,
+      uint3 tid [[threadgroup_position_in_grid]],
+      uint3 lid [[thread_position_in_threadgroup]],
+      uint simd_group_id [[simdgroup_index_in_threadgroup]],
+      uint simd_lane_id [[thread_index_in_simdgroup]])
+      : thread_idx(simd_group_id * 32 + simd_lane_id),
+        bi(thread_idx / n_vecs),
+        bj(vec_size * (thread_idx % n_vecs)),
+        dst(dst_ + bi * dst_ld + bj),
+        src(src_ + bj),
+        params(params_),
+        weight_h(0),
+        weight_w(0) {
+    int out_n_pixels = params.oS[0] * params.oS[1];
+
+    for (int i = 0; i < n_rows; ++i) {
+      int offset_nhw = tid.y * BM + bi + i * bstride;
+      offsets_n[i] = offset_nhw / out_n_pixels;
+      int hw = offset_nhw % out_n_pixels;
+      offsets_oh[i] = hw / params.oS[1];
+      offsets_ow[i] = hw % params.oS[1];
+    }
+
+    (void)lid;
+  }
+
+  /* Load from device memory into threadgroup memory - without bound checking */
+  METAL_FUNC void load_unsafe() const {
+#pragma clang loop unroll(full)
+    for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) {
+      int n = offsets_n[i];
+      int oh = offsets_oh[i];
+      int ow = offsets_ow[i];
+
+      int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0];
+      int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1];
+
+      // Read from input if in bounds
+      if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) {
+        const device T* curr_src = src + n * params.in_strides[0] +
+            ih * params.in_strides[1] + iw * params.in_strides[2];
+
+#pragma clang loop unroll(full)
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = curr_src[j];
+        }
+      }
+
+      // Zero pad otherwise
+      else {
+#pragma clang loop unroll(full)
+        for (short j = 0; j < vec_size; ++j) {
+          dst[is * dst_ld + j] = T(0);
+        }
+      }
+    }
+  }
+
+  /* Iteration helper */
+  METAL_FUNC void next() {
+    if (++weight_w < params.wS[1]) {
+      return;
+    }
+
+    weight_w = 0;
+
+    if (++weight_h < params.wS[0]) {
+      return;
+    }
+
+    weight_h = 0;
+
+    src += BK;
+  }
+};
+
+template <
+    typename T,
+    int BM,
+    int BN,
+    int BK,
+    int vec_size,
+    int tgp_size,
+    int tgp_padding = 0>
+struct Conv2DWeightBlockLoader {
+  // Destination dimensions
+  MLX_MTL_CONST int dst_fd = BN;
+  MLX_MTL_CONST int dst_ld = BK + tgp_padding;
+  MLX_MTL_CONST int n_vecs = BK / vec_size;
+
+  // Stride along
block row within the block + MLX_MTL_CONST int bstride = tgp_size / n_vecs; + MLX_MTL_CONST int n_rows = dst_fd / bstride; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>& params; + + int weight_h; + int weight_w; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoader( + const device T* src_, + threadgroup T* dst_, + const constant MLXConvParams<2>& params_, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_.wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / n_vecs), + bj(vec_size * (thread_idx % n_vecs)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + weight_h(0), + weight_w(0) { + (void)lid; + (void)tid; + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + const device T* curr_src = + src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2]; +#pragma clang loop unroll(full) + for (short i = 0; i < dst_fd; i += bstride) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params.wS[1]) { + return; + } + + weight_w = 0; + + if (++weight_h < params.wS[0]) { + return; + } + + weight_h = 0; + + src += BK; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Transforms +/////////////////////////////////////////////////////////////////////////////// + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + int tgp_padding_a = 0, + int tgp_padding_b = 0, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct Conv2DBlockMMA { + // Warp tile size along M + MLX_MTL_CONST int TM = BM / (WM * 8); + // Warp tile size along N + MLX_MTL_CONST int TN = BN / (WN * 8); + + // Warp tile simdgroup matrix strides along M + MLX_MTL_CONST int TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + MLX_MTL_CONST int TN_stride = 8 * WN; + + // Leading dimensions of threadgroup A, B blocks + MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a; + MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b; + + // Strides of A, B along reduction axis + MLX_MTL_CONST short simd_stride_a = + transpose_a ? TM_stride : TM_stride * lda_tgp; + MLX_MTL_CONST short simd_stride_b = + transpose_b ? TN_stride * ldb_tgp : TN_stride; + + // Jump between elements + MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1; + MLX_MTL_CONST short jump_b = transpose_b ? 
ldb_tgp : 1; + + // Offsets within threadgroup + const int tm; + const int tn; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + short sm; + short sn; + + /* Constructor */ + METAL_FUNC Conv2DBlockMMA( + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + short qid = simd_lane_id / 4; + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { +// Iterate over BK in blocks of 8 +#pragma clang loop unroll(full) + for (short kk = 0; kk < BK; kk += 8) { + short2 offset_a = + transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm); + short2 offset_b = + transpose_b ? short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm); + + const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x; + const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x; + + simdgroup_barrier(mem_flags::mem_none); +// Load elements from threadgroup A as simdgroup matrices +#pragma clang loop unroll(full) + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = static_cast(As__[0]); + Asimd[i].thread_elements()[1] = static_cast(As__[jump_a]); + As__ += simd_stride_a; + } + + simdgroup_barrier(mem_flags::mem_none); +// Load elements from threadgroup B as simdgroup matrices +#pragma clang loop unroll(full) + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = static_cast(Bs__[0]); + Bsimd[j].thread_elements()[1] = static_cast(Bs__[jump_b]); + Bs__ += simd_stride_b; + } + + simdgroup_barrier(mem_flags::mem_none); +// Multiply and accumulate into result simdgroup matrices +#pragma clang loop unroll(full) + for (short i = 0; i < TM; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < TN; j++) { + simdgroup_multiply_accumulate( + results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device T* C, const int ldc) const { +#pragma clang loop unroll(full) + for (int i = 0; i < TM; i++) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] = + Epilogue::apply(results[i * TN + j].thread_elements()[0]); + C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] = + Epilogue::apply(results[i * TN + j].thread_elements()[1]); + } + } + } + + METAL_FUNC void + store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const { +#pragma clang loop unroll(full) + for (int i = 0; i < TM; i++) { + if (tm + i * TM_stride + sm < dst_tile_dims.y) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + if (tn + j * TN_stride + sn < dst_tile_dims.x) { + C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] = + Epilogue::apply(results[i * TN + j].thread_elements()[0]); + } + + if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) { + C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] = + Epilogue::apply(results[i * TN + j].thread_elements()[1]); + } + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + 
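The loaders and MMA tile above let the 2D convolution run as an implicit GEMM: in the kernel struct below, K is taken from params.wt_strides[0] (the weight taps wH * wW * C), N from params.O (output channels), and the M dimension ranges over output pixels. A host-side sketch of that shape mapping (hypothetical helper and parameter names, not MLX API):

    #include <cstdio>

    // Mirror of the implicit-GEMM shape mapping: each output pixel is a
    // GEMM row, each output channel a GEMM column, and the reduction
    // runs over all weight taps.
    struct ConvGemmDims {
      int M, N, K;
    };

    ConvGemmDims implicit_gemm_dims(
        int batch, int oH, int oW, int wH, int wW, int C, int O) {
      return ConvGemmDims{
          batch * oH * oW, // M: output pixels
          O,               // N: output channels (params.O)
          wH * wW * C};    // K: weight taps (params.wt_strides[0])
    }

    int main() {
      ConvGemmDims d = implicit_gemm_dims(8, 32, 32, 3, 3, 64, 128);
      std::printf("GEMM M=%d N=%d K=%d\n", d.M, d.N, d.K);
    }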
+template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct Conv2DImplicitGEMMKernel { + MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T); + MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T); + MLX_MTL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + MLX_MTL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + MLX_MTL_CONST short tgp_size = WM * WN * 32; + MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 8 : 4; + + using loader_a_t = + Conv2DInputBlockLoader; + using loader_b_t = + Conv2DWeightBlockLoader; + using mma_t = Conv2DBlockMMA< + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + tgp_padding_a, + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>& params [[buffer(3)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const int K = params.wt_strides[0]; + const int N = params.O; + + B += c_col * K; + C += c_row * N + c_col; + + // Prepare threadgroup memory for loading + threadgroup T* As = tgp_memory; + threadgroup T* Bs = tgp_memory + tgp_mem_size_a; + + // Prepare threadgroup loading operations + loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid); + loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + mma_op.store_result(C, N); + } +}; \ No newline at end of file diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/gemm.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..95d2e6497f55eb8d464ce7a11cfeb354d8602e0f --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/gemm/gemm.h @@ -0,0 +1,538 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include + +#define MLX_MTL_CONST static constant constexpr const + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BROWS, + int BCOLS, + int BK, + int vec_size, + int tgp_size, + bool transpose, + bool ldK, + int tgp_padding = 0> +struct BlockLoader { + // Destination dimensions + MLX_MTL_CONST int dst_fd = transpose ? 
BCOLS : BROWS; + MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding; + MLX_MTL_CONST int n_vecs = (transpose ? BROWS : BCOLS) / vec_size; + + // Stride along block row within the block + MLX_MTL_CONST int bstride = tgp_size / n_vecs; + + // Leading dimension for src + const int src_ld; + // Stride along reduction axis between blocks + const int tstride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tstride( + BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / n_vecs), + bj(vec_size * (thread_idx % n_vecs)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { +#pragma clang loop unroll(full) + for (short i = 0; i < dst_fd; i += bstride) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy; + + // Iterate over rows of block +#pragma clang loop unroll(full) + for (short i = 0; i < dst_fd; i += bstride) { + // Row is in bounds, we check against column + if ((bi + i) < src_tile_dim.y) { + // Use fast thread memory for bound checks + short tmp_idx[vec_size]; + T tmp_val[vec_size]; + + // Make sure tmp_idx only contains valid indices +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0; + } + + // Read all valid indices into tmp_val +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[i * src_ld + tmp_idx[j]]; + } + + // Zero out unneeded values +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = bj + j < src_tile_dim.x ? 
tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + + // Row is out of bounds, we just fill tgp memory with zeros + else { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tstride; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Transforms +/////////////////////////////////////////////////////////////////////////////// + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + int tgp_padding_a = 0, + int tgp_padding_b = 0, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct BlockMMA { + // Warp tile size along M + MLX_MTL_CONST int TM = BM / (WM * 8); + // Warp tile size along N + MLX_MTL_CONST int TN = BN / (WN * 8); + + // Warp tile simdgroup matrix strides along M + MLX_MTL_CONST int TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + MLX_MTL_CONST int TN_stride = 8 * WN; + + // Leading dimensions of threadgroup A, B blocks + MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a; + MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b; + + // Strides of A, B along reduction axis + MLX_MTL_CONST short simd_stride_a = + transpose_a ? TM_stride : TM_stride * lda_tgp; + MLX_MTL_CONST short simd_stride_b = + transpose_b ? TN_stride * ldb_tgp : TN_stride; + + // Jump between elements + MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1; + MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1; + + // Offsets within threadgroup + const int tm; + const int tn; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + short sm; + short sn; + + /* Constructor */ + METAL_FUNC BlockMMA( + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + short qid = simd_lane_id / 4; + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { +// Iterate over BK in blocks of 8 +#pragma clang loop unroll(full) + for (short kk = 0; kk < BK; kk += 8) { + short2 offset_a = + transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm); + short2 offset_b = + transpose_b ? 
short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm); + + const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x; + const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x; + + simdgroup_barrier(mem_flags::mem_none); +// Load elements from threadgroup A as simdgroup matrices +#pragma clang loop unroll(full) + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = static_cast(As__[0]); + Asimd[i].thread_elements()[1] = static_cast(As__[jump_a]); + As__ += simd_stride_a; + } + + simdgroup_barrier(mem_flags::mem_none); +// Load elements from threadgroup B as simdgroup matrices +#pragma clang loop unroll(full) + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = static_cast(Bs__[0]); + Bsimd[j].thread_elements()[1] = static_cast(Bs__[jump_b]); + Bs__ += simd_stride_b; + } + + simdgroup_barrier(mem_flags::mem_none); +// Multiply and accumulate into result simdgroup matrices +#pragma clang loop unroll(full) + for (short i = 0; i < TM; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < TN; j++) { + simdgroup_multiply_accumulate( + results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device T* C, const int ldc) const { +#pragma clang loop unroll(full) + for (int i = 0; i < TM; i++) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] = + Epilogue::apply(results[i * TN + j].thread_elements()[0]); + C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] = + Epilogue::apply(results[i * TN + j].thread_elements()[1]); + } + } + } + + METAL_FUNC void + store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const { +#pragma clang loop unroll(full) + for (int i = 0; i < TM; i++) { + if (tm + i * TM_stride + sm < dst_tile_dims.y) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + if (tn + j * TN_stride + sn < dst_tile_dims.x) { + C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] = + Epilogue::apply(results[i * TN + j].thread_elements()[0]); + } + + if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) { + C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] = + Epilogue::apply(results[i * TN + j].thread_elements()[1]); + } + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T); + MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T); + MLX_MTL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + MLX_MTL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + MLX_MTL_CONST short tgp_size = WM * WN * 32; + MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 
8 : 4; + + using loader_a_t = BlockLoader< + T, + BM, + BK, + BK, + vec_size, + tgp_size, + transpose_a, + true, + tgp_padding_a>; + using loader_b_t = BlockLoader< + T, + BK, + BN, + BK, + vec_size, + tgp_size, + transpose_b, + false, + tgp_padding_b>; + using mma_t = BlockMMA< + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + tgp_padding_a, + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant int& M [[buffer(3)]], + const constant int& N [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& batch_stride_a [[buffer(6)]], + const constant int& batch_stride_b [[buffer(7)]], + const constant int& batch_stride_c [[buffer(8)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + // Adjust for batch + A += batch_stride_a * tid.z; + B += batch_stride_b * tid.z; + C += batch_stride_c * tid.z; + + // Adjust for transpose + const int lda_dev = transpose_a ? M : K; + const int ldb_dev = transpose_b ? K : N; + + // Find block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + + A += transpose_a ? c_row : c_row * K; + B += transpose_b ? c_col * K : c_col; + C += c_row * N + c_col; + + // Prepare threadgroup memory for loading + threadgroup T* As = tgp_memory; + threadgroup T* Bs = tgp_memory + tgp_mem_size_a; + + // Prepare threadgroup loading operations + loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id); + loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_group_id, simd_lane_id); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned && K_aligned) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + mma_op.store_result(C, N); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN aligned, K unaligned loop + else if (MN_aligned && !K_aligned) { + // Main loop + int k = 0; + for (; k + BK <= K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Loop tail + threadgroup_barrier(mem_flags::mem_threadgroup); + + loader_a.load_safe(short2(K - k, BM)); + loader_b.load_safe(short2(BN, K - k)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + + // Store results to device memory + mma_op.store_result(C, N); + return; + + } + 
/////////////////////////////////////////////////////////////////////////////// + // MNK unaligned loop + else { // Loop over K - unaligned case + + short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row)); + + if (src_tile_dims.y == BM && src_tile_dims.x == BN) { + int k = 0; + for (; k + BK <= K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + if (k < K) { + loader_a.load_safe(short2(K - k, BM)); + loader_b.load_safe(short2(BN, K - k)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + mma_op.store_result(C, N); + return; + + } else { + int k = 0; + for (; k + BK <= K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_safe(short2(BK, src_tile_dims.y)); + loader_b.load_safe(short2(src_tile_dims.x, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + if (k < K) { + loader_a.load_safe(short2(K - k, src_tile_dims.y)); + loader_b.load_safe(short2(src_tile_dims.x, K - k)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + threadgroup_barrier(mem_flags::mem_none); + mma_op.store_result_safe(C, N, src_tile_dims); + + return; + } + } + } +}; \ No newline at end of file diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/reduce.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..1d2b971b29dba8fb78b9b3cca7da6b19c0fd39ce --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/reduce.h @@ -0,0 +1,176 @@ +// Copyright © 2023 Apple Inc. 
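The reduction ops in reduce.h below each bundle an identity value (init), a SIMD-level reduction, an atomic update, and a plain binary operator. A host-side analogue of that op-struct shape (plain C++, illustrative names only, not the Metal code) shows why the identity plus operator pair is all a serial or tree reduction needs:

    #include <cstdio>
    #include <vector>

    // Analogue of the kernel's op structs: identity + binary operator.
    template <typename U>
    struct SumOp {
      static constexpr U init = U(0);
      U operator()(U a, U b) const {
        return a + b;
      }
    };

    template <typename U, typename Op>
    U reduce(const std::vector<U>& xs, Op op) {
      U acc = Op::init; // start from the op's identity element
      for (U x : xs)
        acc = op(acc, x);
      return acc;
    }

    int main() {
      std::vector<int> v{1, 2, 3, 4};
      std::printf("%d\n", reduce(v, SumOp<int>{})); // prints 10
    }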
+
+#pragma once
+
+#include <metal_atomic>
+#include <metal_simdgroup>
+
+#include "mlx/backend/metal/kernels/atomic.h"
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/utils.h"
+
+union bool4_or_uint {
+  bool4 b;
+  unsigned int i;
+};
+
+struct None {
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_store_explicit(out, val, offset);
+  }
+};
+
+struct And {
+  bool simd_reduce(bool val) {
+    return simd_all(val);
+  };
+
+  static constexpr constant bool init = true;
+
+  void atomic_update(
+      device mlx_atomic<unsigned int>* out,
+      bool val,
+      int elem_idx,
+      int offset = 0) {
+    if (!val) {
+      bool4_or_uint update;
+      update.b = {true, true, true, true};
+      update.b[elem_idx] = false;
+      mlx_atomic_fetch_and_explicit(out, update.i, offset);
+    }
+  }
+
+  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
+    if (!val) {
+      mlx_atomic_store_explicit(out, val, offset);
+    }
+  }
+
+  // Non atomic update
+  void update(device bool* out, bool val) {
+    *out &= val;
+  }
+
+  // Operator
+  bool operator()(bool a, bool b) {
+    return a && b;
+  }
+};
+
+struct Or {
+  bool simd_reduce(bool val) {
+    return simd_any(val);
+  };
+
+  static constexpr constant bool init = false;
+
+  void atomic_update(
+      device mlx_atomic<unsigned int>* out,
+      bool val,
+      int elem_idx,
+      int offset = 0) {
+    if (val) {
+      bool4_or_uint update;
+      update.b = {false, false, false, false};
+      update.b[elem_idx] = true;
+      mlx_atomic_fetch_or_explicit(out, update.i, offset);
+    }
+  }
+
+  void atomic_update(device mlx_atomic<bool>* out, bool val, int offset = 0) {
+    if (val) {
+      mlx_atomic_store_explicit(out, val, offset);
+    }
+  }
+
+  // Non atomic update
+  void update(device bool* out, bool val) {
+    *out |= val;
+  }
+
+  // Operator
+  bool operator()(bool a, bool b) {
+    return a || b;
+  }
+};
+
+template <typename U>
+struct Sum {
+  template <typename T>
+  T simd_reduce(T val) {
+    return simd_sum(val);
+  };
+
+  static constexpr constant U init = U(0);
+
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_fetch_add_explicit(out, val, offset);
+  }
+
+  // Operator
+  U operator()(U a, U b) {
+    return a + b;
+  }
+};
+
+template <typename U>
+struct Prod {
+  template <typename T>
+  T simd_reduce(T val) {
+    return simd_product(val);
+  };
+
+  static constexpr constant U init = U(1);
+
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_fetch_mul_explicit(out, val, offset);
+  }
+
+  // Operator
+  U operator()(U a, U b) {
+    return a * b;
+  }
+};
+
+template <typename U>
+struct Min {
+  template <typename T>
+  T simd_reduce(T val) {
+    return simd_min(val);
+  };
+
+  static constexpr constant U init = Limits<U>::max;
+
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_fetch_min_explicit(out, val, offset);
+  }
+
+  // Operator
+  U operator()(U a, U b) {
+    return a < b ? a : b;
+  }
+};
+
+template <typename U>
+struct Max {
+  template <typename T>
+  T simd_reduce(T val) {
+    return simd_max(val);
+  };
+
+  static constexpr constant U init = Limits<U>::min;
+
+  template <typename T>
+  void atomic_update(device mlx_atomic<T>* out, T val, int offset = 0) {
+    mlx_atomic_fetch_max_explicit(out, val, offset);
+  }
+
+  // Operator
+  U operator()(U a, U b) {
+    return a > b ? a : b;
+  }
+};
diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/utils.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..72cdd8b20446a5a802895de2755108dadfedc0ab
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/kernels/utils.h
@@ -0,0 +1,246 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <metal_math>
+#include "mlx/backend/metal/kernels/bf16.h"
+#include "mlx/backend/metal/kernels/complex.h"
+
+///////////////////////////////////////////////////////////////////////////////
+// Type limits utils
+///////////////////////////////////////////////////////////////////////////////
+
+template <typename U>
+struct Limits {
+  static const constant U max;
+  static const constant U min;
+  static const constant U finite_max;
+  static const constant U finite_min;
+};
+
+#define instantiate_default_limit(type)                                      \
+  template <>                                                                \
+  struct Limits<type> {                                                      \
+    static constexpr constant type max = metal::numeric_limits<type>::max(); \
+    static constexpr constant type min = metal::numeric_limits<type>::min(); \
+    static constexpr constant type finite_max =                              \
+        metal::numeric_limits<type>::max();                                  \
+    static constexpr constant type finite_min =                              \
+        metal::numeric_limits<type>::min();                                  \
+  };
+
+instantiate_default_limit(uint8_t);
+instantiate_default_limit(uint16_t);
+instantiate_default_limit(uint32_t);
+instantiate_default_limit(uint64_t);
+instantiate_default_limit(int8_t);
+instantiate_default_limit(int16_t);
+instantiate_default_limit(int32_t);
+instantiate_default_limit(int64_t);
+
+#define instantiate_float_limit(type)             \
+  template <>                                     \
+  struct Limits<type> {                           \
+    static constexpr constant type max =          \
+        metal::numeric_limits<type>::infinity();  \
+    static constexpr constant type min =          \
+        -metal::numeric_limits<type>::infinity(); \
+    static constexpr constant type finite_max =   \
+        metal::numeric_limits<type>::max();       \
+    static constexpr constant type finite_min =   \
+        -metal::numeric_limits<type>::max();      \
+  };
+
+instantiate_float_limit(half);
+instantiate_float_limit(float);
+instantiate_float_limit(bfloat16_t);
+
+template <>
+struct Limits<bool> {
+  static constexpr constant bool max = true;
+  static constexpr constant bool min = false;
+};
+
+///////////////////////////////////////////////////////////////////////////////
+// Indexing utils
+///////////////////////////////////////////////////////////////////////////////
+
+inline size_t elem_to_loc(
+    uint elem,
+    device const int* shape,
+    device const size_t* strides,
+    int ndim) {
+  size_t loc = 0;
+  for (int i = ndim - 1; i >= 0; --i) {
+    loc += (elem % shape[i]) * strides[i];
+    elem /= shape[i];
+  }
+  return loc;
+}
+
+inline size_t elem_to_loc(
+    uint elem,
+    constant const int* shape,
+    constant const size_t* strides,
+    int ndim) {
+  size_t loc = 0;
+  for (int i = ndim - 1; i >= 0; --i) {
+    loc += (elem % shape[i]) * strides[i];
+    elem /= shape[i];
+  }
+  return loc;
+}
+
+template <int NDIM>
+inline uint2 elem_to_loc_2_nd(
+    uint3 elem,
+    constant const int shape[NDIM],
+    constant const size_t a_strides[NDIM],
+    constant const size_t b_strides[NDIM]) {
+  uint2 loc = {
+      static_cast<uint>(
+          elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]),
+      static_cast<uint>(
+          elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])};
+  for (int d = NDIM - 3; d >= 0; --d) {
+    uint l = elem.z % shape[d];
+    loc.x += l * a_strides[d];
+    loc.y += l * b_strides[d];
+    elem.z /= shape[d];
+  }
+  return loc;
+}
+
+template <int NDIM>
+inline size_t elem_to_loc_nd(
+    uint3 elem,
+    constant const int
shape[NDIM], + constant const size_t strides[NDIM]) { + size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2]; + for (int d = NDIM - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) { + return elem * stride; +} + +inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) { + return elem.x * strides[1] + elem.y * strides[0]; +} + +inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) { + return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; +} + +// Non templated version to handle arbitrary dims +inline size_t elem_to_loc( + uint3 elem, + constant const int* shape, + constant const size_t* strides, + int ndim) { + size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +inline uint2 elem_to_loc_2_nd( + uint3 elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + uint2 loc = { + static_cast( + elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), + static_cast( + elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +template +inline uint elem_to_loc_nd( + uint elem, + device const int* shape, + device const size_t* strides); + +template <> +inline uint elem_to_loc_nd<1>( + uint elem, + device const int* shape, + device const size_t* strides) { + return (elem % shape[0]) * strides[0]; +} + +template <> +inline uint elem_to_loc_nd<2>( + uint elem, + device const int* shape, + device const size_t* strides) { + uint loc = (elem % shape[1]) * strides[1]; + elem /= shape[1]; + loc += (elem % shape[0]) * strides[0]; + return loc; +} + +template <> +inline uint elem_to_loc_nd<3>( + uint elem, + device const int* shape, + device const size_t* strides) { + uint loc = (elem % shape[2]) * strides[2]; + elem /= shape[2]; + loc += (elem % shape[1]) * strides[1]; + elem /= shape[1]; + loc += (elem % shape[0]) * strides[0]; + return loc; +} + +template <> +inline uint elem_to_loc_nd<4>( + uint elem, + device const int* shape, + device const size_t* strides) { + uint loc = (elem % shape[3]) * strides[3]; + elem /= shape[3]; + loc += (elem % shape[2]) * strides[2]; + elem /= shape[2]; + loc += (elem % shape[1]) * strides[1]; + elem /= shape[1]; + loc += (elem % shape[0]) * strides[0]; + return loc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Calculation utils +/////////////////////////////////////////////////////////////////////////////// + +/** Compute ceil((float)N/(float)M) */ +inline size_t ceildiv(size_t N, size_t M) { + return (N + M - 1) / M; +} + +// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 +inline float log1p(float x) { + float xp1 = 1.0f + x; + return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f)); +} + +inline bfloat16_t log1p(bfloat16_t x) { + float xp1 = 1.0f + static_cast(x); + bfloat16_t ret = + (xp1 == 1.0f) ? 
x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
+  return ret;
+}
diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/matmul.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/matmul.h
new file mode 100644
index 0000000000000000000000000000000000000000..78a0e7f2695a49bbbcc737accaeaa162fe6818a1
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/matmul.h
@@ -0,0 +1,31 @@
+// Copyright © 2023 Apple Inc.
+
+#include <algorithm>
+#include <future>
+#include <sstream>
+
+#include "mlx/backend/metal/copy.h"
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/mps/gemm.h"
+#include "mlx/backend/metal/utils.h"
+#include "mlx/utils.h"
+
+namespace mlx::core {
+
+void mlx_matmul(
+    const Stream& s,
+    metal::Device& d,
+    const array& a,
+    const array& b,
+    array& out,
+    int M,
+    int N,
+    int K,
+    int batch_size_out,
+    int lda,
+    int ldb,
+    bool transpose_a,
+    bool transpose_b,
+    std::vector<array>& copies);
+
+} // namespace mlx::core
\ No newline at end of file
diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/metal.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/metal.h
new file mode 100644
index 0000000000000000000000000000000000000000..99f4009568826f417303f3399bec32c88b0100e7
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/metal.h
@@ -0,0 +1,31 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <future>
+#include <memory>
+#include <vector>
+
+#include "mlx/array.h"
+#include "mlx/stream.h"
+
+namespace mlx::core::metal {
+
+constexpr bool is_available() {
+#ifdef _METAL_
+  return true;
+#else
+  return false;
+#endif
+}
+
+void new_stream(Stream stream);
+std::shared_ptr<void> new_scoped_memory_pool();
+
+std::function<void()> make_task(
+    array& arr,
+    std::vector<std::shared_future<void>> deps,
+    std::shared_ptr<std::promise<void>> p,
+    bool retain_graph);
+
+} // namespace mlx::core::metal
diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/mps/gemm.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/mps/gemm.h
new file mode 100644
index 0000000000000000000000000000000000000000..c93df68d8317356a24f342ff06635b2ffe2ff3b3
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/mps/gemm.h
@@ -0,0 +1,370 @@
+// Copyright © 2023 Apple Inc.
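Note that metal.h's is_available() above is constexpr, so backend selection can be resolved at compile time; builds without _METAL_ defined fall back to CPU. A minimal usage sketch (assuming an MLX build with this include path):

    #include <cstdio>

    #include "mlx/backend/metal/metal.h"

    int main() {
      // Constant-folded: the dead branch is discarded at compile time.
      if constexpr (mlx::core::metal::is_available()) {
        std::printf("Metal backend compiled in\n");
      } else {
        std::printf("CPU-only build\n");
      }
    }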
+ +#pragma once + +#include + +#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol) +#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor) + +namespace MTL::Private::Class { +_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor); +_MTL_PRIVATE_DEF_CLS(MPSMatrix); +_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor); +_MTL_PRIVATE_DEF_CLS(MPSVector); +_MTL_PRIVATE_DEF_CLS(MPSKernel); +_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication); +_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication); +} // namespace MTL::Private::Class + +namespace MTL::Private::Selector { +_MTL_PRIVATE_DEF_SEL( + matrixDescriptorWithRows_columns_rowBytes_dataType, + "matrixDescriptorWithRows:columns:rowBytes:dataType:"); +_MTL_PRIVATE_DEF_SEL( + matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType, + "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:"); +_MTL_PRIVATE_DEF_SEL(rows, "rows"); +_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:"); +_MTL_PRIVATE_DEF_SEL( + initWithDevice_, + "initWithDevice:transposeLeft:transposeRight:" + "resultRows:resultColumns:interiorColumns:alpha:beta:"); +_MTL_PRIVATE_DEF_SEL( + encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix, + "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:"); +_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:"); +_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:"); +_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:"); +_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:"); +_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:"); +_MTL_PRIVATE_DEF_SEL( + vectorDescriptorWithLength_dataType, + "vectorDescriptorWithLength:dataType:"); +_MTL_PRIVATE_DEF_SEL( + vectorDescriptorWithLength_vectors_vectorBytes_dataType, + "vectorDescriptorWithLength:vectors:vectorBytes:dataType:"); +_MTL_PRIVATE_DEF_SEL( + initWithDevice_transpose_rows_columns_alpha_beta, + "initWithDevice:transpose:rows:columns:alpha:beta:"); +_MTL_PRIVATE_DEF_SEL( + encodeToCommandBuffer_inputMatrix_inputVector_resultVector, + "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:"); +} // namespace MTL::Private::Selector + +namespace MPS { + +typedef enum DataType : uint32_t { + DataTypeFloatBit = 0x10000000, + DataTypeAlternateEncodingBit = 0x80000000, + DataTypeFloat16 = DataTypeFloatBit | 16, + DataTypeFloat32 = DataTypeFloatBit | 32, + DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16 +} DataType; + +class MatrixDescriptor : public NS::Copying { + public: + static class MatrixDescriptor* matrixDescriptor( + NS::UInteger rows, + NS::UInteger columns, + NS::UInteger rowBytes, + NS::UInteger dataType); + static class MatrixDescriptor* matrixDescriptor( + NS::UInteger rows, + NS::UInteger columns, + NS::UInteger matrices, + NS::UInteger rowBytes, + NS::UInteger matrixBytes, + NS::UInteger dataType); + NS::UInteger rows() const; +}; + +class Matrix : public NS::Referencing { + public: + static class Matrix* alloc(); + Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor); + Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor); +}; + +class Kernel : public NS::Referencing { + public: + NS::String* label() const; + MTL::Device* device() const; +}; + +class MatrixMultiplication + : public NS::Referencing { + public: + static class MatrixMultiplication* alloc(); + + MatrixMultiplication* init( + MTL::Device* device, + bool transposeLeft, + bool transposeRight, + NS::UInteger resultRows, + NS::UInteger 
resultColumns, + NS::UInteger interiorColumns, + double alpha, + double beta); + + void encodeToCommandBuffer( + MTL::CommandBuffer* commandBuffer, + Matrix* leftMatrix, + Matrix* rightMatrix, + Matrix* resultMatrix); + + void setLeftMatrixOrigin(MTL::Origin origin); + void setRightMatrixOrigin(MTL::Origin origin); + void setResultMatrixOrigin(MTL::Origin origin); + void setBatchStart(NS::UInteger batchStart); + void setBatchSize(NS::UInteger batchSize); +}; + +class VectorDescriptor : public NS::Copying { + public: + static class VectorDescriptor* vectorDescriptor( + NS::UInteger length, + NS::UInteger dataType); + static class VectorDescriptor* vectorDescriptor( + NS::UInteger length, + NS::UInteger vectors, + NS::UInteger vectorBytes, + NS::UInteger dataType); +}; + +class Vector : public NS::Referencing { + public: + static class Vector* alloc(); + Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor); + Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor); +}; + +class MatrixVectorMultiplication + : public NS::Referencing { + public: + static class MatrixVectorMultiplication* alloc(); + + MatrixVectorMultiplication* init( + MTL::Device* device, + bool transpose, + NS::UInteger rows, + NS::UInteger columns, + double alpha, + double beta); + + void encodeToCommandBuffer( + MTL::CommandBuffer* commandBuffer, + Matrix* inputMatrix, + Vector* inputVector, + Vector* resultVector); +}; + +_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor( + NS::UInteger rows, + NS::UInteger columns, + NS::UInteger rowBytes, + NS::UInteger dataType) { + return Object::sendMessage( + _MPS_PRIVATE_CLS(MPSMatrixDescriptor), + _MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType), + rows, + columns, + rowBytes, + dataType); +} + +_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor( + NS::UInteger rows, + NS::UInteger columns, + NS::UInteger matrices, + NS::UInteger rowBytes, + NS::UInteger matrixBytes, + NS::UInteger dataType) { + return Object::sendMessage( + _MPS_PRIVATE_CLS(MPSMatrixDescriptor), + _MPS_PRIVATE_SEL( + matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType), + rows, + columns, + matrices, + rowBytes, + matrixBytes, + dataType); +} + +_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const { + return Object::sendMessage(this, _MPS_PRIVATE_SEL(rows)); +} + +_MTL_INLINE Matrix* Matrix::alloc() { + return NS::Object::alloc(_MPS_PRIVATE_CLS(MPSMatrix)); +} + +_MTL_INLINE Matrix* Matrix::init( + MTL::Buffer* buffer, + MatrixDescriptor* descriptor) { + return Object::sendMessage( + this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor); +} + +_MTL_INLINE Matrix* Matrix::init( + const MTL::Buffer* buffer, + MatrixDescriptor* descriptor) { + return init(const_cast(buffer), descriptor); +} + +_MTL_INLINE NS::String* Kernel::label() const { + return Object::sendMessage(this, _MPS_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::Device* Kernel::device() const { + return Object::sendMessage(this, _MPS_PRIVATE_SEL(device)); +} + +_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() { + return NS::Object::alloc( + _MPS_PRIVATE_CLS(MPSMatrixMultiplication)); +} + +_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init( + MTL::Device* device, + bool transposeLeft, + bool transposeRight, + NS::UInteger resultRows, + NS::UInteger resultColumns, + NS::UInteger interiorColumns, + double alpha, + double beta) { + return Object::sendMessage( + this, + _MPS_PRIVATE_SEL(initWithDevice_), + device, + 
transposeLeft, + transposeRight, + resultRows, + resultColumns, + interiorColumns, + alpha, + beta); +} + +_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer( + MTL::CommandBuffer* commandBuffer, + Matrix* leftMatrix, + Matrix* rightMatrix, + Matrix* resultMatrix) { + return Object::sendMessage( + this, + _MPS_PRIVATE_SEL( + encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix), + commandBuffer, + leftMatrix, + rightMatrix, + resultMatrix); +} + +_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) { + Object::sendMessage( + this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin); +} + +_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin( + MTL::Origin origin) { + Object::sendMessage( + this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin); +} + +_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin( + MTL::Origin origin) { + Object::sendMessage( + this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin); +} + +_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) { + Object::sendMessage(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart); +} + +_MTL_INLINE void MatrixMultiplication::setBatchSize(NS::UInteger batchSize) { + Object::sendMessage(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize); +} + +_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor( + NS::UInteger length, + NS::UInteger dataType) { + return Object::sendMessage( + _MPS_PRIVATE_CLS(MPSVectorDescriptor), + _MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType), + length, + dataType); +} + +_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor( + NS::UInteger length, + NS::UInteger vectors, + NS::UInteger vectorBytes, + NS::UInteger dataType) { + return Object::sendMessage( + _MPS_PRIVATE_CLS(MPSVectorDescriptor), + _MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType), + length, + vectors, + vectorBytes, + dataType); +} + +_MTL_INLINE Vector* Vector::alloc() { + return NS::Object::alloc(_MPS_PRIVATE_CLS(MPSVector)); +} + +_MTL_INLINE Vector* Vector::init( + MTL::Buffer* buffer, + VectorDescriptor* descriptor) { + return Object::sendMessage( + this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor); +} + +_MTL_INLINE Vector* Vector::init( + const MTL::Buffer* buffer, + VectorDescriptor* descriptor) { + return init(const_cast(buffer), descriptor); +} + +_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() { + return NS::Object::alloc( + _MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication)); +} + +_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init( + MTL::Device* device, + bool transpose, + NS::UInteger rows, + NS::UInteger columns, + double alpha, + double beta) { + return Object::sendMessage( + this, + _MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta), + device, + transpose, + rows, + columns, + alpha, + beta); +} + +_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer( + MTL::CommandBuffer* commandBuffer, + Matrix* inputMatrix, + Vector* inputVector, + Vector* resultVector) { + return Object::sendMessage( + this, + _MPS_PRIVATE_SEL( + encodeToCommandBuffer_inputMatrix_inputVector_resultVector), + commandBuffer, + inputMatrix, + inputVector, + resultVector); +} + +} // namespace MPS \ No newline at end of file diff --git a/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/utils.h b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/utils.h new file mode 100644 index 
0000000000000000000000000000000000000000..378850802212cab755d77aca4267ada4b525005c --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/backend/metal/utils.h @@ -0,0 +1,169 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/metal/device.h" + +namespace mlx::core { + +namespace { + +void set_array_buffer( + MTL::ComputeCommandEncoder* compute_encoder, + MTL::ArgumentEncoder* enc, + const array& a, + int idx) { + auto a_buf = static_cast(a.buffer().ptr()); + auto offset = a.data() - + static_cast(const_cast(a_buf)->contents()); + enc->setBuffer(a_buf, offset, idx); + // MTL::Resource usage through argument buffer needs to be explicitly + // flagged to enable hazard tracking + compute_encoder->useResource(a_buf, MTL::ResourceUsageRead); +} + +void set_array_buffer( + MTL::ComputeCommandEncoder* enc, + const array& a, + int idx) { + auto a_buf = static_cast(a.buffer().ptr()); + auto offset = a.data() - + static_cast(const_cast(a_buf)->contents()); + enc->setBuffer(a_buf, offset, idx); +} + +std::string type_to_name(const array& a) { + std::string tname; + switch (a.dtype()) { + case bool_: + tname = "bool_"; + break; + case uint8: + tname = "uint8"; + break; + case uint16: + tname = "uint16"; + break; + case uint32: + tname = "uint32"; + break; + case uint64: + tname = "uint64"; + break; + case int8: + tname = "int8"; + break; + case int16: + tname = "int16"; + break; + case int32: + tname = "int32"; + break; + case int64: + tname = "int64"; + break; + case float16: + tname = "float16"; + break; + case float32: + tname = "float32"; + break; + case bfloat16: + tname = "bfloat16"; + break; + case complex64: + tname = "complex64"; + break; + } + return tname; +} + +MTL::Size get_block_dims(int dim0, int dim1, int dim2) { + int pows[3] = {0, 0, 0}; + int sum = 0; + while (true) { + int presum = sum; + // Check all the pows + if (dim0 >= (1 << (pows[0] + 1))) { + pows[0]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim1 >= (1 << (pows[1] + 1))) { + pows[1]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim2 >= (1 << (pows[2] + 1))) { + pows[2]++; + sum++; + } + if (sum == presum || sum == 10) { + break; + } + } + return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; +} + +// Collapse dims that are contiguous to possibly route to a better kernel +// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1}) +// should return {{2, 4}, {{1, 2}}}. +// +// When multiple arrays are passed they should all have the same shape. The +// collapsed axes are also the same so one shape is returned. +std::tuple, std::vector>> +collapse_contiguous_dims(const std::vector& xs) { + // Make a vector that has axes separated with -1. Collapse all axes between + // -1. 
+ std::vector to_collapse; + if (xs[0].ndim() > 0) { + to_collapse.push_back(0); + for (int i = 1; i < xs[0].ndim(); i++) { + bool contiguous = true; + for (auto& x : xs) { + if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) { + contiguous = false; + } + if (!contiguous) { + break; + } + } + if (!contiguous) { + to_collapse.push_back(-1); + } + to_collapse.push_back(i); + } + to_collapse.push_back(-1); + } + + std::vector out_shape; + std::vector> out_strides(xs.size()); + for (int i = 0; i < to_collapse.size(); i++) { + int current_shape = xs[0].shape()[to_collapse[i]]; + while (to_collapse[++i] != -1) { + current_shape *= xs[0].shape()[to_collapse[i]]; + } + out_shape.push_back(current_shape); + for (int j = 0; j < xs.size(); j++) { + out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]); + } + } + + return std::make_tuple(out_shape, out_strides); +} + +template +std::tuple, std::vector>> +collapse_contiguous_dims(Arrays... xs) { + return collapse_contiguous_dims( + std::vector{std::forward(xs)...}); +} + +} // namespace + +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/device.h b/lib/python3.11/site-packages/mlx/include/mlx/device.h new file mode 100644 index 0000000000000000000000000000000000000000..e11edf7936176aef0f5bf56863c14715c94f8b0c --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/device.h @@ -0,0 +1,29 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +namespace mlx::core { + +struct Device { + enum class DeviceType { + cpu, + gpu, + }; + + static constexpr DeviceType cpu = DeviceType::cpu; + static constexpr DeviceType gpu = DeviceType::gpu; + + Device(DeviceType type, int index = 0) : type(type), index(index){}; + + DeviceType type; + int index; +}; + +const Device& default_device(); + +void set_default_device(const Device& d); + +bool operator==(const Device& lhs, const Device& rhs); +bool operator!=(const Device& lhs, const Device& rhs); + +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/dtype.h b/lib/python3.11/site-packages/mlx/include/mlx/dtype.h new file mode 100644 index 0000000000000000000000000000000000000000..d5283048505a80cf9c17fe89d5113903c903a10a --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/dtype.h @@ -0,0 +1,105 @@ +// Copyright © 2023 Apple Inc. 
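+// Usage sketch for the Device API declared in device.h above (illustrative,
+// not part of this header):
+//
+//   using namespace mlx::core;
+//   set_default_device(Device(Device::gpu));  // index defaults to 0
+//   bool on_gpu = default_device() == Device(Device::gpu);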
+ +#pragma once + +#include +#include +#include +#include + +#include "mlx/types/complex.h" +#include "mlx/types/half_types.h" + +namespace mlx::core { + +struct Dtype { + enum class Val { + bool_, + uint8, + uint16, + uint32, + uint64, + int8, + int16, + int32, + int64, + float16, + float32, + bfloat16, + complex64, + }; + + enum class Kind { + b, /* bool */ + u, /* unsigned int */ + i, /* signed int */ + f, /* float */ + c, /* complex */ + V, /* void - used for brain float */ + }; + + Val val; + const uint8_t size; + constexpr explicit Dtype(Val val, uint8_t size) : val(val), size(size){}; + constexpr operator Val() const { + return val; + }; +}; + +inline bool is_available(const Dtype& dtype) { + return true; +} + +static constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; + +static constexpr Dtype uint8{Dtype::Val::uint8, sizeof(uint8_t)}; +static constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)}; +static constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)}; +static constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)}; + +static constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)}; +static constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)}; +static constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)}; +static constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)}; + +static constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)}; +static constexpr Dtype float32{Dtype::Val::float32, sizeof(float)}; +static constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)}; +static constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)}; + +Dtype promote_types(const Dtype& t1, const Dtype& t2); + +inline uint8_t size_of(const Dtype& t) { + return t.size; +} + +Dtype::Kind kindof(const Dtype& t); + +inline bool is_unsigned(const Dtype& t) { + return kindof(t) == Dtype::Kind::u || kindof(t) == Dtype::Kind::b; +} + +inline bool is_floating_point(const Dtype& t) { + return kindof(t) == Dtype::Kind::f || kindof(t) == Dtype::Kind::V || + kindof(t) == Dtype::Kind::c; +} + +inline bool is_complex(const Dtype& t) { + return kindof(t) == Dtype::Kind::c; +} + +inline bool is_integral(const Dtype& t) { + return !(is_floating_point(t)); +} + +template +struct TypeToDtype { + operator Dtype(); +}; + +// Array protocol typestring for Dtype +std::string dtype_to_array_protocol(const Dtype& t); +// Dtype from array protocol type string +Dtype dtype_from_array_protocol(const std::string& t); + +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/fft.h b/lib/python3.11/site-packages/mlx/include/mlx/fft.h new file mode 100644 index 0000000000000000000000000000000000000000..dbcc777fed3e68742d017e5b9df85bb95f34c780 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/fft.h @@ -0,0 +1,151 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "array.h" +#include "device.h" +#include "stream.h" + +namespace mlx::core::fft { + +using StreamOrDevice = std::variant; + +/** Compute the n-dimensional Fourier Transform. */ +array fftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}); +array fftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); +array fftn(const array& a, StreamOrDevice s = {}); + +/** Compute the n-dimensional inverse Fourier Transform. 
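+ * As with fftn, omitting n and axes transforms over all axes (following the
+ * usual NumPy-style convention). Illustrative round trip: ifftn(fftn(a))
+ * recovers a, up to floating point error and promotion to complex64.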
*/ +array ifftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}); +array ifftn( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); +array ifftn(const array& a, StreamOrDevice s = {}); + +/** Compute the one-dimensional Fourier Transform. */ +inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return fftn(a, {n}, {axis}, s); +} +inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return fftn(a, {axis}, s); +} + +/** Compute the one-dimensional inverse Fourier Transform. */ +inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return ifftn(a, {n}, {axis}, s); +} +inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return ifftn(a, {axis}, s); +} + +/** Compute the two-dimensional Fourier Transform. */ +inline array fft2( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return fftn(a, n, axes, s); +} +inline array fft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return fftn(a, axes, s); +} + +/** Compute the two-dimensional inverse Fourier Transform. */ +inline array ifft2( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return ifftn(a, n, axes, s); +} +inline array ifft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return ifftn(a, axes, s); +} + +/** Compute the n-dimensional Fourier Transform on a real input. */ +array rfftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}); +array rfftn( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); +array rfftn(const array& a, StreamOrDevice s = {}); + +/** Compute the n-dimensional inverse of `rfftn`. */ +array irfftn( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}); +array irfftn( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); +array irfftn(const array& a, StreamOrDevice s = {}); + +/** Compute the one-dimensional Fourier Transform on a real input. */ +inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return rfftn(a, {n}, {axis}, s); +} +inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return rfftn(a, {axis}, s); +} +/** Compute the one-dimensional inverse of `rfft`. */ +inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return irfftn(a, {n}, {axis}, s); +} +inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return irfftn(a, {axis}, s); +} + +/** Compute the two-dimensional Fourier Transform on a real input. */ +inline array rfft2( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return rfftn(a, n, axes, s); +} +inline array rfft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return rfftn(a, axes, s); +} + +/** Compute the two-dimensional inverse of `rfft2`. 
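+ * Following the usual real-FFT convention, rfft2 of a real input of shape
+ * {H, W} has shape {H, W / 2 + 1}, and irfft2(rfft2(x)) is approximately x.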
*/ +inline array irfft2( + const array& a, + const std::vector& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return irfftn(a, n, axes, s); +} +inline array irfft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return irfftn(a, axes, s); +} + +} // namespace mlx::core::fft diff --git a/lib/python3.11/site-packages/mlx/include/mlx/graph_utils.h b/lib/python3.11/site-packages/mlx/include/mlx/graph_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3bd373beca78c75dcea7058dca91a689e8496969 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/graph_utils.h @@ -0,0 +1,23 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void print_graph(std::ostream& os, const std::vector& outputs); + +template +void print_graph(std::ostream& os, Arrays... outputs) { + print_graph(os, std::vector{std::forward(outputs)...}); +} + +void export_to_dot(std::ostream& os, const std::vector& outputs); + +template +void export_to_dot(std::ostream& os, Arrays... outputs) { + export_to_dot(os, std::vector{std::forward(outputs)...}); +} + +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/io/load.h b/lib/python3.11/site-packages/mlx/include/mlx/io/load.h new file mode 100644 index 0000000000000000000000000000000000000000..1d193392a71b347337996d52a5f7b8bfc75e3fcf --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/io/load.h @@ -0,0 +1,114 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core { + +namespace io { + +class Reader { + public: + virtual bool is_open() const = 0; + virtual bool good() const = 0; + virtual size_t tell() const = 0; + virtual void seek( + int64_t off, + std::ios_base::seekdir way = std::ios_base::beg) = 0; + virtual void read(char* data, size_t n) = 0; + virtual std::string label() const = 0; +}; + +class Writer { + public: + virtual bool is_open() const = 0; + virtual bool good() const = 0; + virtual size_t tell() const = 0; + virtual void seek( + int64_t off, + std::ios_base::seekdir way = std::ios_base::beg) = 0; + virtual void write(const char* data, size_t n) = 0; + virtual std::string label() const = 0; +}; + +class FileReader : public Reader { + public: + explicit FileReader(const std::shared_ptr& is) + : is_(is), label_("stream") {} + explicit FileReader(const std::string& file_path) + : is_(std::make_shared(file_path, std::ios::binary)), + label_(file_path) {} + + bool is_open() const override { + return is_->is_open(); + } + + bool good() const override { + return is_->good(); + } + + size_t tell() const override { + return is_->tellg(); + } + + void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) + override { + is_->seekg(off, way); + } + + void read(char* data, size_t n) override { + is_->read(data, n); + } + + std::string label() const override { + return "file " + label_; + } + + private: + std::shared_ptr is_; + std::string label_; +}; + +class FileWriter : public Writer { + public: + explicit FileWriter(const std::shared_ptr& is) + : os_(is), label_("stream") {} + explicit FileWriter(const std::string& file_path) + : os_(std::make_shared(file_path, std::ios::binary)), + label_(file_path) {} + + bool is_open() const override { + return os_->is_open(); + } + + bool good() const override { + return os_->good(); + } + + size_t tell() const override { + return os_->tellp(); + } + + void seek(int64_t off, std::ios_base::seekdir way 
= std::ios_base::beg) + override { + os_->seekp(off, way); + } + + void write(const char* data, size_t n) override { + os_->write(data, n); + } + + std::string label() const override { + return "file " + label_; + } + + private: + std::shared_ptr os_; + std::string label_; +}; + +} // namespace io +} // namespace mlx::core \ No newline at end of file diff --git a/lib/python3.11/site-packages/mlx/include/mlx/io/safetensor.h b/lib/python3.11/site-packages/mlx/include/mlx/io/safetensor.h new file mode 100644 index 0000000000000000000000000000000000000000..104a226ce8042b72ad3d0b3dd7167d4ca12fda13 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/io/safetensor.h @@ -0,0 +1,32 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/io/load.h" +#include "mlx/ops.h" +#include "mlx/primitives.h" + +using json = nlohmann::json; + +namespace mlx::core { + +#define ST_F16 "F16" +#define ST_BF16 "BF16" +#define ST_F32 "F32" + +#define ST_BOOL "BOOL" +#define ST_I8 "I8" +#define ST_I16 "I16" +#define ST_I32 "I32" +#define ST_I64 "I64" +#define ST_U8 "U8" +#define ST_U16 "U16" +#define ST_U32 "U32" +#define ST_U64 "U64" + +// Note: Complex numbers aren't in the spec yet so this could change - +// https://github.com/huggingface/safetensors/issues/389 +#define ST_C64 "C64" +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/linalg.h b/lib/python3.11/site-packages/mlx/include/mlx/linalg.h new file mode 100644 index 0000000000000000000000000000000000000000..80e484eb5d5fea92139400820abdf6f6aef9a716 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/linalg.h @@ -0,0 +1,63 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/ops.h" +#include "mlx/stream.h" + +namespace mlx::core::linalg { + +/** + * Compute vector or matrix norms. + * + * - If axis and ord are both unspecified, computes the 2-norm of flatten(x). + * - If axis is not provided but ord is, then x must be either 1D or 2D. + * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm + * for matrices) is computed along the given axes. At most 2 axes can be + * specified. + * - If both axis and ord are provided, then the corresponding matrix or vector + * norm is computed. At most 2 axes can be specified. 
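+ *
+ * Illustrative calls (v is a vector, m is a matrix; both are assumptions):
+ *
+ *   norm(x);                          // 2-norm of flatten(x)
+ *   norm(v, 1.0, 0);                  // 1-norm of v along axis 0
+ *   norm(m, std::vector<int>{0, 1});  // Frobenius norm of m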
+ */ +array norm( + const array& a, + const double ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const double ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +array norm( + const array& a, + const std::string& ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const std::string& ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +array norm( + const array& a, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array +norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) { + return norm(a, std::vector{axis}, keepdims, s); +} + +} // namespace mlx::core::linalg diff --git a/lib/python3.11/site-packages/mlx/include/mlx/mlx.h b/lib/python3.11/site-packages/mlx/include/mlx/mlx.h new file mode 100644 index 0000000000000000000000000000000000000000..8d785c39fcdd9246dd6e7c21e54fb00b26f1a234 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/mlx.h @@ -0,0 +1,14 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/metal/metal.h" +#include "mlx/device.h" +#include "mlx/fft.h" +#include "mlx/linalg.h" +#include "mlx/ops.h" +#include "mlx/random.h" +#include "mlx/stream.h" +#include "mlx/transforms.h" +#include "mlx/utils.h" diff --git a/lib/python3.11/site-packages/mlx/include/mlx/ops.h b/lib/python3.11/site-packages/mlx/include/mlx/ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0f7b52da445b609a1f55d4f64c30736bc0056c61 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/ops.h @@ -0,0 +1,1094 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include + +#include "array.h" +#include "device.h" +#include "io/load.h" +#include "stream.h" + +namespace mlx::core { + +using StreamOrDevice = std::variant; + +Stream to_stream(StreamOrDevice s); + +/** Creation operations */ + +/** + * A 1D array of numbers starting at `start` (optional), + * stopping at stop, stepping by `step` (optional). */ +array arange( + double start, + double stop, + double step, + Dtype dtype, + StreamOrDevice s = {}); +array arange(double start, double stop, double step, StreamOrDevice s = {}); +array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {}); +array arange(double start, double stop, StreamOrDevice s = {}); +array arange(double stop, Dtype dtype, StreamOrDevice s = {}); +array arange(double stop, StreamOrDevice s = {}); + +array arange(int start, int stop, int step, StreamOrDevice s = {}); +array arange(int start, int stop, StreamOrDevice s = {}); +array arange(int stop, StreamOrDevice s = {}); + +/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */ +array linspace( + double start, + double stop, + int num = 50, + Dtype dtype = float32, + StreamOrDevice s = {}); + +/** Convert an array to the given data type. */ +array astype(const array& a, Dtype dtype, StreamOrDevice s = {}); + +/** Create a view of an array with the given shape and strides. */ +array as_strided( + const array& a, + std::vector shape, + std::vector strides, + size_t offset, + StreamOrDevice s = {}); + +/** Copy another array. 
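+ * The result has the same shape, dtype, and contents as the input; like the
+ * other operations here, it is lazy until evaluated.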
*/ +array copy(const array& a, StreamOrDevice s = {}); + +/** Fill an array of the given shape with the given value(s). */ +array full( + const std::vector& shape, + const array& vals, + Dtype dtype, + StreamOrDevice s = {}); +array full( + const std::vector& shape, + const array& vals, + StreamOrDevice s = {}); +template +array full( + const std::vector& shape, + T val, + Dtype dtype, + StreamOrDevice s = {}) { + return full(shape, array(val, dtype), to_stream(s)); +} +template +array full(const std::vector& shape, T val, StreamOrDevice s = {}) { + return full(shape, array(val), to_stream(s)); +} + +/** Fill an array of the given shape with zeros. */ +array zeros(const std::vector& shape, Dtype dtype, StreamOrDevice s = {}); +inline array zeros(const std::vector& shape, StreamOrDevice s = {}) { + return zeros(shape, float32, s); +} +array zeros_like(const array& a, StreamOrDevice s = {}); + +/** Fill an array of the given shape with ones. */ +array ones(const std::vector& shape, Dtype dtype, StreamOrDevice s = {}); +inline array ones(const std::vector& shape, StreamOrDevice s = {}) { + return ones(shape, float32, s); +} +array ones_like(const array& a, StreamOrDevice s = {}); + +/** Fill an array of the given shape (n,m) with ones in the specified diagonal + * k, and zeros everywhere else. */ +array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {}); +inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) { + return eye(n, n, 0, dtype, s); +} +inline array eye(int n, int m, StreamOrDevice s = {}) { + return eye(n, m, 0, float32, s); +} +inline array eye(int n, int m, int k, StreamOrDevice s = {}) { + return eye(n, m, k, float32, s); +} +inline array eye(int n, StreamOrDevice s = {}) { + return eye(n, n, 0, float32, s); +} + +/** Create a square matrix of shape (n,n) of zeros, and ones in the major + * diagonal. */ +array identity(int n, Dtype dtype, StreamOrDevice s = {}); +inline array identity(int n, StreamOrDevice s = {}) { + return identity(n, float32, s); +} + +array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {}); +inline array tri(int n, Dtype type, StreamOrDevice s = {}) { + return tri(n, n, 0, type, s); +} + +array tril(array x, int k, StreamOrDevice s = {}); +array triu(array x, int k, StreamOrDevice s = {}); + +/** array manipulation */ + +/** Reshape an array to the given shape. */ +array reshape(const array& a, std::vector shape, StreamOrDevice s = {}); + +/** Flatten the dimensions in the range `[start_axis, end_axis]` . */ +array flatten( + const array& a, + int start_axis, + int end_axis = -1, + StreamOrDevice s = {}); + +/** Flatten the array to 1D. */ +array flatten(const array& a, StreamOrDevice s = {}); + +/** Remove singleton dimensions at the given axes. */ +array squeeze( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** Remove singleton dimensions at the given axis. */ +inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) { + return squeeze(a, std::vector{axis}, s); +} + +/** Remove all singleton dimensions. */ +array squeeze(const array& a, StreamOrDevice s = {}); + +/** Add a singleton dimension at the given axes. */ +array expand_dims( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** Add a singleton dimension at the given axis. */ +inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) { + return expand_dims(a, std::vector{axis}, s); +} + +/** Slice an array. 
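+ * start, stop, and strides each have one entry per dimension of a.
+ * Illustrative call: for a of shape {4, 4}, slice(a, {0, 0}, {4, 4}, {2, 1})
+ * keeps every other row and has shape {2, 4}.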
*/ +array slice( + const array& a, + std::vector start, + std::vector stop, + std::vector strides, + StreamOrDevice s = {}); + +/** Slice an array with a stride of 1 in each dimension. */ +array slice( + const array& a, + const std::vector& start, + const std::vector& stop, + StreamOrDevice s = {}); + +/** Split an array into sub-arrays along a given axis. */ +std::vector +split(const array& a, int num_splits, int axis, StreamOrDevice s = {}); +std::vector split(const array& a, int num_splits, StreamOrDevice s = {}); +std::vector split( + const array& a, + const std::vector& indices, + int axis, + StreamOrDevice s = {}); +std::vector +split(const array& a, const std::vector& indices, StreamOrDevice s = {}); + +/** + * Clip (limit) the values in an array. + */ +array clip( + const array& a, + const std::optional& a_min = std::nullopt, + const std::optional& a_max = std::nullopt, + StreamOrDevice s = {}); + +/** Concatenate arrays along a given axis. */ +array concatenate( + const std::vector& arrays, + int axis, + StreamOrDevice s = {}); +array concatenate(const std::vector& arrays, StreamOrDevice s = {}); + +/** Stack arrays along a new axis. */ +array stack(const std::vector& arrays, int axis, StreamOrDevice s = {}); +array stack(const std::vector& arrays, StreamOrDevice s = {}); + +/** Repeat an array along an axis. */ +array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {}); +array repeat(const array& arr, int repeats, StreamOrDevice s = {}); + +/** Permutes the dimensions according to the given axes. */ +array transpose(const array& a, std::vector axes, StreamOrDevice s = {}); +inline array transpose( + const array& a, + std::initializer_list axes, + StreamOrDevice s = {}) { + return transpose(a, std::vector(axes), s); +} + +/** Swap two axes of an array. */ +array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {}); + +/** Move an axis of an array. */ +array moveaxis( + const array& a, + int source, + int destination, + StreamOrDevice s = {}); + +/** Pad an array with a constant value */ +array pad( + const array& a, + const std::vector& axes, + const std::vector& low_pad_size, + const std::vector& high_pad_size, + const array& pad_value = array(0), + StreamOrDevice s = {}); + +/** Pad an array with a constant value along all axes */ +array pad( + const array& a, + const std::vector>& pad_width, + const array& pad_value = array(0), + StreamOrDevice s = {}); +array pad( + const array& a, + const std::pair& pad_width, + const array& pad_value = array(0), + StreamOrDevice s = {}); +array pad( + const array& a, + int pad_width, + const array& pad_value = array(0), + StreamOrDevice s = {}); + +/** Permutes the dimensions in reverse order. */ +array transpose(const array& a, StreamOrDevice s = {}); + +/** Broadcast an array to a given shape. */ +array broadcast_to( + const array& a, + const std::vector& shape, + StreamOrDevice s = {}); + +/** Broadcast a vector of arrays against one another. */ +std::vector broadcast_arrays( + const std::vector& inputs, + StreamOrDevice s = {}); + +/** Comparison operations */ + +/** Returns the bool array with (a == b) element-wise. */ +array equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator==(const array& a, const array& b) { + return equal(a, b); +} +template +array operator==(T a, const array& b) { + return equal(array(a), b); +} +template +array operator==(const array& a, T b) { + return equal(a, array(b)); +} + +/** Returns the bool array with (a != b) element-wise. 
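+ * Inputs are broadcast against each other, and the templated overloads below
+ * promote scalars, so expressions like (a != 0) work directly.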
*/ +array not_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator!=(const array& a, const array& b) { + return not_equal(a, b); +} +template +array operator!=(T a, const array& b) { + return not_equal(array(a), b); +} +template +array operator!=(const array& a, T b) { + return not_equal(a, array(b)); +} + +/** Returns bool array with (a > b) element-wise. */ +array greater(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator>(const array& a, const array& b) { + return greater(a, b); +} +template +array operator>(T a, const array& b) { + return greater(array(a), b); +} +template +array operator>(const array& a, T b) { + return greater(a, array(b)); +} + +/** Returns bool array with (a >= b) element-wise. */ +array greater_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator>=(const array& a, const array& b) { + return greater_equal(a, b); +} +template +array operator>=(T a, const array& b) { + return greater_equal(array(a), b); +} +template +array operator>=(const array& a, T b) { + return greater_equal(a, array(b)); +} + +/** Returns bool array with (a < b) element-wise. */ +array less(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator<(const array& a, const array& b) { + return less(a, b); +} +template +array operator<(T a, const array& b) { + return less(array(a), b); +} +template +array operator<(const array& a, T b) { + return less(a, array(b)); +} + +/** Returns bool array with (a <= b) element-wise. */ +array less_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator<=(const array& a, const array& b) { + return less_equal(a, b); +} +template +array operator<=(T a, const array& b) { + return less_equal(array(a), b); +} +template +array operator<=(const array& a, T b) { + return less_equal(a, array(b)); +} + +/** True if two arrays have the same shape and elements. */ +array array_equal( + const array& a, + const array& b, + bool equal_nan, + StreamOrDevice s = {}); +inline array +array_equal(const array& a, const array& b, StreamOrDevice s = {}) { + return array_equal(a, b, false, s); +} + +/** Select from x or y depending on condition. */ +array where( + const array& condition, + const array& x, + const array& y, + StreamOrDevice s = {}); + +/** Reduction operations */ + +/** True if all elements in the array are true (or non-zero). **/ +array all(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array all(const array& a, StreamOrDevice s = {}) { + return all(a, false, to_stream(s)); +} + +/** True if the two arrays are equal within the specified tolerance. */ +array allclose( + const array& a, + const array& b, + double rtol = 1e-5, + double atol = 1e-8, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axes. An output value is true + * if all the corresponding inputs are true. + **/ +array all( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axis. An output value is true + * if all the corresponding inputs are true. + **/ +array all( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** True if any elements in the array are true (or non-zero). **/ +array any(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array any(const array& a, StreamOrDevice s = {}) { + return any(a, false, to_stream(s)); +} + +/** + * Reduces the input along the given axes. 
An output value is true + * if any of the corresponding inputs are true. + **/ +array any( + const array& a, + const std::vector<int>& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axis. An output value is true + * if any of the corresponding inputs are true. + **/ +array any( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Sums the elements of an array. */ +array sum(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array sum(const array& a, StreamOrDevice s = {}) { + return sum(a, false, to_stream(s)); +} + +/** Sums the elements of an array along the given axes. */ +array sum( + const array& a, + const std::vector<int>& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Sums the elements of an array along the given axis. */ +array sum( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the mean of the elements of an array. */ +array mean(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array mean(const array& a, StreamOrDevice s = {}) { + return mean(a, false, to_stream(s)); +} + +/** Computes the mean of the elements of an array along the given axes. */ +array mean( + const array& a, + const std::vector<int>& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the mean of the elements of an array along the given axis. */ +array mean( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the variance of the elements of an array. */ +array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {}); +inline array var(const array& a, StreamOrDevice s = {}) { + return var(a, false, 0, to_stream(s)); +} + +/** Computes the variance of the elements of an array along the given axes. */ +array var( + const array& a, + const std::vector<int>& axes, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** Computes the variance of the elements of an array along the given axis. */ +array var( + const array& a, + int axis, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** The product of all elements of the array. */ +array prod(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array prod(const array& a, StreamOrDevice s = {}) { + return prod(a, false, to_stream(s)); +} + +/** The product of the elements of an array along the given axes. */ +array prod( + const array& a, + const std::vector<int>& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The product of the elements of an array along the given axis. */ +array prod( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The maximum of all elements of the array. */ +array max(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array max(const array& a, StreamOrDevice s = {}) { + return max(a, false, to_stream(s)); +} + +/** The maximum of the elements of an array along the given axes. */ +array max( + const array& a, + const std::vector<int>& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The maximum of the elements of an array along the given axis. */ +array max( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The minimum of all elements of the array.
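+ * For example, min(x) reduces over every axis and yields a scalar array,
+ * while min(x, true) keeps each reduced axis as a size-1 dimension.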
*/ +array min(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array min(const array& a, StreamOrDevice s = {}) { + return min(a, false, to_stream(s)); +} + +/** The minimum of the elements of an array along the given axes. */ +array min( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The minimum of the elements of an array along the given axis. */ +array min( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Returns the index of the minimum value in the array. */ +array argmin(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array argmin(const array& a, StreamOrDevice s = {}) { + return argmin(a, false, s); +} + +/** Returns the indices of the minimum values along a given axis. */ +array argmin( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Returns the index of the maximum value in the array. */ +array argmax(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array argmax(const array& a, StreamOrDevice s = {}) { + return argmax(a, false, s); +} + +/** Returns the indices of the maximum values along a given axis. */ +array argmax( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Returns a sorted copy of the flattened array. */ +array sort(const array& a, StreamOrDevice s = {}); + +/** Returns a sorted copy of the array along a given axis. */ +array sort(const array& a, int axis, StreamOrDevice s = {}); + +/** Returns indices that sort the flattened array. */ +array argsort(const array& a, StreamOrDevice s = {}); + +/** Returns indices that sort the array along a given axis. */ +array argsort(const array& a, int axis, StreamOrDevice s = {}); + +/** + * Returns a partitioned copy of the flattened array + * such that the smaller kth elements are first. + **/ +array partition(const array& a, int kth, StreamOrDevice s = {}); + +/** + * Returns a partitioned copy of the array along a given axis + * such that the smaller kth elements are first. + **/ +array partition(const array& a, int kth, int axis, StreamOrDevice s = {}); + +/** + * Returns indices that partition the flattened array + * such that the smaller kth elements are first. + **/ +array argpartition(const array& a, int kth, StreamOrDevice s = {}); + +/** + * Returns indices that partition the array along a given axis + * such that the smaller kth elements are first. + **/ +array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {}); + +/** Returns topk elements of the flattened array. */ +array topk(const array& a, int k, StreamOrDevice s = {}); + +/** Returns topk elements of the array along a given axis. */ +array topk(const array& a, int k, int axis, StreamOrDevice s = {}); + +/** The logsumexp of all elements of the array. */ +array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array logsumexp(const array& a, StreamOrDevice s = {}) { + return logsumexp(a, false, to_stream(s)); +} + +/** The logsumexp of the elements of an array along the given axes. */ +array logsumexp( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The logsumexp of the elements of an array along the given axis. */ +array logsumexp( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Simple arithmetic operations */ + +/** Absolute value of elements in an array. 
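+ * For real inputs this is equivalent to where(a < 0, -a, a).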
*/ +array abs(const array& a, StreamOrDevice s = {}); + +/** Negate an array. */ +array negative(const array& a, StreamOrDevice s = {}); +array operator-(const array& a); + +/** The sign of the elements in an array. */ +array sign(const array& a, StreamOrDevice s = {}); + +/** Logical not of an array. */ +array logical_not(const array& a, StreamOrDevice s = {}); + +/** The reciprocal (1/x) of the elements in an array. */ +array reciprocal(const array& a, StreamOrDevice s = {}); + +/** Add two arrays. */ +array add(const array& a, const array& b, StreamOrDevice s = {}); +array operator+(const array& a, const array& b); +template <typename T> +array operator+(T a, const array& b) { + return add(array(a), b); +} +template <typename T> +array operator+(const array& a, T b) { + return add(a, array(b)); +} + +/** Subtract two arrays. */ +array subtract(const array& a, const array& b, StreamOrDevice s = {}); +array operator-(const array& a, const array& b); +template <typename T> +array operator-(T a, const array& b) { + return subtract(array(a), b); +} +template <typename T> +array operator-(const array& a, T b) { + return subtract(a, array(b)); +} + +/** Multiply two arrays. */ +array multiply(const array& a, const array& b, StreamOrDevice s = {}); +array operator*(const array& a, const array& b); +template <typename T> +array operator*(T a, const array& b) { + return multiply(array(a), b); +} +template <typename T> +array operator*(const array& a, T b) { + return multiply(a, array(b)); +} + +/** Divide two arrays. */ +array divide(const array& a, const array& b, StreamOrDevice s = {}); +array operator/(const array& a, const array& b); +array operator/(double a, const array& b); +array operator/(const array& a, double b); + +/** Compute integer division. Equivalent to doing floor(a / b). */ +array floor_divide(const array& a, const array& b, StreamOrDevice s = {}); + +/** Compute the element-wise remainder of division. */ +array remainder(const array& a, const array& b, StreamOrDevice s = {}); +array operator%(const array& a, const array& b); +template <typename T> +array operator%(T a, const array& b) { + return remainder(array(a), b); +} +template <typename T> +array operator%(const array& a, T b) { + return remainder(a, array(b)); +} + +/** Element-wise maximum between two arrays. */ +array maximum(const array& a, const array& b, StreamOrDevice s = {}); + +/** Element-wise minimum between two arrays. */ +array minimum(const array& a, const array& b, StreamOrDevice s = {}); + +/** Floor the elements of an array. **/ +array floor(const array& a, StreamOrDevice s = {}); + +/** Ceil the elements of an array. **/ +array ceil(const array& a, StreamOrDevice s = {}); + +/** Square the elements of an array. */ +array square(const array& a, StreamOrDevice s = {}); + +/** Exponential of the elements of an array.
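+ * Computes e^a element-wise; for positive a, exp(log(a)) recovers a up to
+ * floating point error.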
*/ +array exp(const array& a, StreamOrDevice s = {}); + +/** Sine of the elements of an array. */ +array sin(const array& a, StreamOrDevice s = {}); + +/** Cosine of the elements of an array. */ +array cos(const array& a, StreamOrDevice s = {}); + +/** Tangent of the elements of an array. */ +array tan(const array& a, StreamOrDevice s = {}); + +/** Arc sine of the elements of an array. */ +array arcsin(const array& a, StreamOrDevice s = {}); + +/** Arc cosine of the elements of an array. */ +array arccos(const array& a, StreamOrDevice s = {}); + +/** Arc tangent of the elements of an array. */ +array arctan(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic sine of the elements of an array. */ +array sinh(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic cosine of the elements of an array. */ +array cosh(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic tangent of the elements of an array. */ +array tanh(const array& a, StreamOrDevice s = {}); + +/** Inverse hyperbolic sine of the elements of an array. */ +array arcsinh(const array& a, StreamOrDevice s = {}); + +/** Inverse hyperbolic cosine of the elements of an array. */ +array arccosh(const array& a, StreamOrDevice s = {}); + +/** Inverse hyperbolic tangent of the elements of an array. */ +array arctanh(const array& a, StreamOrDevice s = {}); + +/** Natural logarithm of the elements of an array. */ +array log(const array& a, StreamOrDevice s = {}); + +/** Log base 2 of the elements of an array. */ +array log2(const array& a, StreamOrDevice s = {}); + +/** Log base 10 of the elements of an array. */ +array log10(const array& a, StreamOrDevice s = {}); + +/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */ +array log1p(const array& a, StreamOrDevice s = {}); + +/** Log-add-exp of two arrays: `log(exp(a) + exp(b))`. */ +array logaddexp(const array& a, const array& b, StreamOrDevice s = {}); + +/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x))`. */ +array sigmoid(const array& a, StreamOrDevice s = {}); + +/** Computes the error function of the elements of an array. */ +array erf(const array& a, StreamOrDevice s = {}); + +/** Computes the inverse error function of the elements of an array. */ +array erfinv(const array& a, StreamOrDevice s = {}); + +/** Stop the flow of gradients. */ +array stop_gradient(const array& a, StreamOrDevice s = {}); + +/** Round the elements of a floating point array. */ +array round(const array& a, int decimals, StreamOrDevice s = {}); +inline array round(const array& a, StreamOrDevice s = {}) { + return round(a, 0, s); +} + +/** Matrix-matrix multiplication. */ +array matmul(const array& a, const array& b, StreamOrDevice s = {}); + +/** Gather array entries given indices and slices. */ +array gather( + const array& a, + const std::vector<array>& indices, + const std::vector<int>& axes, + const std::vector<int>& slice_sizes, + StreamOrDevice s = {}); +inline array gather( + const array& a, + const array& indices, + int axis, + const std::vector<int>& slice_sizes, + StreamOrDevice s = {}) { + return gather(a, {indices}, std::vector<int>{axis}, slice_sizes, s); +} + +/** Take array slices at the given indices of the specified axis. */ +array take( + const array& a, + const array& indices, + int axis, + StreamOrDevice s = {}); + +/** Take array entries at the given indices treating the array as flattened.
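+ * Illustrative call (assuming an initializer-list array constructor for the
+ * indices): if flatten(a) holds {3, 1, 4, 1, 5, 9}, then
+ * take(a, array({0, 2, 4})) evaluates to {3, 4, 5}.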
*/ +array take(const array& a, const array& indices, StreamOrDevice s = {}); + +/** Take array entries given indices along the axis */ +array take_along_axis( + const array& a, + const array& indices, + int axis, + StreamOrDevice s = {}); + +/** Scatter updates to given linear indices */ +array scatter( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and add updates to given indices */ +array scatter_add( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_add( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_add(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and prod updates to given indices */ +array scatter_prod( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_prod( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_prod(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and max updates to given linear indices */ +array scatter_max( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_max( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_max(a, {indices}, updates, std::vector{axis}, s); +} +/** Scatter and min updates to given linear indices */ +array scatter_min( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_min( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_min(a, {indices}, updates, std::vector{axis}, s); +} + +/** Square root the elements of an array. */ +array sqrt(const array& a, StreamOrDevice s = {}); + +/** Square root and reciprocal the elements of an array. */ +array rsqrt(const array& a, StreamOrDevice s = {}); + +/** Softmax of an array. */ +array softmax( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** Softmax of an array. */ +array softmax(const array& a, StreamOrDevice s = {}); + +/** Softmax of an array. */ +inline array softmax(const array& a, int axis, StreamOrDevice s = {}) { + return softmax(a, std::vector{axis}, s); +} + +/** Raise elements of a to the power of b element-wise */ +array power(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator^(const array& a, const array& b) { + return power(a, b); +} +template +array operator^(T a, const array& b) { + return power(array(a), b); +} +template +array operator^(const array& a, T b) { + return power(a, array(b)); +} + +/** Cumulative sum of an array. */ +array cumsum( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative product of an array. */ +array cumprod( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative max of an array. 
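+ * For example, cummax along axis 0 of {1, 3, 2} gives {1, 3, 3}; with
+ * reverse = true it gives {3, 3, 2}.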
*/ +array cummax( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative min of an array. */ +array cummin( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Convolution operations */ + +/** 1D convolution with a filter */ +array conv1d( + const array& input, + const array& weight, + int stride = 1, + int padding = 0, + int dilation = 1, + int groups = 1, + StreamOrDevice s = {}); + +/** 2D convolution with a filter */ +array conv2d( + const array& input, + const array& weight, + const std::pair& stride = {1, 1}, + const std::pair& padding = {0, 0}, + const std::pair& dilation = {1, 1}, + int groups = 1, + StreamOrDevice s = {}); + +/** Serialization operations */ + +/** Save array to out stream in .npy format */ +void save( + std::shared_ptr out_stream, + array a, + bool retain_graph = true); + +/** Save array to file in .npy format */ +void save(const std::string& file, array a, bool retain_graph = true); + +/** Load array from reader in .npy format */ +array load(std::shared_ptr in_stream, StreamOrDevice s = {}); + +/** Load array from file in .npy format */ +array load(const std::string& file, StreamOrDevice s = {}); + +/** Quantized matmul multiplies x with a quantized matrix w*/ +array quantized_matmul( + const array& x, + const array& w, + const array& scales, + const array& biases, + bool transpose = true, + int group_size = 64, + int bits = 4, + StreamOrDevice s = {}); + +/** Quantize a matrix along its last axis */ +std::tuple quantize( + const array& w, + int group_size = 64, + int bits = 4, + StreamOrDevice s = {}); + +/** Dequantize a matrix produced by quantize() */ +array dequantize( + const array& w, + const array& scales, + const array& biases, + int group_size = 64, + int bits = 4, + StreamOrDevice s = {}); + +/** TensorDot returns a contraction of a and b over multiple dimensions. */ +array tensordot( + const array& a, + const array& b, + const int dims = 2, + StreamOrDevice s = {}); + +array tensordot( + const array& a, + const array& b, + const std::pair, std::vector>& dims, + StreamOrDevice s = {}); + +/** Load array map from .safetensors file format */ +std::unordered_map load_safetensors( + std::shared_ptr in_stream, + StreamOrDevice s = {}); +std::unordered_map load_safetensors( + const std::string& file, + StreamOrDevice s = {}); + +void save_safetensors( + std::shared_ptr in_stream, + std::unordered_map, + std::optional retain_graph = std::nullopt); +void save_safetensors( + const std::string& file, + std::unordered_map, + std::optional retain_graph = std::nullopt); +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/primitives.h b/lib/python3.11/site-packages/mlx/include/mlx/primitives.h new file mode 100644 index 0000000000000000000000000000000000000000..e87e934ab4cbea12936ed788b6723bc8d0f6268a --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/primitives.h @@ -0,0 +1,1636 @@ +// Copyright © 2023 Apple Inc. 
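+// Usage sketch for the quantization entry points declared in ops.h above.
+// Illustrative only; the shapes and the input x are assumptions, not part of
+// this header:
+//
+//   using namespace mlx::core;
+//   array w = ones({512, 512});                         // weight matrix
+//   auto [wq, scales, biases] = quantize(w, 64, 4);     // 4-bit, groups of 64
+//   array w_hat = dequantize(wq, scales, biases, 64, 4); // approximate w
+//   array y = quantized_matmul(x, wq, scales, biases);  // x: float activations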
+ +#pragma once + +#include "array.h" +#include "device.h" +#include "io/load.h" +#include "stream.h" + +#define DEFINE_GRADS() \ + array jvp( \ + const std::vector& primals, \ + const std::vector& tangents, \ + const std::vector& argnums) override; \ + \ + std::vector vjp( \ + const std::vector& primals, \ + const array& cotan, \ + const std::vector& argnums) override; + +#define DEFINE_PRINT(PRIMITIVE) \ + void print(std::ostream& os) override { \ + os << #PRIMITIVE; \ + } + +#define DEFINE_DEFAULT_IS_EQUIVALENT() \ + bool is_equivalent(const Primitive& other) const override { \ + return true; \ + } + +namespace mlx::core { + +// Abstract base class +class Primitive { + public: + explicit Primitive(Stream stream) : stream_(stream) {} + + /** The device the primitive will run on. */ + const Device& device() { + return stream().device; + } + + /** The stream the primitive will run on. */ + const Stream& stream() { + return stream_; + } + + /** + * A primitive must know how to evaluate itself on + * the CPU/GPU for the given inputs and populate the output array. + * + * To avoid unnecessary allocations, the evaluation function + * is responsible for allocating space for the array. + */ + virtual void eval_cpu(const std::vector& inputs, array& out) = 0; + virtual void eval_gpu(const std::vector& inputs, array& out) = 0; + + /** + * The Jacobian-vector product. + */ + virtual array jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums); + + /** + * The vector-Jacobian product. + */ + virtual std::vector vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums); + + /** + * The primitive must know how to vectorize itself across + * the given axes. The output is a pair containing the array + * representing the vectorized computation and the axis which + * corresponds to the output vectorized dimension. + */ + virtual std::pair vmap( + const std::vector& inputs, + const std::vector& axes); + + /** Print the primitive. 
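+ * Implementations typically just stream the primitive's name, which is
+ * exactly what the DEFINE_PRINT macro above generates.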
*/ + virtual void print(std::ostream& os) = 0; + + /** Equivalence check defaults to false unless overridden by the primitive */ + virtual bool is_equivalent(const Primitive& other) const { + return false; + } + + virtual ~Primitive() = default; + Primitive(const Primitive& other) = delete; + Primitive(Primitive&& other) = delete; + Primitive& operator=(const Primitive& other) = delete; + Primitive& operator=(Primitive&& other) = delete; + + private: + // Every primitive stores the stream it should run in + Stream stream_; +}; + +class Abs : public Primitive { + public: + explicit Abs(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Abs) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Add : public Primitive { + public: + explicit Add(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Add) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Arange : public Primitive { + public: + explicit Arange(Stream stream, double start, double stop, double step) + : Primitive(stream), start_(start), stop_(stop), step_(step){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(Arange) + bool is_equivalent(const Primitive& other) const override; + + private: + double start_; + double stop_; + double step_; + + void eval(const std::vector& inputs, array& out); +}; + +class ArcCos : public Primitive { + public: + explicit ArcCos(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcCos) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcCosh : public Primitive { + public: + explicit ArcCosh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcCosh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcSin : public Primitive { + public: + explicit ArcSin(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcSin) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcSinh : public Primitive { + public: + explicit ArcSinh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void 
eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcSinh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcTan : public Primitive { + public: + explicit ArcTan(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcTan) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcTanh : public Primitive { + public: + explicit ArcTanh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcTanh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArgPartition : public Primitive { + public: + explicit ArgPartition(Stream stream, int kth, int axis) + : Primitive(stream), kth_(kth), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(ArgPartition) + bool is_equivalent(const Primitive& other) const override; + + private: + int kth_; + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class ArgReduce : public Primitive { + public: + enum ReduceType { + ArgMin, + ArgMax, + }; + + explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis) + : Primitive(stream), reduce_type_(reduce_type), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(ArgReduce) + bool is_equivalent(const Primitive& other) const override; + + private: + ReduceType reduce_type_; + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class ArgSort : public Primitive { + public: + explicit ArgSort(Stream stream, int axis) : Primitive(stream), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(ArgSort) + bool is_equivalent(const Primitive& other) const override; + + private: + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class AsType : public Primitive { + public: + explicit AsType(Stream stream, Dtype dtype) + : Primitive(stream), dtype_(dtype){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(AsType) + bool is_equivalent(const Primitive& other) const override; + + private: + Dtype dtype_; + + void eval(const std::vector& inputs, array& out); +}; + +class AsStrided : public Primitive { + public: + explicit AsStrided( + Stream stream, + const std::vector& shape, + const std::vector& 
strides, + size_t offset) + : Primitive(stream), shape_(shape), strides_(strides), offset_(offset){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_GRADS() + DEFINE_PRINT(AsStrided) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector shape_; + std::vector strides_; + size_t offset_; + + void eval(const std::vector& inputs, array& out); +}; + +class Broadcast : public Primitive { + public: + explicit Broadcast(Stream stream, const std::vector& shape) + : Primitive(stream), shape_(shape){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Broadcast) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector shape_; + + void eval(const std::vector& inputs, array& out); +}; + +class Ceil : public Primitive { + public: + explicit Ceil(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Ceil) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Concatenate : public Primitive { + public: + explicit Concatenate(Stream stream, int axis) + : Primitive(stream), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Concatenate) + bool is_equivalent(const Primitive& other) const override; + + private: + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class Convolution : public Primitive { + public: + explicit Convolution( + Stream stream, + const std::vector& padding, + const std::vector& kernel_strides, + const std::vector& kernel_dilation, + const std::vector& input_dilation) + : Primitive(stream), + padding_(padding), + kernel_strides_(kernel_strides), + kernel_dilation_(kernel_dilation), + input_dilation_(input_dilation){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) override; + + DEFINE_PRINT(Convolution) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector padding_; + std::vector kernel_strides_; + std::vector kernel_dilation_; + std::vector input_dilation_; + + void eval(const std::vector& inputs, array& out); +}; + +class Copy : public Primitive { + public: + explicit Copy(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Copy) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Cos : public Primitive { + public: + explicit Cos(Stream stream) : Primitive(stream){}; + + void eval_cpu(const 
std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Cos) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Cosh : public Primitive { + public: + explicit Cosh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Cosh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Divide : public Primitive { + public: + explicit Divide(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Divide) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Remainder : public Primitive { + public: + explicit Remainder(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Remainder) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Equal : public Primitive { + public: + explicit Equal(Stream stream, bool equal_nan = false) + : Primitive(stream), equal_nan_(equal_nan){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Equal) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); + bool equal_nan_; +}; + +class Erf : public Primitive { + public: + explicit Erf(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Erf) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ErfInv : public Primitive { + public: + explicit ErfInv(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ErfInv) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Exp : public Primitive { + public: + explicit Exp(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Exp) + 
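// DEFINE_GRADS(), DEFINE_PRINT(name), and DEFINE_DEFAULT_IS_EQUIVALENT() are
// helper macros defined earlier in this header; they presumably expand to the
// standard jvp/vjp declarations, a print(std::ostream&) override that writes
// the class name, and an is_equivalent override that keeps the base-class
// default, respectively. Stateless element-wise primitives such as Exp use
// all three.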
DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class FFT : public Primitive { + public: + explicit FFT( + Stream stream, + const std::vector& axes, + bool inverse, + bool real) + : Primitive(stream), axes_(axes), inverse_(inverse), real_(real){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(FFT) + + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector axes_; + bool inverse_; + bool real_; + + void eval(const std::vector& inputs, array& out); +}; + +class Floor : public Primitive { + public: + explicit Floor(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Floor) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Full : public Primitive { + public: + explicit Full(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Full) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Gather : public Primitive { + public: + explicit Gather( + Stream stream, + const std::vector& axes, + const std::vector& slice_sizes) + : Primitive(stream), axes_(axes), slice_sizes_(slice_sizes){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Gather) + bool is_equivalent(const Primitive& other) const override; + + private: + void eval(const std::vector& inputs, array& out); + std::vector axes_; + std::vector slice_sizes_; +}; + +class Greater : public Primitive { + public: + explicit Greater(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Greater) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class GreaterEqual : public Primitive { + public: + explicit GreaterEqual(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(GreaterEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Less : public Primitive { + public: + explicit Less(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const 
std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Less) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class LessEqual : public Primitive { + public: + explicit LessEqual(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(LessEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Load : public Primitive { + public: + explicit Load( + Stream stream, + std::shared_ptr reader, + size_t offset, + bool swap_endianness = false) + : Primitive(stream), + reader_(reader), + offset_(offset), + swap_endianness_(swap_endianness){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(Load) + + private: + void eval(const std::vector& inputs, array& out); + std::shared_ptr reader_; + size_t offset_; + bool swap_endianness_; +}; + +class Log : public Primitive { + public: + enum Base { two, ten, e }; + + explicit Log(Stream stream, Base base) : Primitive(stream), base_(base){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Log) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + Base base_; + void eval(const std::vector& inputs, array& out); +}; + +class Log1p : public Primitive { + public: + explicit Log1p(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Log1p) + + private: + void eval(const std::vector& inputs, array& out); +}; + +class LogicalNot : public Primitive { + public: + explicit LogicalNot(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(LogicalNot) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class LogAddExp : public Primitive { + public: + explicit LogAddExp(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(LogAddExp) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Matmul : public Primitive { + public: + explicit Matmul(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + 
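// Note that Matmul declares vjp directly instead of using DEFINE_GRADS().
// For C = A @ B with output cotangent dC, the standard matrix-calculus
// results are
//   dA = dC @ transpose(B)   and   dB = transpose(A) @ dC,
// which is presumably what the out-of-line vjp definition computes.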
DEFINE_PRINT(Matmul) + DEFINE_DEFAULT_IS_EQUIVALENT() +}; + +class Maximum : public Primitive { + public: + explicit Maximum(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Maximum) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Minimum : public Primitive { + public: + explicit Minimum(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Minimum) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Multiply : public Primitive { + public: + explicit Multiply(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Multiply) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Negative : public Primitive { + public: + explicit Negative(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Negative) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class NotEqual : public Primitive { + public: + explicit NotEqual(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(NotEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Pad : public Primitive { + public: + explicit Pad( + Stream stream, + const std::vector& axes, + const std::vector& low_pad_size, + const std::vector& high_pad_size) + : Primitive(stream), + axes_(axes), + low_pad_size_(low_pad_size), + high_pad_size_(high_pad_size){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Pad) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector axes_; + std::vector low_pad_size_; + std::vector high_pad_size_; + + void eval(const std::vector& inputs, array& out); +}; + +class Partition : public Primitive { + public: + explicit Partition(Stream stream, int kth, int axis) + : Primitive(stream), kth_(kth), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + 
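// Partition follows std::nth_element/numpy.partition semantics: along axis_,
// the element at index kth_ lands in its sorted position, with smaller
// elements before it and larger elements after it, each side in unspecified
// order.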
DEFINE_PRINT(Partition) + bool is_equivalent(const Primitive& other) const override; + + private: + int kth_; + int axis_; + + void eval(const std::vector<array>& inputs, array& out); +}; + +class Power : public Primitive { + public: + explicit Power(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector<array>& inputs, array& out) override; + void eval_gpu(const std::vector<array>& inputs, array& out) override; + + std::pair<array, int> vmap( + const std::vector<array>& inputs, + const std::vector<int>& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Power) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector<array>& inputs, array& out); +}; + +class QuantizedMatmul : public Primitive { + public: + explicit QuantizedMatmul( + Stream stream, + int group_size, + int bits, + bool transpose) + : Primitive(stream), + group_size_(group_size), + bits_(bits), + transpose_(transpose){}; + + void eval_cpu(const std::vector<array>& inputs, array& out) override; + void eval_gpu(const std::vector<array>& inputs, array& out) override; + + std::pair<array, int> vmap( + const std::vector<array>& inputs, + const std::vector<int>& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(QuantizedMatmul) + bool is_equivalent(const Primitive& other) const override; + + private: + int group_size_; + int bits_; + bool transpose_; + + void eval(const std::vector<array>& inputs, array& out); +}; + +class RandomBits : public Primitive { + public: + explicit RandomBits(Stream stream, const std::vector<int>& shape, int width) + : Primitive(stream), shape_(shape), width_(width){}; + + void eval_cpu(const std::vector<array>& inputs, array& out) override; + void eval_gpu(const std::vector<array>& inputs, array& out) override; + + std::pair<array, int> vmap( + const std::vector<array>& inputs, + const std::vector<int>& axes) override; + + DEFINE_PRINT(RandomBits) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector<int> shape_; + int width_; + + void eval(const std::vector<array>& inputs, array& out); +}; + +class Reshape : public Primitive { + public: + explicit Reshape(Stream stream, const std::vector<int>& shape) + : Primitive(stream), shape_(shape){}; + + void eval_cpu(const std::vector<array>& inputs, array& out) override; + void eval_gpu(const std::vector<array>& inputs, array& out) override; + + std::pair<array, int> vmap( + const std::vector<array>& inputs, + const std::vector<int>& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Reshape) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector<int> shape_; + + void eval(const std::vector<array>& inputs, array& out); +}; + +class Reduce : public Primitive { + public: + enum ReduceType { And, Or, Sum, Prod, Min, Max }; + + explicit Reduce( + Stream stream, + ReduceType reduce_type, + const std::vector<int>& axes) + : Primitive(stream), reduce_type_(reduce_type), axes_(axes){}; + + void eval_cpu(const std::vector<array>& inputs, array& out) override; + void eval_gpu(const std::vector<array>& inputs, array& out) override; + + std::pair<array, int> vmap( + const std::vector<array>& inputs, + const std::vector<int>& axes) override; + std::vector<array> vjp( + const std::vector<array>& primals, + const array& cotan, + const std::vector<int>& argnums) override; + + void print(std::ostream& os) override { + switch (reduce_type_) { + case And: + os << "And"; + break; + case Or: + os << "Or"; + break; + case Sum: + os << "Sum"; + break; + case Prod: + os << "Prod"; + break; + case Min: + os << "Min"; + break; + case Max: + os << "Max"; + break; + } + os << " Reduce"; + } + bool is_equivalent(const Primitive& other) const override; + + private: + ReduceType reduce_type_; + std::vector<int> axes_; + + void eval(const std::vector<array>& inputs, array&
out); +}; + +class Round : public Primitive { + public: + explicit Round(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Round) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Scan : public Primitive { + public: + enum ReduceType { Max, Min, Sum, Prod }; + + explicit Scan( + Stream stream, + ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive) + : Primitive(stream), + reduce_type_(reduce_type), + axis_(axis), + reverse_(reverse), + inclusive_(inclusive){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS(); + void print(std::ostream& os) override { + os << "Cum"; + switch (reduce_type_) { + case Sum: + os << "Sum"; + break; + case Prod: + os << "Prod"; + break; + case Min: + os << "Min"; + break; + case Max: + os << "Max"; + break; + } + os << " Reduce"; + } + bool is_equivalent(const Primitive& other) const override; + + private: + ReduceType reduce_type_; + int axis_; + bool reverse_; + bool inclusive_; + + void eval(const std::vector& inputs, array& out); +}; + +class Scatter : public Primitive { + public: + enum ReduceType { Max, Min, Sum, Prod, None }; + + explicit Scatter( + Stream stream, + ReduceType reduce_type, + const std::vector& axes) + : Primitive(stream), reduce_type_(reduce_type), axes_(axes){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(Scatter) + bool is_equivalent(const Primitive& other) const override; + + private: + void eval(const std::vector& inputs, array& out); + ReduceType reduce_type_; + std::vector axes_; +}; + +class Sigmoid : public Primitive { + public: + explicit Sigmoid(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sigmoid) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sign : public Primitive { + public: + explicit Sign(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sign) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sin : public Primitive { + public: + explicit Sin(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sin) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sinh : public Primitive { + public: + explicit Sinh(Stream stream) : 
Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sinh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Slice : public Primitive { + public: + explicit Slice( + Stream stream, + const std::vector& start_indices, + const std::vector& end_indices, + const std::vector& strides) + : Primitive(stream), + start_indices_(start_indices), + end_indices_(end_indices), + strides_(strides){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Slice) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector start_indices_; + std::vector end_indices_; + std::vector strides_; + + void eval(const std::vector& inputs, array& out); +}; + +class Softmax : public Primitive { + public: + explicit Softmax(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Softmax) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sort : public Primitive { + public: + explicit Sort(Stream stream, int axis) : Primitive(stream), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sort) + bool is_equivalent(const Primitive& other) const override; + + private: + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class Square : public Primitive { + public: + explicit Square(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Square) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sqrt : public Primitive { + public: + explicit Sqrt(Stream stream, bool recip = false) + : Primitive(stream), recip_(recip){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sqrt) + bool is_equivalent(const Primitive& other) const override; + + private: + void eval(const std::vector& inputs, array& out); + bool recip_; +}; + +class StopGradient : public Primitive { + public: + explicit StopGradient(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(StopGradient) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + 
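// StopGradient declares no DEFINE_GRADS(), i.e. no vjp/jvp overrides;
// presumably this absence is what keeps gradients from flowing through the
// node, while eval simply forwards its input.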
void eval(const std::vector& inputs, array& out); +}; + +class Subtract : public Primitive { + public: + explicit Subtract(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Subtract) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Tan : public Primitive { + public: + explicit Tan(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Tan) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Tanh : public Primitive { + public: + explicit Tanh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Tanh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Uniform : public Primitive { + public: + explicit Uniform(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(Uniform) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Transpose : public Primitive { + public: + explicit Transpose(Stream stream, const std::vector& axes) + : Primitive(stream), axes_(axes){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Transpose) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector axes_; + + void eval(const std::vector& inputs, array& out); +}; + +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/random.h b/lib/python3.11/site-packages/mlx/include/mlx/random.h new file mode 100644 index 0000000000000000000000000000000000000000..360bdbdb11d8e5fa65f9966edea91e5e07b61d24 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/random.h @@ -0,0 +1,193 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/stream.h" + +namespace mlx::core::random { + +class KeySequence { + public: + explicit KeySequence(uint64_t seed); + + void seed(uint64_t seed); + array next(); + + // static default + static KeySequence& default_() { + static KeySequence ks(0); + return ks; + } + + private: + array key_; +}; + +/** Get a PRNG key from a seed. */ +array key(uint64_t seed); + +/** Seed the default PRNG key. */ +void seed(uint64_t seed); + +/** Generate an array with type uint32 filled with random bits. 
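 *
 * A usage sketch, assuming the key/split API declared above and that `width`
 * is the byte width of each sample (so the default of 4 yields one uint32
 * per element):
 *
 *   auto k = key(0);                     // PRNG key from a seed
 *   auto ks = split(k);                  // pair of derived keys
 *   auto r = bits({2, 2}, 4, ks.first);  // 2x2 array of random uint32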
*/ +array bits( + const std::vector<int>& shape, + int width, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); +inline array bits( + const std::vector<int>& shape, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}) { + return bits(shape, 4, key, s); +} + +/** Split the rng key into a pair of keys. */ +std::pair<array, array> split(const array& key, StreamOrDevice s = {}); + +/** Split the rng key into `num` keys. */ +array split(const array& key, int num, StreamOrDevice s = {}); + +/** Generate uniform random numbers between low and high. */ +array uniform( + const array& low, + const array& high, + const std::vector<int>& shape, + Dtype dtype = float32, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); + +template <typename T, typename U> +array uniform( + T low, + U high, + const std::vector<int>& shape, + Dtype dtype = float32, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}) { + return uniform(array(low), array(high), shape, dtype, key, to_stream(s)); +} + +/** Generate uniform random numbers between 0 and 1. */ +array uniform( + const std::vector<int>& shape, + Dtype dtype, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); +inline array uniform( + const std::vector<int>& shape, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}) { + return uniform(shape, float32, key, s); +} + +/** Generate samples from the standard normal distribution. */ +array normal( + const std::vector<int>& shape, + Dtype dtype, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); +inline array normal( + const std::vector<int>& shape, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, float32, key, s); +} + +/** Generate integer samples uniformly at random */ +array randint( + const array& low, + const array& high, + const std::vector<int>& shape, + Dtype dtype = int32, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); + +template <typename T, typename U> +array randint( + T low, + U high, + const std::vector<int>& shape, + Dtype dtype = int32, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}) { + return randint(array(low), array(high), shape, dtype, key, to_stream(s)); +}; + +/** Generate binary variables with probability to be true equal to p */ +array bernoulli( + const array& p, + const std::vector<int>& shape, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); +array bernoulli( + const array& p, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); + +template <typename T> +array bernoulli( + T p, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}) { + return bernoulli(array(p), key, s); +}; + +template <typename T> +array bernoulli( + T p, + const std::vector<int>& shape, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}) { + return bernoulli(array(p), shape, key, s); +}; + +array bernoulli( + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); + +array truncated_normal( + const array& lower, + const array& upper, + const std::vector<int>& shape, + Dtype dtype = float32, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); + +array truncated_normal( + const array& lower, + const array& upper, + Dtype dtype = float32, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); + +array gumbel( + const std::vector<int>& shape, + Dtype dtype = float32, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s = {}); + +array categorical( + const array& logits, + int axis, + const std::vector<int>& shape, + const std::optional<array>& key = std::nullopt, + StreamOrDevice s
= {}); + +array categorical( + const array& logits_, + int axis, + int num_samples, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array categorical( + const array& logits, + int axis = -1, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +} // namespace mlx::core::random diff --git a/lib/python3.11/site-packages/mlx/include/mlx/scheduler.h b/lib/python3.11/site-packages/mlx/include/mlx/scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..150cc96dbe6def42d2f6c3b6fe851c19cdbb7dbc --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/scheduler.h @@ -0,0 +1,173 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include +#include +#include + +#include "mlx/backend/metal/metal.h" +#include "mlx/device.h" +#include "mlx/stream.h" + +namespace mlx::core::scheduler { + +struct StreamThread { + std::mutex mtx; + std::queue> q; + std::condition_variable cond; + bool stop; + Stream stream; + std::thread thread; + + StreamThread(Stream stream) + : stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {} + + ~StreamThread() { + { + std::unique_lock lk(mtx); + stop = true; + } + cond.notify_one(); + thread.join(); + } + + void thread_fn() { + auto thread_pool = metal::new_scoped_memory_pool(); + metal::new_stream(stream); + while (true) { + std::function task; + { + std::unique_lock lk(mtx); + cond.wait(lk, [this] { return !this->q.empty() || this->stop; }); + if (q.empty() && stop) { + return; + } + task = std::move(q.front()); + q.pop(); + } + task(); + } + } + + template + void enqueue(F&& f) { + { + std::unique_lock lk(mtx); + if (stop) { + throw std::runtime_error( + "Cannot enqueue work after stream is stopped."); + } + q.emplace(std::forward(f)); + } + cond.notify_one(); + } +}; + +class Scheduler { + public: + Scheduler() : n_active_tasks_(0) { + if (metal::is_available()) { + default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); + } + default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); + } + + // Not copyable or moveable + Scheduler(const Scheduler&) = delete; + Scheduler(Scheduler&&) = delete; + Scheduler& operator=(const Scheduler&) = delete; + Scheduler& operator=(Scheduler&&) = delete; + + Stream new_stream(const Device& d) { + auto stream = Stream(streams_.size(), d); + streams_.push_back(new StreamThread{stream}); + return stream; + } + + template + void enqueue(const Stream& stream, F&& f); + + Stream get_default_stream(const Device& d) { + return default_streams_.at(d.type); + } + + void set_default_stream(const Stream& s) { + default_streams_.at(s.device.type) = s; + } + + void notify_new_task(const Stream& stream) { + { + std::unique_lock lk(mtx); + n_active_tasks_++; + } + completion_cv.notify_all(); + } + + void notify_task_completion(const Stream& stream) { + { + std::unique_lock lk(mtx); + n_active_tasks_--; + } + completion_cv.notify_all(); + } + + int n_active_tasks() const { + return n_active_tasks_; + } + + void wait_for_one() { + std::unique_lock lk(mtx); + int n_tasks_old = n_active_tasks(); + if (n_tasks_old > 1) { + completion_cv.wait(lk, [this, n_tasks_old] { + return this->n_active_tasks() != n_tasks_old; + }); + } + } + + ~Scheduler() { + for (auto s : streams_) { + delete s; + } + } + + private: + int n_active_tasks_; + std::vector streams_; + std::unordered_map default_streams_; + std::condition_variable completion_cv; + std::mutex mtx; +}; + +template +void Scheduler::enqueue(const Stream& stream, F&& f) { + 
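  // Dispatch is a plain indexed lookup: Stream::index is assigned from
  // streams_.size() in new_stream() above, so it doubles as the position of
  // the owning StreamThread. A caller-side sketch, assuming the free
  // scheduler() accessor declared below:
  //
  //   auto s = scheduler().new_stream(Device::cpu);
  //   scheduler().enqueue(s, [] { /* work runs on this stream's thread */ });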
streams_[stream.index]->enqueue(std::forward(f)); +} + +Scheduler& scheduler(); + +template +void enqueue(const Stream& stream, F&& f) { + scheduler().enqueue(stream, std::forward(f)); +} + +inline int n_active_tasks() { + return scheduler().n_active_tasks(); +} + +inline void notify_new_task(const Stream& stream) { + scheduler().notify_new_task(stream); +} + +inline void notify_task_completion(const Stream& stream) { + scheduler().notify_task_completion(stream); +} + +inline void wait_for_one() { + scheduler().wait_for_one(); +} + +} // namespace mlx::core::scheduler diff --git a/lib/python3.11/site-packages/mlx/include/mlx/stream.h b/lib/python3.11/site-packages/mlx/include/mlx/stream.h new file mode 100644 index 0000000000000000000000000000000000000000..d7b4268fd0a203685b7e080b20277716c28fb088 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/stream.h @@ -0,0 +1,32 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/device.h" + +namespace mlx::core { + +struct Stream { + int index; + Device device; + explicit Stream(int index, Device device) : index(index), device(device) {} +}; + +/** Get the default stream for the given device. */ +Stream default_stream(Device d); + +/** Make the stream the default for its device. */ +void set_default_stream(Stream s); + +/** Make a new stream on the given device. */ +Stream new_stream(Device d); + +inline bool operator==(const Stream& lhs, const Stream& rhs) { + return lhs.index == rhs.index; +} + +inline bool operator!=(const Stream& lhs, const Stream& rhs) { + return !(lhs == rhs); +} + +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/transforms.h b/lib/python3.11/site-packages/mlx/include/mlx/transforms.h new file mode 100644 index 0000000000000000000000000000000000000000..caf648163aef9e2da3b0f1cac14dd95f28841bcc --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/transforms.h @@ -0,0 +1,187 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "array.h" + +namespace mlx::core { + +/** Fuse equivalent arrays to avoid duplicate execution. */ +void simplify(const std::vector& outputs); + +template +void simplify(Arrays... outputs) { + simplify(std::vector{std::forward(outputs)...}); +} + +void eval(const std::vector& outputs, bool retain_graph = false); + +template +void eval(Arrays... outputs) { + eval(std::vector{std::forward(outputs)...}, false); +} + +/** + * Computes the output and vector-Jacobian product (VJP) of a function. + * + * Computes the vector-Jacobian product of the vector of cotangents with the + * Jacobian of the function evaluated at the primals. Returns a pair of + * vectors of output arrays and VJP arrays. + **/ +std::pair, std::vector> vjp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& cotangents); + +/** + * Computes the output and vector-Jacobian product (VJP) of a unary function. + */ +std::pair vjp( + const std::function& fun, + const array& primal, + const array& cotangent); + +/** + * Computes the output and Jacobian-vector product (JVP) of a function. + * + * Computes the Jacobian-vector product of the Jacobian of the function + * evaluated at the primals with the vector of tangents. Returns a pair of + * vectors of output arrays and JVP arrays. + **/ +std::pair, std::vector> jvp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& tangents); + +/** + * Computes the output and Jacobian-vector product (JVP) of a unary function. 
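 *
 * A usage sketch, assuming `square` from the ops API:
 *
 *   auto [out, dout] = jvp(
 *       [](const array& x) { return square(x); },
 *       array(2.0f),   // primal
 *       array(1.0f));  // tangent
 *   // out == 4; dout == d(x^2)/dx * tangent == 4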
+ */ +std::pair<array, array> jvp( + const std::function<array(const array&)>& fun, + const array& primal, + const array& tangent); + +// Return type of general value_and_grad: a function which takes an input +// vector of arrays and returns a pair of vectors of arrays, one for the +// values and one for the gradients of the first value with respect to the +// requested arguments. +using ValueAndGradFn = + std::function<std::pair<std::vector<array>, std::vector<array>>( + const std::vector<array>&)>; +using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>( + const std::vector<array>&)>; + +/** + * Returns a function which computes the value and gradient of the input + * function with respect to a vector of input arrays. + **/ +ValueAndGradFn value_and_grad( + const std::function<std::vector<array>(const std::vector<array>&)>& fun, + const std::vector<int>& argnums); + +/** + * Returns a function which computes the value and gradient of the input + * function with respect to a single input array. + **/ +ValueAndGradFn inline value_and_grad( + const std::function<std::vector<array>(const std::vector<array>&)>& fun, + int argnum = 0) { + return value_and_grad(fun, std::vector<int>{argnum}); +} + +/** + * Returns a function which computes the value and gradient of the unary + * input function. + **/ +std::function<std::pair<array, array>(const array&)> inline value_and_grad( + const std::function<array(const array&)>& fun) { + return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); }; +} + +SimpleValueAndGradFn inline value_and_grad( + const std::function<array(const std::vector<array>&)>& fun, + const std::vector<int>& argnums) { + return [fun, argnums](auto inputs) { + auto result = value_and_grad( + [fun](auto inputs) { return std::vector<array>{fun(inputs)}; }, + argnums)(inputs); + + return std::make_pair(result.first[0], result.second); + }; +} + +SimpleValueAndGradFn inline value_and_grad( + const std::function<array(const std::vector<array>&)>& fun, + int argnum = 0) { + return value_and_grad(fun, std::vector<int>{argnum}); +} + +/** + * Returns a function which computes the gradient of the input function with + * respect to a vector of input arrays. + * + * The function being differentiated takes a vector of arrays and returns an + * array. The vector of `argnums` specifies which arguments to compute + * the gradient with respect to. At least one argument must be specified. + **/ +std::function<std::vector<array>(const std::vector<array>&)> inline grad( + const std::function<array(const std::vector<array>&)>& fun, + const std::vector<int>& argnums) { + auto fn = value_and_grad(fun, argnums); + return [fn](const std::vector<array>& inputs) { return fn(inputs).second; }; +} + +/** + * Returns a function which computes the gradient of the input function with + * respect to a single input array. + * + * The function being differentiated takes a vector of arrays and returns an + * array. The optional `argnum` index specifies which argument to compute + * the gradient with respect to and defaults to 0. + **/ +std::function<std::vector<array>(const std::vector<array>&)> inline grad( + const std::function<array(const std::vector<array>&)>& fun, + int argnum = 0) { + return grad(fun, std::vector<int>{argnum}); +} + +/** + * Returns a function which computes the gradient of the unary input function. + **/ +std::function<array(const array&)> inline grad( + const std::function<array(const array&)>& fun) { + auto fn = value_and_grad(fun); + return [fn](const array& input) { return fn(input).second; }; +} + +/** + * Automatically vectorize a unary function over the requested axes. + */ +std::function<array(const array&)> vmap( + const std::function<array(const array&)>& fun, + int in_axis = 0, + int out_axis = 0); + +/** + * Automatically vectorize a binary function over the requested axes. + */ +std::function<array(const array&, const array&)> vmap( + const std::function<array(const array&, const array&)>& fun, + int in_axis_a = 0, + int in_axis_b = 0, + int out_axis = 0); + +/** + * Automatically vectorize a function over the requested axes.
+ * + * The input function to `vmap` takes as an argument a vector of arrays and + * returns a vector of arrays. Optionally specify the axes to vectorize over + * with `in_axes` and `out_axes`, otherwise a default of 0 is used. + * Returns a vectorized function with the same signature as the input + * function. + */ +std::function(const std::vector&)> vmap( + const std::function(const std::vector&)>& fun, + const std::vector& in_axes = {}, + const std::vector& out_axes = {}); + +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/transforms_impl.h b/lib/python3.11/site-packages/mlx/include/mlx/transforms_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..201e8009b314607b93005d5a8bb493e0baa43804 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/transforms_impl.h @@ -0,0 +1,17 @@ +// Copyright © 2023 Apple Inc. + +namespace mlx::core::detail { + +std::pair, std::vector> vmap_trace( + const std::function(const std::vector&)>& fun, + const std::vector& inputs, + const std::vector& in_axes); + +std::vector vmap_replace( + const std::vector& inputs, + const std::vector& s_inputs, + const std::vector& s_outputs, + const std::vector& in_axes, + const std::vector& out_axes); + +} // namespace mlx::core::detail diff --git a/lib/python3.11/site-packages/mlx/include/mlx/types/bf16.h b/lib/python3.11/site-packages/mlx/include/mlx/types/bf16.h new file mode 100644 index 0000000000000000000000000000000000000000..5951941747c022284e7926b88ac4bfebbdf78226 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/types/bf16.h @@ -0,0 +1,187 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +#define __MLX_BFLOAT_NAN__ 0x7FC0 + +namespace mlx::core { + +namespace { +union float_bits_bf16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_BFloat16 { + uint16_t bits_; + + // Default constructor + _MLX_BFloat16() = default; + + // Default copy constructor + _MLX_BFloat16(_MLX_BFloat16 const&) = default; + + // Appease std::vector for being special + _MLX_BFloat16& operator=(std::vector::reference x) { + bits_ = x; + return *this; + } + + _MLX_BFloat16& operator=(const float& x) { + return (*this = _MLX_BFloat16(x)); + } + + // From float32 + _MLX_BFloat16(const float& x) { + if (std::isnan(x)) { + bits_ = __MLX_BFLOAT_NAN__; + } else { + // Union + float_bits_bf16 in; + + // Take bits + in.f = x; + + // Round to nearest even + in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF); + + // Take upper 16 bits + bits_ = in.u >> 16; + } + } + + // To float32 + operator float() const { + // Union + float_bits_bf16 out; + + // Upper 16 bits are the data and lower 16 bits are 0s + out.u = ((uint32_t)bits_) << 16; + + return out.f; + } +}; + +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, double, 
double, double); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +#undef bfloat_binop + +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, double, double); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop + +// Negative +inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define bfloat_inplace_op(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_op(+, operator+=); +bfloat_inplace_op(-, operator-=); +bfloat_inplace_op(*, operator*=); +bfloat_inplace_op(/, operator/=); + +#undef bfloat_inplace_op + +// Bitwise ops + +#define bfloat_bitop(__op__, __operator__) \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +bfloat_bitop(|, operator|); +bfloat_bitop(&, operator&); +bfloat_bitop(^, operator^); + +#undef bfloat_bitop + +#define bfloat_inplace_bitop(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_bitop(|, operator|=); +bfloat_inplace_bitop(&, operator&=); +bfloat_inplace_bitop(^, operator^=); + +#undef bfloat_inplace_bitop + +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/types/complex.h b/lib/python3.11/site-packages/mlx/include/mlx/types/complex.h new file mode 100644 index 0000000000000000000000000000000000000000..55cbe447af0b659771d6cc0f6c84664cf5fba2ae --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/types/complex.h @@ -0,0 +1,77 @@ +// Copyright © 2023 Apple Inc. 
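//
// The comparison operators below impose a lexicographic total order on
// complex64_t: real parts are compared first and ties are broken by the
// imaginary parts. Complex numbers have no natural mathematical order; this
// convention presumably exists so that sorting and min/max-style reductions
// stay well defined. For example, under this order:
//
//   complex64_t(1.0f, 5.0f) < complex64_t(2.0f, 0.0f)  // real parts decide
//   complex64_t(1.0f, 2.0f) < complex64_t(1.0f, 3.0f)  // tie broken by imag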
+ +#pragma once +#include +#include "mlx/types/half_types.h" + +namespace mlx::core { + +struct complex64_t; + +template +static constexpr bool can_convert_to_complex64 = + !std::is_same_v && std::is_convertible_v; + +struct complex64_t : public std::complex { + complex64_t(float v, float u) : std::complex(v, u){}; + complex64_t(std::complex v) : std::complex(v){}; + + template < + typename T, + typename = typename std::enable_if>::type> + complex64_t(T x) : std::complex(x){}; + + operator float() const { + return real(); + }; +}; + +inline bool operator>=(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || + (a.real() == b.real() && a.imag() >= b.imag()); +} + +inline bool operator>(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); +} + +inline bool operator<=(const complex64_t& a, const complex64_t& b) { + return operator>=(b, a); +} + +inline bool operator<(const complex64_t& a, const complex64_t& b) { + return operator>(b, a); +} + +inline complex64_t operator-(const complex64_t& v) { + return -static_cast>(v); +} + +// clang-format off +#define complex_binop_helper(_op_, _operator_, itype) \ + inline complex64_t _operator_(itype x, const complex64_t& y) { \ + return x _op_ static_cast>(y); \ + } \ + inline complex64_t _operator_(const complex64_t& x, itype y) { \ + return static_cast>(x) _op_ y; \ + } + +#define complex_binop(_op_, _operator_) \ + inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ + return static_cast>(x) \ + _op_ static_cast>(y); \ + } \ + complex_binop_helper(_op_, _operator_, bool) \ + complex_binop_helper(_op_, _operator_, uint32_t) \ + complex_binop_helper(_op_, _operator_, uint64_t) \ + complex_binop_helper(_op_, _operator_, int32_t) \ + complex_binop_helper(_op_, _operator_, int64_t) \ + complex_binop_helper(_op_, _operator_, float16_t) \ + complex_binop_helper(_op_, _operator_, bfloat16_t) \ + complex_binop_helper(_op_, _operator_, const std::complex&) \ + complex_binop_helper(_op_, _operator_, float) +// clang-format on + +complex_binop(+, operator+) + +} // namespace mlx::core diff --git a/lib/python3.11/site-packages/mlx/include/mlx/types/fp16.h b/lib/python3.11/site-packages/mlx/include/mlx/types/fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..c174afebcb4e221a7437af2ff6e0e6812fe97939 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/include/mlx/types/fp16.h @@ -0,0 +1,234 @@ +// Copyright © 2023 Apple Inc. 
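//
// Background for the conversions below: IEEE 754 binary16 packs 1 sign bit,
// 5 exponent bits (bias 15), and 10 mantissa bits into a uint16_t,
//
//   bits_ = s eeeee mmmmmmmmmm
//
// so, for example, 1.0f round-trips to bits_ == 0x3C00 and -2.0f to 0xC000.
// The float -> half conversion follows the FP16 library referenced in the
// code: it rescales the input so the CPU's round-to-nearest float addition
// performs the half rounding, then reassembles sign, exponent, and mantissa
// from the resulting bit pattern.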
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <vector>
+
+#define __MLX_HALF_NAN__ 0x7D00
+
+namespace mlx::core {
+
+namespace {
+union float_bits_fp16 {
+  float f;
+  uint32_t u;
+};
+} // namespace
+
+struct _MLX_Float16 {
+  uint16_t bits_;
+
+  // Default constructor
+  _MLX_Float16() = default;
+
+  // Default copy constructor
+  _MLX_Float16(_MLX_Float16 const&) = default;
+
+  // Appease std::vector<bool> for being special
+  _MLX_Float16& operator=(std::vector<bool>::reference x) {
+    bits_ = x;
+    return *this;
+  }
+
+  _MLX_Float16& operator=(const float& x) {
+    return (*this = _MLX_Float16(x));
+  }
+
+  // From float32
+  _MLX_Float16(const float& x) : bits_(0) {
+    // Conversion following
+    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
+
+    // Union
+    float_bits_fp16 in;
+
+    // Take fp32 bits
+    in.f = x;
+
+    // Find and take sign bit
+    uint32_t x_sign_32 = in.u & uint32_t(0x80000000);
+    uint16_t x_sign_16 = (x_sign_32 >> 16);
+
+    if (std::isnan(x)) {
+      bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__);
+    } else {
+      // Union
+      float_bits_fp16 inf_scale, zero_scale, magic_bits;
+
+      // Find exponent bits and take the max supported by half
+      uint32_t x_expo_32 = in.u & uint32_t(0x7f800000);
+      uint32_t max_expo_32 = uint32_t(0x38800000);
+      x_expo_32 = x_expo_32 < max_expo_32 ? max_expo_32 : x_expo_32;
+      x_expo_32 += uint32_t(15) << 23;
+
+      // Handle scaling to inf as needed
+      inf_scale.u = uint32_t(0x77800000);
+      zero_scale.u = uint32_t(0x08800000);
+
+      // Combine with magic and let addition do rounding
+      magic_bits.u = x_expo_32;
+      magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f;
+
+      // Take the lower 5 bits of the exponent
+      uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00));
+
+      // Collect the lower 12 bits which have the mantissa
+      uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff);
+
+      // Combine sign, exp and mantissa
+      bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16));
+    }
+  }
+
+  // To float32
+  operator float() const {
+    // Conversion following
+    // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h
+
+    // Union
+    float_bits_fp16 out;
+
+    uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000);
+    uint32_t base = (bits_ << 16);
+    uint32_t two_base = base + base;
+
+    uint32_t denorm_max = 1u << 27;
+    if (two_base < denorm_max) {
+      out.u = uint32_t(126) << 23; // magic mask
+      out.u |= (two_base >> 17); // Bits from fp16
+      out.f -= 0.5f; // magic bias
+    } else {
+      out.u = uint32_t(0xE0) << 23; // exponent offset
+      out.u += (two_base >> 4); // Bits from fp16
+      float out_unscaled = out.f; // Store value
+      out.u = uint32_t(0x7800000); // exponent scale
+      out.f *= out_unscaled;
+    }
+
+    // Add sign
+    out.u |= x_sign_32;
+
+    return out.f;
+  }
+};
+
+#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \
+  inline otype __operator__(atype lhs, btype rhs) {                       \
+    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);        \
+  }
+
+#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \
+  inline otype __operator__(_MLX_Float16 lhs, itype rhs) {           \
+    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \
+  }                                                                  \
+  inline otype __operator__(itype lhs, _MLX_Float16 rhs) {           \
+    return static_cast<ctype>(lhs) __op__ static_cast<ctype>(rhs);   \
+  }
+
+// Operators
+#define half_binop(__op__, __operator__)                                      \
+  half_binop_base(                                                            \
+      __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \
+  half_binop_helper(__op__, __operator__, float, float, float);               \
+  half_binop_helper(__op__, __operator__, double, double, double);            \
+  half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float);         \
+  half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float);      \
+  half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float);     \
+  half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float);      \
+  half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float);
+
+half_binop(+, operator+);
+half_binop(-, operator-);
+half_binop(*, operator*);
+half_binop(/, operator/);
+
+#undef half_binop
+
+// Comparison ops
+#define half_compop(__op__, __operator__)                             \
+  half_binop_base(                                                    \
+      __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \
+  half_binop_helper(__op__, __operator__, bool, float, float);        \
+  half_binop_helper(__op__, __operator__, bool, double, double);      \
+  half_binop_helper(__op__, __operator__, bool, int32_t, float);      \
+  half_binop_helper(__op__, __operator__, bool, uint32_t, float);     \
+  half_binop_helper(__op__, __operator__, bool, int64_t, float);      \
+  half_binop_helper(__op__, __operator__, bool, uint64_t, float);
+
+half_compop(>, operator>);
+half_compop(<, operator<);
+half_compop(>=, operator>=);
+half_compop(<=, operator<=);
+half_compop(==, operator==);
+half_compop(!=, operator!=);
+
+#undef half_compop
+
+// Negative
+inline _MLX_Float16 operator-(_MLX_Float16 lhs) {
+  return -static_cast<float>(lhs);
+}
+
+// Inplace ops
+#define half_inplace_op(__op__, __operator__)                              \
+  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \
+    lhs = lhs __op__ rhs;                                                  \
+    return lhs;                                                            \
+  }                                                                        \
+  inline float& __operator__(float& lhs, _MLX_Float16 rhs) {               \
+    lhs = lhs __op__ rhs;                                                  \
+    return lhs;                                                            \
+  }
+
+half_inplace_op(+, operator+=);
+half_inplace_op(-, operator-=);
+half_inplace_op(*, operator*=);
+half_inplace_op(/, operator/=);
+
+#undef half_inplace_op
+
+// Bitwise ops
+
+#define half_bitop(__op__, __operator__)                                 \
+  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \
+    _MLX_Float16 out;                                                    \
+    out.bits_ = lhs.bits_ __op__ rhs.bits_;                              \
+    return out;                                                          \
+  }                                                                      \
+  inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) {     \
+    _MLX_Float16 out;                                                    \
+    out.bits_ = lhs.bits_ __op__ rhs;                                    \
+    return out;                                                          \
+  }                                                                      \
+  inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) {     \
+    _MLX_Float16 out;                                                    \
+    out.bits_ = lhs __op__ rhs.bits_;                                    \
+    return out;                                                          \
+  }
+
+half_bitop(|, operator|);
+half_bitop(&, operator&);
+half_bitop(^, operator^);
+
+#undef half_bitop
+
+#define half_inplace_bitop(__op__, __operator__)                           \
+  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \
+    lhs.bits_ = lhs.bits_ __op__ rhs.bits_;                                \
+    return lhs;                                                            \
+  }                                                                        \
+  inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) {     \
+    lhs.bits_ = lhs.bits_ __op__ rhs;                                      \
+    return lhs;                                                            \
+  }
+
+half_inplace_bitop(|, operator|=);
+half_inplace_bitop(&, operator&=);
+half_inplace_bitop(^, operator^=);
+
+#undef half_inplace_bitop
+
+} // namespace mlx::core
diff --git a/lib/python3.11/site-packages/mlx/include/mlx/types/half_types.h b/lib/python3.11/site-packages/mlx/include/mlx/types/half_types.h
new file mode 100644
index 0000000000000000000000000000000000000000..430279565726e75bc0b9b5b421e2f61eb8502b7a
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/include/mlx/types/half_types.h
@@ -0,0 +1,56 @@
+// Copyright © 2023 Apple Inc.
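+//
+// A short orientation sketch (illustrative comment): on targets with native
+// scalar fp16/bf16 arithmetic, float16_t and bfloat16_t below are compiler
+// builtins; otherwise they fall back to the software structs. On the
+// fallback path, the helpers at the bottom let the two half types mix by
+// promoting through float, e.g.
+//
+//   mlx::core::float16_t a = 1.0f;
+//   mlx::core::bfloat16_t b = 2.0f;
+//   float c = a + b; // 3.0f, computed in float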
+
+#pragma once
+#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
+
+#include <arm_fp16.h>
+namespace mlx::core {
+typedef __fp16 float16_t;
+} // namespace mlx::core
+
+#else
+
+#define ADD_HALF_BINOPS
+#include "mlx/types/fp16.h"
+namespace mlx::core {
+typedef struct _MLX_Float16 float16_t;
+} // namespace mlx::core
+
+#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
+#ifdef __ARM_FEATURE_BF16
+
+#include <arm_bf16.h>
+namespace mlx::core {
+typedef __bf16 bfloat16_t;
+} // namespace mlx::core
+
+#else
+
+#define ADD_HALF_BINOPS
+#include "mlx/types/bf16.h"
+namespace mlx::core {
+typedef struct _MLX_BFloat16 bfloat16_t;
+} // namespace mlx::core
+
+#endif // __ARM_FEATURE_BF16
+
+#ifdef ADD_HALF_BINOPS
+namespace mlx::core {
+
+// clang-format off
+#define fp16_bf16_binop_helper(__op__, __operator__)               \
+  inline float __operator__(float16_t lhs, bfloat16_t rhs) {       \
+    return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+  }                                                                \
+  inline float __operator__(bfloat16_t lhs, float16_t rhs) {       \
+    return static_cast<float>(lhs) __op__ static_cast<float>(rhs); \
+  }
+
+fp16_bf16_binop_helper(+, operator+)
+fp16_bf16_binop_helper(-, operator-)
+fp16_bf16_binop_helper(*, operator*)
+fp16_bf16_binop_helper(/, operator/)
+// clang-format on
+
+} // namespace mlx::core
+#endif
diff --git a/lib/python3.11/site-packages/mlx/include/mlx/utils.h b/lib/python3.11/site-packages/mlx/include/mlx/utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..823b4c872a1cfff1e047e28945fe87d74c8babb5
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/include/mlx/utils.h
@@ -0,0 +1,44 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include "array.h"
+#include "device.h"
+#include "dtype.h"
+#include "stream.h"
+
+namespace mlx::core {
+
+/** The type from promoting the arrays' types with one another. */
+Dtype result_type(const std::vector<array>& arrays);
+
+std::vector<int> broadcast_shapes(
+    const std::vector<int>& s1,
+    const std::vector<int>& s2);
+
+bool is_same_shape(const std::vector<array>& arrays);
+
+/**
+ * Returns the axis normalized to be in the range [0, ndim).
+ * Based on numpy's normalize_axis_index. See
+ * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
+ */
+int normalize_axis(int axis, int ndim);
+
+std::ostream& operator<<(std::ostream& os, const Device& d);
+std::ostream& operator<<(std::ostream& os, const Stream& s);
+std::ostream& operator<<(std::ostream& os, const Dtype& d);
+std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k);
+std::ostream& operator<<(std::ostream& os, array a);
+std::ostream& operator<<(std::ostream& os, const std::vector<int>& v);
+std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v);
+inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) {
+  return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j";
+}
+inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
+  return os << static_cast<float>(v);
+}
+inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
+  return os << static_cast<float>(v);
+}
+} // namespace mlx::core
diff --git a/lib/python3.11/site-packages/mlx/lib/libmlx.dylib b/lib/python3.11/site-packages/mlx/lib/libmlx.dylib
new file mode 100644
index 0000000000000000000000000000000000000000..a9b17d74e39b1690b24d9f70b6d0fdc1477e49d6
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/lib/libmlx.dylib
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8abefe46a1f39c92b28464814f05a730fa9899b17757703403c6ef362f06ac93
+size 12420704
diff --git a/lib/python3.11/site-packages/mlx/lib/mlx.metallib b/lib/python3.11/site-packages/mlx/lib/mlx.metallib
new file mode 100644
index 0000000000000000000000000000000000000000..111aa15f90129127d5f8c559a7cdae722534dcd9
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/lib/mlx.metallib
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2eedf41000ed270283da11d889bb101aa4c88c6f8f0ec68fe6b040a5be424501
+size 59495531
diff --git a/lib/python3.11/site-packages/mlx/nn/__init__.py b/lib/python3.11/site-packages/mlx/nn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bb7cc63d28f1655b5951e2ed3105b5bb65189a1
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/nn/__init__.py
@@ -0,0 +1,5 @@
+# Copyright © 2023 Apple Inc.
+
+from mlx.nn import losses
+from mlx.nn.layers import *
+from mlx.nn.utils import value_and_grad
diff --git a/lib/python3.11/site-packages/mlx/nn/__pycache__/__init__.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89ccbda6a22dab2afda4af184e00c6ecd4ab888a
Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/__pycache__/__init__.cpython-311.pyc differ
diff --git a/lib/python3.11/site-packages/mlx/nn/__pycache__/losses.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/__pycache__/losses.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8371c6275032768ce291c9ffc2f2123559cae047
Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/__pycache__/losses.cpython-311.pyc differ
diff --git a/lib/python3.11/site-packages/mlx/nn/__pycache__/utils.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d4ded89a910dd42237f2a077690f78510c3555a8
Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/__pycache__/utils.cpython-311.pyc differ
diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__init__.py b/lib/python3.11/site-packages/mlx/nn/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..80d500c9faeaa5351d9b554c60c5ed006243e780
--- /dev/null
+++ b/lib/python3.11/site-packages/mlx/nn/layers/__init__.py
@@ -0,0 +1,63 @@
+# Copyright © 2023 Apple Inc.
+ +from mlx.nn.layers.activations import ( + CELU, + ELU, + GELU, + SELU, + Hardswish, + LeakyReLU, + LogSigmoid, + LogSoftmax, + Mish, + PReLU, + ReLU, + ReLU6, + SiLU, + Softmax, + Softplus, + Softsign, + Step, + Tanh, + celu, + elu, + gelu, + gelu_approx, + gelu_fast_approx, + hardswish, + leaky_relu, + log_sigmoid, + log_softmax, + mish, + prelu, + relu, + relu6, + selu, + silu, + softmax, + softplus, + softsign, + step, + tanh, +) +from mlx.nn.layers.base import Module +from mlx.nn.layers.containers import Sequential +from mlx.nn.layers.convolution import Conv1d, Conv2d +from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d +from mlx.nn.layers.embedding import Embedding +from mlx.nn.layers.linear import Bilinear, Identity, Linear +from mlx.nn.layers.normalization import ( + BatchNorm, + GroupNorm, + InstanceNorm, + LayerNorm, + RMSNorm, +) +from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding +from mlx.nn.layers.quantized import QuantizedLinear +from mlx.nn.layers.transformer import ( + MultiHeadAttention, + Transformer, + TransformerEncoder, + TransformerEncoderLayer, +) diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/__init__.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49fd39acdb57f5985b6276ccaa677d96576af8f7 Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/__init__.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/activations.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/activations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66be8623924c8fa753040396fcfb4702204c7cde Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/activations.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/base.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbb2721e78295c4237307856d09b64f70795661a Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/base.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/containers.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/containers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..383a0074044b157fbc5e180b000f9908bf2e0c88 Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/containers.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/convolution.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/convolution.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9760a34bca18d61775836f1a2893d7e52ce9f3fe Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/convolution.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/dropout.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/dropout.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3173ca1e44ab94df19aed0c610d539b71554481 Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/dropout.cpython-311.pyc differ 
diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/embedding.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4dbcc0a5805e5c55bf9ac1ea6ef46509383bc7b Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/embedding.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/linear.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/linear.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4b6078542c8c9f1ac05b0392f3233ac8e019abd Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/linear.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/normalization.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/normalization.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b439138b76426bdd609dae2212aa9432a125939d Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/normalization.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..553cc4920f4cfa7bb81a575cc410e54dbdf124f9 Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/quantized.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/quantized.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ecc0d0fc7720f260f5650ed3ad99735c4c291a3 Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/quantized.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/transformer.cpython-311.pyc b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/transformer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23ca9ca9ff46e1b8a05f2a1cbd826e4330d0bc8a Binary files /dev/null and b/lib/python3.11/site-packages/mlx/nn/layers/__pycache__/transformer.cpython-311.pyc differ diff --git a/lib/python3.11/site-packages/mlx/nn/layers/activations.py b/lib/python3.11/site-packages/mlx/nn/layers/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb4483ac6aea8ce6c6d3a2010ad45cec39c8d50 --- /dev/null +++ b/lib/python3.11/site-packages/mlx/nn/layers/activations.py @@ -0,0 +1,501 @@ +# Copyright © 2023 Apple Inc. + +import math + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +def _make_activation_module(f): + def decorator(klass): + klass.__doc__ = f.__doc__ + klass.__call__ = lambda self, x: f(x) + return klass + + return decorator + + +def sigmoid(x): + r"""Applies the element-wise function: + + .. math:: + \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)} + """ + return mx.sigmoid(x) + + +def relu(x): + r"""Applies the Rectified Linear Unit. + + Simply ``mx.maximum(x, 0)``. + """ + return mx.maximum(x, 0) + + +def leaky_relu(x, negative_slope=0.01): + r"""Applies the Leaky Rectified Linear Unit. + + Simply ``mx.maximum(negative_slope * x, x)``. 
+    """
+    return mx.maximum(negative_slope * x, x)
+
+
+def log_softmax(x, axis=-1):
+    r"""Applies the Log Softmax function.
+
+    Applies :math:`x - \log \sum_i e^{x_i}` element wise.
+    """
+    return x - mx.logsumexp(x, axis=axis, keepdims=True)
+
+
+def elu(x, alpha=1.0):
+    r"""Applies the Exponential Linear Unit.
+
+    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
+    """
+    return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
+
+
+def relu6(x):
+    r"""Applies the Rectified Linear Unit 6.
+
+    Applies :math:`\min(\max(x, 0), 6)` element wise.
+    """
+    return mx.minimum(mx.maximum(x, 0), 6.0)
+
+
+def softmax(x, axis=-1):
+    r"""Applies the Softmax function.
+
+    Applies :math:`\frac{e^{x_i}}{\sum_j e^{x_j}}` element wise.
+    """
+    return mx.softmax(x, axis=axis)
+
+
+def softplus(x):
+    r"""Applies the Softplus function.
+
+    Applies :math:`\log(1 + \exp(x))` element wise.
+    """
+    return mx.logaddexp(x, 0)
+
+
+def softsign(x):
+    r"""Applies the Softsign function.
+
+    Applies :math:`\frac{x}{1 + |x|}` element wise.
+    """
+    return mx.divide(x, 1 + mx.abs(x))
+
+
+def celu(x, alpha=1.0):
+    r"""Applies the Continuously Differentiable Exponential Linear Unit.
+
+    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
+    element wise.
+    """
+    return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)
+
+
+def silu(x):
+    r"""Applies the Sigmoid Linear Unit. Also known as Swish.
+
+    Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
+    the logistic sigmoid.
+    """
+    return x * mx.sigmoid(x)
+
+
+def log_sigmoid(x):
+    r"""Applies the Log Sigmoid function.
+
+    Applies :math:`\log(\sigma(x)) = -\log(1 + e^{-x})` element wise.
+    """
+    return -softplus(-x)
+
+
+def gelu(x):
+    r"""Applies the Gaussian Error Linear Units function.
+
+    .. math::
+        \textrm{GELU}(x) = x * \Phi(x)
+
+    where :math:`\Phi(x)` is the Gaussian CDF.
+
+    See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster
+    approximations.
+    """
+    return x * (1 + mx.erf(x / math.sqrt(2))) / 2
+
+
+def gelu_approx(x):
+    r"""An approximation to Gaussian Error Linear Unit.
+
+    See :func:`gelu` for the exact computation.
+
+    This function approximates ``gelu`` with a maximum absolute error :math:`<
+    0.0003` in the range :math:`[-6, 6]` using the following
+
+    .. math::
+
+        \textrm{GELUApprox}(x) = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)
+
+    where :math:`\sigma(\cdot)` is the logistic sigmoid.
+    """
+    return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
+
+
+def gelu_fast_approx(x):
+    r"""A fast approximation to Gaussian Error Linear Unit.
+
+    See :func:`gelu` for the exact computation.
+
+    This function approximates ``gelu`` with a maximum absolute error :math:`<
+    0.015` in the range :math:`[-6, 6]` using the following
+
+    .. math::
+
+        \textrm{GELUFast}(x) = x \sigma\left(1.773 x\right)
+
+    where :math:`\sigma(\cdot)` is the logistic sigmoid.
+    """
+    return x * mx.sigmoid(1.773 * x)
+
+
+@_make_activation_module(sigmoid)
+class Sigmoid(Module):
+    r"""Applies the sigmoid function, element-wise.
+
+    .. math::
+        \text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
+    """
+    pass
+
+
+def step(x: mx.array, threshold: float = 0.0):
+    r"""Applies the Step Activation Function.
+
+    This function implements a binary step activation, where the output is set
+    to 1 if the input is greater than a specified threshold, and 0 otherwise.
+
+    .. math::
+        \text{step}(x) = \begin{cases}
+        0 & \text{if } x < \text{threshold} \\
+        1 & \text{if } x \geq \text{threshold}
+        \end{cases}
+
+    Args:
+        threshold: The value to threshold at.
+    """
+
+    return mx.where(x > threshold, 1, 0)
+
+
+def selu(x):
+    r"""Applies the Scaled Exponential Linear Unit.
+
+    .. math::
+        \text{selu}(x) = \begin{cases}
+        \lambda x & \text{if } x > 0 \\
+        \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0
+        \end{cases}
+
+    where :math:`\lambda = 1.0507` and :math:`\alpha = 1.67326`.
+
+    See also :func:`elu`.
+    """
+    return elu(x, 1.67326) * 1.0507
+
+
+def prelu(x: mx.array, alpha: mx.array) -> mx.array:
+    r"""Applies the element-wise parametric ReLU.
+
+    .. math::
+        \text{PReLU}(x) = \max(0,x) + a * \min(0,x)
+
+    where :math:`a` is an array.
+    """
+    return mx.maximum(0, x) + alpha * mx.minimum(0, x)
+
+
+def mish(x: mx.array) -> mx.array:
+    r"""Applies the Mish function, element-wise.
+    Mish: A Self Regularized Non-Monotonic Neural Activation Function.
+
+    Reference: https://arxiv.org/abs/1908.08681
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+    """
+    return x * mx.tanh(softplus(x))
+
+
+def hardswish(x):
+    r"""Applies the hardswish function, element-wise.
+
+    .. math::
+        \text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6
+    """
+    max_x_3 = mx.maximum(x + 3, 0)
+    return x * mx.minimum(max_x_3, 6) / 6
+
+
+@_make_activation_module(mish)
+class Mish(Module):
+    r"""Applies the Mish function, element-wise.
+
+    Reference: https://arxiv.org/abs/1908.08681
+
+    .. math::
+        \text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
+
+    """
+    pass
+
+
+@_make_activation_module(relu)
+class ReLU(Module):
+    r"""Applies the Rectified Linear Unit.
+    Simply ``mx.maximum(x, 0)``.
+
+    See :func:`relu`, for the functional equivalent.
+    """
+    pass
+
+
+class LeakyReLU(Module):
+    r"""Applies the Leaky Rectified Linear Unit.
+
+    Simply ``mx.maximum(negative_slope * x, x)``.
+
+    Args:
+        negative_slope: Controls the angle of the negative slope. Default: 1e-2.
+    """
+
+    def __init__(self, negative_slope=1e-2):
+        super().__init__()
+        self._negative_slope = negative_slope
+
+    def __call__(self, x):
+        return leaky_relu(x, self._negative_slope)
+
+
+class ELU(Module):
+    r"""Applies the Exponential Linear Unit.
+    Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
+
+    See :func:`elu`, for the functional equivalent.
+
+    Args:
+        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
+    """
+
+    def __init__(self, alpha=1.0):
+        super().__init__()
+        self._alpha = alpha
+
+    def __call__(self, x):
+        return elu(x, self._alpha)
+
+
+@_make_activation_module(relu6)
+class ReLU6(Module):
+    r"""Applies the Rectified Linear Unit 6.
+
+    See :func:`relu6`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(softmax)
+class Softmax(Module):
+    r"""Applies the Softmax function.
+
+    See :func:`softmax`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(softplus)
+class Softplus(Module):
+    r"""Applies the Softplus function.
+
+    See :func:`softplus`, for the functional equivalent.
+    """
+    pass
+
+
+@_make_activation_module(softsign)
+class Softsign(Module):
+    r"""Applies the Softsign function.
+
+    See :func:`softsign`, for the functional equivalent.
+    """
+    pass
+
+
+class CELU(Module):
+    r"""Applies the Continuously Differentiable Exponential Linear Unit.
+    Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
+    element wise.
+
+    See :func:`celu`, for the functional equivalent.
+ + Args: + alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0 + """ + + def __init__(self, alpha=1.0): + super().__init__() + self._alpha = alpha + + def __call__(self, x): + return celu(x, self._alpha) + + +@_make_activation_module(silu) +class SiLU(Module): + r"""Applies the Sigmoid Linear Unit. Also known as Swish. + + See :func:`silu`, for the functional equivalent. + """ + pass + + +@_make_activation_module(log_softmax) +class LogSoftmax(Module): + r"""Applies the Log Softmax function. + + See :func:`log_softmax`, for the functional equivalent. + """ + pass + + +@_make_activation_module(log_sigmoid) +class LogSigmoid(Module): + r"""Applies the Log Sigmoid function. + + See :func:`log_sigmoid`, for the functional equivalent. + """ + pass + + +class PReLU(Module): + r"""Applies the element-wise parametric ReLU. + Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a` + is an array. + + See :func:`prelu`, for the functional equivalent. + + Args: + num_parameters: number of :math:`a` to learn. Default: 1 + init: the initial value of :math:`a`. Default: 0.25 + """ + + def __init__(self, num_parameters=1, init=0.25): + super().__init__() + self.weight = mx.full([num_parameters], init) + + def __call__(self, x: mx.array): + return prelu(x, self.weight) + + +class GELU(Module): + r"""Applies the Gaussian Error Linear Units. + + .. math:: + \textrm{GELU}(x) = x * \Phi(x) + + where :math:`\Phi(x)` is the Gaussian CDF. + + However, if ``approx`` is set to 'precise' or 'fast' it applies + + .. math:: + \textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\ + \textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right) + + respectively. + + See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the + functional equivalents and information regarding error bounds. + + Args: + approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any. + """ + + def __init__(self, approx="none"): + super().__init__() + + if approx == "none": + self._act = gelu + elif approx == "precise": + self._act = gelu_approx + elif approx == "fast": + self._act = gelu_fast_approx + else: + raise ValueError( + f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given" + ) + + def __call__(self, x): + return self._act(x) + + +def tanh(x): + """Applies the hyperbolic tangent function. + + Simply ``mx.tanh(x)``. + """ + return mx.tanh(x) + + +@_make_activation_module(tanh) +class Tanh(Module): + r"""Applies the hyperbolic tangent function. + + See :func:`tanh`, for the functional equivalent. + """ + pass + + +@_make_activation_module(hardswish) +class Hardswish(Module): + r"""Applies the hardswish function, element-wise. + + See :func:`hardswish`, for the functional equivalent. + """ + pass + + +class Step(Module): + r"""Applies the Step Activation Function. + + This function implements a binary step activation, where the output is set + to 1 if the input is greater than a specified threshold, and 0 otherwise. + + .. math:: + \text{step}(x) = \begin{cases} + 0 & \text{if } x < \text{threshold} \\ + 1 & \text{if } x \geq \text{threshold} + \end{cases} + + Args: + threshold: The value to threshold at. + """ + + def __init__(self, threshold: float = 0.0): + super().__init__() + self.threshold = threshold + + def __call__(self, x: mx.array): + return step(x, self.threshold) + + +@_make_activation_module(selu) +class SELU(Module): + r"""Applies the Scaled Exponential Linear Unit. 
+ + See :func:`selu`, for the functional equivalent. + """ + pass
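+
+
+# A minimal sketch of the two interchangeable call styles this module
+# defines (illustrative comment):
+#
+#   import mlx.core as mx
+#
+#   x = mx.array([-1.0, 0.0, 2.0])
+#   y1 = relu(x)    # functional form
+#   y2 = ReLU()(x)  # module form; _make_activation_module binds __call__
+#                   # directly to the function, so y1 == y2 element-wise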