Arrcttacsrks commited on
Commit
8cb76ae
·
verified ·
1 Parent(s): 8f8a164

Upload llama.cpp/ggml/src/ggml-cann.cpp with huggingface_hub

Browse files
Files changed (1) hide show
  1. llama.cpp/ggml/src/ggml-cann.cpp +2128 -0
llama.cpp/ggml/src/ggml-cann.cpp ADDED
@@ -0,0 +1,2128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ * Copyright (c) 2023-2024 The ggml authors
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ * of this software and associated documentation files (the "Software"), to
6
+ * deal in the Software without restriction, including without limitation the
7
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
+ * sell copies of the Software, and to permit persons to whom the Software is
9
+ * furnished to do so, subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in
12
+ * all copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
+ * IN THE SOFTWARE.
21
+ */
22
+
23
+ #include "ggml-cann.h"
24
+
25
+ #include <acl/acl.h>
26
+ #include <stdarg.h>
27
+
28
+ #include <cmath>
29
+ #include <cstdio>
30
+ #include <cstring>
31
+ #include <mutex>
32
+
33
+ #include "ggml-impl.h"
34
+ #include "ggml-backend-impl.h"
35
+ #include "ggml-cann/aclnn_ops.h"
36
+ #include "ggml-cann/common.h"
37
+
38
+ #define GGML_COMMON_DECL_C
39
+
40
+ #include "ggml-common.h"
41
+
42
+ #define GGML_CANN_NAME "CANN"
43
+
44
+ /**
45
+ * @brief Handles CANN errors by printing an error message and aborting.
46
+ *
47
+ * @param stmt The statement that caused the error.
48
+ * @param func The function in which the error occurred.
49
+ * @param file The file in which the error occurred.
50
+ * @param line The line number where the error occurred.
51
+ * @param msg The error message.
52
+ */
53
+ [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
54
+ const char* file, int line, const char* msg) {
55
+ int32_t id = -1;
56
+ aclrtGetDevice(&id);
57
+
58
+ GGML_LOG_ERROR("CANN error: %s\n", msg);
59
+ GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
60
+ file, line);
61
+ GGML_LOG_ERROR(" %s\n", stmt);
62
+ // abort with GGML_ASSERT to get a stack trace
63
+ GGML_ABORT("CANN error");
64
+ }
65
+
66
+ /**
67
+ * @brief Sets the device to be used by CANN.
68
+ *
69
+ * @param device The device ID to set.
70
+ */
71
+ void ggml_cann_set_device(const int32_t device) {
72
+ // TODO: uncomment these lines after empty context has fixed.
73
+ // int current_device;
74
+ // ACL_CHECK(aclrtGetDevice(&current_device));
75
+
76
+ // if (device == current_device) {
77
+ // return;
78
+ // }
79
+ ACL_CHECK(aclrtSetDevice(device));
80
+ }
81
+
82
+ /**
83
+ * @brief Retrieves the current device ID.
84
+ *
85
+ * @return The current device ID.
86
+ */
87
+ int32_t ggml_cann_get_device() {
88
+ int32_t id;
89
+ ACL_CHECK(aclrtGetDevice(&id));
90
+ return id;
91
+ }
92
+
93
+ /**
94
+ * @brief Initialize the CANN device information.
95
+ *
96
+ * This function initializes the CANN device information by obtaining the
97
+ * device count and setting the memory allocation granularity for each device.
98
+ *
99
+ * @return A structure containing the device information.
100
+ */
101
+ static ggml_cann_device_info ggml_cann_init() {
102
+ ggml_cann_device_info info = {};
103
+
104
+ aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
105
+
106
+ if (err != ACL_SUCCESS) {
107
+ GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
108
+ __func__, aclGetRecentErrMsg());
109
+ return info;
110
+ }
111
+
112
+ GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
113
+
114
+ for (int id = 0; id < info.device_count; ++id) {
115
+ aclrtPhysicalMemProp prop = {};
116
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
117
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
118
+ prop.memAttr = ACL_HBM_MEM_HUGE;
119
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
120
+ prop.location.id = id;
121
+ prop.reserve = 0;
122
+ ACL_CHECK(aclrtMemGetAllocationGranularity(
123
+ &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
124
+ &info.devices[id].vmm_granularity));
125
+ }
126
+
127
+ // TODO: add more device info later.
128
+ return info;
129
+ }
130
+
131
+ /**
132
+ * @brief Retrieve the CANN device information.
133
+ *
134
+ * This function returns a reference to a structure containing the CANN device
135
+ * information. The device information is initialized once and reused on
136
+ * subsequent calls.
137
+ *
138
+ * @return A reference to the structure containing the device information.
139
+ */
140
+ const ggml_cann_device_info& ggml_cann_info() {
141
+ static ggml_cann_device_info info = ggml_cann_init();
142
+ return info;
143
+ }
144
+
145
+ //#define DEBUG_CANN_MALLOC
146
+ /**
147
+ * @brief A pool of CANN buffers(legacy).
148
+ *
149
+ * This class manages a pool of CANN buffers for a specific device.
150
+ */
151
+ struct ggml_cann_pool_leg : public ggml_cann_pool {
152
+ /**
153
+ * @brief The maximum number of buffers in the pool.
154
+ */
155
+ static const int MAX_BUFFERS = 256;
156
+
157
+ /**
158
+ * @brief The device ID associated with this buffer pool.
159
+ */
160
+ int device;
161
+
162
+ /**
163
+ * @brief Structure representing a CANN buffer.
164
+ */
165
+ struct ggml_cann_buffer {
166
+ void* ptr = nullptr; ///< Pointer to the buffer memory.
167
+ size_t size = 0; ///< Size of the buffer.
168
+ };
169
+
170
+ /**
171
+ * @brief Array of CANN buffers in the pool.
172
+ */
173
+ ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
174
+
175
+ /**
176
+ * @brief Total size of all buffers in the pool.
177
+ */
178
+ size_t pool_size = 0;
179
+
180
+ /**
181
+ * @brief Constructor to initialize the buffer pool for a specific device.
182
+ *
183
+ * @param device The device ID to associate with this buffer pool.
184
+ */
185
+ explicit ggml_cann_pool_leg(int device) : device(device) {}
186
+
187
+ /**
188
+ * @brief Destructor to free all buffers in the pool.
189
+ */
190
+ ~ggml_cann_pool_leg() {
191
+ ggml_cann_set_device(device);
192
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
193
+ ggml_cann_buffer& b = buffer_pool[i];
194
+ if (b.ptr != nullptr) {
195
+ ACL_CHECK(aclrtFree(b.ptr));
196
+ pool_size -= b.size;
197
+ }
198
+ }
199
+ GGML_ASSERT(pool_size == 0);
200
+ }
201
+
202
+ /**
203
+ * @brief Allocate a buffer of the given size.
204
+ *
205
+ * @param size The size of the buffer to allocate.
206
+ * @param actual_size A pointer to a variable to receive the actual size of
207
+ * the allocated buffer.
208
+ * @return A pointer to the allocated buffer.
209
+ */
210
+ void* alloc(size_t size, size_t* actual_size) override {
211
+ #ifdef DEBUG_CANN_MALLOC
212
+ int nnz = 0;
213
+ size_t max_size = 0;
214
+ #endif
215
+ size_t best_diff = 1ull << 36;
216
+ int ibest = -1;
217
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
218
+ ggml_cann_buffer& b = buffer_pool[i];
219
+ if (b.ptr != nullptr) {
220
+ #ifdef DEBUG_CANN_MALLOC
221
+ ++nnz;
222
+ if (b.size > max_size) max_size = b.size;
223
+ #endif
224
+ if (b.size >= size) {
225
+ size_t diff = b.size - size;
226
+ if (diff < best_diff) {
227
+ best_diff = diff;
228
+ ibest = i;
229
+ if (!best_diff) {
230
+ void* ptr = b.ptr;
231
+ *actual_size = b.size;
232
+ b.ptr = nullptr;
233
+ b.size = 0;
234
+ return ptr;
235
+ }
236
+ }
237
+ }
238
+ }
239
+ }
240
+ if (ibest >= 0) {
241
+ ggml_cann_buffer& b = buffer_pool[ibest];
242
+ void* ptr = b.ptr;
243
+ *actual_size = b.size;
244
+ b.ptr = nullptr;
245
+ b.size = 0;
246
+ return ptr;
247
+ }
248
+ void* ptr;
249
+ size_t look_ahead_size = (size_t)(1.05 * size);
250
+ look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
251
+ ggml_cann_set_device(device);
252
+ ACL_CHECK(
253
+ aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
254
+ *actual_size = look_ahead_size;
255
+ pool_size += look_ahead_size;
256
+ #ifdef DEBUG_CANN_MALLOC
257
+ GGML_LOG_INFO(
258
+ "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
259
+ "requested %u MB\n",
260
+ __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
261
+ (uint32_t)(pool_size / 1024 / 1024),
262
+ (uint32_t)(size / 1024 / 1024));
263
+ #endif
264
+ return ptr;
265
+ }
266
+
267
+ /**
268
+ * @brief Free a buffer and return it to the pool.
269
+ *
270
+ * @param ptr Pointer to the buffer to free.
271
+ * @param size Size of the buffer to free.
272
+ */
273
+ void free(void* ptr, size_t size) override {
274
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
275
+ ggml_cann_buffer& b = buffer_pool[i];
276
+ if (b.ptr == nullptr) {
277
+ b.ptr = ptr;
278
+ b.size = size;
279
+ return;
280
+ }
281
+ }
282
+ // memory should always buffered. these memory may still needed by
283
+ // tasks in stream.
284
+ // TODO, fix me.
285
+ GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
286
+ }
287
+ };
288
+
289
+ /**
290
+ * @brief A pool of CANN buffers with virtual memory.
291
+ *
292
+ * This class manages a pool of CANN buffers with virtual memory for a specific
293
+ * device.
294
+ */
295
+ struct ggml_cann_pool_vmm : public ggml_cann_pool {
296
+ /**
297
+ * @brief The maximum size of the virtual memory pool (32 GB).
298
+ */
299
+ static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
300
+
301
+ /**
302
+ * @brief The device ID associated with this buffer pool.
303
+ */
304
+ int device;
305
+
306
+ /**
307
+ * @brief Pointer to the start of the virtual memory pool.
308
+ */
309
+ void* pool_addr = 0;
310
+
311
+ /**
312
+ * @brief Amount of virtual memory used in the pool.
313
+ */
314
+ size_t pool_used = 0;
315
+
316
+ /**
317
+ * @brief Total size of the virtual memory pool.
318
+ */
319
+ size_t pool_size = 0;
320
+
321
+ /**
322
+ * @brief Allocation granularity for the virtual memory pool.
323
+ */
324
+ size_t granularity;
325
+
326
+ /**
327
+ * @brief Handles for the physical memory allocated.
328
+ */
329
+ std::vector<aclrtDrvMemHandle> handles;
330
+
331
+ /**
332
+ * @brief Offsets for the mapped memory regions.
333
+ */
334
+ std::vector<void*> map_offsets;
335
+
336
+ /**
337
+ * @brief Constructor to initialize the buffer pool with virtual memory for
338
+ * a specific device.
339
+ *
340
+ * @param device The device ID to associate with this buffer pool.
341
+ */
342
+ explicit ggml_cann_pool_vmm(int device)
343
+ : device(device),
344
+ granularity(ggml_cann_info().devices[device].vmm_granularity) {}
345
+
346
+ /**
347
+ * @brief Destructor to free all buffers in the virtual memory pool.
348
+ */
349
+ ~ggml_cann_pool_vmm() {
350
+ if (pool_addr != 0) {
351
+ for (auto& offset : map_offsets) {
352
+ ACL_CHECK(aclrtUnmapMem(offset));
353
+ }
354
+ for (auto& handle : handles) {
355
+ ACL_CHECK(aclrtFreePhysical(handle));
356
+ }
357
+ ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
358
+ }
359
+ }
360
+
361
+ /**
362
+ * @brief Allocate a buffer of the given size in the virtual memory pool.
363
+ *
364
+ * @param size The size of the buffer to allocate.
365
+ * @param actual_size A pointer to a variable to receive the actual size of
366
+ * the allocated buffer.
367
+ * @return A pointer to the allocated buffer.
368
+ */
369
+ void* alloc(size_t size, size_t* actual_size) override {
370
+ // round up the allocation size to the alignment to ensure that all
371
+ // allocations are aligned for all data types
372
+ const size_t alignment = 128;
373
+ size = alignment * ((size + alignment - 1) / alignment);
374
+
375
+ size_t avail = pool_size - pool_used;
376
+
377
+ if (size > avail) {
378
+ // round up to the next multiple of the granularity
379
+ size_t reserve_size = size - avail;
380
+ reserve_size =
381
+ granularity * ((reserve_size + granularity - 1) / granularity);
382
+
383
+ GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
384
+
385
+ // allocate more physical memory
386
+ aclrtPhysicalMemProp prop = {};
387
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
388
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
389
+ prop.memAttr = ACL_HBM_MEM_HUGE;
390
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
391
+ prop.location.id = device;
392
+ prop.reserve = 0;
393
+ aclrtDrvMemHandle handle;
394
+ ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
395
+
396
+ // reserve virtual address space (if not already reserved)
397
+ if (pool_addr == 0) {
398
+ ACL_CHECK(aclrtReserveMemAddress(
399
+ &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
400
+ }
401
+
402
+ // map at the end of the pool
403
+ ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
404
+ handle, 0));
405
+
406
+ handles.push_back(handle);
407
+ map_offsets.push_back((char*)pool_addr + pool_size);
408
+
409
+ // add to the pool
410
+ pool_size += reserve_size;
411
+
412
+ // GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (
413
+ // reserved %llu MB)\n",
414
+ // device, (unsigned long long) (pool_size/1024/1024),
415
+ // (unsigned long long) (reserve_size/1024/1024));
416
+ }
417
+
418
+ GGML_ASSERT(pool_addr != 0);
419
+
420
+ void* ptr = (void*)((char*)pool_addr + pool_used);
421
+ *actual_size = size;
422
+ pool_used += size;
423
+
424
+ #ifdef DEBUG_CANN_MALLOC
425
+ GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
426
+ (unsigned long long)size, (unsigned long long)ptr);
427
+ #endif
428
+ return ptr;
429
+ }
430
+
431
+ /**
432
+ * @brief Free a buffer and return it to the virtual memory pool.
433
+ *
434
+ * @param ptr Pointer to the buffer to free.
435
+ * @param size Size of the buffer to free.
436
+ */
437
+ void free(void* ptr, size_t size) override {
438
+ #ifdef DEBUG_CANN_MALLOC
439
+ GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
440
+ (unsigned long long)size, (unsigned long long)ptr);
441
+ #endif
442
+
443
+ pool_used -= size;
444
+
445
+ // all deallocations must be in reverse order of the allocations
446
+ GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
447
+ }
448
+ };
449
+
450
+ /**
451
+ * @brief Create a new CANN pool for a specific device.
452
+ *
453
+ * Factory method to create a new CANN pool object based on the device type.
454
+ *
455
+ * @param device The device ID for which to create the pool.
456
+ * @return A unique pointer to the created CANN pool.
457
+ */
458
+ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
459
+ int device) {
460
+ // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
461
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
462
+ }
463
+
464
+ // cann buffer
465
+ /**
466
+ * @brief Context for managing a CANN buffer associated with a specific device.
467
+ *
468
+ * This structure holds information about a CANN buffer, including the device
469
+ * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
470
+ */
471
+ struct ggml_backend_cann_buffer_context {
472
+ int32_t device; ///< The device ID associated with this buffer context.
473
+ void* dev_ptr =
474
+ nullptr; ///< Pointer to the device memory allocated for the buffer.
475
+
476
+ /**
477
+ * @brief Constructor to initialize the CANN buffer context.
478
+ *
479
+ * @param device The device ID associated with this buffer context.
480
+ * @param dev_ptr Pointer to the device memory allocated for the buffer.
481
+ */
482
+ ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
483
+ : device(device),
484
+ dev_ptr(dev_ptr) {}
485
+
486
+ /**
487
+ * @brief Destructor to free the device memory allocated for the buffer.
488
+ */
489
+ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
490
+ };
491
+
492
+ /**
493
+ * @brief Check if a buffer is a CANN buffer.
494
+ *
495
+ * This function checks if a given buffer is a CANN buffer by comparing its
496
+ * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
497
+ *
498
+ * @param buffer The buffer to check.
499
+ * @return true if the buffer is a CANN buffer, false otherwise.
500
+ */
501
+ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
502
+ static bool ggml_backend_buffer_is_cann(
503
+ ggml_backend_buffer_t buffer) {
504
+ return ggml_backend_buft_is_cann(buffer->buft);
505
+ }
506
+
507
+ /**
508
+ * @brief Free resources associated with a CANN buffer.
509
+ *
510
+ * This function frees the resources associated with a CANN buffer, including
511
+ * its context.
512
+ *
513
+ * @param buffer The CANN buffer to free.
514
+ */
515
+ static void ggml_backend_cann_buffer_free_buffer(
516
+ ggml_backend_buffer_t buffer) {
517
+ ggml_backend_cann_buffer_context* ctx =
518
+ (ggml_backend_cann_buffer_context*)buffer->context;
519
+ delete ctx;
520
+ }
521
+
522
+ /**
523
+ * @brief Retrieve the base pointer of a CANN buffer.
524
+ *
525
+ * This function returns the base pointer of a CANN buffer, which points to the
526
+ * device memory allocated for the buffer.
527
+ *
528
+ * @param buffer The CANN buffer whose base pointer is to be retrieved.
529
+ * @return A pointer to the base of the device memory allocated for the buffer.
530
+ */
531
+ static void* ggml_backend_cann_buffer_get_base(
532
+ ggml_backend_buffer_t buffer) {
533
+ ggml_backend_cann_buffer_context* ctx =
534
+ (ggml_backend_cann_buffer_context*)buffer->context;
535
+ return ctx->dev_ptr;
536
+ }
537
+
538
+ /**
539
+ * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
540
+ * processing.
541
+ *
542
+ * This function transforms quantized Q4.0 tensor data into a format suitable
543
+ * for CANN processing. It extracts quantization values and scales from the
544
+ * source data and prepares them in a format expected by CANN operations.
545
+ *
546
+ * @param tensor Pointer to the tensor information.
547
+ * @param src Pointer to the source data in Q4.0 format.
548
+ * @param dst Pointer to the destination buffer where transformed data will be
549
+ * stored.
550
+ */
551
+ static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
552
+ const void* src,
553
+ void* dst) {
554
+
555
+ int64_t n_elems = ggml_nelements(tensor);
556
+ int64_t groups = n_elems / QK4_0;
557
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
558
+
559
+ uint8_t* quant_offset = (uint8_t*)dst;
560
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
561
+
562
+ for (int i = 0; i < groups; i++) {
563
+ const block_q4_0* group =
564
+ (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
565
+ *scale_offset = group->d;
566
+ scale_offset++;
567
+
568
+ // 0-15
569
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
570
+ (*quant_offset) = (group->qs[j] & 0x0F);
571
+ (*quant_offset) |= ((group->qs[j + 1] << 4));
572
+ quant_offset++;
573
+ }
574
+
575
+ // 16-31
576
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
577
+ (*quant_offset) = (group->qs[j] >> 4);
578
+ (*quant_offset) |= (group->qs[j + 1] & 0xF0);
579
+ quant_offset++;
580
+ }
581
+ }
582
+
583
+ // put (uint4b_t -8) into int4b_t
584
+ for (quant_offset = (uint8_t*)dst;
585
+ quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
586
+ (*quant_offset) ^= 0x88;
587
+ }
588
+ }
589
+
590
+ /**
591
+ * @brief Transform CANN processed data back into quantized Q4.0 format.
592
+ *
593
+ * This function transforms CANN processed data back into quantized Q4.0 format.
594
+ * It reverses the transformation performed by
595
+ * ggml_backend_cann_transform_q4_0(), converting the data back into its
596
+ * original quantized form.
597
+ *
598
+ * @param tensor Pointer to the tensor information.
599
+ * @param src Pointer to the source buffer containing transformed data.
600
+ * @param dst Pointer to the destination buffer where the Q4.0 formatted data
601
+ * will be stored.
602
+ */
603
+ static void ggml_backend_cann_transform_back_q4_0(
604
+ const ggml_tensor* tensor, void* src, void* dst) {
605
+
606
+ int64_t n_elems = ggml_nelements(tensor);
607
+ int64_t groups = n_elems / QK4_0;
608
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
609
+
610
+ uint8_t* quant_offset = (uint8_t*)src;
611
+ uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
612
+
613
+ for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
614
+ (*quant_offset) ^= 0x88;
615
+ }
616
+ quant_offset = (uint8_t*)src;
617
+
618
+ for (int i = 0; i < groups; i++) {
619
+ block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
620
+ group->d = *scale_offset;
621
+ scale_offset++;
622
+
623
+ // 0-15
624
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
625
+ group->qs[j] = ((*quant_offset) & 0x0F);
626
+ group->qs[j + 1] = ((*quant_offset) >> 4);
627
+ quant_offset++;
628
+ }
629
+
630
+ // 16-31
631
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
632
+ group->qs[j] |= ((*quant_offset) << 4);
633
+ group->qs[j + 1] |= ((*quant_offset) & 0xF0);
634
+ quant_offset++;
635
+ }
636
+ }
637
+ }
638
+
639
+ /**
640
+ * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
641
+ * processing.
642
+ *
643
+ * This function transforms quantized Q8.0 tensor data into a format suitable
644
+ * for CANN processing. It extracts quantization values and scales from the
645
+ * source data and prepares them in a format expected by CANN operations.
646
+ *
647
+ * @param tensor Pointer to the tensor information.
648
+ * @param src Pointer to the source data in Q8.0 format.
649
+ * @param dst Pointer to the destination buffer where transformed data will be
650
+ * stored.
651
+ */
652
+ static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
653
+ const void* src,
654
+ void* dst) {
655
+ int64_t n_elems = ggml_nelements(tensor);
656
+ int64_t groups = n_elems / QK8_0;
657
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
658
+
659
+ uint8_t* quant_offset = (uint8_t*)dst;
660
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
661
+
662
+ for (int i = 0; i < groups; i++) {
663
+ const block_q8_0* group =
664
+ (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
665
+ *scale_offset = group->d;
666
+ scale_offset++;
667
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
668
+ memcpy(quant_offset, group->qs, group_quant_size);
669
+ quant_offset += group_quant_size;
670
+ }
671
+ }
672
+
673
+ /**
674
+ * @brief Transform CANN processed data back into quantized Q8.0 format.
675
+ *
676
+ * This function transforms CANN processed data back into quantized Q8.0 format.
677
+ * It reverses the transformation performed by
678
+ * ggml_backend_cann_transform_q8_0(), converting the data back into its
679
+ * original quantized form.
680
+ *
681
+ * @param tensor Pointer to the tensor information.
682
+ * @param src Pointer to the source buffer containing transformed data.
683
+ * @param dst Pointer to the destination buffer where the Q8.0 formatted data
684
+ * will be stored.
685
+ */
686
+ static void ggml_backend_cann_transform_back_q8_0(
687
+ const ggml_tensor* tensor, const void* src, void* dst) {
688
+ int64_t n_elems = ggml_nelements(tensor);
689
+ int64_t groups = n_elems / QK8_0;
690
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
691
+
692
+ const uint8_t* quant_offset = (const uint8_t*)src;
693
+ const uint16_t* scale_offset =
694
+ (const uint16_t*)((const char*)src + quant_bytes);
695
+
696
+ for (int i = 0; i < groups; i++) {
697
+ block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
698
+ group->d = *scale_offset;
699
+ scale_offset++;
700
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
701
+ memcpy(group->qs, quant_offset, group_quant_size);
702
+ quant_offset += group_quant_size;
703
+ }
704
+ }
705
+
706
+ /**
707
+ * @brief Transform tensor data based on its type for CANN processing.
708
+ *
709
+ * This function transforms tensor data based on its quantization type for CANN
710
+ * processing. It dispatches the transformation based on the tensor's type to
711
+ * specialized functions handling Q4.0 and Q8.0 formats.
712
+ *
713
+ * @param tensor Pointer to the tensor information.
714
+ * @param src Pointer to the source data to be transformed.
715
+ * @param dst Pointer to the destination buffer where transformed data will be
716
+ * stored.
717
+ */
718
+ static void ggml_backend_cann_transform(ggml_tensor* tensor,
719
+ const void* src, void* dst) {
720
+ switch (tensor->type) {
721
+ case GGML_TYPE_Q4_0:
722
+ ggml_backend_cann_transform_q4_0(tensor, src, dst);
723
+ break;
724
+ case GGML_TYPE_Q8_0:
725
+ ggml_backend_cann_transform_q8_0(tensor, src, dst);
726
+ break;
727
+ default:
728
+ break;
729
+ }
730
+ }
731
+
732
+ /**
733
+ * @brief Transform CANN processed data back into tensor data based on its type.
734
+ *
735
+ * This function transforms CANN processed data back into tensor data based on
736
+ * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
737
+ * transformation based on the tensor's type to specialized functions.
738
+ *
739
+ * @param tensor Pointer to the tensor information.
740
+ * @param src Pointer to the source data containing CANN processed data.
741
+ * @param dst Pointer to the destination buffer where transformed tensor data
742
+ * will be stored.
743
+ */
744
+ static void ggml_backend_cann_transform_back(
745
+ const ggml_tensor* tensor, void* src, void* dst) {
746
+ switch (tensor->type) {
747
+ case GGML_TYPE_Q4_0:
748
+ ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
749
+ break;
750
+ case GGML_TYPE_Q8_0:
751
+ ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
752
+ break;
753
+ default:
754
+ break;
755
+ }
756
+ }
757
+
758
+ /**
759
+ * @brief Check if transformation is needed for a given tensor type.
760
+ *
761
+ * This function checks if transformation is needed for a given tensor type
762
+ * to prepare data for CANN processing.
763
+ *
764
+ * @param type The tensor type to check.
765
+ * @return true if transformation is needed, false otherwise.
766
+ */
767
+ static bool need_transform(ggml_type type) {
768
+ switch (type) {
769
+ case GGML_TYPE_Q4_0:
770
+ case GGML_TYPE_Q8_0:
771
+ return true;
772
+ default:
773
+ return false;
774
+ }
775
+ }
776
+
777
+ /**
778
+ * @brief Initialize a tensor using data from a CANN buffer.
779
+ *
780
+ * This function initializes a tensor using data from a CANN buffer.
781
+ * It handles special cases such as views and quantization.
782
+ *
783
+ * @param buffer The CANN buffer from which to initialize the tensor.
784
+ * @param tensor Pointer to the tensor to be initialized.
785
+ */
786
+ static void ggml_backend_cann_buffer_init_tensor(
787
+ ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
788
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
789
+ GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
790
+ return;
791
+ }
792
+
793
+ // TODO: can backend doesn't support quantized yet. Just leave the code
794
+ // here.
795
+ if (ggml_is_quantized(tensor->type)) {
796
+ // Initialize padding to 0 to avoid possible NaN values
797
+ size_t original_size = ggml_nbytes(tensor);
798
+ size_t padded_size =
799
+ ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
800
+
801
+ if (padded_size > original_size && tensor->view_src == nullptr) {
802
+ size_t memset_size = padded_size - original_size;
803
+ ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
804
+ memset_size, 0, memset_size));
805
+ }
806
+ }
807
+ }
808
+
809
+ // TODO: need handle tensor which has paddings.
810
+ /**
811
+ * @brief Set tensor data in a CANN buffer.
812
+ *
813
+ * This function sets tensor data in a CANN buffer, handling transformations
814
+ * if needed based on the tensor's type.
815
+ *
816
+ * @param buffer The CANN buffer where the tensor data will be set.
817
+ * @param tensor Pointer to the tensor whose data will be set.
818
+ * @param data Pointer to the source data to be copied into the tensor.
819
+ * @param offset Offset in the source data from where to start copying.
820
+ * @param size Size of the data to be copied, in bytes.
821
+ */
822
+ static void ggml_backend_cann_buffer_set_tensor(
823
+ ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
824
+ size_t offset, size_t size) {
825
+ ggml_backend_cann_buffer_context *ctx =
826
+ (ggml_backend_cann_buffer_context *)buffer->context;
827
+
828
+ ggml_cann_set_device(ctx->device);
829
+ // TODO: refer to cann(#6017), it use thread's default stream.
830
+ // For acl, synchronous functions use this default stream.
831
+ // Why aclrtSynchronizeDevice?
832
+
833
+ if (!need_transform(tensor->type)) {
834
+ ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
835
+ ACL_MEMCPY_HOST_TO_DEVICE));
836
+ } else {
837
+ void *transform_buffer = malloc(size);
838
+ ggml_backend_cann_transform(tensor, data, transform_buffer);
839
+
840
+ ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
841
+ transform_buffer, size,
842
+ ACL_MEMCPY_HOST_TO_DEVICE));
843
+ free(transform_buffer);
844
+ }
845
+ }
846
+
847
+ /**
848
+ * @brief Get tensor data from a CANN buffer.
849
+ *
850
+ * This function retrieves tensor data from a CANN buffer, handling
851
+ * transformations if needed based on the tensor's type.
852
+ *
853
+ * @param buffer The CANN buffer from which to retrieve tensor data.
854
+ * @param tensor Pointer to the tensor whose data will be retrieved.
855
+ * @param data Pointer to the destination buffer where the tensor data will be
856
+ * copied.
857
+ * @param offset Offset in the destination buffer where to start copying.
858
+ * @param size Size of the data to be copied, in bytes.
859
+ */
860
+ static void ggml_backend_cann_buffer_get_tensor(
861
+ ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
862
+ size_t offset, size_t size) {
863
+ ggml_backend_cann_buffer_context* ctx =
864
+ (ggml_backend_cann_buffer_context*)buffer->context;
865
+
866
+ ggml_cann_set_device(ctx->device);
867
+
868
+ if (!need_transform(tensor->type)) {
869
+ ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
870
+ ACL_MEMCPY_DEVICE_TO_HOST));
871
+ } else {
872
+ void* transform_buffer = malloc(size);
873
+ ACL_CHECK(aclrtMemcpy(transform_buffer, size,
874
+ (char*)tensor->data + offset, size,
875
+ ACL_MEMCPY_DEVICE_TO_HOST));
876
+ ggml_backend_cann_transform_back(tensor, transform_buffer, data);
877
+ free(transform_buffer);
878
+ }
879
+ }
880
+
881
+ /**
882
+ * @brief Copy tensor data between CANN buffers if possible.
883
+ *
884
+ * This function copies tensor data between CANN buffers if the source and
885
+ * destination buffers are CANN buffers and they meet the necessary conditions
886
+ * (same device or devices can access each other).
887
+ *
888
+ * @param buffer The destination CANN buffer where the tensor data will be
889
+ * copied.
890
+ * @param src Pointer to the source tensor whose data will be copied.
891
+ * @param dst Pointer to the destination tensor where the data will be copied.
892
+ * @return true if the copy operation succeeded, false otherwise.
893
+ */
894
+ static bool ggml_backend_cann_buffer_cpy_tensor(
895
+ ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
896
+ if (ggml_backend_buffer_is_cann(src->buffer)) {
897
+ ggml_backend_cann_buffer_context* src_ctx =
898
+ (ggml_backend_cann_buffer_context*)src->buffer->context;
899
+ ggml_backend_cann_buffer_context* dst_ctx =
900
+ (ggml_backend_cann_buffer_context*)buffer->context;
901
+
902
+ size_t memcpy_size = ggml_nbytes(src);
903
+ // Same device.
904
+ if (src_ctx->device == dst_ctx->device) {
905
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
906
+ (const char*)src->data, memcpy_size,
907
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
908
+ return true;
909
+ } else {
910
+ // Different device but can access by peer.
911
+ int32_t canAccessPeer = 0;
912
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
913
+ dst_ctx->device));
914
+ if (canAccessPeer) {
915
+ ggml_cann_set_device(src_ctx->device);
916
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
917
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
918
+ (const char*)src->data, memcpy_size,
919
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
920
+ return true;
921
+ }
922
+ }
923
+ }
924
+ return false;
925
+ }
926
+
927
+ /**
928
+ * @brief Clear a CANN buffer by setting all its memory to a specified value.
929
+ *
930
+ * This function clears a CANN buffer by setting all its memory to a specified
931
+ * value.
932
+ *
933
+ * @param buffer The CANN buffer to be cleared.
934
+ * @param value The value to which each byte in the buffer will be set.
935
+ */
936
+ static void ggml_backend_cann_buffer_clear(
937
+ ggml_backend_buffer_t buffer, uint8_t value) {
938
+ ggml_backend_cann_buffer_context* ctx =
939
+ (ggml_backend_cann_buffer_context*)buffer->context;
940
+
941
+ ggml_cann_set_device(ctx->device);
942
+ ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
943
+ }
944
+
945
+ /**
946
+ * @brief Interface for a CANN buffer in the backend.
947
+ *
948
+ * This structure defines function pointers to operations that can be performed
949
+ * on a CANN buffer within the backend.
950
+ */
951
+ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
952
+ /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
953
+ /* .get_base = */ ggml_backend_cann_buffer_get_base,
954
+ /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
955
+ /* .memset_tensor = */ NULL,
956
+ /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
957
+ /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
958
+ /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
959
+ /* .clear = */ ggml_backend_cann_buffer_clear,
960
+ /* .reset = */ NULL,
961
+ };
962
+
963
+ // cann buffer type
964
+ /**
965
+ * @brief Structure representing context information for a specific backend
966
+ * buffer type.
967
+ */
968
+ struct ggml_backend_cann_buffer_type_context {
969
+ int32_t
970
+ device; /**< Device identifier associated with the buffer context. */
971
+ std::string name; /**< Name associated with the buffer context. */
972
+ };
973
+
974
+ /**
975
+ * @brief Retrieves the name associated with a CANN buffer type.
976
+ *
977
+ * This function returns the descriptive name associated with the specified
978
+ * CANN buffer type context.
979
+ *
980
+ * @param buft Pointer to the buffer type context.
981
+ * @return Const pointer to the C-style string containing the name.
982
+ */
983
+ static const char* ggml_backend_cann_buffer_type_name(
984
+ ggml_backend_buffer_type_t buft) {
985
+ ggml_backend_cann_buffer_type_context* buft_ctx =
986
+ (ggml_backend_cann_buffer_type_context*)buft->context;
987
+
988
+ return buft_ctx->name.c_str();
989
+ }
990
+
991
+ /**
992
+ * @brief Allocates a new CANN buffer of the specified type and size.
993
+ *
994
+ * This function allocates a new CANN buffer on the specified device with the
995
+ * given size.
996
+ *
997
+ * @param buft Pointer to the buffer type context.
998
+ * @param size Size in bytes of the buffer to allocate.
999
+ * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1000
+ */
1001
+ static ggml_backend_buffer_t
1002
+ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1003
+ size_t size) {
1004
+ ggml_backend_cann_buffer_type_context* buft_ctx =
1005
+ (ggml_backend_cann_buffer_type_context*)buft->context;
1006
+
1007
+ ggml_cann_set_device(buft_ctx->device);
1008
+
1009
+ size = std::max(size, (size_t)1);
1010
+
1011
+ void* dev_ptr;
1012
+ aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1013
+ if (err != ACL_SUCCESS) {
1014
+ GGML_LOG_ERROR(
1015
+ "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
1016
+ __func__, size / 1024.0 / 1024.0, buft_ctx->device,
1017
+ aclGetRecentErrMsg());
1018
+ return nullptr;
1019
+ }
1020
+
1021
+ ggml_backend_cann_buffer_context* ctx =
1022
+ new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1023
+
1024
+ return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
1025
+ ctx, size);
1026
+ }
1027
+
1028
+ /**
1029
+ * @brief Retrieves the memory alignment requirement for CANN buffers of this
1030
+ * type.
1031
+ *
1032
+ * This function returns the alignment requirement in bytes for memory allocated
1033
+ * by the CANN buffer type.
1034
+ *
1035
+ * @param buft Pointer to the buffer type context (unused in this
1036
+ * implementation).
1037
+ * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1038
+ * buffers).
1039
+ */
1040
+ static size_t ggml_backend_cann_buffer_type_get_alignment(
1041
+ ggml_backend_buffer_type_t buft) {
1042
+ return 128;
1043
+
1044
+ GGML_UNUSED(buft);
1045
+ }
1046
+
1047
+ /**
1048
+ * @brief Calculates the allocation size required for a tensor in a CANN buffer.
1049
+ *
1050
+ * Computes the total allocation size needed for storing the tensor's data in a
1051
+ * CANN buffer, considering any necessary padding or adjustments for quantized
1052
+ * types.
1053
+ *
1054
+ * @param buft Pointer to the buffer type context (unused in this
1055
+ * implementation).
1056
+ * @param tensor Pointer to the tensor for which the allocation size is
1057
+ * calculated.
1058
+ * @return The total allocation size in bytes required for the tensor in the
1059
+ * CANN buffer.
1060
+ */
1061
+ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1062
+ ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
1063
+ size_t size = ggml_nbytes(tensor);
1064
+ int64_t ne0 = tensor->ne[0];
1065
+
1066
+ // last line must bigger than 32, because every single op deal at
1067
+ // least 32 bytes.
1068
+ // TODO: quantized type?
1069
+ // int64_t line_size = ne0 * ggml_element_size(tensor);
1070
+ // int64_t line_size_align_32 = (line_size + 31) & ~31;
1071
+ // size += (line_size_align_32 - line_size);
1072
+
1073
+ // TODO: not support quantized yet.
1074
+ // TODO: consider un-continue tensor.
1075
+ if (ggml_is_quantized(tensor->type)) {
1076
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
1077
+ size += ggml_row_size(
1078
+ tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1079
+ }
1080
+ }
1081
+
1082
+ return size;
1083
+
1084
+ GGML_UNUSED(buft);
1085
+ }
1086
+
1087
+ static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1088
+ return false;
1089
+
1090
+ GGML_UNUSED(buft);
1091
+ }
1092
+
1093
+ /**
1094
+ * @brief Interface for managing CANN buffer types in the GGML backend.
1095
+ *
1096
+ * Provides function pointers for allocating, querying properties, and managing
1097
+ * memory for CANN buffer types in the GGML backend.
1098
+ */
1099
+ static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
1100
+ /* .get_name = */ ggml_backend_cann_buffer_type_name,
1101
+ /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
1102
+ /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
1103
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1104
+ /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
1105
+ /* .is_host = */ ggml_backend_cann_buffer_type_is_host,
1106
+ };
1107
+
1108
+ /**
1109
+ * @brief Retrieves the CANN buffer type for a specified device.
1110
+ *
1111
+ * This function initializes and returns the buffer type interface associated
1112
+ * with the given device. It ensures thread-safe access using a mutex.
1113
+ *
1114
+ * @param device The device index for which to retrieve the buffer type.
1115
+ * @return A pointer to the buffer type interface for the specified device, or
1116
+ * nullptr if the device index is out of range.
1117
+ */
1118
+ ggml_backend_buffer_type_t
1119
+ ggml_backend_cann_buffer_type(int32_t device) {
1120
+ static std::mutex mutex;
1121
+ std::lock_guard<std::mutex> lock(mutex);
1122
+
1123
+ if (device >= ggml_backend_cann_get_device_count()) {
1124
+ return nullptr;
1125
+ }
1126
+
1127
+ static ggml_backend_buffer_type
1128
+ ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1129
+
1130
+ static bool ggml_backend_cann_buffer_type_initialized = false;
1131
+
1132
+ if (!ggml_backend_cann_buffer_type_initialized) {
1133
+ for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
1134
+ ggml_backend_cann_buffer_types[i] = {
1135
+ /* .iface = */ ggml_backend_cann_buffer_type_interface,
1136
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
1137
+ /* .context = */
1138
+ new ggml_backend_cann_buffer_type_context{
1139
+ i, "CANN" + std::to_string(i)},
1140
+ };
1141
+ }
1142
+ ggml_backend_cann_buffer_type_initialized = true;
1143
+ }
1144
+
1145
+ return &ggml_backend_cann_buffer_types[device];
1146
+ }
1147
+
1148
+ /**
1149
+ * @brief Retrieves the name associated with a CANN host buffer type.
1150
+ *
1151
+ * This function returns the descriptive name associated with the specified
1152
+ * CANN host buffer type context.
1153
+ *
1154
+ * @param buft Pointer to the host buffer type context.
1155
+ * @return Const pointer to the C-style string containing the name.
1156
+ */
1157
+ static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
1158
+ return "CANN_Host";
1159
+
1160
+ GGML_UNUSED(buft);
1161
+ }
1162
+
1163
+ /**
1164
+ * @brief Retrieves the name associated with a CANN host buffer.
1165
+ *
1166
+ * This function returns the descriptive name associated with the specified
1167
+ * CANN host buffer context.
1168
+ *
1169
+ * @param buft Pointer to the host buffer context.
1170
+ * @return Const pointer to the C-style string containing the name.
1171
+ */
1172
+ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
1173
+ return "CANN_Host";
1174
+
1175
+ GGML_UNUSED(buffer);
1176
+ }
1177
+
1178
+ /**
1179
+ * @brief Free resources associated with a CANN host buffer.
1180
+ *
1181
+ * This function frees the resources associated with a CANN host buffer, including
1182
+ * its context.
1183
+ *
1184
+ * @param buffer The CANN host buffer to free.
1185
+ */
1186
+ static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
1187
+ ACL_CHECK(aclrtFreeHost(buffer->context));
1188
+ }
1189
+
1190
+ /**
1191
+ * @brief Allocates a new CANN host buffer of the specified size.
1192
+ *
1193
+ * This function allocates a new CANN host buffer with the given size.
1194
+ * @param size Size in bytes of the host buffer to allocate.
1195
+ * @return Pointer to the allocated host buffer, or nullptr if allocation fails.
1196
+ */
1197
+ static void * ggml_cann_host_malloc(size_t size) {
1198
+ if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
1199
+ return nullptr;
1200
+ }
1201
+
1202
+ void * hostPtr = nullptr;
1203
+ aclError err = aclrtMallocHost((void **) &hostPtr, size);
1204
+ if (err != ACL_SUCCESS) {
1205
+
1206
+ GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
1207
+ size / 1024.0 / 1024.0, aclGetRecentErrMsg());
1208
+ return nullptr;
1209
+ }
1210
+ return hostPtr;
1211
+ }
1212
+
1213
+ /**
1214
+ * @brief Allocates a new CANN host buffer of the specified type and size.
1215
+ *
1216
+ * @param buft Pointer to the host buffer type context.
1217
+ * @param size Size in bytes of the host buffer to allocate.
1218
+ * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
1219
+ */
1220
+ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1221
+ void * hostPtr = ggml_cann_host_malloc(size);
1222
+
1223
+ if (hostPtr == nullptr) {
1224
+ // fallback to cpu buffer
1225
+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
1226
+ }
1227
+
1228
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
1229
+ buffer->buft = buft;
1230
+ buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
1231
+
1232
+ return buffer;
1233
+ }
1234
+
1235
+ /**
1236
+ * @brief Interface for managing CANN host buffer types in the GGML backend.
1237
+ *
1238
+ * Provides function pointers for allocating, querying properties, and managing
1239
+ * memory for CANN buffer types in the GGML backend.
1240
+ */
1241
+ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1242
+ static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
1243
+ /* .iface = */ {
1244
+ /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
1245
+ /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1246
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1247
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1248
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
1249
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1250
+ },
1251
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1252
+ /* .context = */ nullptr,
1253
+ };
1254
+
1255
+ return &ggml_backend_cann_buffer_type_host;
1256
+ }
1257
+
1258
+ /**
1259
+ * @brief Computes the forward operation for a given tensor using CANN
1260
+ * operations.
1261
+ *
1262
+ * This function selects the appropriate CANN operation based on the type of
1263
+ * operation specified in the tensor and performs the computation.
1264
+ *
1265
+ * @param ctx The CANN context containing necessary resources and
1266
+ * configurations.
1267
+ * @param dst The destination tensor where the result of the computation will be
1268
+ * stored.
1269
+ * @return true if the computation was successful; false otherwise.
1270
+ */
1271
+ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1272
+ struct ggml_tensor* dst) {
1273
+ switch (dst->op) {
1274
+ case GGML_OP_REPEAT:
1275
+ ggml_cann_repeat(ctx, dst);
1276
+ break;
1277
+ case GGML_OP_GET_ROWS:
1278
+ ggml_cann_get_rows(ctx, dst);
1279
+ break;
1280
+ case GGML_OP_DUP:
1281
+ ggml_cann_dup(ctx, dst);
1282
+ break;
1283
+ case GGML_OP_ADD:
1284
+ ggml_cann_add(ctx, dst);
1285
+ break;
1286
+ case GGML_OP_ACC:
1287
+ ggml_cann_acc(ctx, dst);
1288
+ break;
1289
+ case GGML_OP_MUL:
1290
+ ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
1291
+ break;
1292
+ case GGML_OP_DIV:
1293
+ ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
1294
+ break;
1295
+ case GGML_OP_UNARY:
1296
+ switch (ggml_get_unary_op(dst)) {
1297
+ case GGML_UNARY_OP_GELU:
1298
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1299
+ ctx, dst);
1300
+ break;
1301
+ case GGML_UNARY_OP_SILU:
1302
+ ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
1303
+ ctx, dst);
1304
+ break;
1305
+ // TODO: Use faster gelu??
1306
+ case GGML_UNARY_OP_GELU_QUICK:
1307
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1308
+ ctx, dst);
1309
+ break;
1310
+ case GGML_UNARY_OP_TANH:
1311
+ ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
1312
+ ctx, dst);
1313
+ break;
1314
+ case GGML_UNARY_OP_RELU:
1315
+ ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
1316
+ ctx, dst);
1317
+ break;
1318
+ case GGML_UNARY_OP_HARDSIGMOID:
1319
+ ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
1320
+ aclnnHardsigmoid>(ctx, dst);
1321
+ break;
1322
+ case GGML_UNARY_OP_HARDSWISH:
1323
+ ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
1324
+ aclnnHardswish>(ctx, dst);
1325
+ break;
1326
+ default:
1327
+ return false;
1328
+ }
1329
+ break;
1330
+ case GGML_OP_NORM:
1331
+ ggml_cann_norm(ctx, dst);
1332
+ break;
1333
+ case GGML_OP_GROUP_NORM:
1334
+ ggml_cann_group_norm(ctx, dst);
1335
+ break;
1336
+ case GGML_OP_CONCAT:
1337
+ ggml_cann_concat(ctx, dst);
1338
+ break;
1339
+ case GGML_OP_UPSCALE:
1340
+ ggml_cann_upsample_nearest2d(ctx, dst);
1341
+ break;
1342
+ case GGML_OP_PAD:
1343
+ ggml_cann_pad(ctx, dst);
1344
+ break;
1345
+ case GGML_OP_ARANGE:
1346
+ ggml_cann_arange(ctx, dst);
1347
+ break;
1348
+ case GGML_OP_TIMESTEP_EMBEDDING:
1349
+ ggml_cann_timestep_embedding(ctx, dst);
1350
+ break;
1351
+ case GGML_OP_LEAKY_RELU:
1352
+ ggml_cann_leaky_relu(ctx, dst);
1353
+ break;
1354
+ case GGML_OP_RMS_NORM:
1355
+ ggml_cann_rms_norm(ctx, dst);
1356
+ break;
1357
+ case GGML_OP_MUL_MAT:
1358
+ ggml_cann_mul_mat(ctx, dst);
1359
+ break;
1360
+ case GGML_OP_MUL_MAT_ID:
1361
+ return false;
1362
+ case GGML_OP_SCALE:
1363
+ ggml_cann_scale(ctx, dst);
1364
+ break;
1365
+ case GGML_OP_SQR:
1366
+ ggml_cann_sqr(ctx, dst);
1367
+ break;
1368
+ case GGML_OP_CLAMP:
1369
+ ggml_cann_clamp(ctx, dst);
1370
+ break;
1371
+ case GGML_OP_CPY:
1372
+ ggml_cann_cpy(ctx, dst);
1373
+ break;
1374
+ case GGML_OP_CONT:
1375
+ ggml_cann_dup(ctx, dst);
1376
+ break;
1377
+ case GGML_OP_NONE:
1378
+ case GGML_OP_RESHAPE:
1379
+ case GGML_OP_VIEW:
1380
+ case GGML_OP_PERMUTE:
1381
+ case GGML_OP_TRANSPOSE:
1382
+ break;
1383
+ case GGML_OP_DIAG_MASK_INF:
1384
+ ggml_cann_diag_mask(ctx, dst, -INFINITY);
1385
+ break;
1386
+ case GGML_OP_SOFT_MAX:
1387
+ ggml_cann_softmax(ctx, dst);
1388
+ break;
1389
+ case GGML_OP_ROPE:
1390
+ ggml_cann_rope(ctx, dst);
1391
+ break;
1392
+ case GGML_OP_IM2COL:
1393
+ ggml_cann_im2col(ctx, dst);
1394
+ break;
1395
+ case GGML_OP_POOL_2D:
1396
+ ggml_cann_pool2d(ctx, dst);
1397
+ break;
1398
+ case GGML_OP_SUM_ROWS:
1399
+ ggml_cann_sum_rows(ctx, dst);
1400
+ break;
1401
+ case GGML_OP_ARGSORT:
1402
+ ggml_cann_argsort(ctx, dst);
1403
+ break;
1404
+ default:
1405
+ return false;
1406
+ }
1407
+
1408
+ return true;
1409
+ }
1410
+
1411
+ // backend
1412
+ /**
1413
+ * @brief Retrieves the name associated with the CANN backend.
1414
+ *
1415
+ * This function returns the name assigned to the CANN backend, which is stored
1416
+ * in the context of the provided backend structure.
1417
+ *
1418
+ * @param backend Pointer to the CANN backend structure.
1419
+ * @return A pointer to a constant string representing the backend name.
1420
+ */
1421
+ static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1422
+ ggml_backend_cann_context* cann_ctx =
1423
+ (ggml_backend_cann_context*)backend->context;
1424
+
1425
+ return cann_ctx->name.c_str();
1426
+ }
1427
+
1428
+ /**
1429
+ * @brief Frees resources associated with the CANN backend.
1430
+ *
1431
+ * This function releases resources associated with the CANN backend context
1432
+ * and resets the device associated with the backend to its initial state.
1433
+ *
1434
+ * @param backend Pointer to the CANN backend structure to be freed.
1435
+ */
1436
+ static void ggml_backend_cann_free(ggml_backend_t backend) {
1437
+ ggml_backend_cann_context* cann_ctx =
1438
+ (ggml_backend_cann_context*)backend->context;
1439
+ ACL_CHECK(aclrtSynchronizeDevice());
1440
+ ACL_CHECK(aclrtResetDevice(cann_ctx->device));
1441
+
1442
+ // finalize when last backend freed.
1443
+ if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
1444
+ ACL_CHECK(aclFinalize());
1445
+ }
1446
+
1447
+ delete cann_ctx;
1448
+ delete backend;
1449
+ }
1450
+
1451
+ /**
1452
+ * @brief Sets tensor data asynchronously in the CANN backend.
1453
+ *
1454
+ * This function asynchronously sets tensor data in the CANN backend. Depending
1455
+ * on the tensor type, it may perform data transformations before copying data
1456
+ * to the device.
1457
+ *
1458
+ * @param backend Pointer to the CANN backend structure.
1459
+ * @param tensor Pointer to the tensor structure to set data for.
1460
+ * @param data Pointer to the host data to copy to the tensor.
1461
+ * @param offset Offset in bytes within the host data.
1462
+ * @param size Size of the data to copy in bytes.
1463
+ */
1464
+ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1465
+ ggml_tensor *tensor,
1466
+ const void *data,
1467
+ size_t offset,
1468
+ size_t size) {
1469
+ ggml_backend_cann_context *cann_ctx =
1470
+ (ggml_backend_cann_context *)backend->context;
1471
+
1472
+ if (!need_transform(tensor->type)) {
1473
+ ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
1474
+ size, ACL_MEMCPY_HOST_TO_DEVICE,
1475
+ cann_ctx->stream()));
1476
+ } else {
1477
+ void *transform_buffer = malloc(size);
1478
+ ggml_backend_cann_transform(tensor, data, transform_buffer);
1479
+
1480
+ ACL_CHECK(aclrtMemcpyAsync(
1481
+ (char *)tensor->data + offset, size, transform_buffer, size,
1482
+ ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
1483
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1484
+ free(transform_buffer);
1485
+ }
1486
+ }
1487
+
1488
+ static void ggml_backend_cann_get_tensor_async(
1489
+ ggml_backend_t backend, const ggml_tensor *tensor, void *data,
1490
+ size_t offset, size_t size) {
1491
+ ggml_backend_cann_context *cann_ctx =
1492
+ (ggml_backend_cann_context *)backend->context;
1493
+ ggml_backend_buffer_t buf =
1494
+ tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1495
+
1496
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
1497
+ "unsupported buffer type");
1498
+
1499
+ if (!need_transform(tensor->type)) {
1500
+ ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
1501
+ size, ACL_MEMCPY_DEVICE_TO_HOST,
1502
+ cann_ctx->stream()));
1503
+ } else {
1504
+ void *transform_buffer = malloc(size);
1505
+ ACL_CHECK(aclrtMemcpyAsync(
1506
+ transform_buffer, size, (char *)tensor->data + offset, size,
1507
+ ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
1508
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1509
+ ggml_backend_cann_transform_back(tensor, transform_buffer, data);
1510
+ free(transform_buffer);
1511
+ }
1512
+ }
1513
+
1514
+ /**
1515
+ * @brief Asynchronously copies tensor data between CANN backends.
1516
+ *
1517
+ * This function copies tensor data asynchronously between two CANN backends. It
1518
+ * checks if both tensors reside in CANN buffers and whether the devices support
1519
+ * peer-to-peer access for direct copying. If not, it returns false.
1520
+ *
1521
+ * @param backend_src Pointer to the source CANN backend structure.
1522
+ * @param backend_dst Pointer to the destination CANN backend structure.
1523
+ * @param src Pointer to the source tensor to copy data from.
1524
+ * @param dst Pointer to the destination tensor to copy data to.
1525
+ * @return true if the copy operation succeeds, false otherwise.
1526
+ */
1527
+ static bool ggml_backend_cann_cpy_tensor_async(
1528
+ ggml_backend_t backend_src, ggml_backend_t backend_dst,
1529
+ const ggml_tensor* src, ggml_tensor* dst) {
1530
+ GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
1531
+ ggml_backend_is_cann(backend_dst));
1532
+
1533
+ if (!ggml_backend_buffer_is_cann(src->buffer) ||
1534
+ !ggml_backend_buffer_is_cann(dst->buffer)) {
1535
+ return false;
1536
+ }
1537
+
1538
+ ggml_backend_buffer_t buf_src =
1539
+ src->view_src ? src->view_src->buffer : src->buffer;
1540
+ ggml_backend_buffer_t buf_dst =
1541
+ dst->view_src ? dst->view_src->buffer : dst->buffer;
1542
+
1543
+ ggml_backend_cann_context* cann_ctx_src =
1544
+ (ggml_backend_cann_context*)backend_src->context;
1545
+ ggml_backend_cann_context* cann_ctx_dst =
1546
+ (ggml_backend_cann_context*)backend_dst->context;
1547
+
1548
+ size_t copy_size = ggml_nbytes(dst);
1549
+ if (backend_src != backend_dst) {
1550
+ ggml_backend_cann_buffer_context* buf_ctx_src =
1551
+ (ggml_backend_cann_buffer_context*)buf_src->context;
1552
+ ggml_backend_cann_buffer_context* buf_ctx_dst =
1553
+ (ggml_backend_cann_buffer_context*)buf_dst->context;
1554
+
1555
+ GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
1556
+ GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
1557
+
1558
+ int32_t canAccessPeer = 0;
1559
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
1560
+ cann_ctx_dst->device));
1561
+ if (!canAccessPeer) {
1562
+ return false;
1563
+ }
1564
+
1565
+ // need open both directions for memcpyasync between devices.
1566
+ ggml_cann_set_device(cann_ctx_dst->device);
1567
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
1568
+ ggml_cann_set_device(cann_ctx_src->device);
1569
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
1570
+
1571
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1572
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
1573
+ cann_ctx_src->stream()));
1574
+
1575
+ //TODO: workaround for Event didn`t work here.
1576
+ aclrtSynchronizeStream(cann_ctx_src->stream());
1577
+ } else {
1578
+ // src and dst are on the same backend
1579
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1580
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
1581
+ cann_ctx_dst->stream()));
1582
+ }
1583
+
1584
+ return true;
1585
+ }
1586
+
1587
+ /**
1588
+ * @brief Synchronizes a CANN backend.
1589
+ *
1590
+ * This function synchronizes the specified CANN backend by waiting for all
1591
+ * operations in its associated stream to complete.
1592
+ *
1593
+ * @param backend Pointer to the CANN backend structure to synchronize.
1594
+ */
1595
+ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
1596
+ ggml_backend_cann_context* cann_ctx =
1597
+ (ggml_backend_cann_context*)backend->context;
1598
+
1599
+ ggml_cann_set_device(cann_ctx->device);
1600
+
1601
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1602
+ }
1603
+
1604
+ /**
1605
+ * @brief Computes a computational graph using a CANN backend.
1606
+ *
1607
+ * This function computes the operations defined in the computational graph
1608
+ * using the specified CANN backend.
1609
+ *
1610
+ * @param backend Pointer to the CANN backend structure to use for computation.
1611
+ * @param cgraph Pointer to the computational graph structure containing nodes
1612
+ * representing operations to be computed.
1613
+ * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
1614
+ * completes successfully, otherwise an appropriate error status.
1615
+ */
1616
+ static enum ggml_status ggml_backend_cann_graph_compute(
1617
+ ggml_backend_t backend, ggml_cgraph* cgraph) {
1618
+ ggml_backend_cann_context* cann_ctx =
1619
+ (ggml_backend_cann_context*)backend->context;
1620
+
1621
+ ggml_cann_set_device(cann_ctx->device);
1622
+
1623
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1624
+ ggml_tensor* node = cgraph->nodes[i];
1625
+
1626
+ if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
1627
+ continue;
1628
+ }
1629
+
1630
+ bool ok = ggml_cann_compute_forward(*cann_ctx, node);
1631
+
1632
+ if (!ok) {
1633
+ GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
1634
+ node->name, ggml_op_name(node->op));
1635
+ }
1636
+ GGML_ASSERT(ok);
1637
+ }
1638
+
1639
+ return GGML_STATUS_SUCCESS;
1640
+ }
1641
+
1642
+ /**
1643
+ * @brief Checks if the CANN backend supports a specific operation.
1644
+ *
1645
+ * This function checks whether the specified operation is supported by the
1646
+ * CANN backend.
1647
+ *
1648
+ * @param backend Pointer to the CANN backend structure to check support for
1649
+ * the operation.
1650
+ * @param op Pointer to the tensor representing the operation to check.
1651
+ * @return bool Returns true if the operation is supported by the backend,
1652
+ * otherwise false.
1653
+ */
1654
+ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1655
+ const ggml_tensor* op) {
1656
+ switch (op->op) {
1657
+ case GGML_OP_UNARY:
1658
+ switch (ggml_get_unary_op(op)) {
1659
+ case GGML_UNARY_OP_GELU:
1660
+ case GGML_UNARY_OP_SILU:
1661
+ case GGML_UNARY_OP_RELU:
1662
+ case GGML_UNARY_OP_HARDSIGMOID:
1663
+ case GGML_UNARY_OP_HARDSWISH:
1664
+ case GGML_UNARY_OP_GELU_QUICK:
1665
+ case GGML_UNARY_OP_TANH:
1666
+ return true;
1667
+ default:
1668
+ return false;
1669
+ }
1670
+ case GGML_OP_MUL_MAT: {
1671
+ switch (op->src[0]->type) {
1672
+ case GGML_TYPE_F16:
1673
+ case GGML_TYPE_F32:
1674
+ case GGML_TYPE_Q8_0:
1675
+ // TODO: fix me
1676
+ // Current groupsize should not be greater than k-1 in
1677
+ // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
1678
+ case GGML_TYPE_Q4_0:
1679
+ return true;
1680
+ default:
1681
+ return false;
1682
+ }
1683
+ }
1684
+ case GGML_OP_MUL_MAT_ID:
1685
+ return false;
1686
+ // embedding
1687
+ case GGML_OP_GET_ROWS: {
1688
+ switch (op->src[0]->type) {
1689
+ case GGML_TYPE_F32:
1690
+ case GGML_TYPE_F16:
1691
+ case GGML_TYPE_Q4_0:
1692
+ case GGML_TYPE_Q8_0:
1693
+ return true;
1694
+ default:
1695
+ return false;
1696
+ }
1697
+ } break;
1698
+ case GGML_OP_CPY: {
1699
+ switch (op->type) {
1700
+ case GGML_TYPE_F32:
1701
+ case GGML_TYPE_F16:
1702
+ case GGML_TYPE_Q8_0:
1703
+ case GGML_TYPE_Q4_0:
1704
+ return true;
1705
+ default:
1706
+ return false;
1707
+ }
1708
+ }
1709
+ case GGML_OP_DUP:
1710
+ case GGML_OP_REPEAT:
1711
+ case GGML_OP_CONCAT:
1712
+ case GGML_OP_NONE:
1713
+ case GGML_OP_RESHAPE:
1714
+ case GGML_OP_VIEW:
1715
+ case GGML_OP_PERMUTE:
1716
+ case GGML_OP_TRANSPOSE:
1717
+ case GGML_OP_NORM:
1718
+ case GGML_OP_ADD:
1719
+ case GGML_OP_MUL:
1720
+ case GGML_OP_DIV:
1721
+ case GGML_OP_RMS_NORM:
1722
+ case GGML_OP_SCALE:
1723
+ case GGML_OP_SQR:
1724
+ case GGML_OP_CLAMP:
1725
+ case GGML_OP_CONT:
1726
+ case GGML_OP_DIAG_MASK_INF:
1727
+ case GGML_OP_SOFT_MAX:
1728
+ case GGML_OP_ROPE:
1729
+ case GGML_OP_IM2COL:
1730
+ case GGML_OP_POOL_2D:
1731
+ case GGML_OP_SUM_ROWS:
1732
+ case GGML_OP_ARGSORT:
1733
+ case GGML_OP_ACC:
1734
+ case GGML_OP_GROUP_NORM:
1735
+ case GGML_OP_UPSCALE:
1736
+ case GGML_OP_PAD:
1737
+ case GGML_OP_ARANGE:
1738
+ case GGML_OP_TIMESTEP_EMBEDDING:
1739
+ case GGML_OP_LEAKY_RELU:
1740
+ return true;
1741
+ default:
1742
+ return false;
1743
+ }
1744
+
1745
+ GGML_UNUSED(dev);
1746
+ }
1747
+
1748
+ /**
1749
+ * @brief Checks if the backend buffer type is associated with the CANN backend.
1750
+ *
1751
+ * This function checks whether the provided backend buffer type is associated
1752
+ * with the CANN backend based on the comparison of its name retrieval function
1753
+ * pointer.
1754
+ *
1755
+ * @param buft Pointer to the backend buffer type to check.
1756
+ * @return bool Returns true if the buffer type is associated with the CANN
1757
+ * backend, otherwise false.
1758
+ */
1759
+ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
1760
+ return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
1761
+ }
1762
+
1763
+ /**
1764
+ * @brief Determines if a tensor operation should be offloaded to the CANN
1765
+ * backend.
1766
+ *
1767
+ * This function checks if a given tensor operation should be offloaded to the
1768
+ * CANN backend based on the operation type and the size of the tensor. It
1769
+ * returns true if the second dimension (ne[1]) of the tensor is greater than or
1770
+ * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
1771
+ *
1772
+ * @param backend Pointer to the CANN backend.
1773
+ * @param op Pointer to the tensor operation to check.
1774
+ * @return bool Returns true if the operation should be offloaded, otherwise
1775
+ * false.
1776
+ */
1777
+ static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
1778
+ const ggml_tensor* op) {
1779
+ const int min_batch_size = 32;
1780
+ GGML_UNUSED(dev);
1781
+
1782
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
1783
+ }
1784
+
1785
+ /**
1786
+ * @brief Records an event on the CANN backend stream.
1787
+ *
1788
+ * This function records the given event on the ACL runtime stream associated
1789
+ * with the backend context.
1790
+ *
1791
+ * @param event Pointer to the event structure to be recorded.
1792
+ */
1793
+ static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
1794
+ ggml_backend_cann_context* cann_ctx =
1795
+ (ggml_backend_cann_context*)backend->context;
1796
+ ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
1797
+ }
1798
+
1799
+ /**
1800
+ * @brief Waits for a recorded event to complete on the CANN backend stream.
1801
+ *
1802
+ * This function makes the given backend wait for the event to complete on its
1803
+ * ACL runtime stream.
1804
+ *
1805
+ * @param backend Pointer to the backend structure.
1806
+ * @param event Pointer to the event structure that the backend needs to wait
1807
+ * for.
1808
+ */
1809
+ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
1810
+ ggml_backend_event_t event) {
1811
+ ggml_backend_cann_context* cann_ctx =
1812
+ (ggml_backend_cann_context*)backend->context;
1813
+ if (ggml_backend_is_cann(backend)) {
1814
+ ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
1815
+ (aclrtEvent)event->context));
1816
+ } else {
1817
+ GGML_ABORT("fatal error");
1818
+ }
1819
+ }
1820
+
1821
+ /**
1822
+ * @brief Structure defining the interface for the CANN backend.
1823
+ *
1824
+ * This structure contains function pointers for various operations
1825
+ * supported by the CANN backend, including name retrieval, memory
1826
+ * management, tensor operations, synchronization, and event handling.
1827
+ */
1828
+ static const ggml_backend_i ggml_backend_cann_interface = {
1829
+ /* .get_name = */ ggml_backend_cann_name,
1830
+ /* .free = */ ggml_backend_cann_free,
1831
+ /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
1832
+ /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
1833
+ /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
1834
+ /* .synchronize = */ ggml_backend_cann_synchronize,
1835
+ /* .graph_plan_create = */ NULL,
1836
+ /* .graph_plan_free = */ NULL,
1837
+ /* .graph_plan_update = */ NULL,
1838
+ /* .graph_plan_compute = */ NULL,
1839
+ /* .graph_compute = */ ggml_backend_cann_graph_compute,
1840
+ /* .event_record = */ ggml_backend_cann_event_record,
1841
+ /* .event_wait = */ ggml_backend_cann_event_wait,
1842
+ };
1843
+
1844
+ /**
1845
+ * @brief Return the hardcoded GUID for the CANN backend.
1846
+ *
1847
+ * This function returns a static GUID which uniquely identifies the CANN
1848
+ * backend.
1849
+ *
1850
+ * @return A pointer to the static GUID.
1851
+ */
1852
+ static ggml_guid_t ggml_backend_cann_guid() {
1853
+ static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
1854
+ 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
1855
+ return &guid;
1856
+ }
1857
+
1858
+ // backend device
1859
+ struct ggml_backend_cann_device_context {
1860
+ int device;
1861
+ std::string name;
1862
+ std::string description;
1863
+ };
1864
+
1865
+ static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
1866
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1867
+ return ctx->name.c_str();
1868
+ }
1869
+
1870
+ static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
1871
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1872
+ return ctx->description.c_str();
1873
+ }
1874
+
1875
+ static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1876
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1877
+ ggml_backend_cann_get_device_memory(ctx->device, free, total);
1878
+ }
1879
+
1880
+ static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
1881
+ GGML_UNUSED(dev);
1882
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
1883
+ }
1884
+
1885
+ static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
1886
+ props->name = ggml_backend_cann_device_get_name(dev);
1887
+ props->description = ggml_backend_cann_device_get_description(dev);
1888
+ props->type = ggml_backend_cann_device_get_type(dev);
1889
+ ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
1890
+
1891
+ bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
1892
+
1893
+ props->caps = {
1894
+ /* .async = */ false,
1895
+ /* .host_buffer = */ host_buffer,
1896
+ /* .buffer_from_host_ptr = */ false,
1897
+ /* .events = */ true,
1898
+ };
1899
+ }
1900
+
1901
+ static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
1902
+ GGML_UNUSED(params);
1903
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1904
+ return ggml_backend_cann_init(ctx->device);
1905
+ }
1906
+
1907
+ /**
1908
+ * @brief Checks if the CANN backend supports a specific backend buffer type.
1909
+ *
1910
+ * This function determines whether the CANN backend supports the given backend
1911
+ * buffer type by comparing the device context of the backend and buffer type.
1912
+ * It returns true if the devices are same between the backend context and
1913
+ * buffer type context.
1914
+ *
1915
+ * @param backend Pointer to the CANN backend.
1916
+ * @param buft Pointer to the backend buffer type to check.
1917
+ * @return bool Returns true if the CANN backend supports the buffer type,
1918
+ * otherwise false.
1919
+ */
1920
+ static bool ggml_backend_cann_supports_buft(
1921
+ ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1922
+ if (ggml_backend_buft_is_cann(buft)) {
1923
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
1924
+ ggml_backend_cann_buffer_type_context * buft_ctx =
1925
+ (ggml_backend_cann_buffer_type_context *)buft->context;
1926
+ return buft_ctx->device == dev_ctx->device;
1927
+ }
1928
+ return false;
1929
+ }
1930
+
1931
+ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
1932
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1933
+ return ggml_backend_cann_buffer_type(ctx->device);
1934
+ }
1935
+
1936
+ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
1937
+ GGML_UNUSED(dev);
1938
+ return ggml_backend_cann_host_buffer_type();
1939
+ }
1940
+
1941
+ /**
1942
+ * @brief Creates a new event for the CANN backend device.
1943
+ *
1944
+ * This function initializes a new event for the CANN backend by setting the
1945
+ * device and creating an ACL runtime event. The created event is then wrapped
1946
+ * in a ggml_backend_event structure and returned.
1947
+ *
1948
+ * @param backend Pointer to the CANN backend.
1949
+ * @return ggml_backend_event_t Returns a pointer to the new event structure.
1950
+ */
1951
+ static ggml_backend_event_t ggml_backend_cann_device_event_new(
1952
+ ggml_backend_dev_t dev) {
1953
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
1954
+
1955
+ ggml_cann_set_device(dev_ctx->device);
1956
+
1957
+ aclrtEvent event;
1958
+ ACL_CHECK(aclrtCreateEvent(&event));
1959
+
1960
+ return new ggml_backend_event{
1961
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
1962
+ /* .context = */ event,
1963
+ };
1964
+ }
1965
+
1966
+ /**
1967
+ * @brief Frees a CANN backend event.
1968
+ *
1969
+ * This function destroys the ACL runtime event associated with the given CANN
1970
+ * backend event and then deletes the event structure itself.
1971
+ *
1972
+ * @param event Pointer to the event structure to be freed.
1973
+ */
1974
+ static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
1975
+ ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
1976
+
1977
+ delete event;
1978
+ GGML_UNUSED(dev);
1979
+ }
1980
+
1981
+ /**
1982
+ * @brief Synchronizes the given event on the CANN backend.
1983
+ *
1984
+ * This function waits for the specified event to complete on the ACL runtime.
1985
+ *
1986
+ * @param event Pointer to the event structure to be synchronized.
1987
+ */
1988
+ static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
1989
+ ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
1990
+
1991
+ GGML_UNUSED(dev);
1992
+ }
1993
+
1994
+ static const ggml_backend_device_i ggml_backend_cann_device_interface = {
1995
+ /* .get_name = */ ggml_backend_cann_device_get_name,
1996
+ /* .get_description = */ ggml_backend_cann_device_get_description,
1997
+ /* .get_memory = */ ggml_backend_cann_device_get_memory,
1998
+ /* .get_type = */ ggml_backend_cann_device_get_type,
1999
+ /* .get_props = */ ggml_backend_cann_device_get_props,
2000
+ /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
2001
+ /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
2002
+ /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
2003
+ /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
2004
+ /* .supports_op = */ ggml_backend_cann_supports_op,
2005
+ /* .supports_buft = */ ggml_backend_cann_supports_buft,
2006
+ /* .offload_op = */ ggml_backend_cann_offload_op,
2007
+ /* .event_new = */ ggml_backend_cann_device_event_new,
2008
+ /* .event_free = */ ggml_backend_cann_device_event_free,
2009
+ /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
2010
+ };
2011
+
2012
+
2013
+ // backend reg
2014
+ struct ggml_backend_cann_reg_context {
2015
+ std::vector<ggml_backend_dev_t> devices;
2016
+ };
2017
+
2018
+ static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
2019
+ GGML_UNUSED(reg);
2020
+ return GGML_CANN_NAME;
2021
+ }
2022
+
2023
+ static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
2024
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2025
+ return ctx->devices.size();
2026
+ }
2027
+
2028
+ static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2029
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2030
+ GGML_ASSERT(index < ctx->devices.size());
2031
+ return ctx->devices[index];
2032
+ }
2033
+
2034
+ static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
2035
+ GGML_UNUSED(reg);
2036
+ GGML_UNUSED(name);
2037
+ // reserved for future use
2038
+ return nullptr;
2039
+ }
2040
+
2041
+ static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
2042
+ /* .get_name = */ ggml_backend_cann_reg_get_name,
2043
+ /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
2044
+ /* .get_device_get = */ ggml_backend_cann_reg_get_device,
2045
+ /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
2046
+ };
2047
+
2048
+ // backend registry, called only once for cann backend
2049
+ ggml_backend_reg_t ggml_backend_cann_reg() {
2050
+ static ggml_backend_reg reg;
2051
+ static bool initialized = false;
2052
+
2053
+ {
2054
+ static std::mutex mutex;
2055
+ std::lock_guard<std::mutex> lock(mutex);
2056
+ if (!initialized) {
2057
+ aclInit(nullptr);
2058
+ ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
2059
+
2060
+ for (int i = 0; i < ggml_cann_info().device_count; i++) {
2061
+ ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
2062
+ dev_ctx->description = aclrtGetSocName();
2063
+ dev_ctx->device = i;
2064
+ dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2065
+ ggml_cann_set_device(i);
2066
+ ggml_backend_dev_t dev = new ggml_backend_device {
2067
+ /* .interface = */ ggml_backend_cann_device_interface,
2068
+ /* .reg = */ &reg,
2069
+ /* .context = */ dev_ctx
2070
+ };
2071
+ ctx->devices.push_back(dev);
2072
+ }
2073
+
2074
+ reg = ggml_backend_reg {
2075
+ /* .interface = */ ggml_backend_cann_reg_interface,
2076
+ /* .context = */ ctx
2077
+ };
2078
+ }
2079
+
2080
+ initialized = true;
2081
+ }
2082
+
2083
+ return &reg;
2084
+ }
2085
+
2086
+ ggml_backend_t ggml_backend_cann_init(int32_t device) {
2087
+ aclInit(nullptr);
2088
+ if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
2089
+ GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
2090
+ return nullptr;
2091
+ }
2092
+
2093
+ ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
2094
+ if (ctx == nullptr) {
2095
+ GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
2096
+ return nullptr;
2097
+ }
2098
+ ggml_cann_set_device(ctx->device);
2099
+ ggml_backend_t cann_backend =
2100
+ new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
2101
+ /* .interface = */ ggml_backend_cann_interface,
2102
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
2103
+ /* .context = */ ctx};
2104
+
2105
+ return cann_backend;
2106
+ }
2107
+
2108
+ bool ggml_backend_is_cann(ggml_backend_t backend) {
2109
+ return backend != NULL &&
2110
+ ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
2111
+ }
2112
+
2113
+ int32_t ggml_backend_cann_get_device_count() {
2114
+ return ggml_cann_info().device_count;
2115
+ }
2116
+
2117
+ void ggml_backend_cann_get_device_description(
2118
+ int32_t device, char* description, size_t description_size) {
2119
+ ggml_cann_set_device(device);
2120
+ const char* soc_name = aclrtGetSocName();
2121
+ snprintf(description, description_size, "%s", soc_name);
2122
+ }
2123
+
2124
+ void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
2125
+ size_t* total) {
2126
+ ggml_cann_set_device(device);
2127
+ ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
2128
+ }