Skip to content

Commit 6724b3b

Browse files
committed
changes from PR comments
1 parent 645b8b6 commit 6724b3b

File tree

5 files changed

+98
-303
lines changed

5 files changed

+98
-303
lines changed

csharp/src/Microsoft.ML.OnnxRuntime/FixedBufferOnnxValue.shared.cs

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,12 @@
44
using Microsoft.ML.OnnxRuntime.Tensors;
55
using System;
66

7-
#if NET8_0_OR_GREATER
8-
using DotnetTensors = System.Numerics.Tensors;
9-
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
10-
#endif
11-
127
namespace Microsoft.ML.OnnxRuntime
138
{
149
/// <summary>
1510
/// This is a legacy class that is kept for backward compatibility.
1611
/// Use OrtValue based API.
17-
///
12+
///
1813
/// Represents an OrtValue with its underlying buffer pinned
1914
/// </summary>
2015
public class FixedBufferOnnxValue : IDisposable
@@ -44,22 +39,6 @@ public static FixedBufferOnnxValue CreateFromTensor<T>(Tensor<T> value)
4439
return new FixedBufferOnnxValue(ref ortValue, OnnxValueType.ONNX_TYPE_TENSOR, elementType);
4540
}
4641

47-
#if NET8_0_OR_GREATER
48-
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
49-
/// <summary>
50-
/// Creates a <see cref="FixedBufferOnnxValue"/> object from the tensor and pins its underlying buffer.
51-
/// </summary>
52-
/// <typeparam name="T"></typeparam>
53-
/// <param name="value"></param>
54-
/// <returns>a disposable instance of FixedBufferOnnxValue</returns>
55-
public static FixedBufferOnnxValue CreateFromDotnetTensor<T>(DotnetTensors.Tensor<T> value) where T : unmanaged
56-
{
57-
var ortValue = OrtValue.CreateTensorValueFromDotnetTensorObject<T>(value);
58-
return new FixedBufferOnnxValue(ref ortValue, OnnxValueType.ONNX_TYPE_TENSOR, TensorBase.GetTypeInfo(typeof(T)).ElementType);
59-
}
60-
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
61-
#endif
62-
6342
/// <summary>
6443
/// This is a factory method that creates a disposable instance of FixedBufferOnnxValue
6544
/// on top of a buffer. Internally, it will pin managed buffer and will create
@@ -83,7 +62,7 @@ public static FixedBufferOnnxValue CreateFromDotnetTensor<T>(DotnetTensors.Tenso
8362
/// Here is an example of using a 3rd party library class for processing float16/bfloat16.
8463
/// Currently, to pass tensor data and create a tensor one must copy data to Float16/BFloat16 structures
8564
/// so DenseTensor can recognize it.
86-
///
65+
///
8766
/// If you are using a library that has a class Half and it is blittable, that is its managed in memory representation
8867
/// matches native one and its size is 16-bits, you can use the following conceptual example
8968
/// to feed/fetch data for inference using Half array. This allows you to avoid copying data from your Half[] to Float16[]
@@ -94,7 +73,7 @@ public static FixedBufferOnnxValue CreateFromDotnetTensor<T>(DotnetTensors.Tenso
9473
/// var input_shape = new long[] {input.Length};
9574
/// Half[] output = new Half[40]; // Whatever the expected len/shape is must match
9675
/// var output_shape = new long[] {output.Length};
97-
///
76+
///
9877
/// var memInfo = OrtMemoryInfo.DefaultInstance; // CPU
9978
///
10079
/// using(var fixedBufferInput = FixedBufferOnnxvalue.CreateFromMemory{Half}(memInfo,

csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,6 @@
66
using System.Collections.Generic;
77
using System.Diagnostics;
88
using System.Linq;
9-
using System.Reflection;
10-
11-
12-
#if NET8_0_OR_GREATER
13-
using DotnetTensors = System.Numerics.Tensors;
14-
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
15-
#endif
169

1710
namespace Microsoft.ML.OnnxRuntime
1811
{
@@ -173,41 +166,13 @@ private static OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata el
173166
/// <exception cref="OnnxRuntimeException"></exception>
174167
private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta)
175168
{
176-
#if NET8_0_OR_GREATER
177-
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
178-
if (node.Value is not TensorBase && node.Value.GetType().GetGenericTypeDefinition() != typeof(DotnetTensors.Tensor<>))
179-
{
180-
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
181-
$"NamedOnnxValue contains: {node.Value.GetType()}, expecting a Tensor<T>");
182-
}
183-
184-
OrtValue ortValue;
185-
TensorElementType elementType;
186-
187-
if (node.Value is TensorBase)
188-
{
189-
ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out elementType);
190-
}
191-
else
192-
{
193-
MethodInfo method = typeof(OrtValue).GetMethod(nameof(OrtValue.CreateTensorValueFromDotnetTensorObject), BindingFlags.Static | BindingFlags.Public);
194-
Type tensorType = node.Value.GetType().GetGenericArguments()[0];
195-
MethodInfo generic = method.MakeGenericMethod(tensorType);
196-
ortValue = (OrtValue)generic.Invoke(null, [node.Value]);
197-
elementType = TensorBase.GetTypeInfo(tensorType).ElementType;
198-
}
199-
200-
201-
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
202-
#else
203169
if (node.Value is not TensorBase)
204170
{
205171
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
206172
$"NamedOnnxValue contains: {node.Value.GetType()}, expecting a Tensor<T>");
207173
}
208-
OrtValue ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out TensorElementType elementType);
209174

210-
#endif
175+
OrtValue ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out TensorElementType elementType);
211176
try
212177
{
213178
if (elementType != elementMeta.ElementDataType)
@@ -226,3 +191,4 @@ private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata
226191
}
227192
}
228193
}
194+

csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs

Lines changed: 12 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,6 @@
88
using System.Diagnostics;
99
using System.Linq;
1010

11-
#if NET8_0_OR_GREATER
12-
using DotnetTensors = System.Numerics.Tensors;
13-
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
14-
#endif
15-
1611
namespace Microsoft.ML.OnnxRuntime
1712
{
1813
/// <summary>
@@ -35,37 +30,37 @@ internal MapHelper(TensorBase keys, TensorBase values)
3530
/// <summary>
3631
/// This is a legacy class that is kept for backward compatibility.
3732
/// Use OrtValue based API.
38-
///
39-
/// The class associates a name with an Object.
33+
///
34+
/// The class associates a name with an Object.
4035
/// The name of the class is a misnomer, it does not hold any Onnx values,
4136
/// just managed representation of them.
42-
///
37+
///
4338
/// The class is currently used as both inputs and outputs. Because it is non-
4439
/// disposable, it can not hold on to any native objects.
45-
///
40+
///
4641
/// When used as input, we temporarily create OrtValues that map managed inputs
4742
/// directly. Thus we are able to avoid copying of contiguous data.
48-
///
43+
///
4944
/// For outputs, tensor buffers works the same as input, providing it matches
5045
/// the expected output shape. For other types (maps and sequences) we create a copy of the data.
5146
/// This is because, the class is not Disposable and it is a public interface, thus it can not own
5247
/// the underlying OrtValues that must be destroyed before Run() returns.
53-
///
48+
///
5449
/// To avoid data copying on output, use DisposableNamedOnnxValue class that is returned from Run() methods.
5550
/// This provides access to the native memory tensors and avoids copying.
56-
///
51+
///
5752
/// It is a recursive structure that may contain Tensors (base case)
5853
/// Other sequences and maps. Although the OnnxValueType is exposed,
5954
/// the caller is supposed to know the actual data type contained.
60-
///
55+
///
6156
/// The convention is that for tensors, it would contain a DenseTensor{T} instance or
6257
/// anything derived from Tensor{T}.
63-
///
58+
///
6459
/// For sequences, it would contain a IList{T} where T is an instance of NamedOnnxValue that
6560
/// would contain a tensor or another type.
66-
///
61+
///
6762
/// For Maps, it would contain a IDictionary{K, V} where K,V are primitive types or strings.
68-
///
63+
///
6964
/// </summary>
7065
public class NamedOnnxValue
7166
{
@@ -145,23 +140,6 @@ public static NamedOnnxValue CreateFromTensor<T>(string name, Tensor<T> value)
145140
return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_TENSOR);
146141
}
147142

148-
#if NET8_0_OR_GREATER
149-
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
150-
/// <summary>
151-
/// This is a factory method that instantiates NamedOnnxValue
152-
/// and associated name with an instance of a Tensor<typeparamref name="T"/>
153-
/// </summary>
154-
/// <typeparam name="T"></typeparam>
155-
/// <param name="name">name</param>
156-
/// <param name="value">Tensor<typeparamref name="T"/></param>
157-
/// <returns></returns>
158-
public static NamedOnnxValue CreateFromDotnetTensor<T>(string name, DotnetTensors.Tensor<T> value)
159-
{
160-
return new NamedOnnxValue(name, value, OnnxValueType.ONNX_TYPE_TENSOR);
161-
}
162-
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
163-
#endif
164-
165143
/// <summary>
166144
/// This is a factory method that instantiates NamedOnnxValue.
167145
/// It would contain a sequence of elements
@@ -218,21 +196,6 @@ public Tensor<T> AsTensor<T>()
218196
return _value as Tensor<T>; // will return null if not castable
219197
}
220198

221-
222-
#if NET8_0_OR_GREATER
223-
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
224-
/// <summary>
225-
/// Try-get value as a Tensor&lt;T&gt;.
226-
/// </summary>
227-
/// <typeparam name="T">Type</typeparam>
228-
/// <returns>Tensor object if contained value is a Tensor. Null otherwise</returns>
229-
public DotnetTensors.Tensor<T> AsDotnetTensor<T>()
230-
{
231-
return _value as DotnetTensors.Tensor<T>; // will return null if not castable
232-
}
233-
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
234-
#endif
235-
236199
/// <summary>
237200
/// Try-get value as an Enumerable&lt;T&gt;.
238201
/// T is usually a NamedOnnxValue instance that may contain
@@ -303,7 +266,7 @@ internal virtual IntPtr OutputToOrtValueHandle(NodeMetadata metadata, out IDispo
303266
}
304267
}
305268

306-
throw new OnnxRuntimeException(ErrorCode.NotImplemented,
269+
throw new OnnxRuntimeException(ErrorCode.NotImplemented,
307270
$"Can not create output OrtValue for NamedOnnxValue '{metadata.OnnxValueType}' type." +
308271
$" Only tensors can be pre-allocated for outputs " +
309272
$" Use Run() overloads that return DisposableNamedOnnxValue to get access to all Onnx value types that may be returned as output.");

csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ public DotnetTensors.ReadOnlyTensorSpan<T> GetTensorDataAsTensorSpan<T>() where
234234

235235
var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
236236
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
237-
var nArray = shape.Select(x => (nint)x).ToArray();
237+
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));
238238

239239
return new DotnetTensors.ReadOnlyTensorSpan<T>(typeSpan, nArray, []);
240240
}
@@ -281,7 +281,7 @@ public DotnetTensors.TensorSpan<T> GetTensorMutableDataAsTensorSpan<T>() where T
281281

282282
var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
283283
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
284-
var nArray = shape.Select(x => (nint)x).ToArray();
284+
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));
285285

286286
return new DotnetTensors.TensorSpan<T>(typeSpan, nArray, []);
287287
}
@@ -308,7 +308,7 @@ public DotnetTensors.TensorSpan<byte> GetTensorSpanMutableRawData<T>() where T :
308308
var byteSpan = GetTensorBufferRawData(typeof(T));
309309

310310
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
311-
var nArray = shape.Select(x => (nint)x).ToArray();
311+
nint[] nArray = Array.ConvertAll(shape, new Converter<long, nint>(x => (nint)x));
312312

313313
return new DotnetTensors.TensorSpan<byte>(byteSpan, nArray, []);
314314
}
@@ -720,8 +720,7 @@ public static OrtValue CreateTensorValueFromDotnetTensorObject<T>(DotnetTensors.
720720
}
721721

722722
var bufferLengthInBytes = tensor.FlattenedLength * sizeof(T);
723-
724-
var shape = tensor.Lengths.ToArray().Select(x => (long)x).ToArray();
723+
long[] shape = Array.ConvertAll(tensor.Lengths.ToArray(), new Converter<nint, long>(x => (long)x));
725724

726725
var typeInfo = TensorBase.GetTypeInfo(typeof(T)) ??
727726
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Tensor of type: {typeof(T)} is not supported");

0 commit comments

Comments
 (0)