Skip to content

Commit 7d2e575

Browse files
committed
more changes
1 parent d4a7046 commit 7d2e575

File tree

4 files changed

+227
-38
lines changed

4 files changed

+227
-38
lines changed

csharp/src/Microsoft.ML.OnnxRuntime/ManagedProjections.shared.cs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
using System.Collections.Generic;
77
using System.Diagnostics;
88
using System.Linq;
9+
using System.Reflection;
10+
11+
12+
#if NET8_0
13+
using DotnetTensors = System.Numerics.Tensors;
14+
using TensorPrimitives = System.Numerics.Tensors.TensorPrimitives;
15+
#endif
916

1017
namespace Microsoft.ML.OnnxRuntime
1118
{
@@ -166,13 +173,41 @@ private static OrtValue CreateMapProjection(NamedOnnxValue node, NodeMetadata el
166173
/// <exception cref="OnnxRuntimeException"></exception>
167174
private static OrtValue CreateTensorProjection(NamedOnnxValue node, NodeMetadata elementMeta)
168175
{
169-
if (node.Value is not TensorBase)
176+
#if NET8_0
177+
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
178+
if (node.Value is not TensorBase && node.Value.GetType().GetGenericTypeDefinition() != typeof(DotnetTensors.Tensor<>))
170179
{
171180
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
172181
$"NamedOnnxValue contains: {node.Value.GetType()}, expecting a Tensor<T>");
173182
}
174183

184+
OrtValue ortValue;
185+
TensorElementType elementType;
186+
187+
if (node.Value is TensorBase)
188+
{
189+
ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out elementType);
190+
}
191+
else
192+
{
193+
MethodInfo method = typeof(OrtValue).GetMethod(nameof(OrtValue.CreateTensorValueFromDotnetTensorObject), BindingFlags.Static | BindingFlags.Public);
194+
Type tensorType = node.Value.GetType().GetGenericArguments()[0];
195+
MethodInfo generic = method.MakeGenericMethod(tensorType);
196+
ortValue = (OrtValue)generic.Invoke(null, [node.Value]);
197+
elementType = TensorBase.GetTypeInfo(tensorType).ElementType;
198+
}
199+
200+
201+
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
202+
#else
203+
if (node.Value is not TensorBase)
204+
{
205+
throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
206+
$"NamedOnnxValue contains: {node.Value.GetType()}, expecting a Tensor<T>");
207+
}
175208
OrtValue ortValue = OrtValue.CreateFromTensorObject(node.Value as TensorBase, out TensorElementType elementType);
209+
210+
#endif
176211
try
177212
{
178213
if (elementType != elementMeta.ElementDataType)

csharp/src/Microsoft.ML.OnnxRuntime/NamedOnnxValue.shared.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,21 @@ public Tensor<T> AsTensor<T>()
218218
return _value as Tensor<T>; // will return null if not castable
219219
}
220220

221+
222+
#if NET8_0
223+
#pragma warning disable SYSLIB5001 // System.Numerics.Tensors is only in preview so we can continue receiving API feedback
224+
/// <summary>
225+
/// Try-get value as a Tensor&lt;T&gt;.
226+
/// </summary>
227+
/// <typeparam name="T">Type</typeparam>
228+
/// <returns>Tensor object if contained value is a Tensor. Null otherwise</returns>
229+
public DotnetTensors.Tensor<T> AsDotnetTensor<T>()
230+
{
231+
return _value as DotnetTensors.Tensor<T>; // will return null if not castable
232+
}
233+
#pragma warning restore SYSLIB5001 // System.Numerics.Tensors is only in preview so it can continue receiving API feedback
234+
#endif
235+
221236
/// <summary>
222237
/// Try-get value as an Enumerable&lt;T&gt;.
223238
/// T is usually a NamedOnnxValue instance that may contain

csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
using System.Buffers;
77
using System.Collections.Generic;
88
using System.Diagnostics;
9+
using System.Linq;
10+
using System.Reflection;
911
using System.Runtime.CompilerServices;
1012
using System.Runtime.InteropServices;
1113
using System.Text;
@@ -232,11 +234,7 @@ public DotnetTensors.ReadOnlyTensorSpan<T> GetTensorDataAsTensorSpan<T>() where
232234

233235
var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
234236
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
235-
var nArray = new nint[shape.Length];
236-
for (int i = 0; i < shape.Length; i++)
237-
{
238-
nArray[i] = (nint)shape[i];
239-
}
237+
var nArray = shape.Select(x => (nint)x).ToArray();
240238

241239
return new DotnetTensors.ReadOnlyTensorSpan<T>(typeSpan, nArray, []);
242240
}
@@ -283,11 +281,7 @@ public DotnetTensors.TensorSpan<T> GetTensorMutableDataAsTensorSpan<T>() where T
283281

284282
var typeSpan = MemoryMarshal.Cast<byte, T>(byteSpan);
285283
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
286-
var nArray = new nint[shape.Length];
287-
for (int i = 0; i < shape.Length; i++)
288-
{
289-
nArray[i] = (nint)shape[i];
290-
}
284+
var nArray = shape.Select(x => (nint)x).ToArray();
291285

292286
return new DotnetTensors.TensorSpan<T>(typeSpan, nArray, []);
293287
}
@@ -314,11 +308,7 @@ public DotnetTensors.TensorSpan<byte> GetTensorSpanMutableRawData<T>() where T :
314308
var byteSpan = GetTensorBufferRawData(typeof(T));
315309

316310
var shape = GetTypeInfo().TensorTypeAndShapeInfo.Shape;
317-
var nArray = new nint[shape.Length];
318-
for (int i = 0; i < shape.Length; i++)
319-
{
320-
nArray[i] = (nint)shape[i];
321-
}
311+
var nArray = shape.Select(x => (nint)x).ToArray();
322312

323313
return new DotnetTensors.TensorSpan<byte>(byteSpan, nArray, []);
324314
}
@@ -716,7 +706,10 @@ public static OrtValue CreateTensorValueFromDotnetTensorObject<T>(DotnetTensors.
716706
}
717707
unsafe
718708
{
719-
GCHandle handle = GCHandle.Alloc(tensor, GCHandleType.Pinned);
709+
var field = tensor.GetType().GetFields(BindingFlags.Instance | BindingFlags.NonPublic).Where(x => x.Name == "_values").FirstOrDefault();
710+
var backingData = (T[])field.GetValue(tensor);
711+
GCHandle handle = GCHandle.Alloc(backingData, GCHandleType.Pinned);
712+
//GCHandle handle = GCHandle.Alloc(tensor.GetPinnableReference(), GCHandleType.Pinned);
720713
var memHandle = new MemoryHandle(Unsafe.AsPointer(ref tensor.GetPinnableReference()), handle);
721714

722715
try
@@ -729,11 +722,7 @@ public static OrtValue CreateTensorValueFromDotnetTensorObject<T>(DotnetTensors.
729722

730723
var bufferLengthInBytes = tensor.FlattenedLength * sizeof(T);
731724

732-
var shape = new long[tensor.Rank];
733-
for (int i = 0; i < shape.Length; i++)
734-
{
735-
shape[i] = tensor.Lengths[i];
736-
}
725+
var shape = tensor.Lengths.ToArray().Select(x => (long)x).ToArray();
737726

738727
var typeInfo = TensorBase.GetTypeInfo(typeof(T)) ??
739728
throw new OnnxRuntimeException(ErrorCode.InvalidArgument, $"Tensor of type: {typeof(T)} is not supported");

0 commit comments

Comments
 (0)