66using System . Buffers ;
77using System . Collections . Generic ;
88using System . Diagnostics ;
9+ using System . Linq ;
10+ using System . Reflection ;
911using System . Runtime . CompilerServices ;
1012using System . Runtime . InteropServices ;
1113using System . Text ;
@@ -232,11 +234,7 @@ public DotnetTensors.ReadOnlyTensorSpan<T> GetTensorDataAsTensorSpan<T>() where
232234
233235 var typeSpan = MemoryMarshal . Cast < byte , T > ( byteSpan ) ;
234236 var shape = GetTypeInfo ( ) . TensorTypeAndShapeInfo . Shape ;
235- var nArray = new nint [ shape . Length ] ;
236- for ( int i = 0 ; i < shape . Length ; i ++ )
237- {
238- nArray [ i ] = ( nint ) shape [ i ] ;
239- }
237+ var nArray = shape . Select ( x => ( nint ) x ) . ToArray ( ) ;
240238
241239 return new DotnetTensors . ReadOnlyTensorSpan < T > ( typeSpan , nArray , [ ] ) ;
242240 }
@@ -283,11 +281,7 @@ public DotnetTensors.TensorSpan<T> GetTensorMutableDataAsTensorSpan<T>() where T
283281
284282 var typeSpan = MemoryMarshal . Cast < byte , T > ( byteSpan ) ;
285283 var shape = GetTypeInfo ( ) . TensorTypeAndShapeInfo . Shape ;
286- var nArray = new nint [ shape . Length ] ;
287- for ( int i = 0 ; i < shape . Length ; i ++ )
288- {
289- nArray [ i ] = ( nint ) shape [ i ] ;
290- }
284+ var nArray = shape . Select ( x => ( nint ) x ) . ToArray ( ) ;
291285
292286 return new DotnetTensors . TensorSpan < T > ( typeSpan , nArray , [ ] ) ;
293287 }
@@ -314,11 +308,7 @@ public DotnetTensors.TensorSpan<byte> GetTensorSpanMutableRawData<T>() where T :
314308 var byteSpan = GetTensorBufferRawData ( typeof ( T ) ) ;
315309
316310 var shape = GetTypeInfo ( ) . TensorTypeAndShapeInfo . Shape ;
317- var nArray = new nint [ shape . Length ] ;
318- for ( int i = 0 ; i < shape . Length ; i ++ )
319- {
320- nArray [ i ] = ( nint ) shape [ i ] ;
321- }
311+ var nArray = shape . Select ( x => ( nint ) x ) . ToArray ( ) ;
322312
323313 return new DotnetTensors . TensorSpan < byte > ( byteSpan , nArray , [ ] ) ;
324314 }
@@ -716,7 +706,10 @@ public static OrtValue CreateTensorValueFromDotnetTensorObject<T>(DotnetTensors.
716706 }
717707 unsafe
718708 {
719- GCHandle handle = GCHandle . Alloc ( tensor , GCHandleType . Pinned ) ;
709+ var field = tensor . GetType ( ) . GetFields ( BindingFlags . Instance | BindingFlags . NonPublic ) . Where ( x => x . Name == "_values" ) . FirstOrDefault ( ) ;
710+ var backingData = ( T [ ] ) field . GetValue ( tensor ) ;
711+ GCHandle handle = GCHandle . Alloc ( backingData , GCHandleType . Pinned ) ;
712+ //GCHandle handle = GCHandle.Alloc(tensor.GetPinnableReference(), GCHandleType.Pinned);
720713 var memHandle = new MemoryHandle ( Unsafe . AsPointer ( ref tensor . GetPinnableReference ( ) ) , handle ) ;
721714
722715 try
@@ -729,11 +722,7 @@ public static OrtValue CreateTensorValueFromDotnetTensorObject<T>(DotnetTensors.
729722
730723 var bufferLengthInBytes = tensor . FlattenedLength * sizeof ( T ) ;
731724
732- var shape = new long [ tensor . Rank ] ;
733- for ( int i = 0 ; i < shape . Length ; i ++ )
734- {
735- shape [ i ] = tensor . Lengths [ i ] ;
736- }
725+ var shape = tensor . Lengths . ToArray ( ) . Select ( x => ( long ) x ) . ToArray ( ) ;
737726
738727 var typeInfo = TensorBase . GetTypeInfo ( typeof ( T ) ) ??
739728 throw new OnnxRuntimeException ( ErrorCode . InvalidArgument , $ "Tensor of type: { typeof ( T ) } is not supported") ;
0 commit comments