w11wo commited on
Commit
b704bc0
1 Parent(s): 82c888a

Added Reduced onnxruntime.xcframework

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.a filter=lfs diff=lfs merge=lfs -text
1.15.1/onnxruntime.xcframework/Headers/coreml_provider_factory.h ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+ #pragma once
4
+
5
+ #include "onnxruntime_c_api.h"
6
+
7
+ // COREMLFlags are bool options we want to set for CoreML EP
8
+ // This enum is defined as bit flags, and cannot have negative value
9
+ // To generate an uint32_t coreml_flags for using with OrtSessionOptionsAppendExecutionProvider_CoreML below,
10
+ // uint32_t coreml_flags = 0;
11
+ // coreml_flags |= COREML_FLAG_USE_CPU_ONLY;
12
+ enum COREMLFlags {
13
+ COREML_FLAG_USE_NONE = 0x000,
14
+
15
+ // Using CPU only in CoreML EP, this may decrease the perf but will provide
16
+ // reference output value without precision loss, which is useful for validation
17
+ COREML_FLAG_USE_CPU_ONLY = 0x001,
18
+
19
+ // Enable CoreML EP on subgraph
20
+ COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002,
21
+
22
+ // By default CoreML Execution provider will be enabled for all compatible Apple devices
23
+ // Enable this option will only enable CoreML EP for Apple devices with ANE (Apple Neural Engine)
24
+ // Please note, enable this option does not guarantee the entire model to be executed using ANE only
25
+ COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004,
26
+
27
+ // Keep COREML_FLAG_MAX at the end of the enum definition
28
+ // And assign the last COREMLFlag to it
29
+ COREML_FLAG_LAST = COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE,
30
+ };
31
+
32
+ #ifdef __cplusplus
33
+ extern "C" {
34
+ #endif
35
+
36
+ ORT_EXPORT ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CoreML,
37
+ _In_ OrtSessionOptions* options, uint32_t coreml_flags);
38
+
39
+ #ifdef __cplusplus
40
+ }
41
+ #endif
1.15.1/onnxruntime.xcframework/Headers/cpu_provider_factory.h ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ #include "onnxruntime_c_api.h"
5
+
6
+ #ifdef __cplusplus
7
+ extern "C" {
8
+ #endif
9
+
10
+ /**
11
+ * \param use_arena zero: false. non-zero: true.
12
+ */
13
+ ORT_EXPORT
14
+ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CPU, _In_ OrtSessionOptions* options, int use_arena)
15
+ ORT_ALL_ARGS_NONNULL;
16
+
17
+ #ifdef __cplusplus
18
+ }
19
+ #endif
1.15.1/onnxruntime.xcframework/Headers/onnxruntime_c_api.h ADDED
The diff for this file is too large to render. See raw diff
 
1.15.1/onnxruntime.xcframework/Headers/onnxruntime_cxx_api.h ADDED
@@ -0,0 +1,1878 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5
+ //
6
+ // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7
+ // and automatically releasing resources in the destructors. The primary purpose of C++ API is exception safety so
8
+ // all the resources follow RAII and do not leak memory.
9
+ //
10
+ // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
11
+ // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). However, you can't use them
12
+ // until you assign an instance that actually holds an underlying object.
13
+ //
14
+ // For Ort objects only move assignment between objects is allowed, there are no copy constructors.
15
+ // Some objects have explicit 'Clone' methods for this purpose.
16
+ //
17
+ // ConstXXXX types are copyable since they do not own the underlying C object, so you can pass them to functions as arguments
18
+ // by value or by reference. ConstXXXX types are restricted to const only interfaces.
19
+ //
20
+ // UnownedXXXX are similar to ConstXXXX but also allow non-const interfaces.
21
+ //
22
+ // The lifetime of the corresponding owning object must eclipse the lifetimes of the ConstXXXX/UnownedXXXX types. They exists so you do not
23
+ // have to fallback to C types and the API with the usual pitfalls. In general, do not use C API from your C++ code.
24
+
25
+ #pragma once
26
+ #include "onnxruntime_c_api.h"
27
+ #include <cstddef>
28
+ #include <array>
29
+ #include <memory>
30
+ #include <stdexcept>
31
+ #include <string>
32
+ #include <vector>
33
+ #include <unordered_map>
34
+ #include <utility>
35
+ #include <type_traits>
36
+
37
+ #ifdef ORT_NO_EXCEPTIONS
38
+ #include <iostream>
39
+ #endif
40
+
41
+ /** \brief All C++ Onnxruntime APIs are defined inside this namespace
42
+ *
43
+ */
44
+ namespace Ort {
45
+
46
+ /** \brief All C++ methods that can fail will throw an exception of this type
47
+ *
48
+ * If <tt>ORT_NO_EXCEPTIONS</tt> is defined, then any error will result in a call to abort()
49
+ */
50
+ struct Exception : std::exception {
51
+ Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
52
+
53
+ OrtErrorCode GetOrtErrorCode() const { return code_; }
54
+ const char* what() const noexcept override { return message_.c_str(); }
55
+
56
+ private:
57
+ std::string message_;
58
+ OrtErrorCode code_;
59
+ };
60
+
61
+ #ifdef ORT_NO_EXCEPTIONS
62
+ // The #ifndef is for the very special case where the user of this library wants to define their own way of handling errors.
63
+ // NOTE: This header expects control flow to not continue after calling ORT_CXX_API_THROW
64
+ #ifndef ORT_CXX_API_THROW
65
+ #define ORT_CXX_API_THROW(string, code) \
66
+ do { \
67
+ std::cerr << Ort::Exception(string, code) \
68
+ .what() \
69
+ << std::endl; \
70
+ abort(); \
71
+ } while (false)
72
+ #endif
73
+ #else
74
+ #define ORT_CXX_API_THROW(string, code) \
75
+ throw Ort::Exception(string, code)
76
+ #endif
77
+
78
+ // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi,
79
+ // it's in a template so that we can define a global variable in a header and make
80
+ // it transparent to the users of the API.
81
+ template <typename T>
82
+ struct Global {
83
+ static const OrtApi* api_;
84
+ };
85
+
86
+ // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
87
+ template <typename T>
88
+ #ifdef ORT_API_MANUAL_INIT
89
+ const OrtApi* Global<T>::api_{};
90
+ inline void InitApi() noexcept { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
91
+
92
+ // Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is
93
+ // required by C++ APIs.
94
+ //
95
+ // Example mycustomop.cc:
96
+ //
97
+ // #define ORT_API_MANUAL_INIT
98
+ // #include <onnxruntime_cxx_api.h>
99
+ // #undef ORT_API_MANUAL_INIT
100
+ //
101
+ // OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api_base) {
102
+ // Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
103
+ // // ...
104
+ // }
105
+ //
106
+ inline void InitApi(const OrtApi* api) noexcept { Global<void>::api_ = api; }
107
+ #else
108
+ #if defined(_MSC_VER) && !defined(__clang__)
109
+ #pragma warning(push)
110
+ // "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers.
111
+ // Please define ORT_API_MANUAL_INIT if it conerns you.
112
+ #pragma warning(disable : 26426)
113
+ #endif
114
+ const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
115
+ #if defined(_MSC_VER) && !defined(__clang__)
116
+ #pragma warning(pop)
117
+ #endif
118
+ #endif
119
+
120
+ /// This returns a reference to the OrtApi interface in use
121
+ inline const OrtApi& GetApi() noexcept { return *Global<void>::api_; }
122
+
123
+ /// <summary>
124
+ /// This is a C++ wrapper for OrtApi::GetAvailableProviders() and
125
+ /// returns a vector of strings representing the available execution providers.
126
+ /// </summary>
127
+ /// <returns>vector of strings</returns>
128
+ std::vector<std::string> GetAvailableProviders();
129
+
130
+ /** \brief IEEE 754 half-precision floating point data type
131
+ * \details It is necessary for type dispatching to make use of C++ API
132
+ * The type is implicitly convertible to/from uint16_t.
133
+ * The size of the structure should align with uint16_t and one can freely cast
134
+ * uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
135
+ *
136
+ * Generally, you can feed any of your types as float16/blfoat16 data to create a tensor
137
+ * on top of it, providing it can form a continuous buffer with 16-bit elements with no padding.
138
+ * And you can also feed a array of uint16_t elements directly. For example,
139
+ *
140
+ * \code{.unparsed}
141
+ * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
142
+ * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
143
+ * std::vector<int64_t> dims = {values_length}; // one dimensional example
144
+ * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
145
+ * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
146
+ * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
147
+ * dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
148
+ * \endcode
149
+ *
150
+ * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
151
+ * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
152
+ * template specialization.
153
+ *
154
+ * \code{.unparsed}
155
+ * namespace yours { struct half {}; } // assume this is your type, define this:
156
+ * namespace Ort {
157
+ * template<>
158
+ * struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
159
+ * } //namespace Ort
160
+ *
161
+ * std::vector<yours::half> values;
162
+ * std::vector<int64_t> dims = {values.size()}; // one dimensional example
163
+ * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
164
+ * // Here we are passing element count -> values.size()
165
+ * auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
166
+ *
167
+ * \endcode
168
+ */
169
+ struct Float16_t {
170
+ uint16_t value;
171
+ constexpr Float16_t() noexcept : value(0) {}
172
+ constexpr Float16_t(uint16_t v) noexcept : value(v) {}
173
+ constexpr operator uint16_t() const noexcept { return value; }
174
+ constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; };
175
+ constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; };
176
+ };
177
+
178
+ static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
179
+
180
+ /** \brief bfloat16 (Brain Floating Point) data type
181
+ * \details It is necessary for type dispatching to make use of C++ API
182
+ * The type is implicitly convertible to/from uint16_t.
183
+ * The size of the structure should align with uint16_t and one can freely cast
184
+ * uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
185
+ *
186
+ * See also code examples for Float16_t above.
187
+ */
188
+ struct BFloat16_t {
189
+ uint16_t value;
190
+ constexpr BFloat16_t() noexcept : value(0) {}
191
+ constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
192
+ constexpr operator uint16_t() const noexcept { return value; }
193
+ constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; };
194
+ constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; };
195
+ };
196
+
197
+ static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
198
+
199
+ namespace detail {
200
+ // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
201
+ // This can't be done in the C API since C doesn't have function overloading.
202
+ #define ORT_DEFINE_RELEASE(NAME) \
203
+ inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
204
+
205
+ ORT_DEFINE_RELEASE(Allocator);
206
+ ORT_DEFINE_RELEASE(MemoryInfo);
207
+ ORT_DEFINE_RELEASE(CustomOpDomain);
208
+ ORT_DEFINE_RELEASE(ThreadingOptions);
209
+ ORT_DEFINE_RELEASE(Env);
210
+ ORT_DEFINE_RELEASE(RunOptions);
211
+ ORT_DEFINE_RELEASE(Session);
212
+ ORT_DEFINE_RELEASE(SessionOptions);
213
+ ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
214
+ ORT_DEFINE_RELEASE(SequenceTypeInfo);
215
+ ORT_DEFINE_RELEASE(MapTypeInfo);
216
+ ORT_DEFINE_RELEASE(TypeInfo);
217
+ ORT_DEFINE_RELEASE(Value);
218
+ ORT_DEFINE_RELEASE(ModelMetadata);
219
+ ORT_DEFINE_RELEASE(IoBinding);
220
+ ORT_DEFINE_RELEASE(ArenaCfg);
221
+ ORT_DEFINE_RELEASE(Status);
222
+ ORT_DEFINE_RELEASE(OpAttr);
223
+ ORT_DEFINE_RELEASE(Op);
224
+ ORT_DEFINE_RELEASE(KernelInfo);
225
+
226
+ #undef ORT_DEFINE_RELEASE
227
+
228
+ /** \brief This is a tagging template type. Use it with Base<T> to indicate that the C++ interface object
229
+ * has no ownership of the underlying C object.
230
+ */
231
+ template <typename T>
232
+ struct Unowned {
233
+ using Type = T;
234
+ };
235
+
236
+ /** \brief Used internally by the C++ API. C++ wrapper types inherit from this.
237
+ * This is a zero cost abstraction to wrap the C API objects and delete them on destruction.
238
+ *
239
+ * All of the C++ classes
240
+ * a) serve as containers for pointers to objects that are created by the underlying C API.
241
+ * Their size is just a pointer size, no need to dynamically allocate them. Use them by value.
242
+ * b) Each of struct XXXX, XXX instances function as smart pointers to the underlying C API objects.
243
+ * they would release objects owned automatically when going out of scope, they are move-only.
244
+ * c) ConstXXXX and UnownedXXX structs function as non-owning, copyable containers for the above pointers.
245
+ * ConstXXXX allow calling const interfaces only. They give access to objects that are owned by somebody else
246
+ * such as Onnxruntime or instances of XXXX classes.
247
+ * d) serve convenient interfaces that return C++ objects and further enhance exception and type safety so they can be used
248
+ * in C++ code.
249
+ *
250
+ */
251
+
252
+ /// <summary>
253
+ /// This is a non-const pointer holder that is move-only. Disposes of the pointer on destruction.
254
+ /// </summary>
255
+ template <typename T>
256
+ struct Base {
257
+ using contained_type = T;
258
+
259
+ constexpr Base() = default;
260
+ constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
261
+ ~Base() { OrtRelease(p_); }
262
+
263
+ Base(const Base&) = delete;
264
+ Base& operator=(const Base&) = delete;
265
+
266
+ Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
267
+ Base& operator=(Base&& v) noexcept {
268
+ OrtRelease(p_);
269
+ p_ = v.release();
270
+ return *this;
271
+ }
272
+
273
+ constexpr operator contained_type*() const noexcept { return p_; }
274
+
275
+ /// \brief Relinquishes ownership of the contained C object pointer
276
+ /// The underlying object is not destroyed
277
+ contained_type* release() {
278
+ T* p = p_;
279
+ p_ = nullptr;
280
+ return p;
281
+ }
282
+
283
+ protected:
284
+ contained_type* p_{};
285
+ };
286
+
287
+ // Undefined. For const types use Base<Unowned<const T>>
288
+ template <typename T>
289
+ struct Base<const T>;
290
+
291
+ /// <summary>
292
+ /// Covers unowned pointers owned by either the ORT
293
+ /// or some other instance of CPP wrappers.
294
+ /// Used for ConstXXX and UnownedXXXX types that are copyable.
295
+ /// Also convenient to wrap raw OrtXX pointers .
296
+ /// </summary>
297
+ /// <typeparam name="T"></typeparam>
298
+ template <typename T>
299
+ struct Base<Unowned<T>> {
300
+ using contained_type = typename Unowned<T>::Type;
301
+
302
+ constexpr Base() = default;
303
+ constexpr explicit Base(contained_type* p) noexcept : p_{p} {}
304
+
305
+ ~Base() = default;
306
+
307
+ Base(const Base&) = default;
308
+ Base& operator=(const Base&) = default;
309
+
310
+ Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
311
+ Base& operator=(Base&& v) noexcept {
312
+ p_ = nullptr;
313
+ std::swap(p_, v.p_);
314
+ return *this;
315
+ }
316
+
317
+ constexpr operator contained_type*() const noexcept { return p_; }
318
+
319
+ protected:
320
+ contained_type* p_{};
321
+ };
322
+
323
+ // Light functor to release memory with OrtAllocator
324
+ struct AllocatedFree {
325
+ OrtAllocator* allocator_;
326
+ explicit AllocatedFree(OrtAllocator* allocator)
327
+ : allocator_(allocator) {}
328
+ void operator()(void* ptr) const {
329
+ if (ptr) allocator_->Free(allocator_, ptr);
330
+ }
331
+ };
332
+
333
+ } // namespace detail
334
+
335
+ struct AllocatorWithDefaultOptions;
336
+ struct Env;
337
+ struct TypeInfo;
338
+ struct Value;
339
+ struct ModelMetadata;
340
+
341
+ /** \brief unique_ptr typedef used to own strings allocated by OrtAllocators
342
+ * and release them at the end of the scope. The lifespan of the given allocator
343
+ * must eclipse the lifespan of AllocatedStringPtr instance
344
+ */
345
+ using AllocatedStringPtr = std::unique_ptr<char, detail::AllocatedFree>;
346
+
347
+ /** \brief The Status that holds ownership of OrtStatus received from C API
348
+ * Use it to safely destroy OrtStatus* returned from the C API. Use appropriate
349
+ * constructors to construct an instance of a Status object from exceptions.
350
+ */
351
+ struct Status : detail::Base<OrtStatus> {
352
+ explicit Status(std::nullptr_t) {} ///< Create an empty object, must be assigned a valid one to be used
353
+ explicit Status(OrtStatus* status); ///< Takes ownership of OrtStatus instance returned from the C API. Must be non-null
354
+ explicit Status(const Exception&); ///< Creates status instance out of exception
355
+ explicit Status(const std::exception&); ///< Creates status instance out of exception
356
+ std::string GetErrorMessage() const;
357
+ OrtErrorCode GetErrorCode() const;
358
+ };
359
+
360
+ /** \brief The ThreadingOptions
361
+ *
362
+ * The ThreadingOptions used for set global threadpools' options of The Env.
363
+ */
364
+ struct ThreadingOptions : detail::Base<OrtThreadingOptions> {
365
+ /// \brief Wraps OrtApi::CreateThreadingOptions
366
+ ThreadingOptions();
367
+
368
+ /// \brief Wraps OrtApi::SetGlobalIntraOpNumThreads
369
+ ThreadingOptions& SetGlobalIntraOpNumThreads(int intra_op_num_threads);
370
+
371
+ /// \brief Wraps OrtApi::SetGlobalInterOpNumThreads
372
+ ThreadingOptions& SetGlobalInterOpNumThreads(int inter_op_num_threads);
373
+
374
+ /// \brief Wraps OrtApi::SetGlobalSpinControl
375
+ ThreadingOptions& SetGlobalSpinControl(int allow_spinning);
376
+
377
+ /// \brief Wraps OrtApi::SetGlobalDenormalAsZero
378
+ ThreadingOptions& SetGlobalDenormalAsZero();
379
+
380
+ /// \brief Wraps OrtApi::SetGlobalCustomCreateThreadFn
381
+ ThreadingOptions& SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn);
382
+
383
+ /// \brief Wraps OrtApi::SetGlobalCustomThreadCreationOptions
384
+ ThreadingOptions& SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options);
385
+
386
+ /// \brief Wraps OrtApi::SetGlobalCustomJoinThreadFn
387
+ ThreadingOptions& SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn);
388
+ };
389
+
390
+ /** \brief The Env (Environment)
391
+ *
392
+ * The Env holds the logging state used by all other objects.
393
+ * <b>Note:</b> One Env must be created before using any other Onnxruntime functionality
394
+ */
395
+ struct Env : detail::Base<OrtEnv> {
396
+ explicit Env(std::nullptr_t) {} ///< Create an empty Env object, must be assigned a valid one to be used
397
+
398
+ /// \brief Wraps OrtApi::CreateEnv
399
+ Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
400
+
401
+ /// \brief Wraps OrtApi::CreateEnvWithCustomLogger
402
+ Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
403
+
404
+ /// \brief Wraps OrtApi::CreateEnvWithGlobalThreadPools
405
+ Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
406
+
407
+ /// \brief Wraps OrtApi::CreateEnvWithCustomLoggerAndGlobalThreadPools
408
+ Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
409
+ OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
410
+
411
+ /// \brief C Interop Helper
412
+ explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
413
+
414
+ Env& EnableTelemetryEvents(); ///< Wraps OrtApi::EnableTelemetryEvents
415
+ Env& DisableTelemetryEvents(); ///< Wraps OrtApi::DisableTelemetryEvents
416
+
417
+ Env& UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level); ///< Wraps OrtApi::UpdateEnvWithCustomLogLevel
418
+
419
+ Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); ///< Wraps OrtApi::CreateAndRegisterAllocator
420
+ };
421
+
422
+ /** \brief Custom Op Domain
423
+ *
424
+ */
425
+ struct CustomOpDomain : detail::Base<OrtCustomOpDomain> {
426
+ explicit CustomOpDomain(std::nullptr_t) {} ///< Create an empty CustomOpDomain object, must be assigned a valid one to be used
427
+
428
+ /// \brief Wraps OrtApi::CreateCustomOpDomain
429
+ explicit CustomOpDomain(const char* domain);
430
+
431
+ // This does not take ownership of the op, simply registers it.
432
+ void Add(const OrtCustomOp* op); ///< Wraps CustomOpDomain_Add
433
+ };
434
+
435
+ /** \brief RunOptions
436
+ *
437
+ */
438
+ struct RunOptions : detail::Base<OrtRunOptions> {
439
+ explicit RunOptions(std::nullptr_t) {} ///< Create an empty RunOptions object, must be assigned a valid one to be used
440
+ RunOptions(); ///< Wraps OrtApi::CreateRunOptions
441
+
442
+ RunOptions& SetRunLogVerbosityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogVerbosityLevel
443
+ int GetRunLogVerbosityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogVerbosityLevel
444
+
445
+ RunOptions& SetRunLogSeverityLevel(int); ///< Wraps OrtApi::RunOptionsSetRunLogSeverityLevel
446
+ int GetRunLogSeverityLevel() const; ///< Wraps OrtApi::RunOptionsGetRunLogSeverityLevel
447
+
448
+ RunOptions& SetRunTag(const char* run_tag); ///< wraps OrtApi::RunOptionsSetRunTag
449
+ const char* GetRunTag() const; ///< Wraps OrtApi::RunOptionsGetRunTag
450
+
451
+ RunOptions& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddRunConfigEntry
452
+
453
+ /** \brief Terminates all currently executing Session::Run calls that were made using this RunOptions instance
454
+ *
455
+ * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error
456
+ * Wraps OrtApi::RunOptionsSetTerminate
457
+ */
458
+ RunOptions& SetTerminate();
459
+
460
+ /** \brief Clears the terminate flag so this RunOptions instance can be used in a new Session::Run call without it instantly terminating
461
+ *
462
+ * Wraps OrtApi::RunOptionsUnsetTerminate
463
+ */
464
+ RunOptions& UnsetTerminate();
465
+ };
466
+
467
+
468
+ namespace detail {
469
+ // Utility function that returns a SessionOption config entry key for a specific custom operator.
470
+ // Ex: custom_op.[custom_op_name].[config]
471
+ std::string MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config);
472
+ } // namespace detail
473
+
474
+ /// <summary>
475
+ /// Class that represents session configuration entries for one or more custom operators.
476
+ ///
477
+ /// Example:
478
+ /// Ort::CustomOpConfigs op_configs;
479
+ /// op_configs.AddConfig("my_custom_op", "device_type", "CPU");
480
+ ///
481
+ /// Passed to Ort::SessionOptions::RegisterCustomOpsLibrary.
482
+ /// </summary>
483
+ struct CustomOpConfigs {
484
+ CustomOpConfigs() = default;
485
+ ~CustomOpConfigs() = default;
486
+ CustomOpConfigs(const CustomOpConfigs&) = default;
487
+ CustomOpConfigs& operator=(const CustomOpConfigs&) = default;
488
+ CustomOpConfigs(CustomOpConfigs&& o) = default;
489
+ CustomOpConfigs& operator=(CustomOpConfigs&& o) = default;
490
+
491
+ /** \brief Adds a session configuration entry/value for a specific custom operator.
492
+ *
493
+ * \param custom_op_name The name of the custom operator for which to add a configuration entry.
494
+ * Must match the name returned by the CustomOp's GetName() method.
495
+ * \param config_key The name of the configuration entry.
496
+ * \param config_value The value of the configuration entry.
497
+ * \return A reference to this object to enable call chaining.
498
+ */
499
+ CustomOpConfigs& AddConfig(const char* custom_op_name, const char* config_key, const char* config_value);
500
+
501
+ /** \brief Returns a flattened map of custom operator configuration entries and their values.
502
+ *
503
+ * The keys has been flattened to include both the custom operator name and the configuration entry key name.
504
+ * For example, a prior call to AddConfig("my_op", "key", "value") corresponds to the flattened key/value pair
505
+ * {"my_op.key", "value"}.
506
+ *
507
+ * \return An unordered map of flattened configurations.
508
+ */
509
+ const std::unordered_map<std::string, std::string>& GetFlattenedConfigs() const;
510
+
511
+ private:
512
+ std::unordered_map<std::string, std::string> flat_configs_;
513
+ };
514
+
515
+ /** \brief Options object used when creating a new Session object
516
+ *
517
+ * Wraps ::OrtSessionOptions object and methods
518
+ */
519
+
520
+ struct SessionOptions;
521
+
522
+ namespace detail {
523
+ // we separate const-only methods because passing const ptr to non-const methods
524
+ // is only discovered when inline methods are compiled which is counter-intuitive
525
+ template <typename T>
526
+ struct ConstSessionOptionsImpl : Base<T> {
527
+ using B = Base<T>;
528
+ using B::B;
529
+
530
+ SessionOptions Clone() const; ///< Creates and returns a copy of this SessionOptions object. Wraps OrtApi::CloneSessionOptions
531
+
532
+ std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
533
+ bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
534
+ std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def);
535
+ };
536
+
537
+ template <typename T>
538
+ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
539
+ using B = ConstSessionOptionsImpl<T>;
540
+ using B::B;
541
+
542
+ SessionOptionsImpl& SetIntraOpNumThreads(int intra_op_num_threads); ///< Wraps OrtApi::SetIntraOpNumThreads
543
+ SessionOptionsImpl& SetInterOpNumThreads(int inter_op_num_threads); ///< Wraps OrtApi::SetInterOpNumThreads
544
+ SessionOptionsImpl& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::SetSessionGraphOptimizationLevel
545
+
546
+ SessionOptionsImpl& EnableCpuMemArena(); ///< Wraps OrtApi::EnableCpuMemArena
547
+ SessionOptionsImpl& DisableCpuMemArena(); ///< Wraps OrtApi::DisableCpuMemArena
548
+
549
+ SessionOptionsImpl& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); ///< Wraps OrtApi::SetOptimizedModelFilePath
550
+
551
+ SessionOptionsImpl& EnableProfiling(const ORTCHAR_T* profile_file_prefix); ///< Wraps OrtApi::EnableProfiling
552
+ SessionOptionsImpl& DisableProfiling(); ///< Wraps OrtApi::DisableProfiling
553
+
554
+ SessionOptionsImpl& EnableOrtCustomOps(); ///< Wraps OrtApi::EnableOrtCustomOps
555
+
556
+ SessionOptionsImpl& EnableMemPattern(); ///< Wraps OrtApi::EnableMemPattern
557
+ SessionOptionsImpl& DisableMemPattern(); ///< Wraps OrtApi::DisableMemPattern
558
+
559
+ SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
560
+
561
+ SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
562
+ SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
563
+
564
+ SessionOptionsImpl& Add(OrtCustomOpDomain* custom_op_domain); ///< Wraps OrtApi::AddCustomOpDomain
565
+
566
+ SessionOptionsImpl& DisablePerSessionThreads(); ///< Wraps OrtApi::DisablePerSessionThreads
567
+
568
+ SessionOptionsImpl& AddConfigEntry(const char* config_key, const char* config_value); ///< Wraps OrtApi::AddSessionConfigEntry
569
+
570
+ SessionOptionsImpl& AddInitializer(const char* name, const OrtValue* ort_val); ///< Wraps OrtApi::AddInitializer
571
+ SessionOptionsImpl& AddExternalInitializers(const std::vector<std::string>& names, const std::vector<Value>& ort_values); ///< Wraps OrtApi::AddExternalInitializers
572
+
573
+ SessionOptionsImpl& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA
574
+ SessionOptionsImpl& AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CUDA_V2
575
+ SessionOptionsImpl& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM
576
+ SessionOptionsImpl& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
577
+ SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
578
+ SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
579
+ SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
580
+ ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
581
+ SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
582
+ ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
583
+ SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
584
+ /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports SNPE and XNNPACK.
585
+ SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
586
+ const std::unordered_map<std::string, std::string>& provider_options = {});
587
+
588
+ SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
589
+ SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
590
+ SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
591
+
592
+ ///< Registers the custom operator from the specified shared library via OrtApi::RegisterCustomOpsLibrary_V2.
593
+ ///< The custom operator configurations are optional. If provided, custom operator configs are set via
594
+ ///< OrtApi::AddSessionConfigEntry.
595
+ SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {});
596
+
597
+ SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction
598
+ };
599
+ } // namespace detail
600
+
601
+ using UnownedSessionOptions = detail::SessionOptionsImpl<detail::Unowned<OrtSessionOptions>>;
602
+ using ConstSessionOptions = detail::ConstSessionOptionsImpl<detail::Unowned<const OrtSessionOptions>>;
603
+
604
+ /** \brief Wrapper around ::OrtSessionOptions
605
+ *
606
+ */
607
+ struct SessionOptions : detail::SessionOptionsImpl<OrtSessionOptions> {
608
+ explicit SessionOptions(std::nullptr_t) {} ///< Create an empty SessionOptions object, must be assigned a valid one to be used
609
+ SessionOptions(); ///< Wraps OrtApi::CreateSessionOptions
610
+ explicit SessionOptions(OrtSessionOptions* p) : SessionOptionsImpl<OrtSessionOptions>{p} {} ///< Used for interop with the C API
611
+ UnownedSessionOptions GetUnowned() const { return UnownedSessionOptions{this->p_}; }
612
+ ConstSessionOptions GetConst() const { return ConstSessionOptions{this->p_}; }
613
+ };
614
+
615
+ /** \brief Wrapper around ::OrtModelMetadata
616
+ *
617
+ */
618
+ struct ModelMetadata : detail::Base<OrtModelMetadata> {
619
+ explicit ModelMetadata(std::nullptr_t) {} ///< Create an empty ModelMetadata object, must be assigned a valid one to be used
620
+ explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} ///< Used for interop with the C API
621
+
622
+ /** \brief Returns a copy of the producer name.
623
+ *
624
+ * \param allocator to allocate memory for the copy of the name returned
625
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
626
+ * The OrtAllocator instances must be valid at the point of memory release.
627
+ */
628
+ AllocatedStringPtr GetProducerNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetProducerName
629
+
630
+ /** \brief Returns a copy of the graph name.
631
+ *
632
+ * \param allocator to allocate memory for the copy of the name returned
633
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
634
+ * The OrtAllocator instances must be valid at the point of memory release.
635
+ */
636
+ AllocatedStringPtr GetGraphNameAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphName
637
+
638
+ /** \brief Returns a copy of the domain name.
639
+ *
640
+ * \param allocator to allocate memory for the copy of the name returned
641
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
642
+ * The OrtAllocator instances must be valid at the point of memory release.
643
+ */
644
+ AllocatedStringPtr GetDomainAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDomain
645
+
646
+ /** \brief Returns a copy of the description.
647
+ *
648
+ * \param allocator to allocate memory for the copy of the string returned
649
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
650
+ * The OrtAllocator instances must be valid at the point of memory release.
651
+ */
652
+ AllocatedStringPtr GetDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetDescription
653
+
654
+ /** \brief Returns a copy of the graph description.
655
+ *
656
+ * \param allocator to allocate memory for the copy of the string returned
657
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
658
+ * The OrtAllocator instances must be valid at the point of memory release.
659
+ */
660
+ AllocatedStringPtr GetGraphDescriptionAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetGraphDescription
661
+
662
+ /** \brief Returns a vector of copies of the custom metadata keys.
663
+ *
664
+ * \param allocator to allocate memory for the copy of the string returned
665
+ * \return a instance std::vector of smart pointers that would deallocate the buffers when out of scope.
666
+ * The OrtAllocator instance must be valid at the point of memory release.
667
+ */
668
+ std::vector<AllocatedStringPtr> GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataGetCustomMetadataMapKeys
669
+
670
+ /** \brief Looks up a value by a key in the Custom Metadata map
671
+ *
672
+ * \param key zero terminated string key to lookup
673
+ * \param allocator to allocate memory for the copy of the string returned
674
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
675
+ * maybe nullptr if key is not found.
676
+ *
677
+ * The OrtAllocator instances must be valid at the point of memory release.
678
+ */
679
+ AllocatedStringPtr LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const; ///< Wraps OrtApi::ModelMetadataLookupCustomMetadataMap
680
+
681
+ int64_t GetVersion() const; ///< Wraps OrtApi::ModelMetadataGetVersion
682
+ };
683
+
684
+ struct IoBinding;
685
+
686
+ namespace detail {
687
+
688
+ // we separate const-only methods because passing const ptr to non-const methods
689
+ // is only discovered when inline methods are compiled which is counter-intuitive
690
+ template <typename T>
691
+ struct ConstSessionImpl : Base<T> {
692
+ using B = Base<T>;
693
+ using B::B;
694
+
695
+ size_t GetInputCount() const; ///< Returns the number of model inputs
696
+ size_t GetOutputCount() const; ///< Returns the number of model outputs
697
+ size_t GetOverridableInitializerCount() const; ///< Returns the number of inputs that have defaults that can be overridden
698
+
699
+ /** \brief Returns a copy of input name at the specified index.
700
+ *
701
+ * \param index must less than the value returned by GetInputCount()
702
+ * \param allocator to allocate memory for the copy of the name returned
703
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
704
+ * The OrtAllocator instances must be valid at the point of memory release.
705
+ */
706
+ AllocatedStringPtr GetInputNameAllocated(size_t index, OrtAllocator* allocator) const;
707
+
708
+ /** \brief Returns a copy of output name at then specified index.
709
+ *
710
+ * \param index must less than the value returned by GetOutputCount()
711
+ * \param allocator to allocate memory for the copy of the name returned
712
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
713
+ * The OrtAllocator instances must be valid at the point of memory release.
714
+ */
715
+ AllocatedStringPtr GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const;
716
+
717
+ /** \brief Returns a copy of the overridable initializer name at then specified index.
718
+ *
719
+ * \param index must less than the value returned by GetOverridableInitializerCount()
720
+ * \param allocator to allocate memory for the copy of the name returned
721
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
722
+ * The OrtAllocator instances must be valid at the point of memory release.
723
+ */
724
+ AllocatedStringPtr GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const; ///< Wraps OrtApi::SessionGetOverridableInitializerName
725
+
726
+ uint64_t GetProfilingStartTimeNs() const; ///< Wraps OrtApi::SessionGetProfilingStartTimeNs
727
+ ModelMetadata GetModelMetadata() const; ///< Wraps OrtApi::SessionGetModelMetadata
728
+
729
+ TypeInfo GetInputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetInputTypeInfo
730
+ TypeInfo GetOutputTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOutputTypeInfo
731
+ TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; ///< Wraps OrtApi::SessionGetOverridableInitializerTypeInfo
732
+ };
733
+
734
+ template <typename T>
735
+ struct SessionImpl : ConstSessionImpl<T> {
736
+ using B = ConstSessionImpl<T>;
737
+ using B::B;
738
+
739
+ /** \brief Run the model returning results in an Ort allocated vector.
740
+ *
741
+ * Wraps OrtApi::Run
742
+ *
743
+ * The caller provides a list of inputs and a list of the desired outputs to return.
744
+ *
745
+ * See the output logs for more information on warnings/errors that occur while processing the model.
746
+ * Common errors are.. (TODO)
747
+ *
748
+ * \param[in] run_options
749
+ * \param[in] input_names Array of null terminated strings of length input_count that is the list of input names
750
+ * \param[in] input_values Array of Value objects of length input_count that is the list of input values
751
+ * \param[in] input_count Number of inputs (the size of the input_names & input_values arrays)
752
+ * \param[in] output_names Array of C style strings of length output_count that is the list of output names
753
+ * \param[in] output_count Number of outputs (the size of the output_names array)
754
+ * \return A std::vector of Value objects that directly maps to the output_names array (eg. output_name[0] is the first entry of the returned vector)
755
+ */
756
+ std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
757
+ const char* const* output_names, size_t output_count);
758
+
759
+ /** \brief Run the model returning results in user provided outputs
760
+ * Same as Run(const RunOptions&, const char* const*, const Value*, size_t,const char* const*, size_t)
761
+ */
762
+ void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
763
+ const char* const* output_names, Value* output_values, size_t output_count);
764
+
765
+ void Run(const RunOptions& run_options, const IoBinding&); ///< Wraps OrtApi::RunWithBinding
766
+
767
+ /** \brief End profiling and return a copy of the profiling file name.
768
+ *
769
+ * \param allocator to allocate memory for the copy of the string returned
770
+ * \return a instance of smart pointer that would deallocate the buffer when out of scope.
771
+ * The OrtAllocator instances must be valid at the point of memory release.
772
+ */
773
+ AllocatedStringPtr EndProfilingAllocated(OrtAllocator* allocator); ///< Wraps OrtApi::SessionEndProfiling
774
+ };
775
+
776
+ } // namespace detail
777
+
778
+ using ConstSession = detail::ConstSessionImpl<detail::Unowned<const OrtSession>>;
779
+ using UnownedSession = detail::SessionImpl<detail::Unowned<OrtSession>>;
780
+
781
+ /** \brief Wrapper around ::OrtSession
782
+ *
783
+ */
784
+ struct Session : detail::SessionImpl<OrtSession> {
785
+ explicit Session(std::nullptr_t) {} ///< Create an empty Session object, must be assigned a valid one to be used
786
+ Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); ///< Wraps OrtApi::CreateSession
787
+ Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
788
+ OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionWithPrepackedWeightsContainer
789
+ Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); ///< Wraps OrtApi::CreateSessionFromArray
790
+ Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options,
791
+ OrtPrepackedWeightsContainer* prepacked_weights_container); ///< Wraps OrtApi::CreateSessionFromArrayWithPrepackedWeightsContainer
792
+
793
+ ConstSession GetConst() const { return ConstSession{this->p_}; }
794
+ UnownedSession GetUnowned() const { return UnownedSession{this->p_}; }
795
+ };
796
+
797
+ namespace detail {
798
+ template <typename T>
799
+ struct MemoryInfoImpl : Base<T> {
800
+ using B = Base<T>;
801
+ using B::B;
802
+
803
+ std::string GetAllocatorName() const;
804
+ OrtAllocatorType GetAllocatorType() const;
805
+ int GetDeviceId() const;
806
+ OrtMemoryInfoDeviceType GetDeviceType() const;
807
+ OrtMemType GetMemoryType() const;
808
+
809
+ template <typename U>
810
+ bool operator==(const MemoryInfoImpl<U>& o) const;
811
+ };
812
+ } // namespace detail
813
+
814
+ // Const object holder that does not own the underlying object
815
+ using ConstMemoryInfo = detail::MemoryInfoImpl<detail::Unowned<const OrtMemoryInfo>>;
816
+
817
+ /** \brief Wrapper around ::OrtMemoryInfo
818
+ *
819
+ */
820
+ struct MemoryInfo : detail::MemoryInfoImpl<OrtMemoryInfo> {
821
+ static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
822
+ explicit MemoryInfo(std::nullptr_t) {} ///< No instance is created
823
+ explicit MemoryInfo(OrtMemoryInfo* p) : MemoryInfoImpl<OrtMemoryInfo>{p} {} ///< Take ownership of a pointer created by C Api
824
+ MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
825
+ ConstMemoryInfo GetConst() const { return ConstMemoryInfo{this->p_}; }
826
+ };
827
+
828
+ namespace detail {
829
+ template <typename T>
830
+ struct TensorTypeAndShapeInfoImpl : Base<T> {
831
+ using B = Base<T>;
832
+ using B::B;
833
+
834
+ ONNXTensorElementDataType GetElementType() const; ///< Wraps OrtApi::GetTensorElementType
835
+ size_t GetElementCount() const; ///< Wraps OrtApi::GetTensorShapeElementCount
836
+
837
+ size_t GetDimensionsCount() const; ///< Wraps OrtApi::GetDimensionsCount
838
+
839
+ /** \deprecated use GetShape() returning std::vector
840
+ * [[deprecated]]
841
+ * This interface is unsafe to use
842
+ */
843
+ [[deprecated("use GetShape()")]] void GetDimensions(int64_t* values, size_t values_count) const; ///< Wraps OrtApi::GetDimensions
844
+
845
+ void GetSymbolicDimensions(const char** values, size_t values_count) const; ///< Wraps OrtApi::GetSymbolicDimensions
846
+
847
+ std::vector<int64_t> GetShape() const; ///< Uses GetDimensionsCount & GetDimensions to return a std::vector of the shape
848
+ };
849
+
850
+ } // namespace detail
851
+
852
+ using ConstTensorTypeAndShapeInfo = detail::TensorTypeAndShapeInfoImpl<detail::Unowned<const OrtTensorTypeAndShapeInfo>>;
853
+
854
+ /** \brief Wrapper around ::OrtTensorTypeAndShapeInfo
855
+ *
856
+ */
857
+ struct TensorTypeAndShapeInfo : detail::TensorTypeAndShapeInfoImpl<OrtTensorTypeAndShapeInfo> {
858
+ explicit TensorTypeAndShapeInfo(std::nullptr_t) {} ///< Create an empty TensorTypeAndShapeInfo object, must be assigned a valid one to be used
859
+ explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : TensorTypeAndShapeInfoImpl{p} {} ///< Used for interop with the C API
860
+ ConstTensorTypeAndShapeInfo GetConst() const { return ConstTensorTypeAndShapeInfo{this->p_}; }
861
+ };
862
+
863
+ namespace detail {
864
+ template <typename T>
865
+ struct SequenceTypeInfoImpl : Base<T> {
866
+ using B = Base<T>;
867
+ using B::B;
868
+ TypeInfo GetSequenceElementType() const; ///< Wraps OrtApi::GetSequenceElementType
869
+ };
870
+
871
+ } // namespace detail
872
+
873
+ using ConstSequenceTypeInfo = detail::SequenceTypeInfoImpl<detail::Unowned<const OrtSequenceTypeInfo>>;
874
+
875
+ /** \brief Wrapper around ::OrtSequenceTypeInfo
876
+ *
877
+ */
878
+ struct SequenceTypeInfo : detail::SequenceTypeInfoImpl<OrtSequenceTypeInfo> {
879
+ explicit SequenceTypeInfo(std::nullptr_t) {} ///< Create an empty SequenceTypeInfo object, must be assigned a valid one to be used
880
+ explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : SequenceTypeInfoImpl<OrtSequenceTypeInfo>{p} {} ///< Used for interop with the C API
881
+ ConstSequenceTypeInfo GetConst() const { return ConstSequenceTypeInfo{this->p_}; }
882
+ };
883
+
884
+ namespace detail {
885
+ template <typename T>
886
+ struct MapTypeInfoImpl : detail::Base<T> {
887
+ using B = Base<T>;
888
+ using B::B;
889
+ ONNXTensorElementDataType GetMapKeyType() const; ///< Wraps OrtApi::GetMapKeyType
890
+ TypeInfo GetMapValueType() const; ///< Wraps OrtApi::GetMapValueType
891
+ };
892
+
893
+ } // namespace detail
894
+
895
+ using ConstMapTypeInfo = detail::MapTypeInfoImpl<detail::Unowned<const OrtMapTypeInfo>>;
896
+
897
+ /** \brief Wrapper around ::OrtMapTypeInfo
898
+ *
899
+ */
900
+ struct MapTypeInfo : detail::MapTypeInfoImpl<OrtMapTypeInfo> {
901
+ explicit MapTypeInfo(std::nullptr_t) {} ///< Create an empty MapTypeInfo object, must be assigned a valid one to be used
902
+ explicit MapTypeInfo(OrtMapTypeInfo* p) : MapTypeInfoImpl<OrtMapTypeInfo>{p} {} ///< Used for interop with the C API
903
+ ConstMapTypeInfo GetConst() const { return ConstMapTypeInfo{this->p_}; }
904
+ };
905
+
906
+ namespace detail {
907
+ template <typename T>
908
+ struct TypeInfoImpl : detail::Base<T> {
909
+ using B = Base<T>;
910
+ using B::B;
911
+
912
+ ConstTensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; ///< Wraps OrtApi::CastTypeInfoToTensorInfo
913
+ ConstSequenceTypeInfo GetSequenceTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToSequenceTypeInfo
914
+ ConstMapTypeInfo GetMapTypeInfo() const; ///< Wraps OrtApi::CastTypeInfoToMapTypeInfo
915
+
916
+ ONNXType GetONNXType() const;
917
+ };
918
+ } // namespace detail
919
+
920
+ /// <summary>
921
+ /// Contains a constant, unowned OrtTypeInfo that can be copied and passed around by value.
922
+ /// Provides access to const OrtTypeInfo APIs.
923
+ /// </summary>
924
+ using ConstTypeInfo = detail::TypeInfoImpl<detail::Unowned<const OrtTypeInfo>>;
925
+
926
+ /// <summary>
927
+ /// Type information that may contain either TensorTypeAndShapeInfo or
928
+ /// the information about contained sequence or map depending on the ONNXType.
929
+ /// </summary>
930
+ struct TypeInfo : detail::TypeInfoImpl<OrtTypeInfo> {
931
+ explicit TypeInfo(std::nullptr_t) {} ///< Create an empty TypeInfo object, must be assigned a valid one to be used
932
+ explicit TypeInfo(OrtTypeInfo* p) : TypeInfoImpl<OrtTypeInfo>{p} {} ///< C API Interop
933
+
934
+ ConstTypeInfo GetConst() const { return ConstTypeInfo{this->p_}; }
935
+ };
936
+
937
+ namespace detail {
938
+ // This structure is used to feed sparse tensor values
939
+ // information for use with FillSparseTensor<Format>() API
940
+ // if the data type for the sparse tensor values is numeric
941
+ // use data.p_data, otherwise, use data.str pointer to feed
942
+ // values. data.str is an array of const char* that are zero terminated.
943
+ // number of strings in the array must match shape size.
944
+ // For fully sparse tensors use shape {0} and set p_data/str
945
+ // to nullptr.
946
+ struct OrtSparseValuesParam {
947
+ const int64_t* values_shape;
948
+ size_t values_shape_len;
949
+ union {
950
+ const void* p_data;
951
+ const char** str;
952
+ } data;
953
+ };
954
+
955
+ // Provides a way to pass shape in a single
956
+ // argument
957
+ struct Shape {
958
+ const int64_t* shape;
959
+ size_t shape_len;
960
+ };
961
+
962
+ template <typename T>
963
+ struct ConstValueImpl : Base<T> {
964
+ using B = Base<T>;
965
+ using B::B;
966
+
967
+ /// <summary>
968
+ /// Obtains a pointer to a user defined data for experimental purposes
969
+ /// </summary>
970
+ template <typename R>
971
+ void GetOpaqueData(const char* domain, const char* type_name, R&) const; ///< Wraps OrtApi::GetOpaqueValue
972
+
973
+ bool IsTensor() const; ///< Returns true if Value is a tensor, false for other types like map/sequence/etc
974
+ bool HasValue() const; /// < Return true if OrtValue contains data and returns false if the OrtValue is a None
975
+
976
+ size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
977
+ Value GetValue(int index, OrtAllocator* allocator) const;
978
+
979
+ /// <summary>
980
+ /// This API returns a full length of string data contained within either a tensor or a sparse Tensor.
981
+ /// For sparse tensor it returns a full length of stored non-empty strings (values). The API is useful
982
+ /// for allocating necessary memory and calling GetStringTensorContent().
983
+ /// </summary>
984
+ /// <returns>total length of UTF-8 encoded bytes contained. No zero terminators counted.</returns>
985
+ size_t GetStringTensorDataLength() const;
986
+
987
+ /// <summary>
988
+ /// The API copies all of the UTF-8 encoded string data contained within a tensor or a sparse tensor
989
+ /// into a supplied buffer. Use GetStringTensorDataLength() to find out the length of the buffer to allocate.
990
+ /// The user must also allocate offsets buffer with the number of entries equal to that of the contained
991
+ /// strings.
992
+ ///
993
+ /// Strings are always assumed to be on CPU, no X-device copy.
994
+ /// </summary>
995
+ /// <param name="buffer">user allocated buffer</param>
996
+ /// <param name="buffer_length">length in bytes of the allocated buffer</param>
997
+ /// <param name="offsets">a pointer to the offsets user allocated buffer</param>
998
+ /// <param name="offsets_count">count of offsets, must be equal to the number of strings contained.
999
+ /// that can be obtained from the shape of the tensor or from GetSparseTensorValuesTypeAndShapeInfo()
1000
+ /// for sparse tensors</param>
1001
+ void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
1002
+
1003
+ /// <summary>
1004
+ /// Returns a const typed pointer to the tensor contained data.
1005
+ /// No type checking is performed, the caller must ensure the type matches the tensor type.
1006
+ /// </summary>
1007
+ /// <typeparam name="R"></typeparam>
1008
+ /// <returns>const pointer to data, no copies made</returns>
1009
+ template <typename R>
1010
+ const R* GetTensorData() const; ///< Wraps OrtApi::GetTensorMutableData
1011
+
1012
+ /// <summary>
1013
+ /// Returns a non-typed pointer to a tensor contained data.
1014
+ /// </summary>
1015
+ /// <returns>const pointer to data, no copies made</returns>
1016
+ const void* GetTensorRawData() const;
1017
+
1018
+ /// <summary>
1019
+ /// The API returns type information for data contained in a tensor. For sparse
1020
+ /// tensors it returns type information for contained non-zero values.
1021
+ /// It returns dense shape for sparse tensors.
1022
+ /// </summary>
1023
+ /// <returns>TypeInfo</returns>
1024
+ TypeInfo GetTypeInfo() const;
1025
+
1026
+ /// <summary>
1027
+ /// The API returns type information for data contained in a tensor. For sparse
1028
+ /// tensors it returns type information for contained non-zero values.
1029
+ /// It returns dense shape for sparse tensors.
1030
+ /// </summary>
1031
+ /// <returns>TensorTypeAndShapeInfo</returns>
1032
+ TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
1033
+
1034
+ /// <summary>
1035
+ /// This API returns information about the memory allocation used to hold data.
1036
+ /// </summary>
1037
+ /// <returns>Non owning instance of MemoryInfo</returns>
1038
+ ConstMemoryInfo GetTensorMemoryInfo() const;
1039
+
1040
+ /// <summary>
1041
+ /// The API copies UTF-8 encoded bytes for the requested string element
1042
+ /// contained within a tensor or a sparse tensor into a provided buffer.
1043
+ /// Use GetStringTensorElementLength() to obtain the length of the buffer to allocate.
1044
+ /// </summary>
1045
+ /// <param name="buffer_length"></param>
1046
+ /// <param name="element_index"></param>
1047
+ /// <param name="buffer"></param>
1048
+ void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
1049
+
1050
+ /// <summary>
1051
+ /// The API returns a byte length of UTF-8 encoded string element
1052
+ /// contained in either a tensor or a sparse tensor's values.
1053
+ /// </summary>
1054
+ /// <param name="element_index"></param>
1055
+ /// <returns>byte length for the specified string element</returns>
1056
+ size_t GetStringTensorElementLength(size_t element_index) const;
1057
+
1058
+ #if !defined(DISABLE_SPARSE_TENSORS)
1059
+ /// <summary>
1060
+ /// The API returns the sparse data format this OrtValue holds in a sparse tensor.
1061
+ /// If the sparse tensor was not fully constructed, i.e. Use*() or Fill*() API were not used
1062
+ /// the value returned is ORT_SPARSE_UNDEFINED.
1063
+ /// </summary>
1064
+ /// <returns>Format enum</returns>
1065
+ OrtSparseFormat GetSparseFormat() const;
1066
+
1067
+ /// <summary>
1068
+ /// The API returns type and shape information for stored non-zero values of the
1069
+ /// sparse tensor. Use GetSparseTensorValues() to obtain values buffer pointer.
1070
+ /// </summary>
1071
+ /// <returns>TensorTypeAndShapeInfo values information</returns>
1072
+ TensorTypeAndShapeInfo GetSparseTensorValuesTypeAndShapeInfo() const;
1073
+
1074
+ /// <summary>
1075
+ /// The API returns type and shape information for the specified indices. Each supported
1076
+ /// kind of indices has its own enum value even if a given format has more than one kind of indices.
1077
+ /// Use GetSparseTensorIndicesData() to obtain pointer to indices buffer.
1078
+ /// </summary>
1079
+ /// <param name="format">enum requested</param>
1080
+ /// <returns>type and shape information</returns>
1081
+ TensorTypeAndShapeInfo GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat format) const;
1082
+
1083
+ /// <summary>
1084
+ /// The API retrieves a pointer to the internal indices buffer. The API merely performs
1085
+ /// a convenience data type casting on the return type pointer. Make sure you are requesting
1086
+ /// the right type, use GetSparseTensorIndicesTypeShapeInfo();
1087
+ /// </summary>
1088
+ /// <typeparam name="T">type to cast to</typeparam>
1089
+ /// <param name="indices_format">requested indices kind</param>
1090
+ /// <param name="num_indices">number of indices entries</param>
1091
+ /// <returns>Pointer to the internal sparse tensor buffer containing indices. Do not free this pointer.</returns>
1092
+ template <typename R>
1093
+ const R* GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const;
1094
+
1095
+ /// <summary>
1096
+ /// Returns true if the OrtValue contains a sparse tensor
1097
+ /// </summary>
1098
+ /// <returns></returns>
1099
+ bool IsSparseTensor() const;
1100
+
1101
+ /// <summary>
1102
+ /// The API returns a pointer to an internal buffer of the sparse tensor
1103
+ /// containing non-zero values. The API merely does casting. Make sure you
1104
+ /// are requesting the right data type by calling GetSparseTensorValuesTypeAndShapeInfo()
1105
+ /// first.
1106
+ /// </summary>
1107
+ /// <typeparam name="T">numeric data types only. Use GetStringTensor*() to retrieve strings.</typeparam>
1108
+ /// <returns>a pointer to the internal values buffer. Do not free this pointer.</returns>
1109
+ template <typename R>
1110
+ const R* GetSparseTensorValues() const;
1111
+
1112
+ #endif
1113
+ };
1114
+
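Editor's note: the two-call pattern described above (GetStringTensorDataLength followed by GetStringTensorContent) is easiest to see in a short sketch. The helper below is illustrative only and assumes `v` already holds an ONNX string tensor; it is not part of the shipped header.

#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h"

inline std::vector<std::string> ReadAllStrings(const Ort::Value& v) {
  const size_t count = v.GetTensorTypeAndShapeInfo().GetElementCount();  // number of strings
  std::vector<std::string> out;
  if (count == 0) return out;
  const size_t total = v.GetStringTensorDataLength();  // total UTF-8 bytes, no terminators
  std::vector<char> chars(total);
  std::vector<size_t> offsets(count);
  v.GetStringTensorContent(chars.data(), chars.size(), offsets.data(), offsets.size());
  out.reserve(count);
  for (size_t i = 0; i < count; ++i) {
    // Each offset marks where a string starts inside the contiguous character buffer.
    const size_t end = (i + 1 < count) ? offsets[i + 1] : total;
    out.emplace_back(chars.data() + offsets[i], end - offsets[i]);
  }
  return out;
}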
1115
+ template <typename T>
1116
+ struct ValueImpl : ConstValueImpl<T> {
1117
+ using B = ConstValueImpl<T>;
1118
+ using B::B;
1119
+
1120
+ /// <summary>
1121
+ /// Returns a non-const typed pointer to an OrtValue/Tensor contained buffer
1122
+ /// No type checking is performed, the caller must ensure the type matches the tensor type.
1123
+ /// </summary>
1124
+ /// <returns>non-const pointer to data, no copies made</returns>
1125
+ template <typename R>
1126
+ R* GetTensorMutableData();
1127
+
1128
+ /// <summary>
1129
+ /// Returns a non-typed non-const pointer to a tensor contained data.
1130
+ /// </summary>
1131
+ /// <returns>pointer to data, no copies made</returns>
1132
+ void* GetTensorMutableRawData();
1133
+
1134
+ /// <summary>
1135
+ /// Obtains a reference to an element of data at the location specified
1136
+ /// by the vector of dims.
1137
+ /// </summary>
1138
+ /// <typeparam name="R"></typeparam>
1139
+ /// <param name="location">[in] expressed by a vector of dimension offsets</param>
1140
+ /// <returns></returns>
1141
+ template <typename R>
1142
+ R& At(const std::vector<int64_t>& location);
1143
+
1144
+ /// <summary>
1145
+ /// Set all strings at once in a string tensor
1146
+ /// </summary>
1147
+ /// <param name="s">[in] An array of strings. Each string in this array must be null terminated.</param>
1148
+ /// <param name="s_len">[in] Count of strings in s (Must match the size of \p value's tensor shape)</param>
1149
+ void FillStringTensor(const char* const* s, size_t s_len);
1150
+
1151
+ /// <summary>
1152
+ /// Set a single string in a string tensor
1153
+ /// </summary>
1154
+ /// <param name="s">[in] A null terminated UTF-8 encoded string</param>
1155
+ /// <param name="index">[in] Index of the string in the tensor to set</param>
1156
+ void FillStringTensorElement(const char* s, size_t index);
1157
+
1158
+ #if !defined(DISABLE_SPARSE_TENSORS)
1159
+ /// <summary>
1160
+ /// Supplies COO format specific indices and marks the contained sparse tensor as being a COO format tensor.
1161
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1162
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
1163
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1164
+ /// </summary>
1165
+ /// <param name="indices_data">pointer to the user allocated buffer with indices. Use nullptr for fully sparse tensors.</param>
1166
+ /// <param name="indices_num">number of indices entries. Use 0 for fully sparse tensors</param>
1167
+ void UseCooIndices(int64_t* indices_data, size_t indices_num);
1168
+
1169
+ /// <summary>
1170
+ /// Supplies CSR format specific indices and marks the contained sparse tensor as being a CSR format tensor.
1171
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1172
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
1173
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1174
+ /// </summary>
1175
+ /// <param name="inner_data">pointer to the user allocated buffer with inner indices or nullptr for fully sparse tensors</param>
1176
+ /// <param name="inner_num">number of csr inner indices or 0 for fully sparse tensors</param>
1177
+ /// <param name="outer_data">pointer to the user allocated buffer with outer indices or nullptr for fully sparse tensors</param>
1178
+ /// <param name="outer_num">number of csr outer indices or 0 for fully sparse tensors</param>
1179
+ void UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num);
1180
+
1181
+ /// <summary>
1182
+ /// Supplies BlockSparse format specific indices and marks the contained sparse tensor as being a BlockSparse format tensor.
1183
+ /// Values are supplied with a CreateSparseTensor() API. The supplied indices are not copied and the user
1184
+ /// allocated buffers lifespan must eclipse that of the OrtValue.
1185
+ /// The location of the indices is assumed to be the same as specified by OrtMemoryInfo argument at the creation time.
1186
+ /// </summary>
1187
+ /// <param name="indices_shape">indices shape or a {0} for fully sparse</param>
1188
+ /// <param name="indices_data">user allocated buffer with indices or nullptr for fully sparse tensors</param>
1189
+ void UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data);
1190
+
1191
+ /// <summary>
1192
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1193
+ /// and copy the values and COO indices into it. If data_mem_info specifies that the data is located
1194
+ /// on a different device than the allocator, an X-device copy will be performed if possible.
1195
+ /// </summary>
1196
+ /// <param name="data_mem_info">specified buffer memory description</param>
1197
+ /// <param name="values_param">values buffer information.</param>
1198
+ /// <param name="indices_data">coo indices buffer or nullptr for fully sparse data</param>
1199
+ /// <param name="indices_num">number of COO indices or 0 for fully sparse data</param>
1200
+ void FillSparseTensorCoo(const OrtMemoryInfo* data_mem_info, const OrtSparseValuesParam& values_param,
1201
+ const int64_t* indices_data, size_t indices_num);
1202
+
1203
+ /// <summary>
1204
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1205
+ /// and copy the values and CSR indices into it. If data_mem_info specifies that the data is located
1206
+ /// on a different device than the allocator, an X-device copy will be performed if possible.
1207
+ /// </summary>
1208
+ /// <param name="data_mem_info">specified buffer memory description</param>
1209
+ /// <param name="values">values buffer information</param>
1210
+ /// <param name="inner_indices_data">csr inner indices pointer or nullptr for fully sparse tensors</param>
1211
+ /// <param name="inner_indices_num">number of csr inner indices or 0 for fully sparse tensors</param>
1212
+ /// <param name="outer_indices_data">pointer to csr indices data or nullptr for fully sparse tensors</param>
1213
+ /// <param name="outer_indices_num">number of csr outer indices or 0</param>
1214
+ void FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1215
+ const OrtSparseValuesParam& values,
1216
+ const int64_t* inner_indices_data, size_t inner_indices_num,
1217
+ const int64_t* outer_indices_data, size_t outer_indices_num);
1218
+
1219
+ /// <summary>
1220
+ /// The API will allocate memory using the allocator instance supplied to the CreateSparseTensor() API
1221
+ /// and copy the values and BlockSparse indices into it. If data_mem_info specifies that the data is located
1222
+ /// on a different device than the allocator, an X-device copy will be performed if possible.
1223
+ /// </summary>
1224
+ /// <param name="data_mem_info">specified buffer memory description</param>
1225
+ /// <param name="values">values buffer information</param>
1226
+ /// <param name="indices_shape">indices shape. use {0} for fully sparse tensors</param>
1227
+ /// <param name="indices_data">pointer to indices data or nullptr for fully sparse tensors</param>
1228
+ void FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1229
+ const OrtSparseValuesParam& values,
1230
+ const Shape& indices_shape,
1231
+ const int32_t* indices_data);
1232
+
1233
+ #endif
1234
+ };
1235
+
1236
+ } // namespace detail
1237
+
1238
+ using ConstValue = detail::ConstValueImpl<detail::Unowned<const OrtValue>>;
1239
+ using UnownedValue = detail::ValueImpl<detail::Unowned<OrtValue>>;
1240
+
1241
+ /** \brief Wrapper around ::OrtValue
1242
+ *
1243
+ */
1244
+ struct Value : detail::ValueImpl<OrtValue> {
1245
+ using Base = detail::ValueImpl<OrtValue>;
1246
+ using OrtSparseValuesParam = detail::OrtSparseValuesParam;
1247
+ using Shape = detail::Shape;
1248
+
1249
+ explicit Value(std::nullptr_t) {} ///< Create an empty Value object, must be assigned a valid one to be used
1250
+ explicit Value(OrtValue* p) : Base{p} {} ///< Used for interop with the C API
1251
+ Value(Value&&) = default;
1252
+ Value& operator=(Value&&) = default;
1253
+
1254
+ ConstValue GetConst() const { return ConstValue{this->p_}; }
1255
+ UnownedValue GetUnowned() const { return UnownedValue{this->p_}; }
1256
+
1257
+ /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1258
+ * \tparam T The numeric datatype. This API is not suitable for strings.
1259
+ * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1260
+ * \param p_data Pointer to the data buffer.
1261
+ * \param p_data_element_count The number of elements in the data buffer.
1262
+ * \param shape Pointer to the tensor shape dimensions.
1263
+ * \param shape_len The number of tensor shape dimensions.
1264
+ */
1265
+ template <typename T>
1266
+ static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
1267
+
1268
+ /** \brief Creates a tensor with a user supplied buffer. Wraps OrtApi::CreateTensorWithDataAsOrtValue.
1269
+ * \param info Memory description of where the p_data buffer resides (CPU vs GPU etc).
1270
+ * \param p_data Pointer to the data buffer.
1271
+ * \param p_data_byte_count The number of bytes in the data buffer.
1272
+ * \param shape Pointer to the tensor shape dimensions.
1273
+ * \param shape_len The number of tensor shape dimensions.
1274
+ * \param type The data type.
1275
+ */
1276
+ static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1277
+ ONNXTensorElementDataType type);
1278
+
1279
+ /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1280
+ * \tparam T The numeric datatype. This API is not suitable for strings.
1281
+ * \param allocator The allocator to use.
1282
+ * \param shape Pointer to the tensor shape dimensions.
1283
+ * \param shape_len The number of tensor shape dimensions.
1284
+ */
1285
+ template <typename T>
1286
+ static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
1287
+
1288
+ /** \brief Creates a tensor using a supplied OrtAllocator. Wraps OrtApi::CreateTensorAsOrtValue.
1289
+ * \param allocator The allocator to use.
1290
+ * \param shape Pointer to the tensor shape dimensions.
1291
+ * \param shape_len The number of tensor shape dimensions.
1292
+ * \param type The data type.
1293
+ */
1294
+ static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
1295
+
1296
+ static Value CreateMap(Value& keys, Value& values); ///< Wraps OrtApi::CreateValue
1297
+ static Value CreateSequence(std::vector<Value>& values); ///< Wraps OrtApi::CreateValue
1298
+
1299
+ template <typename T>
1300
+ static Value CreateOpaque(const char* domain, const char* type_name, const T&); ///< Wraps OrtApi::CreateOpaqueValue
1301
+
1302
+ #if !defined(DISABLE_SPARSE_TENSORS)
1303
+ /// <summary>
1304
+ /// This is a simple forwarding method to the other overload that helps deduce the
1305
+ /// data type enum value from the type of the buffer.
1306
+ /// </summary>
1307
+ /// <typeparam name="T">numeric datatype. This API is not suitable for strings.</typeparam>
1308
+ /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1309
+ /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1310
+ /// <param name="dense_shape">the would-be dense shape of the tensor</param>
1311
+ /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1312
+ /// <returns></returns>
1313
+ template <typename T>
1314
+ static Value CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1315
+ const Shape& values_shape);
1316
+
1317
+ /// <summary>
1318
+ /// Creates an OrtValue instance containing SparseTensor. This constructs
1319
+ /// a sparse tensor that makes use of user allocated buffers. It does not make copies
1320
+ /// of the user provided data and does not modify it. The lifespan of user provided buffers should
1321
+ /// eclipse the life span of the resulting OrtValue. This call constructs an instance that only contain
1322
+ /// a pointer to non-zero values. To fully populate the sparse tensor call Use<Format>Indices() API below
1323
+ /// to supply a sparse format specific indices.
1324
+ /// This API is not suitable for string data. Use CreateSparseTensor() with allocator specified so strings
1325
+ /// can be properly copied into the allocated buffer.
1326
+ /// </summary>
1327
+ /// <param name="info">Memory description where the user buffers reside (CPU vs GPU etc)</param>
1328
+ /// <param name="p_data">pointer to the user supplied buffer, use nullptr for fully sparse tensors</param>
1329
+ /// <param name="dense_shape">the would-be dense shape of the tensor</param>
1330
+ /// <param name="values_shape">non zero values shape. Use a single 0 shape for fully sparse tensors.</param>
1331
+ /// <param name="type">data type</param>
1332
+ /// <returns>Ort::Value instance containing SparseTensor</returns>
1333
+ static Value CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1334
+ const Shape& values_shape, ONNXTensorElementDataType type);
1335
+
1336
+ /// <summary>
1337
+ /// This is a simple forwarding method to the below CreateSparseTensor.
1338
+ /// This helps to specify data type enum in terms of C++ data type.
1339
+ /// Use CreateSparseTensor<T>
1340
+ /// </summary>
1341
+ /// <typeparam name="T">numeric data type only. String data enum must be specified explicitly.</typeparam>
1342
+ /// <param name="allocator">allocator to use</param>
1343
+ /// <param name="dense_shape">the would-be dense shape of the tensor</param>
1344
+ /// <returns>Ort::Value</returns>
1345
+ template <typename T>
1346
+ static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape);
1347
+
1348
+ /// <summary>
1349
+ /// Creates an instance of OrtValue containing sparse tensor. The created instance has no data.
1350
+ /// The data must be supplied by one of the FillSparseTensor<Format>() methods that take both non-zero values
1351
+ /// and indices. The data will be copied into a buffer that would be allocated using the supplied allocator.
1352
+ /// Use this API to create OrtValues that contain sparse tensors with all supported data types including
1353
+ /// strings.
1354
+ /// </summary>
1355
+ /// <param name="allocator">allocator to use. The allocator lifespan must eclipse that of the resulting OrtValue</param>
1356
+ /// <param name="dense_shape">the would-be dense shape of the tensor</param>
1357
+ /// <param name="type">data type</param>
1358
+ /// <returns>an instance of Ort::Value</returns>
1359
+ static Value CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape, ONNXTensorElementDataType type);
1360
+
1361
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1362
+ };
1363
+
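Editor's note: a minimal usage sketch of the CreateTensor overloads above, wrapping a caller-owned CPU buffer. Variable names are placeholders; nothing here is part of the shipped header.

#include <array>
#include <cstdio>
#include "onnxruntime_cxx_api.h"

inline void ExampleCreateTensorOverUserBuffer() {
  std::array<float, 4> data{1.f, 2.f, 3.f, 4.f};
  std::array<int64_t, 2> shape{2, 2};
  // Memory description of where `data` resides: CPU, default arena allocator.
  Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  // No copy is made; `data` must stay alive for as long as `tensor` is used.
  Ort::Value tensor = Ort::Value::CreateTensor<float>(mem_info, data.data(), data.size(),
                                                      shape.data(), shape.size());
  // Read back through the typed accessor documented earlier.
  const float* p = tensor.GetTensorData<float>();
  std::printf("first element: %f\n", p[0]);
}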
1364
+ /// <summary>
1365
+ /// Represents native memory allocation coming from one of the
1366
+ /// OrtAllocators registered with OnnxRuntime.
1367
+ /// Use it to wrap an allocation made by an allocator
1368
+ /// so it can be automatically released when no longer needed.
1369
+ /// </summary>
1370
+ struct MemoryAllocation {
1371
+ MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
1372
+ ~MemoryAllocation();
1373
+ MemoryAllocation(const MemoryAllocation&) = delete;
1374
+ MemoryAllocation& operator=(const MemoryAllocation&) = delete;
1375
+ MemoryAllocation(MemoryAllocation&&) noexcept;
1376
+ MemoryAllocation& operator=(MemoryAllocation&&) noexcept;
1377
+
1378
+ void* get() { return p_; }
1379
+ size_t size() const { return size_; }
1380
+
1381
+ private:
1382
+ OrtAllocator* allocator_;
1383
+ void* p_;
1384
+ size_t size_;
1385
+ };
1386
+
1387
+ namespace detail {
1388
+ template <typename T>
1389
+ struct AllocatorImpl : Base<T> {
1390
+ using B = Base<T>;
1391
+ using B::B;
1392
+
1393
+ void* Alloc(size_t size);
1394
+ MemoryAllocation GetAllocation(size_t size);
1395
+ void Free(void* p);
1396
+ ConstMemoryInfo GetInfo() const;
1397
+ };
1398
+
1399
+ } // namespace detail
1400
+
1401
+ /** \brief Wrapper around ::OrtAllocator default instance that is owned by Onnxruntime
1402
+ *
1403
+ */
1404
+ struct AllocatorWithDefaultOptions : detail::AllocatorImpl<detail::Unowned<OrtAllocator>> {
1405
+ explicit AllocatorWithDefaultOptions(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1406
+ AllocatorWithDefaultOptions();
1407
+ };
1408
+
1409
+ /** \brief Wrapper around ::OrtAllocator
1410
+ *
1411
+ */
1412
+ struct Allocator : detail::AllocatorImpl<OrtAllocator> {
1413
+ explicit Allocator(std::nullptr_t) {} ///< Convenience to create a class member and then replace with an instance
1414
+ Allocator(const Session& session, const OrtMemoryInfo*);
1415
+ };
1416
+
1417
+ using UnownedAllocator = detail::AllocatorImpl<detail::Unowned<OrtAllocator>>;
1418
+
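Editor's note: a short sketch of the allocator wrappers above. AllocatorWithDefaultOptions needs no session, and MemoryAllocation releases the block automatically; names are illustrative only.

#include <cstring>
#include "onnxruntime_cxx_api.h"

inline void ExampleAllocation() {
  Ort::AllocatorWithDefaultOptions allocator;
  // RAII: the 256-byte block is freed when `block` goes out of scope.
  Ort::MemoryAllocation block = allocator.GetAllocation(256);
  std::memset(block.get(), 0, block.size());
  // Raw Alloc/Free remain available when manual lifetime management is preferred.
  void* p = allocator.Alloc(64);
  allocator.Free(p);
}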
1419
+ namespace detail {
1420
+ namespace binding_utils {
1421
+ // Bring these out of template
1422
+ std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator*);
1423
+ std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator*);
1424
+ } // namespace binding_utils
1425
+
1426
+ template <typename T>
1427
+ struct ConstIoBindingImpl : Base<T> {
1428
+ using B = Base<T>;
1429
+ using B::B;
1430
+
1431
+ std::vector<std::string> GetOutputNames() const;
1432
+ std::vector<std::string> GetOutputNames(OrtAllocator*) const;
1433
+ std::vector<Value> GetOutputValues() const;
1434
+ std::vector<Value> GetOutputValues(OrtAllocator*) const;
1435
+ };
1436
+
1437
+ template <typename T>
1438
+ struct IoBindingImpl : ConstIoBindingImpl<T> {
1439
+ using B = ConstIoBindingImpl<T>;
1440
+ using B::B;
1441
+
1442
+ void BindInput(const char* name, const Value&);
1443
+ void BindOutput(const char* name, const Value&);
1444
+ void BindOutput(const char* name, const OrtMemoryInfo*);
1445
+ void ClearBoundInputs();
1446
+ void ClearBoundOutputs();
1447
+ void SynchronizeInputs();
1448
+ void SynchronizeOutputs();
1449
+ };
1450
+
1451
+ } // namespace detail
1452
+
1453
+ using ConstIoBinding = detail::ConstIoBindingImpl<detail::Unowned<const OrtIoBinding>>;
1454
+ using UnownedIoBinding = detail::IoBindingImpl<detail::Unowned<OrtIoBinding>>;
1455
+
1456
+ /** \brief Wrapper around ::OrtIoBinding
1457
+ *
1458
+ */
1459
+ struct IoBinding : detail::IoBindingImpl<OrtIoBinding> {
1460
+ explicit IoBinding(std::nullptr_t) {} ///< Create an empty object for convenience. Sometimes, we want to initialize members later.
1461
+ explicit IoBinding(Session& session);
1462
+ ConstIoBinding GetConst() const { return ConstIoBinding{this->p_}; }
1463
+ UnownedIoBinding GetUnowned() const { return UnownedIoBinding{this->p_}; }
1464
+ };
1465
+
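Editor's note: an illustrative sketch of IoBinding. `session`, `input_value`, and the tensor names "input"/"output" are placeholders for a real model, and it assumes the Session::Run(RunOptions, IoBinding) overload declared elsewhere in this header.

#include <vector>
#include "onnxruntime_cxx_api.h"

inline std::vector<Ort::Value> ExampleRunWithBinding(Ort::Session& session, Ort::Value& input_value) {
  Ort::IoBinding binding(session);
  binding.BindInput("input", input_value);
  // Let ORT allocate the output on CPU; it is retrieved after the run.
  Ort::MemoryInfo cpu_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  binding.BindOutput("output", cpu_info);
  session.Run(Ort::RunOptions{nullptr}, binding);
  return binding.GetOutputValues();
}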
1466
+ /*! \struct Ort::ArenaCfg
1467
+ * \brief it is a structure that represents the configuration of an arena based allocator
1468
+ * \details Please see docs/C_API.md for details
1469
+ */
1470
+ struct ArenaCfg : detail::Base<OrtArenaCfg> {
1471
+ explicit ArenaCfg(std::nullptr_t) {} ///< Create an empty ArenaCfg object, must be assigned a valid one to be used
1472
+ /**
1473
+ * Wraps OrtApi::CreateArenaCfg
1474
+ * \param max_mem - use 0 to allow ORT to choose the default
1475
+ * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
1476
+ * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
1477
+ * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
1478
+ * See docs/C_API.md for details on what the following parameters mean and how to choose these values
1479
+ */
1480
+ ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
1481
+ };
1482
+
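Editor's note: a tiny sketch of the parameter conventions above, keeping ORT defaults for everything except the extend strategy. The helper name is hypothetical.

#include "onnxruntime_cxx_api.h"

inline Ort::ArenaCfg ExampleArenaCfg() {
  // 0 / -1 ask ORT to pick defaults; arena_extend_strategy 0 selects kNextPowerOfTwo.
  return Ort::ArenaCfg(/*max_mem*/ 0, /*arena_extend_strategy*/ 0,
                       /*initial_chunk_size_bytes*/ -1, /*max_dead_bytes_per_chunk*/ -1);
}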
1483
+ //
1484
+ // Custom OPs (only needed to implement custom OPs)
1485
+ //
1486
+
1487
+ /// <summary>
1488
+ /// This struct provides lifetime management for a custom op attribute
1489
+ /// </summary>
1490
+ struct OpAttr : detail::Base<OrtOpAttr> {
1491
+ OpAttr(const char* name, const void* data, int len, OrtOpAttrType type);
1492
+ };
1493
+
1494
+ /// <summary>
1495
+ /// This class wraps a raw pointer OrtKernelContext* that is being passed
1496
+ /// to the custom kernel Compute() method. Use it to safely access context
1497
+ /// attributes, input and output parameters with exception safety guarantees.
1498
+ /// See usage example in onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
1499
+ /// </summary>
1500
+ struct KernelContext {
1501
+ explicit KernelContext(OrtKernelContext* context);
1502
+ size_t GetInputCount() const;
1503
+ size_t GetOutputCount() const;
1504
+ ConstValue GetInput(size_t index) const;
1505
+ UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
1506
+ UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
1507
+ void* GetGPUComputeStream() const;
1508
+
1509
+ private:
1510
+ OrtKernelContext* ctx_;
1511
+ };
1512
+
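Editor's note: a minimal Compute() sketch using KernelContext to read input 0 and write an identically shaped output. The float element type is an assumption made purely for illustration.

#include <cstring>
#include "onnxruntime_cxx_api.h"

inline void ExampleCompute(OrtKernelContext* raw_context) {
  Ort::KernelContext ctx(raw_context);
  Ort::ConstValue input = ctx.GetInput(0);
  auto type_shape = input.GetTensorTypeAndShapeInfo();
  Ort::UnownedValue output = ctx.GetOutput(0, type_shape.GetShape());
  // Pass the (assumed float) data straight through to the output.
  const size_t bytes = type_shape.GetElementCount() * sizeof(float);
  std::memcpy(output.GetTensorMutableRawData(), input.GetTensorRawData(), bytes);
}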
1513
+ struct KernelInfo;
1514
+
1515
+ namespace detail {
1516
+ namespace attr_utils {
1517
+ void GetAttr(const OrtKernelInfo* p, const char* name, float&);
1518
+ void GetAttr(const OrtKernelInfo* p, const char* name, int64_t&);
1519
+ void GetAttr(const OrtKernelInfo* p, const char* name, std::string&);
1520
+ void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>&);
1521
+ void GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>&);
1522
+ } // namespace attr_utils
1523
+
1524
+ template <typename T>
1525
+ struct KernelInfoImpl : Base<T> {
1526
+ using B = Base<T>;
1527
+ using B::B;
1528
+
1529
+ KernelInfo Copy() const;
1530
+
1531
+ template <typename R> // R is only implemented for float, int64_t, and string
1532
+ R GetAttribute(const char* name) const {
1533
+ R val;
1534
+ attr_utils::GetAttr(this->p_, name, val);
1535
+ return val;
1536
+ }
1537
+
1538
+ template <typename R> // R is only implemented for std::vector<float>, std::vector<int64_t>
1539
+ std::vector<R> GetAttributes(const char* name) const {
1540
+ std::vector<R> result;
1541
+ attr_utils::GetAttrs(this->p_, name, result);
1542
+ return result;
1543
+ }
1544
+
1545
+ Value GetTensorAttribute(const char* name, OrtAllocator* allocator) const;
1546
+
1547
+ size_t GetInputCount() const;
1548
+ size_t GetOutputCount() const;
1549
+
1550
+ std::string GetInputName(size_t index) const;
1551
+ std::string GetOutputName(size_t index) const;
1552
+
1553
+ TypeInfo GetInputTypeInfo(size_t index) const;
1554
+ TypeInfo GetOutputTypeInfo(size_t index) const;
1555
+ };
1556
+
1557
+ } // namespace detail
1558
+
1559
+ using ConstKernelInfo = detail::KernelInfoImpl<detail::Unowned<const OrtKernelInfo>>;
1560
+
1561
+ /// <summary>
1562
+ /// This struct owns the OrtKernelInfo* pointer when a copy is made.
1563
+ /// For convenient wrapping of OrtKernelInfo* passed to kernel constructor
1564
+ /// and querying attributes, wrap the pointer with an Ort::Unowned<KernelInfo> instance
1565
+ /// so it does not destroy the pointer the kernel does not own.
1566
+ /// </summary>
1567
+ struct KernelInfo : detail::KernelInfoImpl<OrtKernelInfo> {
1568
+ explicit KernelInfo(std::nullptr_t) {} ///< Create an empty instance to initialize later
1569
+ explicit KernelInfo(OrtKernelInfo* info); ///< Take ownership of the instance
1570
+ ConstKernelInfo GetConst() const { return ConstKernelInfo{this->p_}; }
1571
+ };
1572
+
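Editor's note: a sketch of reading node attributes in a custom kernel constructor via the wrappers above. The attribute names "alpha" and "axes" are hypothetical.

#include <cstdint>
#include <vector>
#include "onnxruntime_cxx_api.h"

struct ExampleAttrKernel {
  explicit ExampleAttrKernel(const OrtKernelInfo* info) {
    // Wrap without taking ownership; the kernel does not own `info`.
    Ort::ConstKernelInfo ki{info};
    alpha_ = ki.GetAttribute<float>("alpha");
    axes_ = ki.GetAttributes<int64_t>("axes");
  }
  float alpha_{};
  std::vector<int64_t> axes_;
};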
1573
+ /// <summary>
1574
+ /// Create and own custom defined operation.
1575
+ /// </summary>
1576
+ struct Op : detail::Base<OrtOp> {
1577
+ explicit Op(std::nullptr_t) {} ///< Create an empty Operator object, must be assigned a valid one to be used
1578
+
1579
+ explicit Op(OrtOp*); ///< Take ownership of the OrtOp
1580
+
1581
+ static Op Create(const OrtKernelInfo* info, const char* op_name, const char* domain,
1582
+ int version, const char** type_constraint_names,
1583
+ const ONNXTensorElementDataType* type_constraint_values,
1584
+ size_t type_constraint_count,
1585
+ const OpAttr* attr_values,
1586
+ size_t attr_count,
1587
+ size_t input_count, size_t output_count);
1588
+
1589
+ void Invoke(const OrtKernelContext* context,
1590
+ const Value* input_values,
1591
+ size_t input_count,
1592
+ Value* output_values,
1593
+ size_t output_count);
1594
+
1595
+ // For easier refactoring
1596
+ void Invoke(const OrtKernelContext* context,
1597
+ const OrtValue* const* input_values,
1598
+ size_t input_count,
1599
+ OrtValue* const* output_values,
1600
+ size_t output_count);
1601
+ };
1602
+
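Editor's note: a sketch of Op::Create/Invoke for building a reusable operator inside a custom kernel. "Gelu" in the "com.microsoft" domain is only an example of an existing contrib op; every name below is illustrative.

#include "onnxruntime_cxx_api.h"

struct ExampleWrappedOpKernel {
  explicit ExampleWrappedOpKernel(const OrtKernelInfo* info) : op_(nullptr) {
    const char* type_names[] = {"T"};
    const ONNXTensorElementDataType type_values[] = {ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT};
    op_ = Ort::Op::Create(info, "Gelu", "com.microsoft", /*version*/ 1,
                          type_names, type_values, /*type_constraint_count*/ 1,
                          /*attr_values*/ nullptr, /*attr_count*/ 0,
                          /*input_count*/ 1, /*output_count*/ 1);
  }

  void Compute(OrtKernelContext* context) {
    Ort::KernelContext ctx(context);
    Ort::ConstValue input = ctx.GetInput(0);
    Ort::UnownedValue output = ctx.GetOutput(0, input.GetTensorTypeAndShapeInfo().GetShape());
    // The second Invoke overload takes raw OrtValue pointer arrays.
    const OrtValue* inputs[] = {input};
    OrtValue* outputs[] = {output};
    op_.Invoke(context, inputs, 1, outputs, 1);
  }

  Ort::Op op_;
};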
1603
+ /// <summary>
1604
+ /// This entire structure is deprecated, but we are not marking
1605
+ /// it as a whole yet since we want to preserve it for the next release.
1606
+ /// </summary>
1607
+ struct CustomOpApi {
1608
+ CustomOpApi(const OrtApi& api) : api_(api) {}
1609
+
1610
+ /** \deprecated use Ort::Value::GetTensorTypeAndShape()
1611
+ * [[deprecated]]
1612
+ * This interface produces a pointer that must be released. Not exception safe.
1613
+ */
1614
+ [[deprecated("use Ort::Value::GetTensorTypeAndShape()")]] OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
1615
+
1616
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementCount()
1617
+ * [[deprecated]]
1618
+ * This interface is redundant.
1619
+ */
1620
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementCount()")]] size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1621
+
1622
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetElementType()
1623
+ * [[deprecated]]
1624
+ * This interface is redundant.
1625
+ */
1626
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetElementType()")]] ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
1627
+
1628
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()
1629
+ * [[deprecated]]
1630
+ * This interface is redundant.
1631
+ */
1632
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetDimensionsCount()")]] size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
1633
+
1634
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
1635
+ * [[deprecated]]
1636
+ * This interface is redundant.
1637
+ */
1638
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
1639
+
1640
+ /** \deprecated
1641
+ * [[deprecated]]
1642
+ * This interface sets dimensions to TensorTypeAndShapeInfo, but has no effect on the OrtValue.
1643
+ */
1644
+ [[deprecated("Do not use")]] void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
1645
+
1646
+ /** \deprecated use Ort::Value::GetTensorMutableData()
1647
+ * [[deprecated]]
1648
+ * This interface is redundant.
1649
+ */
1650
+ template <typename T>
1651
+ [[deprecated("use Ort::Value::GetTensorMutableData()")]] T* GetTensorMutableData(_Inout_ OrtValue* value);
1652
+
1653
+ /** \deprecated use Ort::Value::GetTensorData()
1654
+ * [[deprecated]]
1655
+ * This interface is redundant.
1656
+ */
1657
+ template <typename T>
1658
+ [[deprecated("use Ort::Value::GetTensorData()")]] const T* GetTensorData(_Inout_ const OrtValue* value);
1659
+
1660
+ /** \deprecated use Ort::Value::GetTensorMemoryInfo()
1661
+ * [[deprecated]]
1662
+ * This interface is redundant.
1663
+ */
1664
+ [[deprecated("use Ort::Value::GetTensorMemoryInfo()")]] const OrtMemoryInfo* GetTensorMemoryInfo(_In_ const OrtValue* value);
1665
+
1666
+ /** \deprecated use Ort::TensorTypeAndShapeInfo::GetShape()
1667
+ * [[deprecated]]
1668
+ * This interface is redundant.
1669
+ */
1670
+ [[deprecated("use Ort::TensorTypeAndShapeInfo::GetShape()")]] std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
1671
+
1672
+ /** \deprecated use TensorTypeAndShapeInfo instances for automatic ownership.
1673
+ * [[deprecated]]
1674
+ * This interface is not exception safe.
1675
+ */
1676
+ [[deprecated("use TensorTypeAndShapeInfo")]] void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
1677
+
1678
+ /** \deprecated use Ort::KernelContext::GetInputCount
1679
+ * [[deprecated]]
1680
+ * This interface is redundant.
1681
+ */
1682
+ [[deprecated("use Ort::KernelContext::GetInputCount")]] size_t KernelContext_GetInputCount(const OrtKernelContext* context);
1683
+
1684
+ /** \deprecated use Ort::KernelContext::GetInput
1685
+ * [[deprecated]]
1686
+ * This interface is redundant.
1687
+ */
1688
+ [[deprecated("use Ort::KernelContext::GetInput")]] const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
1689
+
1690
+ /** \deprecated use Ort::KernelContext::GetOutputCount
1691
+ * [[deprecated]]
1692
+ * This interface is redundant.
1693
+ */
1694
+ [[deprecated("use Ort::KernelContext::GetOutputCount")]] size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
1695
+
1696
+ /** \deprecated use Ort::KernelContext::GetOutput
1697
+ * [[deprecated]]
1698
+ * This interface is redundant.
1699
+ */
1700
+ [[deprecated("use Ort::KernelContext::GetOutput")]] OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
1701
+
1702
+ /** \deprecated use Ort::KernelContext::GetGPUComputeStream
1703
+ * [[deprecated]]
1704
+ * This interface is redundant.
1705
+ */
1706
+ [[deprecated("use Ort::KernelContext::GetGPUComputeStream")]] void* KernelContext_GetGPUComputeStream(const OrtKernelContext* context);
1707
+
1708
+ /** \deprecated use Ort::ThrowOnError()
1709
+ * [[deprecated]]
1710
+ * This interface is redundant.
1711
+ */
1712
+ [[deprecated("use Ort::ThrowOnError()")]] void ThrowOnError(OrtStatus* result);
1713
+
1714
+ /** \deprecated use Ort::OpAttr
1715
+ * [[deprecated]]
1716
+ * This interface is not exception safe.
1717
+ */
1718
+ [[deprecated("use Ort::OpAttr")]] OrtOpAttr* CreateOpAttr(_In_ const char* name,
1719
+ _In_ const void* data,
1720
+ _In_ int len,
1721
+ _In_ OrtOpAttrType type);
1722
+
1723
+ /** \deprecated use Ort::OpAttr
1724
+ * [[deprecated]]
1725
+ * This interface is not exception safe.
1726
+ */
1727
+ [[deprecated("use Ort::OpAttr")]] void ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr);
1728
+
1729
+ /** \deprecated use Ort::Op
1730
+ * [[deprecated]]
1731
+ * This interface is not exception safe.
1732
+ */
1733
+ [[deprecated("use Ort::Op")]] OrtOp* CreateOp(_In_ const OrtKernelInfo* info,
1734
+ _In_ const char* op_name,
1735
+ _In_ const char* domain,
1736
+ _In_ int version,
1737
+ _In_opt_ const char** type_constraint_names,
1738
+ _In_opt_ const ONNXTensorElementDataType* type_constraint_values,
1739
+ _In_opt_ int type_constraint_count,
1740
+ _In_opt_ const OrtOpAttr* const* attr_values,
1741
+ _In_opt_ int attr_count,
1742
+ _In_ int input_count,
1743
+ _In_ int output_count);
1744
+
1745
+ /** \deprecated use Ort::Op::Invoke
1746
+ * [[deprecated]]
1747
+ * This interface is redundant
1748
+ */
1749
+ [[deprecated("use Ort::Op::Invoke")]] void InvokeOp(_In_ const OrtKernelContext* context,
1750
+ _In_ const OrtOp* ort_op,
1751
+ _In_ const OrtValue* const* input_values,
1752
+ _In_ int input_count,
1753
+ _Inout_ OrtValue* const* output_values,
1754
+ _In_ int output_count);
1755
+
1756
+ /** \deprecated use Ort::Op for automatic lifespan management.
1757
+ * [[deprecated]]
1758
+ * This interface is not exception safe.
1759
+ */
1760
+ [[deprecated("use Ort::Op")]] void ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op);
1761
+
1762
+ /** \deprecated use Ort::KernelInfo for automatic lifespan management or for
1763
+ * querying attributes
1764
+ * [[deprecated]]
1765
+ * This interface is redundant
1766
+ */
1767
+ template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
1768
+ [[deprecated("use Ort::KernelInfo::GetAttribute")]] T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
1769
+
1770
+ /** \deprecated use Ort::KernelInfo::Copy
1771
+ * querying attributes
1772
+ * [[deprecated]]
1773
+ * This interface is not exception safe
1774
+ */
1775
+ [[deprecated("use Ort::KernelInfo::Copy")]] OrtKernelInfo* CopyKernelInfo(_In_ const OrtKernelInfo* info);
1776
+
1777
+ /** \deprecated use Ort::KernelInfo for lifespan management
1778
+ * querying attributes
1779
+ * [[deprecated]]
1780
+ * This interface is not exception safe
1781
+ */
1782
+ [[deprecated("use Ort::KernelInfo")]] void ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy);
1783
+
1784
+ private:
1785
+ const OrtApi& api_;
1786
+ };
1787
+
1788
+ template <typename TOp, typename TKernel>
1789
+ struct CustomOpBase : OrtCustomOp {
1790
+ CustomOpBase() {
1791
+ OrtCustomOp::version = ORT_API_VERSION;
1792
+ OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
1793
+ OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
1794
+
1795
+ OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
1796
+
1797
+ OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
1798
+ OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
1799
+ OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputMemoryType(index); };
1800
+
1801
+ OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
1802
+ OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
1803
+
1804
+ OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
1805
+ #if defined(_MSC_VER) && !defined(__clang__)
1806
+ #pragma warning(push)
1807
+ #pragma warning(disable : 26409)
1808
+ #endif
1809
+ OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
1810
+ #if defined(_MSC_VER) && !defined(__clang__)
1811
+ #pragma warning(pop)
1812
+ #endif
1813
+ OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputCharacteristic(index); };
1814
+ OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index); };
1815
+
1816
+ OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicInputMinArity(); };
1817
+ OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicInputHomogeneity()); };
1818
+ OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetVariadicOutputMinArity(); };
1819
+ OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp* this_) { return static_cast<int>(static_cast<const TOp*>(this_)->GetVariadicOutputHomogeneity()); };
1820
+ }
1821
+
1822
+ // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
1823
+ const char* GetExecutionProviderType() const { return nullptr; }
1824
+
1825
+ // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
1826
+ // (inputs and outputs are required by default)
1827
+ OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
1828
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
1829
+ }
1830
+
1831
+ OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
1832
+ return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
1833
+ }
1834
+
1835
+ // Default implementation of GetInputMemoryType() that returns OrtMemTypeDefault
1836
+ OrtMemType GetInputMemoryType(size_t /*index*/) const {
1837
+ return OrtMemTypeDefault;
1838
+ }
1839
+
1840
+ // Default implementation of GetVariadicInputMinArity() returns 1 to specify that a variadic input
1841
+ // should expect at least 1 argument.
1842
+ int GetVariadicInputMinArity() const {
1843
+ return 1;
1844
+ }
1845
+
1846
+ // Default implementation of GetVariadicInputHomogeneity() returns true to specify that all arguments
1847
+ // to a variadic input should be of the same type.
1848
+ bool GetVariadicInputHomogeneity() const {
1849
+ return true;
1850
+ }
1851
+
1852
+ // Default implementation of GetVariadicOutputMinArity() returns 1 to specify that a variadic output
1853
+ // should produce at least 1 output value.
1854
+ int GetVariadicOutputMinArity() const {
1855
+ return 1;
1856
+ }
1857
+
1858
+ // Default implementation of GetVariadicOutputHomogeneity() returns true to specify that all output values
1859
+ // produced by a variadic output should be of the same type.
1860
+ bool GetVariadicOutputHomogeneity() const {
1861
+ return true;
1862
+ }
1863
+
1864
+ // Declare list of session config entries used by this Custom Op.
1865
+ // Implement this function in order to get configs from CustomOpBase::GetSessionConfigs().
1866
+ // This default implementation returns an empty vector of config entries.
1867
+ std::vector<std::string> GetSessionConfigKeys() const {
1868
+ return std::vector<std::string>{};
1869
+ }
1870
+
1871
+ protected:
1872
+ // Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
1873
+ void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;
1874
+ };
1875
+
1876
+ } // namespace Ort
1877
+
1878
+ #include "onnxruntime_cxx_inline.h"
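Editor's note: a minimal custom op built on CustomOpBase, shown after the header for orientation. The op name, element type, and kernel are placeholders; registration (typically via an Ort::CustomOpDomain added to the session options) is unchanged from the rest of the C++ API.

#include "onnxruntime_cxx_api.h"

struct NegateKernel {
  NegateKernel(const OrtApi& /*api*/, const OrtKernelInfo* /*info*/) {}
  void Compute(OrtKernelContext* context) {
    Ort::KernelContext ctx(context);
    Ort::ConstValue in = ctx.GetInput(0);
    auto info = in.GetTensorTypeAndShapeInfo();
    Ort::UnownedValue out = ctx.GetOutput(0, info.GetShape());
    const float* src = in.GetTensorData<float>();
    float* dst = out.GetTensorMutableData<float>();
    // Negate every (assumed float) element.
    for (size_t i = 0, n = info.GetElementCount(); i < n; ++i) dst[i] = -src[i];
  }
};

struct NegateCustomOp : Ort::CustomOpBase<NegateCustomOp, NegateKernel> {
  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
    return new NegateKernel(api, info);
  }
  const char* GetName() const { return "CustomNegate"; }
  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
};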
1.15.1/onnxruntime.xcframework/Headers/onnxruntime_cxx_inline.h ADDED
@@ -0,0 +1,1888 @@
1
+ // Copyright (c) Microsoft Corporation. All rights reserved.
2
+ // Licensed under the MIT License.
3
+
4
+ // Do not include this file directly. Please include "onnxruntime_cxx_api.h" instead.
5
+ // If interested in trying out features of the new experimental C++ API, include "experimental_onnxruntime_cxx_api.h" instead.
6
+ //
7
+ // These are the inline implementations of the C++ header APIs. They're in this separate file so as not to clutter
8
+ // the main C++ file with implementation details.
9
+
10
+ namespace Ort {
11
+
12
+ namespace detail {
13
+ inline void ThrowStatus(const Status& st) {
14
+ std::string error_message = st.GetErrorMessage();
15
+ OrtErrorCode error_code = st.GetErrorCode();
16
+ ORT_CXX_API_THROW(std::move(error_message), error_code);
17
+ }
18
+ } // namespace detail
19
+
20
+ inline void ThrowOnError(OrtStatus* ort_status) {
21
+ if (ort_status) {
22
+ Ort::Status st(ort_status);
23
+ detail::ThrowStatus(st);
24
+ }
25
+ }
26
+
27
+ inline void ThrowOnError(const Status& st) {
28
+ if (st) {
29
+ detail::ThrowStatus(st);
30
+ }
31
+ }
32
+
33
+ inline Status::Status(OrtStatus* status) : Base<OrtStatus>{status} {
34
+ }
35
+
36
+ inline Status::Status(const std::exception& e) {
37
+ p_ = GetApi().CreateStatus(ORT_FAIL, e.what());
38
+ }
39
+
40
+ inline Status::Status(const Exception& e) {
41
+ p_ = GetApi().CreateStatus(e.GetOrtErrorCode(), e.what());
42
+ }
43
+
44
+ inline std::string Status::GetErrorMessage() const {
45
+ std::string message(GetApi().GetErrorMessage(p_));
46
+ return message;
47
+ }
48
+
49
+ inline OrtErrorCode Status::GetErrorCode() const {
50
+ return GetApi().GetErrorCode(p_);
51
+ }
52
+
53
+ // This template converts a C++ type into its ONNXTensorElementDataType
54
+ template <typename T>
55
+ struct TypeToTensorType;
56
+ template <>
57
+ struct TypeToTensorType<float> {
58
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
59
+ };
60
+ template <>
61
+ struct TypeToTensorType<Float16_t> {
62
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
63
+ };
64
+ template <>
65
+ struct TypeToTensorType<BFloat16_t> {
66
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
67
+ };
68
+ template <>
69
+ struct TypeToTensorType<double> {
70
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
71
+ };
72
+ template <>
73
+ struct TypeToTensorType<int8_t> {
74
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
75
+ };
76
+ template <>
77
+ struct TypeToTensorType<int16_t> {
78
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
79
+ };
80
+ template <>
81
+ struct TypeToTensorType<int32_t> {
82
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
83
+ };
84
+ template <>
85
+ struct TypeToTensorType<int64_t> {
86
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
87
+ };
88
+ template <>
89
+ struct TypeToTensorType<uint8_t> {
90
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
91
+ };
92
+ template <>
93
+ struct TypeToTensorType<uint16_t> {
94
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
95
+ };
96
+ template <>
97
+ struct TypeToTensorType<uint32_t> {
98
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
99
+ };
100
+ template <>
101
+ struct TypeToTensorType<uint64_t> {
102
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
103
+ };
104
+ template <>
105
+ struct TypeToTensorType<bool> {
106
+ static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
107
+ };
108
+
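Editor's note: as seen from user code that includes onnxruntime_cxx_api.h, the trait above is what the templated Value::CreateTensor<T> overloads use to pick the element type at compile time. A couple of illustrative compile-time checks:

#include <cstdint>
#include "onnxruntime_cxx_api.h"

static_assert(Ort::TypeToTensorType<float>::type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
              "CreateTensor<float> produces a FLOAT tensor");
static_assert(Ort::TypeToTensorType<int64_t>::type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
              "CreateTensor<int64_t> produces an INT64 tensor");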
109
+ inline MemoryAllocation::MemoryAllocation(OrtAllocator* allocator, void* p, size_t size)
110
+ : allocator_(allocator), p_(p), size_(size) {
111
+ }
112
+
113
+ inline MemoryAllocation::~MemoryAllocation() {
114
+ if (p_ != nullptr) {
115
+ // We do not throw out of destructor
116
+ auto ret = GetApi().AllocatorFree(allocator_, p_);
117
+ static_cast<void>(ret);
118
+ }
119
+ }
120
+
121
+ inline MemoryAllocation::MemoryAllocation(MemoryAllocation&& o) noexcept : allocator_(nullptr), p_(nullptr), size_(0) {
122
+ *this = std::move(o);
123
+ }
124
+
125
+ inline MemoryAllocation& MemoryAllocation::operator=(MemoryAllocation&& o) noexcept {
126
+ OrtAllocator* alloc = nullptr;
127
+ void* p = nullptr;
128
+ size_t sz = 0;
129
+
130
+ // Swap out this
131
+ std::swap(alloc, allocator_);
132
+ std::swap(p, p_);
133
+ std::swap(sz, size_);
134
+
135
+ // Swap with incoming
136
+ std::swap(allocator_, o.allocator_);
137
+ std::swap(p_, o.p_);
138
+ std::swap(size_, o.size_);
139
+
140
+ // Destroy this instance if needed
141
+ MemoryAllocation this_alloc(alloc, p, sz);
142
+ return *this;
143
+ }
144
+
145
+ namespace detail {
146
+
147
+ template <typename T>
148
+ inline void* AllocatorImpl<T>::Alloc(size_t size) {
149
+ void* out;
150
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
151
+ return out;
152
+ }
153
+
154
+ template <typename T>
155
+ inline MemoryAllocation AllocatorImpl<T>::GetAllocation(size_t size) {
156
+ void* out;
157
+ ThrowOnError(GetApi().AllocatorAlloc(this->p_, size, &out));
158
+ MemoryAllocation result(this->p_, out, size);
159
+ return result;
160
+ }
161
+
162
+ template <typename T>
163
+ inline void AllocatorImpl<T>::Free(void* p) {
164
+ ThrowOnError(GetApi().AllocatorFree(this->p_, p));
165
+ }
166
+
167
+ template <typename T>
168
+ inline ConstMemoryInfo AllocatorImpl<T>::GetInfo() const {
169
+ const OrtMemoryInfo* out;
170
+ ThrowOnError(GetApi().AllocatorGetInfo(this->p_, &out));
171
+ return ConstMemoryInfo{out};
172
+ }
173
+
174
+ } // namespace detail
175
+
176
+ inline AllocatorWithDefaultOptions::AllocatorWithDefaultOptions() {
177
+ ThrowOnError(GetApi().GetAllocatorWithDefaultOptions(&this->p_));
178
+ }
179
+
180
+ inline Allocator::Allocator(const Session& sess, const OrtMemoryInfo* mem_info) {
181
+ ThrowOnError(GetApi().CreateAllocator(sess, mem_info, &this->p_));
182
+ }
183
+
184
+ namespace detail {
185
+
186
+ template <typename T>
187
+ inline std::string MemoryInfoImpl<T>::GetAllocatorName() const {
188
+ const char* name = nullptr;
189
+ ThrowOnError(GetApi().MemoryInfoGetName(this->p_, &name));
190
+ return std::string(name);
191
+ }
192
+
193
+ template <typename T>
194
+ inline OrtAllocatorType MemoryInfoImpl<T>::GetAllocatorType() const {
195
+ OrtAllocatorType type;
196
+ ThrowOnError(GetApi().MemoryInfoGetType(this->p_, &type));
197
+ return type;
198
+ }
199
+
200
+ template <typename T>
201
+ inline int MemoryInfoImpl<T>::GetDeviceId() const {
202
+ int id = 0;
203
+ ThrowOnError(GetApi().MemoryInfoGetId(this->p_, &id));
204
+ return id;
205
+ }
206
+
207
+ template <typename T>
208
+ inline OrtMemoryInfoDeviceType MemoryInfoImpl<T>::GetDeviceType() const {
209
+ OrtMemoryInfoDeviceType type;
210
+ GetApi().MemoryInfoGetDeviceType(this->p_, &type);
211
+ return type;
212
+ }
213
+
214
+ template <typename T>
215
+ inline OrtMemType MemoryInfoImpl<T>::GetMemoryType() const {
216
+ OrtMemType type;
217
+ ThrowOnError(GetApi().MemoryInfoGetMemType(this->p_, &type));
218
+ return type;
219
+ }
220
+
221
+ template <typename T>
222
+ template <typename U>
223
+ inline bool MemoryInfoImpl<T>::operator==(const MemoryInfoImpl<U>& o) const {
224
+ int comp_result = 0;
225
+ ThrowOnError(Ort::GetApi().CompareMemoryInfo(this->p_, o, &comp_result));
226
+ return comp_result == 0;
227
+ }
228
+
229
+ } // namespace detail
230
+
231
+ inline MemoryInfo MemoryInfo::CreateCpu(OrtAllocatorType type, OrtMemType mem_type) {
232
+ OrtMemoryInfo* p;
233
+ ThrowOnError(GetApi().CreateCpuMemoryInfo(type, mem_type, &p));
234
+ return MemoryInfo(p);
235
+ }
236
+
237
+ inline MemoryInfo::MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type) {
238
+ ThrowOnError(GetApi().CreateMemoryInfo(name, type, id, mem_type, &this->p_));
239
+ }
240
+
241
+ namespace detail {
242
+ template <typename T>
243
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames() const {
244
+ AllocatorWithDefaultOptions allocator;
245
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
246
+ }
247
+
248
+ template <typename T>
249
+ inline std::vector<std::string> ConstIoBindingImpl<T>::GetOutputNames(OrtAllocator* allocator) const {
250
+ return binding_utils::GetOutputNamesHelper(this->p_, allocator);
251
+ }
252
+
253
+ template <typename T>
254
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues() const {
255
+ AllocatorWithDefaultOptions allocator;
256
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
257
+ }
258
+
259
+ template <typename T>
260
+ inline std::vector<Value> ConstIoBindingImpl<T>::GetOutputValues(OrtAllocator* allocator) const {
261
+ return binding_utils::GetOutputValuesHelper(this->p_, allocator);
262
+ }
263
+
264
+ template <typename T>
265
+ inline void IoBindingImpl<T>::BindInput(const char* name, const Value& value) {
266
+ ThrowOnError(GetApi().BindInput(this->p_, name, value));
267
+ }
268
+
269
+ template <typename T>
270
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const Value& value) {
271
+ ThrowOnError(GetApi().BindOutput(this->p_, name, value));
272
+ }
273
+
274
+ template <typename T>
275
+ inline void IoBindingImpl<T>::BindOutput(const char* name, const OrtMemoryInfo* mem_info) {
276
+ ThrowOnError(GetApi().BindOutputToDevice(this->p_, name, mem_info));
277
+ }
278
+
279
+ template <typename T>
280
+ inline void IoBindingImpl<T>::ClearBoundInputs() {
281
+ GetApi().ClearBoundInputs(this->p_);
282
+ }
283
+
284
+ template <typename T>
285
+ inline void IoBindingImpl<T>::ClearBoundOutputs() {
286
+ GetApi().ClearBoundOutputs(this->p_);
287
+ }
288
+
289
+ template <typename T>
290
+ inline void IoBindingImpl<T>::SynchronizeInputs() {
291
+ ThrowOnError(GetApi().SynchronizeBoundInputs(this->p_));
292
+ }
293
+
294
+ template <typename T>
295
+ inline void IoBindingImpl<T>::SynchronizeOutputs() {
296
+ ThrowOnError(GetApi().SynchronizeBoundOutputs(this->p_));
297
+ }
298
+
299
+ namespace binding_utils {
300
+ inline std::vector<std::string> GetOutputNamesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
301
+ std::vector<std::string> result;
302
+ auto free_fn = detail::AllocatedFree(allocator);
303
+ using Ptr = std::unique_ptr<void, decltype(free_fn)>;
304
+
305
+ char* buffer = nullptr;
306
+ size_t* lengths = nullptr;
307
+ size_t count = 0;
308
+ ThrowOnError(GetApi().GetBoundOutputNames(binding, allocator, &buffer, &lengths, &count));
309
+
310
+ if (count == 0) {
311
+ return result;
312
+ }
313
+
314
+ Ptr buffer_g(buffer, free_fn);
315
+ Ptr lengths_g(lengths, free_fn);
316
+
317
+ result.reserve(count);
318
+ for (size_t i = 0; i < count; ++i) {
319
+ auto sz = *lengths;
320
+ result.emplace_back(buffer, sz);
321
+ buffer += sz;
322
+ ++lengths;
323
+ }
324
+ return result;
325
+ }
326
+
327
+ inline std::vector<Value> GetOutputValuesHelper(const OrtIoBinding* binding, OrtAllocator* allocator) {
328
+ std::vector<Value> result;
329
+ size_t owned = 0;
330
+ size_t output_count = 0;
331
+ // Lambda to release the buffer when no longer needed and
332
+ // make sure that we destroy all instances on exception
333
+ auto free_fn = [&owned, &output_count, allocator](OrtValue** buffer) {
334
+ if (buffer) {
335
+ while (owned < output_count) {
336
+ auto* p = buffer + owned++;
337
+ GetApi().ReleaseValue(*p);
338
+ }
339
+ allocator->Free(allocator, buffer);
340
+ }
341
+ };
342
+ using Ptr = std::unique_ptr<OrtValue*, decltype(free_fn)>;
343
+
344
+ OrtValue** output_buffer = nullptr;
345
+ ThrowOnError(GetApi().GetBoundOutputValues(binding, allocator, &output_buffer, &output_count));
346
+ if (output_count == 0) {
347
+ return result;
348
+ }
349
+
350
+ Ptr buffer_g(output_buffer, free_fn);
351
+
352
+ result.reserve(output_count);
353
+ for (size_t i = 0; i < output_count; ++i) {
354
+ result.emplace_back(output_buffer[i]);
355
+ ++owned;
356
+ }
357
+ return result;
358
+ }
359
+
360
+ } // namespace binding_utils
361
+ } // namespace detail
362
+
363
+ inline IoBinding::IoBinding(Session& session) {
364
+ ThrowOnError(GetApi().CreateIoBinding(session, &this->p_));
365
+ }
366
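A sketch of I/O binding built on the wrappers above; `session`, `input_tensor`, and the I/O names are assumptions, not values from this repository (illustrative fragment):

Ort::IoBinding binding(session);
binding.BindInput("input", input_tensor);
auto cpu_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
binding.BindOutput("output", cpu_info);               // let ORT allocate the output buffer
session.Run(Ort::RunOptions{nullptr}, binding);
std::vector<Ort::Value> outputs = binding.GetOutputValues();
std::vector<std::string> names = binding.GetOutputNames();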
+
367
+ inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
368
+ ThrowOnError(GetApi().CreateArenaCfg(max_mem, arena_extend_strategy, initial_chunk_size_bytes, max_dead_bytes_per_chunk, &p_));
369
+ }
370
+
371
+ inline ThreadingOptions::ThreadingOptions() {
372
+ ThrowOnError(GetApi().CreateThreadingOptions(&p_));
373
+ }
374
+
375
+ inline ThreadingOptions& ThreadingOptions::SetGlobalIntraOpNumThreads(int intra_op_num_threads) {
376
+ ThrowOnError(GetApi().SetGlobalIntraOpNumThreads(p_, intra_op_num_threads));
377
+ return *this;
378
+ }
379
+
380
+ inline ThreadingOptions& ThreadingOptions::SetGlobalInterOpNumThreads(int inter_op_num_threads) {
381
+ ThrowOnError(GetApi().SetGlobalInterOpNumThreads(p_, inter_op_num_threads));
382
+ return *this;
383
+ }
384
+
385
+ inline ThreadingOptions& ThreadingOptions::SetGlobalSpinControl(int allow_spinning) {
386
+ ThrowOnError(GetApi().SetGlobalSpinControl(p_, allow_spinning));
387
+ return *this;
388
+ }
389
+
390
+ inline ThreadingOptions& ThreadingOptions::SetGlobalDenormalAsZero() {
391
+ ThrowOnError(GetApi().SetGlobalDenormalAsZero(p_));
392
+ return *this;
393
+ }
394
+
395
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
396
+ ThrowOnError(GetApi().SetGlobalCustomCreateThreadFn(p_, ort_custom_create_thread_fn));
397
+ return *this;
398
+ }
399
+
400
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
401
+ ThrowOnError(GetApi().SetGlobalCustomThreadCreationOptions(p_, ort_custom_thread_creation_options));
402
+ return *this;
403
+ }
404
+
405
+ inline ThreadingOptions& ThreadingOptions::SetGlobalCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
406
+ ThrowOnError(GetApi().SetGlobalCustomJoinThreadFn(p_, ort_custom_join_thread_fn));
407
+ return *this;
408
+ }
409
+
410
+ inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
411
+ ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
412
+ if (strcmp(logid, "onnxruntime-node") == 0) {
413
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
414
+ } else {
415
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
416
+ }
417
+ }
418
+
419
+ inline Env::Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
420
+ ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, logging_level, logid, &p_));
421
+ if (strcmp(logid, "onnxruntime-node") == 0) {
422
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
423
+ } else {
424
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
425
+ }
426
+ }
427
+
428
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level, _In_ const char* logid) {
429
+ ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(logging_level, logid, tp_options, &p_));
430
+ if (strcmp(logid, "onnxruntime-node") == 0) {
431
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
432
+ } else {
433
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
434
+ }
435
+ }
436
+
437
+ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
438
+ OrtLoggingLevel logging_level, _In_ const char* logid) {
439
+ ThrowOnError(GetApi().CreateEnvWithCustomLoggerAndGlobalThreadPools(logging_function, logger_param, logging_level, logid, tp_options, &p_));
440
+ if (strcmp(logid, "onnxruntime-node") == 0) {
441
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
442
+ } else {
443
+ ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
444
+ }
445
+ }
446
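A sketch of environment creation using the constructors above (illustrative; the log id is arbitrary):

Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "example-app");
// Alternatively, share global thread pools across sessions:
Ort::ThreadingOptions tp;
tp.SetGlobalIntraOpNumThreads(4).SetGlobalSpinControl(0);
Ort::Env pooled_env(tp, ORT_LOGGING_LEVEL_WARNING, "example-app");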
+
447
+ inline Env& Env::EnableTelemetryEvents() {
448
+ ThrowOnError(GetApi().EnableTelemetryEvents(p_));
449
+ return *this;
450
+ }
451
+
452
+ inline Env& Env::DisableTelemetryEvents() {
453
+ ThrowOnError(GetApi().DisableTelemetryEvents(p_));
454
+ return *this;
455
+ }
456
+
457
+ inline Env& Env::UpdateEnvWithCustomLogLevel(OrtLoggingLevel log_severity_level) {
458
+ ThrowOnError(GetApi().UpdateEnvWithCustomLogLevel(p_, log_severity_level));
459
+ return *this;
460
+ }
461
+
462
+ inline Env& Env::CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg) {
463
+ ThrowOnError(GetApi().CreateAndRegisterAllocator(p_, mem_info, arena_cfg));
464
+ return *this;
465
+ }
466
+
467
+ inline CustomOpDomain::CustomOpDomain(const char* domain) {
468
+ ThrowOnError(GetApi().CreateCustomOpDomain(domain, &p_));
469
+ }
470
+
471
+ inline void CustomOpDomain::Add(const OrtCustomOp* op) {
472
+ ThrowOnError(GetApi().CustomOpDomain_Add(p_, op));
473
+ }
474
+
475
+ inline RunOptions::RunOptions() {
476
+ ThrowOnError(GetApi().CreateRunOptions(&p_));
477
+ }
478
+
479
+ inline RunOptions& RunOptions::SetRunLogVerbosityLevel(int level) {
480
+ ThrowOnError(GetApi().RunOptionsSetRunLogVerbosityLevel(p_, level));
481
+ return *this;
482
+ }
483
+
484
+ inline RunOptions& RunOptions::SetRunLogSeverityLevel(int level) {
485
+ ThrowOnError(GetApi().RunOptionsSetRunLogSeverityLevel(p_, level));
486
+ return *this;
487
+ }
488
+
489
+ inline int RunOptions::GetRunLogVerbosityLevel() const {
490
+ int out;
491
+ ThrowOnError(GetApi().RunOptionsGetRunLogVerbosityLevel(p_, &out));
492
+ return out;
493
+ }
494
+
495
+ inline int RunOptions::GetRunLogSeverityLevel() const {
496
+ int out;
497
+ ThrowOnError(GetApi().RunOptionsGetRunLogSeverityLevel(p_, &out));
498
+ return out;
499
+ }
500
+
501
+ inline RunOptions& RunOptions::SetRunTag(const char* run_tag) {
502
+ ThrowOnError(GetApi().RunOptionsSetRunTag(p_, run_tag));
503
+ return *this;
504
+ }
505
+
506
+ inline const char* RunOptions::GetRunTag() const {
507
+ const char* out;
508
+ ThrowOnError(GetApi().RunOptionsGetRunTag(p_, &out));
509
+ return out;
510
+ }
511
+
512
+ inline RunOptions& RunOptions::AddConfigEntry(const char* config_key, const char* config_value) {
513
+ ThrowOnError(GetApi().AddRunConfigEntry(p_, config_key, config_value));
514
+ return *this;
515
+ }
516
+
517
+ inline RunOptions& RunOptions::SetTerminate() {
518
+ ThrowOnError(GetApi().RunOptionsSetTerminate(p_));
519
+ return *this;
520
+ }
521
+
522
+ inline RunOptions& RunOptions::UnsetTerminate() {
523
+ ThrowOnError(GetApi().RunOptionsUnsetTerminate(p_));
524
+ return *this;
525
+ }
526
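A sketch of per-call configuration using the RunOptions setters above (illustrative; the config key shown is one of the documented run-option keys and may vary by version):

Ort::RunOptions run_options;
run_options.SetRunTag("request-42")
           .SetRunLogSeverityLevel(ORT_LOGGING_LEVEL_ERROR)
           .AddConfigEntry("memory.enable_memory_arena_shrinkage", "cpu:0");
// From another thread, run_options.SetTerminate() asks an in-flight Run() to abort;
// UnsetTerminate() re-arms the object for the next call.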
+
527
+ namespace detail {
528
+
529
+ template <typename T>
530
+ inline Ort::SessionOptions ConstSessionOptionsImpl<T>::Clone() const {
531
+ OrtSessionOptions* out;
532
+ ThrowOnError(GetApi().CloneSessionOptions(this->p_, &out));
533
+ return SessionOptions{out};
534
+ }
535
+
536
+ template <typename T>
537
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntry(const char* config_key) const {
538
+ size_t size = 0;
539
+ // Feed nullptr for the data buffer to query the true size of the string value
540
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, nullptr, &size));
541
+
542
+ std::string out;
543
+ out.resize(size);
544
+ Ort::ThrowOnError(GetApi().GetSessionConfigEntry(this->p_, config_key, &out[0], &size));
545
+ out.resize(size - 1); // remove the terminating character '\0'
546
+
547
+ return out;
548
+ }
549
+
550
+ template <typename T>
551
+ inline bool ConstSessionOptionsImpl<T>::HasConfigEntry(const char* config_key) const {
552
+ int out = 0;
553
+ Ort::ThrowOnError(GetApi().HasSessionConfigEntry(this->p_, config_key, &out));
554
+ return static_cast<bool>(out);
555
+ }
556
+
557
+ template <typename T>
558
+ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const char* config_key, const std::string& def) {
559
+ if (!this->HasConfigEntry(config_key)) {
560
+ return def;
561
+ }
562
+
563
+ return this->GetConfigEntry(config_key);
564
+ }
565
+
566
+ template <typename T>
567
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
568
+ ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
569
+ return *this;
570
+ }
571
+
572
+ template <typename T>
573
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetInterOpNumThreads(int inter_op_num_threads) {
574
+ ThrowOnError(GetApi().SetInterOpNumThreads(this->p_, inter_op_num_threads));
575
+ return *this;
576
+ }
577
+
578
+ template <typename T>
579
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
580
+ ThrowOnError(GetApi().SetSessionGraphOptimizationLevel(this->p_, graph_optimization_level));
581
+ return *this;
582
+ }
583
+
584
+ template <typename T>
585
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_filepath) {
586
+ ThrowOnError(GetApi().SetOptimizedModelFilePath(this->p_, optimized_model_filepath));
587
+ return *this;
588
+ }
589
+
590
+ template <typename T>
591
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableProfiling(const ORTCHAR_T* profile_file_prefix) {
592
+ ThrowOnError(GetApi().EnableProfiling(this->p_, profile_file_prefix));
593
+ return *this;
594
+ }
595
+
596
+ template <typename T>
597
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableProfiling() {
598
+ ThrowOnError(GetApi().DisableProfiling(this->p_));
599
+ return *this;
600
+ }
601
+
602
+ template <typename T>
603
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableOrtCustomOps() {
604
+ ThrowOnError(GetApi().EnableOrtCustomOps(this->p_));
605
+ return *this;
606
+ }
607
+
608
+ template <typename T>
609
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableMemPattern() {
610
+ ThrowOnError(GetApi().EnableMemPattern(this->p_));
611
+ return *this;
612
+ }
613
+
614
+ template <typename T>
615
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableMemPattern() {
616
+ ThrowOnError(GetApi().DisableMemPattern(this->p_));
617
+ return *this;
618
+ }
619
+
620
+ template <typename T>
621
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::EnableCpuMemArena() {
622
+ ThrowOnError(GetApi().EnableCpuMemArena(this->p_));
623
+ return *this;
624
+ }
625
+
626
+ template <typename T>
627
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisableCpuMemArena() {
628
+ ThrowOnError(GetApi().DisableCpuMemArena(this->p_));
629
+ return *this;
630
+ }
631
+
632
+ template <typename T>
633
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionMode execution_mode) {
634
+ ThrowOnError(GetApi().SetSessionExecutionMode(this->p_, execution_mode));
635
+ return *this;
636
+ }
637
+
638
+ template <typename T>
639
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
640
+ ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
641
+ return *this;
642
+ }
643
+
644
+ template <typename T>
645
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogSeverityLevel(int level) {
646
+ ThrowOnError(GetApi().SetSessionLogSeverityLevel(this->p_, level));
647
+ return *this;
648
+ }
649
+
650
+ template <typename T>
651
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::Add(OrtCustomOpDomain* custom_op_domain) {
652
+ ThrowOnError(GetApi().AddCustomOpDomain(this->p_, custom_op_domain));
653
+ return *this;
654
+ }
655
+
656
+ template <typename T>
657
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddConfigEntry(const char* config_key, const char* config_value) {
658
+ ThrowOnError(GetApi().AddSessionConfigEntry(this->p_, config_key, config_value));
659
+ return *this;
660
+ }
661
+
662
+ template <typename T>
663
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddInitializer(const char* name, const OrtValue* ort_val) {
664
+ ThrowOnError(GetApi().AddInitializer(this->p_, name, ort_val));
665
+ return *this;
666
+ }
667
+
668
+ template <typename T>
669
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::DisablePerSessionThreads() {
670
+ ThrowOnError(GetApi().DisablePerSessionThreads(this->p_));
671
+ return *this;
672
+ }
673
+
674
+ template <typename T>
675
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AddExternalInitializers(const std::vector<std::string>& names,
676
+ const std::vector<Value>& ort_values) {
677
+ const size_t inputs_num = names.size();
678
+ if (inputs_num != ort_values.size()) {
679
+ ORT_CXX_API_THROW("Expecting names and ort_values to have the same length", ORT_INVALID_ARGUMENT);
680
+ }
681
+ std::vector<const char*> names_ptr;
682
+ std::vector<const OrtValue*> ort_values_ptrs;
683
+ names_ptr.reserve(inputs_num);
684
+ ort_values_ptrs.reserve(inputs_num);
685
+ for (size_t i = 0; i < inputs_num; ++i) {
686
+ names_ptr.push_back(names[i].c_str());
687
+ ort_values_ptrs.push_back(ort_values[i]);
688
+ }
689
+ ThrowOnError(GetApi().AddExternalInitializers(this->p_, names_ptr.data(), ort_values_ptrs.data(), inputs_num));
690
+ return *this;
691
+ }
692
+
693
+ template <typename T>
694
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options) {
695
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA(this->p_, &provider_options));
696
+ return *this;
697
+ }
698
+
699
+ template <typename T>
700
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CUDA_V2(const OrtCUDAProviderOptionsV2& provider_options) {
701
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CUDA_V2(this->p_, &provider_options));
702
+ return *this;
703
+ }
704
+
705
+ template <typename T>
706
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options) {
707
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_ROCM(this->p_, &provider_options));
708
+ return *this;
709
+ }
710
+
711
+ template <typename T>
712
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options) {
713
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT(this->p_, &provider_options));
714
+ return *this;
715
+ }
716
+
717
+ template <typename T>
718
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options) {
719
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_TensorRT_V2(this->p_, &provider_options));
720
+ return *this;
721
+ }
722
+
723
+ template <typename T>
724
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options) {
725
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(this->p_, &provider_options));
726
+ return *this;
727
+ }
728
+
729
+ template <typename T>
730
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
731
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
732
+ return *this;
733
+ }
734
+
735
+ template <typename T>
736
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options) {
737
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_Dnnl(this->p_, &provider_options));
738
+ return *this;
739
+ }
740
+
741
+ template <typename T>
742
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider(
743
+ const std::string& provider_name,
744
+ const std::unordered_map<std::string, std::string>& provider_options) {
745
+ auto num_entries = provider_options.size();
746
+ std::vector<const char*> keys, values;
747
+ if (num_entries > 0) {
748
+ keys.reserve(num_entries);
749
+ values.reserve(num_entries);
750
+
751
+ for (const auto& entry : provider_options) {
752
+ keys.push_back(entry.first.c_str());
753
+ values.push_back(entry.second.c_str());
754
+ }
755
+ }
756
+
757
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider(this->p_, provider_name.c_str(),
758
+ keys.data(), values.data(), num_entries));
759
+
760
+ return *this;
761
+ }
762
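A sketch of typical session-option setup with the builders above; the provider name and option keys are examples and depend on how the library was built:

Ort::SessionOptions so;
so.SetIntraOpNumThreads(2)
  .SetGraphOptimizationLevel(ORT_ENABLE_ALL)
  .AddConfigEntry("session.disable_prepacking", "1");
std::unordered_map<std::string, std::string> ep_options{{"intra_op_num_threads", "2"}};
so.AppendExecutionProvider("XNNPACK", ep_options);    // generic named-EP entry point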
+
763
+ template <typename T>
764
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
765
+ ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
766
+ return *this;
767
+ }
768
+
769
+ template <typename T>
770
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options) {
771
+ ThrowOnError(GetApi().SessionOptionsSetCustomThreadCreationOptions(this->p_, ort_custom_thread_creation_options));
772
+ return *this;
773
+ }
774
+
775
+ template <typename T>
776
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn) {
777
+ ThrowOnError(GetApi().SessionOptionsSetCustomJoinThreadFn(this->p_, ort_custom_join_thread_fn));
778
+ return *this;
779
+ }
780
+
781
+ template <typename T>
782
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options) {
783
+ ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_OpenVINO(this->p_, &provider_options));
784
+ return *this;
785
+ }
786
+
787
+ template <typename T>
788
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name,
789
+ const CustomOpConfigs& custom_op_configs) {
790
+ // Add custom op config entries before registering the custom op library. Otherwise, the config entries _may_ be ignored by
791
+ // the custom op library.
792
+ for (const auto& config_iter : custom_op_configs.GetFlattenedConfigs()) {
793
+ AddConfigEntry(config_iter.first.c_str(), config_iter.second.c_str());
794
+ }
795
+
796
+ ThrowOnError(GetApi().RegisterCustomOpsLibrary_V2(this->p_, library_name));
797
+ return *this;
798
+ }
799
+
800
+ template <typename T>
801
+ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::RegisterCustomOpsUsingFunction(const char* registration_function_name) {
802
+ ThrowOnError(GetApi().RegisterCustomOpsUsingFunction(this->p_, registration_function_name));
803
+ return *this;
804
+ }
805
+
806
+ /// Session
807
+ template <typename T>
808
+ inline size_t ConstSessionImpl<T>::GetInputCount() const {
809
+ size_t out;
810
+ ThrowOnError(GetApi().SessionGetInputCount(this->p_, &out));
811
+ return out;
812
+ }
813
+
814
+ template <typename T>
815
+ inline size_t ConstSessionImpl<T>::GetOutputCount() const {
816
+ size_t out;
817
+ ThrowOnError(GetApi().SessionGetOutputCount(this->p_, &out));
818
+ return out;
819
+ }
820
+
821
+ template <typename T>
822
+ inline size_t ConstSessionImpl<T>::GetOverridableInitializerCount() const {
823
+ size_t out;
824
+ ThrowOnError(GetApi().SessionGetOverridableInitializerCount(this->p_, &out));
825
+ return out;
826
+ }
827
+
828
+ template <typename T>
829
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetInputNameAllocated(size_t index, OrtAllocator* allocator) const {
830
+ char* out;
831
+ ThrowOnError(GetApi().SessionGetInputName(this->p_, index, allocator, &out));
832
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
833
+ }
834
+
835
+ template <typename T>
836
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOutputNameAllocated(size_t index, OrtAllocator* allocator) const {
837
+ char* out;
838
+ ThrowOnError(GetApi().SessionGetOutputName(this->p_, index, allocator, &out));
839
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
840
+ }
841
+
842
+ template <typename T>
843
+ inline AllocatedStringPtr ConstSessionImpl<T>::GetOverridableInitializerNameAllocated(size_t index, OrtAllocator* allocator) const {
844
+ char* out;
845
+ ThrowOnError(GetApi().SessionGetOverridableInitializerName(this->p_, index, allocator, &out));
846
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
847
+ }
848
+
849
+ template <typename T>
850
+ inline uint64_t ConstSessionImpl<T>::GetProfilingStartTimeNs() const {
851
+ uint64_t out;
852
+ ThrowOnError(GetApi().SessionGetProfilingStartTimeNs(this->p_, &out));
853
+ return out;
854
+ }
855
+
856
+ template <typename T>
857
+ inline ModelMetadata ConstSessionImpl<T>::GetModelMetadata() const {
858
+ OrtModelMetadata* out;
859
+ ThrowOnError(GetApi().SessionGetModelMetadata(this->p_, &out));
860
+ return ModelMetadata{out};
861
+ }
862
+
863
+ template <typename T>
864
+ inline TypeInfo ConstSessionImpl<T>::GetInputTypeInfo(size_t index) const {
865
+ OrtTypeInfo* out;
866
+ ThrowOnError(GetApi().SessionGetInputTypeInfo(this->p_, index, &out));
867
+ return TypeInfo{out};
868
+ }
869
+
870
+ template <typename T>
871
+ inline TypeInfo ConstSessionImpl<T>::GetOutputTypeInfo(size_t index) const {
872
+ OrtTypeInfo* out;
873
+ ThrowOnError(GetApi().SessionGetOutputTypeInfo(this->p_, index, &out));
874
+ return TypeInfo{out};
875
+ }
876
+
877
+ template <typename T>
878
+ inline TypeInfo ConstSessionImpl<T>::GetOverridableInitializerTypeInfo(size_t index) const {
879
+ OrtTypeInfo* out;
880
+ ThrowOnError(GetApi().SessionGetOverridableInitializerTypeInfo(this->p_, index, &out));
881
+ return TypeInfo{out};
882
+ }
883
+
884
+ template <typename T>
885
+ inline std::vector<Value> SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
886
+ const char* const* output_names, size_t output_count) {
887
+ std::vector<Value> output_values;
888
+ output_values.reserve(output_count);
889
+ for (size_t i = 0; i < output_count; i++)
890
+ output_values.emplace_back(nullptr);
891
+ Run(run_options, input_names, input_values, input_count, output_names, output_values.data(), output_count);
892
+ return output_values;
893
+ }
894
+
895
+ template <typename T>
896
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
897
+ const char* const* output_names, Value* output_values, size_t output_count) {
898
+ static_assert(sizeof(Value) == sizeof(OrtValue*), "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
899
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
900
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
901
+ ThrowOnError(GetApi().Run(this->p_, run_options, input_names, ort_input_values, input_count, output_names, output_count, ort_output_values));
902
+ }
903
+
904
+ template <typename T>
905
+ inline void SessionImpl<T>::Run(const RunOptions& run_options, const IoBinding& io_binding) {
906
+ ThrowOnError(GetApi().RunWithBinding(this->p_, run_options, io_binding));
907
+ }
908
+
909
+ template <typename T>
910
+ inline AllocatedStringPtr SessionImpl<T>::EndProfilingAllocated(OrtAllocator* allocator) {
911
+ char* out = nullptr;
912
+ ThrowOnError(GetApi().SessionEndProfiling(this->p_, allocator, &out));
913
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
914
+ }
915
+
916
+ } // namespace detail
917
+
918
+ inline SessionOptions::SessionOptions() {
919
+ ThrowOnError(GetApi().CreateSessionOptions(&this->p_));
920
+ }
921
+
922
+ /// CustomOpConfigs
923
+ inline std::string detail::MakeCustomOpConfigEntryKey(const char* custom_op_name, const char* config) {
924
+ std::string config_key = "custom_op.";
925
+
926
+ config_key += custom_op_name;
927
+ config_key += ".";
928
+ config_key += config;
929
+
930
+ return config_key;
931
+ }
932
+
933
+ inline CustomOpConfigs& CustomOpConfigs::AddConfig(const char* custom_op_name, const char* config_key, const char* config_value) {
934
+ const std::string full_flat_key = detail::MakeCustomOpConfigEntryKey(custom_op_name, config_key);
935
+ flat_configs_[full_flat_key] = config_value;
936
+ return *this;
937
+ }
938
+
939
+ inline const std::unordered_map<std::string, std::string>& CustomOpConfigs::GetFlattenedConfigs() const {
940
+ return flat_configs_;
941
+ }
942
+
943
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options) {
944
+ ThrowOnError(GetApi().CreateSession(env, model_path, options, &this->p_));
945
+ }
946
+
947
+ inline Session::Session(const Env& env, const ORTCHAR_T* model_path, const SessionOptions& options,
948
+ OrtPrepackedWeightsContainer* prepacked_weights_container) {
949
+ ThrowOnError(GetApi().CreateSessionWithPrepackedWeightsContainer(env, model_path, options, prepacked_weights_container, &this->p_));
950
+ }
951
+
952
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options) {
953
+ ThrowOnError(GetApi().CreateSessionFromArray(env, model_data, model_data_length, options, &this->p_));
954
+ }
955
+
956
+ inline Session::Session(const Env& env, const void* model_data, size_t model_data_length,
957
+ const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container) {
958
+ ThrowOnError(GetApi().CreateSessionFromArrayWithPrepackedWeightsContainer(env, model_data, model_data_length, options,
959
+ prepacked_weights_container, &this->p_));
960
+ }
961
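A self-contained end-to-end sketch using the session constructors above; the model path, tensor shape, and I/O names are placeholders, not taken from this repository:

#include <onnxruntime_cxx_api.h>
#include <array>
#include <vector>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "demo");
  Ort::SessionOptions so;
  Ort::Session session(env, ORT_TSTR("model.onnx"), so);

  auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
  std::vector<float> input(1 * 3 * 224 * 224, 0.0f);
  std::array<int64_t, 4> shape{1, 3, 224, 224};
  Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
      mem_info, input.data(), input.size(), shape.data(), shape.size());

  const char* input_names[] = {"input"};
  const char* output_names[] = {"output"};
  std::vector<Ort::Value> outputs = session.Run(
      Ort::RunOptions{nullptr}, input_names, &input_tensor, 1, output_names, 1);
  const float* scores = outputs[0].GetTensorData<float>();
  return scores != nullptr ? 0 : 1;
}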
+
962
+ inline AllocatedStringPtr ModelMetadata::GetProducerNameAllocated(OrtAllocator* allocator) const {
963
+ char* out;
964
+ ThrowOnError(GetApi().ModelMetadataGetProducerName(p_, allocator, &out));
965
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
966
+ }
967
+
968
+ inline AllocatedStringPtr ModelMetadata::GetGraphNameAllocated(OrtAllocator* allocator) const {
969
+ char* out;
970
+ ThrowOnError(GetApi().ModelMetadataGetGraphName(p_, allocator, &out));
971
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
972
+ }
973
+
974
+ inline AllocatedStringPtr ModelMetadata::GetDomainAllocated(OrtAllocator* allocator) const {
975
+ char* out;
976
+ ThrowOnError(GetApi().ModelMetadataGetDomain(p_, allocator, &out));
977
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
978
+ }
979
+
980
+ inline AllocatedStringPtr Ort::ModelMetadata::GetDescriptionAllocated(OrtAllocator* allocator) const {
981
+ char* out;
982
+ ThrowOnError(GetApi().ModelMetadataGetDescription(p_, allocator, &out));
983
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
984
+ }
985
+
986
+ inline AllocatedStringPtr ModelMetadata::GetGraphDescriptionAllocated(OrtAllocator* allocator) const {
987
+ char* out;
988
+ ThrowOnError(GetApi().ModelMetadataGetGraphDescription(p_, allocator, &out));
989
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
990
+ }
991
+
992
+ inline AllocatedStringPtr ModelMetadata::LookupCustomMetadataMapAllocated(const char* key, OrtAllocator* allocator) const {
993
+ char* out;
994
+ ThrowOnError(GetApi().ModelMetadataLookupCustomMetadataMap(p_, allocator, key, &out));
995
+ return AllocatedStringPtr(out, detail::AllocatedFree(allocator));
996
+ }
997
+
998
+ inline std::vector<AllocatedStringPtr> ModelMetadata::GetCustomMetadataMapKeysAllocated(OrtAllocator* allocator) const {
999
+ auto deletor = detail::AllocatedFree(allocator);
1000
+ std::vector<AllocatedStringPtr> result;
1001
+
1002
+ char** out = nullptr;
1003
+ int64_t num_keys = 0;
1004
+ ThrowOnError(GetApi().ModelMetadataGetCustomMetadataMapKeys(p_, allocator, &out, &num_keys));
1005
+ if (num_keys <= 0) {
1006
+ return result;
1007
+ }
1008
+
1009
+ // array of pointers will be freed
1010
+ std::unique_ptr<void, decltype(deletor)> array_guard(out, deletor);
1011
+ // reserve may throw
1012
+ auto strings_deletor = [&deletor, num_keys](char** out) { for(int64_t i = 0; i < num_keys; ++i) deletor(out[i]); };
1013
+ std::unique_ptr<char*, decltype(strings_deletor)> strings_guard(out, strings_deletor);
1014
+ result.reserve(static_cast<size_t>(num_keys));
1015
+ strings_guard.release();
1016
+ for (int64_t i = 0; i < num_keys; ++i) {
1017
+ result.push_back(AllocatedStringPtr(out[i], deletor));
1018
+ }
1019
+
1020
+ return result;
1021
+ }
1022
+
1023
+ inline int64_t ModelMetadata::GetVersion() const {
1024
+ int64_t out;
1025
+ ThrowOnError(GetApi().ModelMetadataGetVersion(p_, &out));
1026
+ return out;
1027
+ }
1028
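A sketch of reading model metadata with the accessors above, assuming an existing Ort::Session `session`; the "author" key is a placeholder (illustrative fragment):

Ort::AllocatorWithDefaultOptions alloc;
Ort::ModelMetadata meta = session.GetModelMetadata();
Ort::AllocatedStringPtr producer = meta.GetProducerNameAllocated(alloc);
Ort::AllocatedStringPtr author = meta.LookupCustomMetadataMapAllocated("author", alloc);  // may hold nullptr if the key is absent
int64_t graph_version = meta.GetVersion();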
+
1029
+ namespace detail {
1030
+
1031
+ template <typename T>
1032
+ inline ONNXTensorElementDataType TensorTypeAndShapeInfoImpl<T>::GetElementType() const {
1033
+ ONNXTensorElementDataType out;
1034
+ ThrowOnError(GetApi().GetTensorElementType(this->p_, &out));
1035
+ return out;
1036
+ }
1037
+
1038
+ template <typename T>
1039
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetElementCount() const {
1040
+ size_t out;
1041
+ ThrowOnError(GetApi().GetTensorShapeElementCount(this->p_, &out));
1042
+ return static_cast<size_t>(out);
1043
+ }
1044
+
1045
+ template <typename T>
1046
+ inline size_t TensorTypeAndShapeInfoImpl<T>::GetDimensionsCount() const {
1047
+ size_t out;
1048
+ ThrowOnError(GetApi().GetDimensionsCount(this->p_, &out));
1049
+ return out;
1050
+ }
1051
+
1052
+ template <typename T>
1053
+ inline void TensorTypeAndShapeInfoImpl<T>::GetDimensions(int64_t* values, size_t values_count) const {
1054
+ ThrowOnError(GetApi().GetDimensions(this->p_, values, values_count));
1055
+ }
1056
+
1057
+ template <typename T>
1058
+ inline void TensorTypeAndShapeInfoImpl<T>::GetSymbolicDimensions(const char** values, size_t values_count) const {
1059
+ ThrowOnError(GetApi().GetSymbolicDimensions(this->p_, values, values_count));
1060
+ }
1061
+
1062
+ template <typename T>
1063
+ inline std::vector<int64_t> TensorTypeAndShapeInfoImpl<T>::GetShape() const {
1064
+ std::vector<int64_t> out(GetDimensionsCount(), 0);
1065
+ ThrowOnError(GetApi().GetDimensions(this->p_, out.data(), out.size()));
1066
+ return out;
1067
+ }
1068
+
1069
+ } // namespace detail
1070
+
1071
+ namespace detail {
1072
+ template <typename T>
1073
+ inline ConstTensorTypeAndShapeInfo TypeInfoImpl<T>::GetTensorTypeAndShapeInfo() const {
1074
+ const OrtTensorTypeAndShapeInfo* out;
1075
+ ThrowOnError(GetApi().CastTypeInfoToTensorInfo(this->p_, &out));
1076
+ return ConstTensorTypeAndShapeInfo{out};
1077
+ }
1078
+
1079
+ template <typename T>
1080
+ inline ConstSequenceTypeInfo TypeInfoImpl<T>::GetSequenceTypeInfo() const {
1081
+ const OrtSequenceTypeInfo* out;
1082
+ ThrowOnError(GetApi().CastTypeInfoToSequenceTypeInfo(this->p_, &out));
1083
+ return ConstSequenceTypeInfo{out};
1084
+ }
1085
+
1086
+ template <typename T>
1087
+ inline ConstMapTypeInfo TypeInfoImpl<T>::GetMapTypeInfo() const {
1088
+ const OrtMapTypeInfo* out;
1089
+ ThrowOnError(GetApi().CastTypeInfoToMapTypeInfo(this->p_, &out));
1090
+ return ConstMapTypeInfo{out};
1091
+ }
1092
+
1093
+ template <typename T>
1094
+ inline ONNXType TypeInfoImpl<T>::GetONNXType() const {
1095
+ ONNXType out;
1096
+ ThrowOnError(GetApi().GetOnnxTypeFromTypeInfo(this->p_, &out));
1097
+ return out;
1098
+ }
1099
+
1100
+ } // namespace detail
1101
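A sketch of inspecting an input's type and shape through the TypeInfo wrappers above, assuming an existing Ort::Session `session` (illustrative fragment):

Ort::TypeInfo type_info = session.GetInputTypeInfo(0);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();        // ConstTensorTypeAndShapeInfo
ONNXTensorElementDataType elem_type = tensor_info.GetElementType();
std::vector<int64_t> dims = tensor_info.GetShape();              // -1 marks symbolic dimensions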
+
1102
+ namespace detail {
1103
+ template <typename T>
1104
+ inline TypeInfo SequenceTypeInfoImpl<T>::GetSequenceElementType() const {
1105
+ OrtTypeInfo* output;
1106
+ ThrowOnError(GetApi().GetSequenceElementType(this->p_, &output));
1107
+ return TypeInfo{output};
1108
+ }
1109
+
1110
+ } // namespace detail
1111
+
1112
+ namespace detail {
1113
+ template <typename T>
1114
+ inline ONNXTensorElementDataType MapTypeInfoImpl<T>::GetMapKeyType() const {
1115
+ ONNXTensorElementDataType out;
1116
+ ThrowOnError(GetApi().GetMapKeyType(this->p_, &out));
1117
+ return out;
1118
+ }
1119
+
1120
+ template <typename T>
1121
+ inline TypeInfo MapTypeInfoImpl<T>::GetMapValueType() const {
1122
+ OrtTypeInfo* output;
1123
+ ThrowOnError(GetApi().GetMapValueType(this->p_, &output));
1124
+ return TypeInfo{output};
1125
+ }
1126
+ } // namespace detail
1127
+
1128
+ namespace detail {
1129
+
1130
+ template <typename T>
1131
+ template <typename R>
1132
+ inline void ConstValueImpl<T>::GetOpaqueData(const char* domain, const char* type_name, R& out) const {
1133
+ ThrowOnError(GetApi().GetOpaqueValue(domain, type_name, this->p_, &out, sizeof(R)));
1134
+ }
1135
+
1136
+ template <typename T>
1137
+ inline bool ConstValueImpl<T>::IsTensor() const {
1138
+ int out;
1139
+ ThrowOnError(GetApi().IsTensor(this->p_, &out));
1140
+ return out != 0;
1141
+ }
1142
+
1143
+ template <typename T>
1144
+ inline bool ConstValueImpl<T>::HasValue() const {
1145
+ int out;
1146
+ ThrowOnError(GetApi().HasValue(this->p_, &out));
1147
+ return out != 0;
1148
+ }
1149
+
1150
+ template <typename T>
1151
+ inline size_t ConstValueImpl<T>::GetCount() const {
1152
+ size_t out;
1153
+ ThrowOnError(GetApi().GetValueCount(this->p_, &out));
1154
+ return out;
1155
+ }
1156
+
1157
+ template <typename T>
1158
+ inline Value ConstValueImpl<T>::GetValue(int index, OrtAllocator* allocator) const {
1159
+ OrtValue* out;
1160
+ ThrowOnError(GetApi().GetValue(this->p_, index, allocator, &out));
1161
+ return Value{out};
1162
+ }
1163
+
1164
+ template <typename T>
1165
+ inline size_t ConstValueImpl<T>::GetStringTensorDataLength() const {
1166
+ size_t out;
1167
+ ThrowOnError(GetApi().GetStringTensorDataLength(this->p_, &out));
1168
+ return out;
1169
+ }
1170
+
1171
+ template <typename T>
1172
+ inline size_t ConstValueImpl<T>::GetStringTensorElementLength(size_t element_index) const {
1173
+ size_t out;
1174
+ ThrowOnError(GetApi().GetStringTensorElementLength(this->p_, element_index, &out));
1175
+ return out;
1176
+ }
1177
+
1178
+ template <typename T>
1179
+ template <typename R>
1180
+ inline const R* ConstValueImpl<T>::GetTensorData() const {
1181
+ R* out;
1182
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), (void**)&out));
1183
+ return out;
1184
+ }
1185
+
1186
+ template <typename T>
1187
+ inline const void* ConstValueImpl<T>::GetTensorRawData() const {
1188
+ void* out;
1189
+ ThrowOnError(GetApi().GetTensorMutableData(const_cast<OrtValue*>(this->p_), &out));
1190
+ return out;
1191
+ }
1192
+
1193
+ template <typename T>
1194
+ inline TypeInfo ConstValueImpl<T>::GetTypeInfo() const {
1195
+ OrtTypeInfo* output;
1196
+ ThrowOnError(GetApi().GetTypeInfo(this->p_, &output));
1197
+ return TypeInfo{output};
1198
+ }
1199
+
1200
+ template <typename T>
1201
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetTensorTypeAndShapeInfo() const {
1202
+ OrtTensorTypeAndShapeInfo* output;
1203
+ ThrowOnError(GetApi().GetTensorTypeAndShape(this->p_, &output));
1204
+ return TensorTypeAndShapeInfo{output};
1205
+ }
1206
+
1207
+ template <typename T>
1208
+ inline ConstMemoryInfo ConstValueImpl<T>::GetTensorMemoryInfo() const {
1209
+ const OrtMemoryInfo* mem_info;
1210
+ ThrowOnError(GetApi().GetTensorMemoryInfo(this->p_, &mem_info));
1211
+ return ConstMemoryInfo(mem_info);
1212
+ }
1213
+
1214
+ template <typename T>
1215
+ inline void ConstValueImpl<T>::GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const {
1216
+ ThrowOnError(GetApi().GetStringTensorElement(this->p_, buffer_length, element_index, buffer));
1217
+ }
1218
+
1219
+ template <typename T>
1220
+ inline void ConstValueImpl<T>::GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const {
1221
+ ThrowOnError(GetApi().GetStringTensorContent(this->p_, buffer, buffer_length, offsets, offsets_count));
1222
+ }
1223
+
1224
+ #if !defined(DISABLE_SPARSE_TENSORS)
1225
+ template <typename T>
1226
+ inline OrtSparseFormat ConstValueImpl<T>::GetSparseFormat() const {
1227
+ OrtSparseFormat format;
1228
+ ThrowOnError(GetApi().GetSparseTensorFormat(this->p_, &format));
1229
+ return format;
1230
+ }
1231
+
1232
+ template <typename T>
1233
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorValuesTypeAndShapeInfo() const {
1234
+ OrtTensorTypeAndShapeInfo* output;
1235
+ ThrowOnError(GetApi().GetSparseTensorValuesTypeAndShape(this->p_, &output));
1236
+ return TensorTypeAndShapeInfo{output};
1237
+ }
1238
+
1239
+ template <typename T>
1240
+ inline TensorTypeAndShapeInfo ConstValueImpl<T>::GetSparseTensorIndicesTypeShapeInfo(OrtSparseIndicesFormat indices_format) const {
1241
+ OrtTensorTypeAndShapeInfo* output;
1242
+ ThrowOnError(GetApi().GetSparseTensorIndicesTypeShape(this->p_, indices_format, &output));
1243
+ return TensorTypeAndShapeInfo{output};
1244
+ }
1245
+
1246
+ template <typename T>
1247
+ template <typename R>
1248
+ inline const R* ConstValueImpl<T>::GetSparseTensorIndicesData(OrtSparseIndicesFormat indices_format, size_t& num_indices) const {
1249
+ const void* out;
1250
+ ThrowOnError(GetApi().GetSparseTensorIndices(this->p_, indices_format, &num_indices, &out));
1251
+ return reinterpret_cast<const R*>(out);
1252
+ }
1253
+
1254
+ template <typename T>
1255
+ inline bool ConstValueImpl<T>::IsSparseTensor() const {
1256
+ int out;
1257
+ ThrowOnError(GetApi().IsSparseTensor(this->p_, &out));
1258
+ return out != 0;
1259
+ }
1260
+
1261
+ template <typename T>
1262
+ template <typename R>
1263
+ inline const R* ConstValueImpl<T>::GetSparseTensorValues() const {
1264
+ const void* out;
1265
+ ThrowOnError(GetApi().GetSparseTensorValues(this->p_, &out));
1266
+ return reinterpret_cast<const R*>(out);
1267
+ }
1268
+
1269
+ #endif
1270
+
1271
+ template <typename T>
1272
+ void ValueImpl<T>::FillStringTensor(const char* const* s, size_t s_len) {
1273
+ ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len));
1274
+ }
1275
+
1276
+ template <typename T>
1277
+ void ValueImpl<T>::FillStringTensorElement(const char* s, size_t index) {
1278
+ ThrowOnError(GetApi().FillStringTensorElement(this->p_, s, index));
1279
+ }
1280
+
1281
+ template <typename T>
1282
+ void* ValueImpl<T>::GetTensorMutableRawData() {
1283
+ void* out;
1284
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, &out));
1285
+ return out;
1286
+ }
1287
+
1288
+ template <typename T>
1289
+ template <typename R>
1290
+ R* ValueImpl<T>::GetTensorMutableData() {
1291
+ R* out;
1292
+ ThrowOnError(GetApi().GetTensorMutableData(this->p_, (void**)&out));
1293
+ return out;
1294
+ }
1295
+
1296
+ template <typename T>
1297
+ template <typename R>
1298
+ R& ValueImpl<T>::At(const std::vector<int64_t>& location) {
1299
+ static_assert(!std::is_same<T, std::string>::value, "this api does not support std::string");
1300
+ R* out;
1301
+ ThrowOnError(GetApi().TensorAt(this->p_, location.data(), location.size(), (void**)&out));
1302
+ return *out;
1303
+ }
1304
+
1305
+ #if !defined(DISABLE_SPARSE_TENSORS)
1306
+ template <typename T>
1307
+ void ValueImpl<T>::UseCooIndices(int64_t* indices_data, size_t indices_num) {
1308
+ ThrowOnError(GetApi().UseCooIndices(this->p_, indices_data, indices_num));
1309
+ }
1310
+
1311
+ template <typename T>
1312
+ void ValueImpl<T>::UseCsrIndices(int64_t* inner_data, size_t inner_num, int64_t* outer_data, size_t outer_num) {
1313
+ ThrowOnError(GetApi().UseCsrIndices(this->p_, inner_data, inner_num, outer_data, outer_num));
1314
+ }
1315
+
1316
+ template <typename T>
1317
+ void ValueImpl<T>::UseBlockSparseIndices(const Shape& indices_shape, int32_t* indices_data) {
1318
+ ThrowOnError(GetApi().UseBlockSparseIndices(this->p_, indices_shape.shape, indices_shape.shape_len, indices_data));
1319
+ }
1320
+
1321
+ template <typename T>
1322
+ void ValueImpl<T>::FillSparseTensorCoo(const OrtMemoryInfo* mem_info, const OrtSparseValuesParam& values_param,
1323
+ const int64_t* indices_data, size_t indices_num) {
1324
+ ThrowOnError(GetApi().FillSparseTensorCoo(this->p_, mem_info, values_param.values_shape,
1325
+ values_param.values_shape_len, values_param.data.p_data,
1326
+ indices_data, indices_num));
1327
+ }
1328
+
1329
+ template <typename T>
1330
+ void ValueImpl<T>::FillSparseTensorCsr(const OrtMemoryInfo* data_mem_info,
1331
+ const OrtSparseValuesParam& values,
1332
+ const int64_t* inner_indices_data, size_t inner_indices_num,
1333
+ const int64_t* outer_indices_data, size_t outer_indices_num) {
1334
+ ThrowOnError(GetApi().FillSparseTensorCsr(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1335
+ inner_indices_data, inner_indices_num,
1336
+ outer_indices_data, outer_indices_num));
1337
+ }
1338
+
1339
+ template <typename T>
1340
+ void ValueImpl<T>::FillSparseTensorBlockSparse(const OrtMemoryInfo* data_mem_info,
1341
+ const OrtSparseValuesParam& values,
1342
+ const Shape& indices_shape,
1343
+ const int32_t* indices_data) {
1344
+ ThrowOnError(GetApi().FillSparseTensorBlockSparse(this->p_, data_mem_info, values.values_shape, values.values_shape_len, values.data.p_data,
1345
+ indices_shape.shape, indices_shape.shape_len,
1346
+ indices_data));
1347
+ }
1348
+
1349
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1350
+
1351
+ } // namespace detail
1352
+
1353
+ template <typename T>
1354
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len) {
1355
+ return CreateTensor(info, p_data, p_data_element_count * sizeof(T), shape, shape_len, TypeToTensorType<T>::type);
1356
+ }
1357
+
1358
+ inline Value Value::CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
1359
+ ONNXTensorElementDataType type) {
1360
+ OrtValue* out;
1361
+ ThrowOnError(GetApi().CreateTensorWithDataAsOrtValue(info, p_data, p_data_byte_count, shape, shape_len, type, &out));
1362
+ return Value{out};
1363
+ }
1364
+
1365
+ template <typename T>
1366
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len) {
1367
+ return CreateTensor(allocator, shape, shape_len, TypeToTensorType<T>::type);
1368
+ }
1369
+
1370
+ inline Value Value::CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type) {
1371
+ OrtValue* out;
1372
+ ThrowOnError(GetApi().CreateTensorAsOrtValue(allocator, shape, shape_len, type, &out));
1373
+ return Value{out};
1374
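A sketch of the two tensor-creation paths above: one owning its buffer via an allocator, one wrapping caller-owned memory (illustrative fragment):

Ort::AllocatorWithDefaultOptions alloc;
std::array<int64_t, 2> shape{2, 3};
// Allocator-owned: ORT allocates and frees the buffer together with the tensor.
Ort::Value owned = Ort::Value::CreateTensor<float>(alloc, shape.data(), shape.size());
float* dst = owned.GetTensorMutableData<float>();
// Caller-owned: the tensor merely views `data`, which must outlive it.
std::vector<float> data(6, 1.0f);
auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value view = Ort::Value::CreateTensor<float>(
    mem_info, data.data(), data.size(), shape.data(), shape.size());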
+ }
1375
+
1376
+ #if !defined(DISABLE_SPARSE_TENSORS)
1377
+
1378
+ template <typename T>
1379
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, T* p_data, const Shape& dense_shape,
1380
+ const Shape& values_shape) {
1381
+ return CreateSparseTensor(info, p_data, dense_shape, values_shape, TypeToTensorType<T>::type);
1382
+ }
1383
+
1384
+ inline Value Value::CreateSparseTensor(const OrtMemoryInfo* info, void* p_data, const Shape& dense_shape,
1385
+ const Shape& values_shape, ONNXTensorElementDataType type) {
1386
+ OrtValue* out;
1387
+ ThrowOnError(GetApi().CreateSparseTensorWithValuesAsOrtValue(info, p_data, dense_shape.shape, dense_shape.shape_len,
1388
+ values_shape.shape, values_shape.shape_len, type, &out));
1389
+ return Value{out};
1390
+ }
1391
+
1392
+ template <typename T>
1393
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape) {
1394
+ return CreateSparseTensor(allocator, dense_shape, TypeToTensorType<T>::type);
1395
+ }
1396
+
1397
+ inline Value Value::CreateSparseTensor(OrtAllocator* allocator, const Shape& dense_shape,
1398
+ ONNXTensorElementDataType type) {
1399
+ OrtValue* out;
1400
+ ThrowOnError(GetApi().CreateSparseTensorAsOrtValue(allocator, dense_shape.shape, dense_shape.shape_len, type, &out));
1401
+ return Value{out};
1402
+ }
1403
+ #endif // !defined(DISABLE_SPARSE_TENSORS)
1404
+
1405
+ inline Value Value::CreateMap(Value& keys, Value& values) {
1406
+ OrtValue* out;
1407
+ OrtValue* inputs[2] = {keys, values};
1408
+ ThrowOnError(GetApi().CreateValue(inputs, 2, ONNX_TYPE_MAP, &out));
1409
+ return Value{out};
1410
+ }
1411
+
1412
+ inline Value Value::CreateSequence(std::vector<Value>& values) {
1413
+ OrtValue* out;
1414
+ std::vector<OrtValue*> values_ort{values.data(), values.data() + values.size()};
1415
+ ThrowOnError(GetApi().CreateValue(values_ort.data(), values_ort.size(), ONNX_TYPE_SEQUENCE, &out));
1416
+ return Value{out};
1417
+ }
1418
+
1419
+ template <typename T>
1420
+ inline Value Value::CreateOpaque(const char* domain, const char* type_name, const T& data_container) {
1421
+ OrtValue* out;
1422
+ ThrowOnError(GetApi().CreateOpaqueValue(domain, type_name, &data_container, sizeof(T), &out));
1423
+ return Value{out};
1424
+ }
1425
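A sketch of the composite-value constructors above, building a sequence from two tensors (illustrative fragment; keeping the element buffers alive while the values are in use is assumed):

std::vector<float> a(4, 0.0f), b(4, 1.0f);
const int64_t shape[] = {4};
auto mem_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
std::vector<Ort::Value> elements;
elements.push_back(Ort::Value::CreateTensor<float>(mem_info, a.data(), a.size(), shape, 1));
elements.push_back(Ort::Value::CreateTensor<float>(mem_info, b.data(), b.size(), shape, 1));
Ort::Value sequence = Ort::Value::CreateSequence(elements);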
+
1426
+ //
1427
+ // Custom OP Inlines
1428
+ //
1429
+ inline KernelContext::KernelContext(OrtKernelContext* context) : ctx_(context) {
1430
+ }
1431
+
1432
+ inline size_t KernelContext::GetInputCount() const {
1433
+ size_t out = 0;
1434
+ Ort::ThrowOnError(GetApi().KernelContext_GetInputCount(ctx_, &out));
1435
+ return out;
1436
+ }
1437
+
1438
+ inline size_t KernelContext::GetOutputCount() const {
1439
+ size_t out = 0;
1440
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutputCount(ctx_, &out));
1441
+ return out;
1442
+ }
1443
+
1444
+ inline ConstValue KernelContext::GetInput(size_t index) const {
1445
+ const OrtValue* out = nullptr;
1446
+ Ort::ThrowOnError(GetApi().KernelContext_GetInput(ctx_, index, &out));
1447
+ return ConstValue{out};
1448
+ }
1449
+
1450
+ inline UnownedValue KernelContext::GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const {
1451
+ OrtValue* out = nullptr;
1452
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dim_values, dim_count, &out));
1453
+ return UnownedValue(out);
1454
+ }
1455
+
1456
+ inline UnownedValue KernelContext::GetOutput(size_t index, const std::vector<int64_t>& dims) const {
1457
+ OrtValue* out = nullptr;
1458
+ Ort::ThrowOnError(GetApi().KernelContext_GetOutput(ctx_, index, dims.data(), dims.size(), &out));
1459
+ return UnownedValue(out);
1460
+ }
1461
+
1462
+ inline void* KernelContext::GetGPUComputeStream() const {
1463
+ void* out = nullptr;
1464
+ Ort::ThrowOnError(GetApi().KernelContext_GetGPUComputeStream(ctx_, &out));
1465
+ return out;
1466
+ }
1467
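A sketch of how a custom-op kernel body might use the KernelContext wrapper above: copy the first input to the first output, assuming matching float shapes (illustrative only, not a complete custom-op registration):

#include <algorithm>

void ComputeIdentity(OrtKernelContext* raw_ctx) {
  Ort::KernelContext ctx(raw_ctx);
  Ort::ConstValue in = ctx.GetInput(0);
  auto info = in.GetTensorTypeAndShapeInfo();
  Ort::UnownedValue out = ctx.GetOutput(0, info.GetShape());
  const float* src = in.GetTensorData<float>();
  float* dst = out.GetTensorMutableData<float>();
  std::copy(src, src + info.GetElementCount(), dst);
}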
+
1468
+ inline OpAttr::OpAttr(const char* name, const void* data, int len, OrtOpAttrType type) {
1469
+ Ort::ThrowOnError(GetApi().CreateOpAttr(name, data, len, type, &p_));
1470
+ }
1471
+
1472
+ namespace detail {
1473
+ template <typename T>
1474
+ inline KernelInfo KernelInfoImpl<T>::Copy() const {
1475
+ OrtKernelInfo* info_copy = nullptr;
1476
+ Ort::ThrowOnError(GetApi().CopyKernelInfo(this->p_, &info_copy));
1477
+ return KernelInfo{info_copy};
1478
+ }
1479
+
1480
+ template <typename T>
1481
+ inline size_t KernelInfoImpl<T>::GetInputCount() const {
1482
+ size_t out = 0;
1483
+ ThrowOnError(GetApi().KernelInfo_GetInputCount(this->p_, &out));
1484
+ return out;
1485
+ }
1486
+
1487
+ template <typename T>
1488
+ inline size_t KernelInfoImpl<T>::GetOutputCount() const {
1489
+ size_t out = 0;
1490
+ ThrowOnError(GetApi().KernelInfo_GetOutputCount(this->p_, &out));
1491
+ return out;
1492
+ }
1493
+
1494
+ template <typename T>
1495
+ inline std::string KernelInfoImpl<T>::GetInputName(size_t index) const {
1496
+ size_t size = 0;
1497
+
1498
+ // Feed nullptr for the data buffer to query the true size of the string value
1499
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, nullptr, &size));
1500
+
1501
+ std::string out;
1502
+ out.resize(size);
1503
+ Ort::ThrowOnError(GetApi().KernelInfo_GetInputName(this->p_, index, &out[0], &size));
1504
+ out.resize(size - 1); // remove the terminating character '\0'
1505
+
1506
+ return out;
1507
+ }
1508
+
1509
+ template <typename T>
1510
+ inline std::string KernelInfoImpl<T>::GetOutputName(size_t index) const {
1511
+ size_t size = 0;
1512
+
1513
+ // Feed nullptr for the data buffer to query the true size of the string value
1514
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, nullptr, &size));
1515
+
1516
+ std::string out;
1517
+ out.resize(size);
1518
+ Ort::ThrowOnError(GetApi().KernelInfo_GetOutputName(this->p_, index, &out[0], &size));
1519
+ out.resize(size - 1); // remove the terminating character '\0'
1520
+
1521
+ return out;
1522
+ }
1523
+
1524
+ template <typename T>
1525
+ inline TypeInfo KernelInfoImpl<T>::GetInputTypeInfo(size_t index) const {
1526
+ OrtTypeInfo* out = nullptr;
1527
+ ThrowOnError(GetApi().KernelInfo_GetInputTypeInfo(this->p_, index, &out));
1528
+ return TypeInfo{out};
1529
+ }
1530
+
1531
+ template <typename T>
1532
+ inline TypeInfo KernelInfoImpl<T>::GetOutputTypeInfo(size_t index) const {
1533
+ OrtTypeInfo* out = nullptr;
1534
+ ThrowOnError(GetApi().KernelInfo_GetOutputTypeInfo(this->p_, index, &out));
1535
+ return TypeInfo{out};
1536
+ }
1537
+
1538
+ template <typename T>
1539
+ inline Value KernelInfoImpl<T>::GetTensorAttribute(const char* name, OrtAllocator* allocator) const {
1540
+ OrtValue* out = nullptr;
1541
+ ThrowOnError(GetApi().KernelInfoGetAttribute_tensor(this->p_, name, allocator, &out));
1542
+ return Value{out};
1543
+ }
1544
+
1545
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, float& out) {
1546
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_float(p, name, &out));
1547
+ }
1548
+
1549
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, int64_t& out) {
1550
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_int64(p, name, &out));
1551
+ }
1552
+
1553
+ inline void attr_utils::GetAttr(const OrtKernelInfo* p, const char* name, std::string& result) {
1554
+ size_t size = 0;
1555
+ // Feed nullptr for the data buffer to query the true size of the string attribute
1556
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, nullptr, &size));
1557
+
1558
+ std::string out;
1559
+ out.resize(size);
1560
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttribute_string(p, name, &out[0], &size));
1561
+ out.resize(size - 1); // remove the terminating character '\0'
1562
+ out.swap(result);
1563
+ }
1564
+
1565
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<float>& result) {
1566
+ size_t size = 0;
1567
+ // Feed nullptr for the data buffer to query the true size of the attribute
1568
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, nullptr, &size));
1569
+
1570
+ std::vector<float> out;
1571
+ out.resize(size);
1572
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_float(p, name, out.data(), &size));
1573
+ out.swap(result);
1574
+ }
1575
+
1576
+ inline void attr_utils::GetAttrs(const OrtKernelInfo* p, const char* name, std::vector<int64_t>& result) {
1577
+ size_t size = 0;
1578
+
1579
+ // Feed nullptr for the data buffer to query the true size of the attribute
1580
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, nullptr, &size));
1581
+
1582
+ std::vector<int64_t> out;
1583
+ out.resize(size);
1584
+ Ort::ThrowOnError(GetApi().KernelInfoGetAttributeArray_int64(p, name, out.data(), &size));
1585
+ out.swap(result);
1586
+ }
1587
+ } // namespace detail
1588
+
1589
+ inline KernelInfo::KernelInfo(OrtKernelInfo* info) : detail::KernelInfoImpl<OrtKernelInfo>{info} {}
1590
+
1591
+ inline Op::Op(OrtOp* p) : Base<OrtOp>(p) {}
1592
+
1593
+ inline Op Op::Create(const OrtKernelInfo* info, const char* op_name, const char* domain, int version,
1594
+ const char** type_constraint_names,
1595
+ const ONNXTensorElementDataType* type_constraint_values,
1596
+ size_t type_constraint_count,
1597
+ const OpAttr* attr_values, size_t attr_count,
1598
+ size_t input_count, size_t output_count) {
1599
+ static_assert(sizeof(OpAttr) == sizeof(OrtOpAttr*),
1600
+ "OpAttr's is expected to be just an array of OrtOpAttr in memory so we can reinterpret safely");
1601
+ auto attr_input_values = reinterpret_cast<const OrtOpAttr* const*>(attr_values);
1602
+ OrtOp* op;
1603
+ Ort::ThrowOnError(GetApi().CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1604
+ static_cast<int>(type_constraint_count),
1605
+ attr_input_values,
1606
+ static_cast<int>(attr_count),
1607
+ static_cast<int>(input_count),
1608
+ static_cast<int>(output_count), &op));
1609
+ return Op{op};
1610
+ }
1611
+
1612
+ inline void Op::Invoke(const OrtKernelContext* context,
1613
+ const Value* input_values,
1614
+ size_t input_count,
1615
+ Value* output_values,
1616
+ size_t output_count) {
1617
+ static_assert(sizeof(Value) == sizeof(OrtValue*),
1618
+ "Value is really just an array of OrtValue* in memory, so we can reinterpret_cast safely");
1619
+ auto ort_input_values = reinterpret_cast<const OrtValue* const*>(input_values);
1620
+ auto ort_output_values = reinterpret_cast<OrtValue**>(output_values);
1621
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, ort_input_values, static_cast<int>(input_count),
1622
+ ort_output_values, static_cast<int>(output_count)));
1623
+ }
1624
+
1625
+ inline void Op::Invoke(const OrtKernelContext* context,
1626
+ const OrtValue* const* input_values,
1627
+ size_t input_count,
1628
+ OrtValue* const* output_values,
1629
+ size_t output_count) {
1630
+ Ort::ThrowOnError(GetApi().InvokeOp(context, p_, input_values, static_cast<int>(input_count),
1631
+ output_values, static_cast<int>(output_count)));
1632
+ }
1633
+
1634
+ inline void CustomOpApi::ThrowOnError(OrtStatus* status) {
1635
+ Ort::ThrowOnError(status);
1636
+ }
1637
+
1638
+ template <>
1639
+ inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1640
+ float out;
1641
+ Ort::ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
1642
+ return out;
1643
+ }
1644
+
1645
+ template <>
1646
+ inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1647
+ int64_t out;
1648
+ Ort::ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
1649
+ return out;
1650
+ }
1651
+
1652
+ template <>
1653
+ inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1654
+ size_t size = 0;
1655
+ std::string out;
1656
+
1657
+ // Feed nullptr for the data buffer to query the true size of the string attribute
1658
+ OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
1659
+
1660
+ if (status == nullptr) {
1661
+ out.resize(size);
1662
+ Ort::ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
1663
+ out.resize(size - 1); // remove the terminating character '\0'
1664
+ } else {
1665
+ Ort::ThrowOnError(status);
1666
+ }
1667
+ return out;
1668
+ }
1669
+
1670
+ template <>
1671
+ inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1672
+ size_t size = 0;
1673
+ std::vector<float> out;
1674
+
1675
+ // Feed nullptr for the data buffer to query the true size of the attribute
1676
+ OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
1677
+
1678
+ if (status == nullptr) {
1679
+ out.resize(size);
1680
+ Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
1681
+ } else {
1682
+ Ort::ThrowOnError(status);
1683
+ }
1684
+ return out;
1685
+ }
1686
+
1687
+ template <>
1688
+ inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) {
1689
+ size_t size = 0;
1690
+ std::vector<int64_t> out;
1691
+
1692
+ // Feed nullptr for the data buffer to query the true size of the attribute
1693
+ OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
1694
+
1695
+ if (status == nullptr) {
1696
+ out.resize(size);
1697
+ Ort::ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
1698
+ } else {
1699
+ Ort::ThrowOnError(status);
1700
+ }
1701
+ return out;
1702
+ }
1703
+ inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) {
1704
+ OrtTensorTypeAndShapeInfo* out;
1705
+ Ort::ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
1706
+ return out;
1707
+ }
1708
+
1709
+ inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1710
+ size_t out;
1711
+ Ort::ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
1712
+ return out;
1713
+ }
1714
+
1715
+ inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) {
1716
+ ONNXTensorElementDataType out;
1717
+ Ort::ThrowOnError(api_.GetTensorElementType(info, &out));
1718
+ return out;
1719
+ }
1720
+
1721
+ inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) {
1722
+ size_t out;
1723
+ Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1724
+ return out;
1725
+ }
1726
+
1727
+ inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) {
1728
+ Ort::ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
1729
+ }
1730
+
1731
+ inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) {
1732
+ Ort::ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
1733
+ }
1734
+
1735
+ template <typename T>
1736
+ inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) {
1737
+ T* data;
1738
+ Ort::ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
1739
+ return data;
1740
+ }
1741
+
1742
+ inline const OrtMemoryInfo* CustomOpApi::GetTensorMemoryInfo(_In_ const OrtValue* value) {
1743
+ const OrtMemoryInfo* mem_info;
1744
+ Ort::ThrowOnError(api_.GetTensorMemoryInfo(value, &mem_info));
1745
+ return mem_info;
1746
+ }
1747
+
1748
+ template <typename T>
1749
+ inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) {
1750
+ T* data = nullptr;
1751
+ Ort::ThrowOnError(api_.GetTensorMutableData(const_cast<OrtValue*>(value), reinterpret_cast<void**>(&data)));
1752
+ return data;
1753
+ }
1754
+
1755
+ inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) {
1756
+ size_t out;
1757
+ Ort::ThrowOnError(api_.GetDimensionsCount(info, &out));
1758
+ std::vector<int64_t> output(out);
1759
+ Ort::ThrowOnError(api_.GetDimensions(info, output.data(), out));
1760
+ return output;
1761
+ }
1762
+
1763
+ inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) {
1764
+ api_.ReleaseTensorTypeAndShapeInfo(input);
1765
+ }
1766
+
1767
+ inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) {
1768
+ size_t out;
1769
+ Ort::ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
1770
+ return out;
1771
+ }
1772
+
1773
+ inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) {
1774
+ const OrtValue* out;
1775
+ Ort::ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
1776
+ return out;
1777
+ }
1778
+
1779
+ inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) {
1780
+ size_t out;
1781
+ Ort::ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
1782
+ return out;
1783
+ }
1784
+
1785
+ inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
1786
+ _In_ const int64_t* dim_values, size_t dim_count) {
1787
+ OrtValue* out;
1788
+ Ort::ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
1789
+ return out;
1790
+ }
1791
+
1792
+ inline void* CustomOpApi::KernelContext_GetGPUComputeStream(const OrtKernelContext* context) {
1793
+ void* out;
1794
+ Ort::ThrowOnError(api_.KernelContext_GetGPUComputeStream(context, &out));
1795
+ return out;
1796
+ }
1797
+
1798
+ inline OrtOpAttr* CustomOpApi::CreateOpAttr(_In_ const char* name,
1799
+ _In_ const void* data,
1800
+ _In_ int len,
1801
+ _In_ OrtOpAttrType type) {
1802
+ OrtOpAttr* op_attr{};
1803
+ Ort::ThrowOnError(api_.CreateOpAttr(name, data, len, type, &op_attr));
1804
+ return op_attr;
1805
+ }
1806
+
1807
+ inline void CustomOpApi::ReleaseOpAttr(_Frees_ptr_opt_ OrtOpAttr* op_attr) {
1808
+ api_.ReleaseOpAttr(op_attr);
1809
+ }
1810
+
1811
+ inline OrtOp* CustomOpApi::CreateOp(_In_ const OrtKernelInfo* info,
1812
+ _In_ const char* op_name,
1813
+ _In_ const char* domain,
1814
+ _In_ int version,
1815
+ _In_opt_ const char** type_constraint_names,
1816
+ _In_opt_ const ONNXTensorElementDataType* type_constraint_values,
1817
+ _In_opt_ int type_constraint_count,
1818
+ _In_opt_ const OrtOpAttr* const* attr_values,
1819
+ _In_opt_ int attr_count,
1820
+ _In_ int input_count,
1821
+ _In_ int output_count) {
1822
+ OrtOp* ort_op{};
1823
+ Ort::ThrowOnError(api_.CreateOp(info, op_name, domain, version, type_constraint_names, type_constraint_values,
1824
+ type_constraint_count, attr_values, attr_count, input_count, output_count, &ort_op));
1825
+ return ort_op;
1826
+ }
1827
+
1828
+ inline void CustomOpApi::InvokeOp(_In_ const OrtKernelContext* context,
1829
+ _In_ const OrtOp* ort_op,
1830
+ _In_ const OrtValue* const* input_values,
1831
+ _In_ int input_count,
1832
+ _Inout_ OrtValue* const* output_values,
1833
+ _In_ int output_count) {
1834
+ Ort::ThrowOnError(api_.InvokeOp(context, ort_op, input_values, input_count, output_values, output_count));
1835
+ }
1836
+
1837
+ inline void CustomOpApi::ReleaseOp(_Frees_ptr_opt_ OrtOp* ort_op) {
1838
+ api_.ReleaseOp(ort_op);
1839
+ }
1840
+
1841
+ inline OrtKernelInfo* CustomOpApi::CopyKernelInfo(_In_ const OrtKernelInfo* info) {
1842
+ OrtKernelInfo* info_copy{};
1843
+ Ort::ThrowOnError(api_.CopyKernelInfo(info, &info_copy));
1844
+ return info_copy;
1845
+ }
1846
+
1847
+ inline void CustomOpApi::ReleaseKernelInfo(_Frees_ptr_opt_ OrtKernelInfo* info_copy) {
1848
+ api_.ReleaseKernelInfo(info_copy);
1849
+ }
1850
+
1851
+ inline std::vector<std::string> GetAvailableProviders() {
1852
+ char** providers;
1853
+ int len;
1854
+
1855
+ auto release_fn = [&len](char** providers) {
1856
+ // This should always return nullptr.
1857
+ ThrowOnError(GetApi().ReleaseAvailableProviders(providers, len));
1858
+ };
1859
+
1860
+ ThrowOnError(GetApi().GetAvailableProviders(&providers, &len));
1861
+ std::unique_ptr<char*, decltype(release_fn)> guard(providers, release_fn);
1862
+ std::vector<std::string> available_providers;
1863
+ available_providers.reserve(static_cast<size_t>(len));
1864
+ for (int i = 0; i < len; ++i) {
1865
+ available_providers.emplace_back(providers[i]);
1866
+ }
1867
+ return available_providers;
1868
+ }
1869
+
1870
+ template <typename TOp, typename TKernel>
1871
+ void CustomOpBase<TOp, TKernel>::GetSessionConfigs(std::unordered_map<std::string, std::string>& out,
1872
+ ConstSessionOptions options) const {
1873
+ const TOp* derived = static_cast<const TOp*>(this);
1874
+ std::vector<std::string> keys = derived->GetSessionConfigKeys();
1875
+
1876
+ out.reserve(keys.size());
1877
+
1878
+ std::string config_entry_key = detail::MakeCustomOpConfigEntryKey(derived->GetName(), "");
1879
+ const size_t prefix_size = config_entry_key.length();
1880
+
1881
+ for (const auto& key : keys) {
1882
+ config_entry_key.resize(prefix_size);
1883
+ config_entry_key.append(key);
1884
+ out[key] = options.GetConfigEntryOrDefault(config_entry_key.c_str(), "");
1885
+ }
1886
+ }
1887
+
1888
+ } // namespace Ort
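The KernelInfoGetAttribute<T> specializations above are typically called from a custom-op kernel's constructor. The sketch below is illustrative only and not part of this commit; MyKernel and the attribute names "alpha" and "mode" are made-up, and it assumes onnxruntime_cxx_api.h is on the include path.

// Hypothetical custom-op kernel (sketch only, not part of this commit).
// It reads a float attribute "alpha" and a string attribute "mode" through
// Ort::CustomOpApi, whose KernelInfoGetAttribute<T> specializations are
// implemented in the header above.
#include <string>
#include "onnxruntime_cxx_api.h"

struct MyKernel {
  MyKernel(const OrtApi& api, const OrtKernelInfo* info)
      : ort_(api),
        alpha_(ort_.KernelInfoGetAttribute<float>(info, "alpha")),
        mode_(ort_.KernelInfoGetAttribute<std::string>(info, "mode")) {}

  Ort::CustomOpApi ort_;  // thin wrapper over the C API, defined in this header
  float alpha_;           // value of the node's "alpha" attribute
  std::string mode_;      // value of the node's "mode" attribute
};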
1.15.1/onnxruntime.xcframework/Info.plist ADDED
@@ -0,0 +1,40 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+ <plist version="1.0">
+ <dict>
+   <key>AvailableLibraries</key>
+   <array>
+     <dict>
+       <key>LibraryIdentifier</key>
+       <string>ios-arm64_x86_64-simulator</string>
+       <key>LibraryPath</key>
+       <string>onnxruntime.a</string>
+       <key>SupportedArchitectures</key>
+       <array>
+         <string>arm64</string>
+         <string>x86_64</string>
+       </array>
+       <key>SupportedPlatform</key>
+       <string>ios</string>
+       <key>SupportedPlatformVariant</key>
+       <string>simulator</string>
+     </dict>
+     <dict>
+       <key>LibraryIdentifier</key>
+       <string>ios-arm64</string>
+       <key>LibraryPath</key>
+       <string>onnxruntime.a</string>
+       <key>SupportedArchitectures</key>
+       <array>
+         <string>arm64</string>
+       </array>
+       <key>SupportedPlatform</key>
+       <string>ios</string>
+     </dict>
+   </array>
+   <key>CFBundlePackageType</key>
+   <string>XFWK</string>
+   <key>XCFrameworkFormatVersion</key>
+   <string>1.0</string>
+ </dict>
+ </plist>
1.15.1/onnxruntime.xcframework/ios-arm64/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bd716418bdd0b9b7df2d65701d075c7698aa54bcdfe811c360c79fa61e7f3e3b
+ size 57978208
1.15.1/onnxruntime.xcframework/ios-arm64_x86_64-simulator/onnxruntime.a ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ae51a98ef737bcfade20f599fb589141ed1e98239f41cd70c59658ef828dfd14
+ size 118264080
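Both onnxruntime.a entries are Git LFS pointers; the actual static libraries (roughly 58 MB for the device slice and 118 MB for the simulator slice, per the sizes above) are only materialized when the LFS content is fetched. Once an app links against the slice Xcode picks from the xcframework's Info.plist, one quick smoke test is to list the execution providers compiled into this build via Ort::GetAvailableProviders(), implemented in the onnxruntime_cxx_api.h added by this commit. A hypothetical usage sketch, not part of the commit:

// Sketch only: print the execution providers available in this onnxruntime build.
#include <iostream>
#include <string>
#include "onnxruntime_cxx_api.h"

int main() {
  // The exact list depends on how the framework was compiled;
  // "CPUExecutionProvider" is always present.
  for (const std::string& provider : Ort::GetAvailableProviders()) {
    std::cout << provider << "\n";
  }
  return 0;
}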