-
Notifications
You must be signed in to change notification settings - Fork 3.6k
TypeScript: Allow constructing float16 tensors using Float16Array #26742
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This change updates the TypeScript definitions to allow constructing float16 tensors using Float16Array in environments where it is available. Runtime behavior remains unchanged (float16 is still represented as Uint16Array). - Introduce GlobalFloat16Array helper type to safely detect Float16Array without requiring global polyfills. - Add type-specific and inferred constructor overloads for float16. - No changes to runtime logic or public C APIs. This resolves compile-time errors when passing Float16Array to the Tensor constructor in the onnxruntime-web package.
|
@microsoft-github-policy-service agree |
|
I think you could reuse the onnxruntime/js/common/lib/type-helper.ts Line 27 in 790018d
|
|
Updated to use |
js/common/lib/tensor.ts
Outdated
|
|
||
| // Helper type: resolves to the instance type of `Float16Array` if it exists in the global scope, | ||
| // or `never` otherwise. Uses the shared TryGetGlobalType helper. | ||
| export type GlobalFloat16Array = TryGetGlobalType<'Float16Array', never>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it needs to be exported / part of the public API, just a local helper.
| export type GlobalFloat16Array = TryGetGlobalType<'Float16Array', never>; | |
| type GlobalFloat16Array = TryGetGlobalType<'Float16Array', never>; |
js/common/lib/tensor.ts
Outdated
| */ | ||
| new ( | ||
| type: 'float16', | ||
| data: Tensor.DataTypeMap['float16'] | GlobalFloat16Array | readonly number[], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, wouldn't it be easier to inline this helper into the DataTypeMap?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated DataTypeMap.float16 to inline TryGetGlobalType<'Float16Array'> and map it to the instance type via prototype, so Float16Array instances are accepted where supported without changing runtime behavior.
|
Thanks for working on this! Just to be clear, I'm not a maintainer on this repo, we'll need for one of them to review too. |
js/common/lib/tensor.ts
Outdated
| string: string[]; | ||
| bool: Uint8Array; | ||
| float16: Uint16Array; // Keep using Uint16Array until we have a concrete solution for float 16. | ||
| float16: Uint16Array | (TryGetGlobalType<'Float16Array'> extends { prototype: infer P } ? P : never); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TryGetGlobalType does the fallback for you btw, so this can be done much simpler.
| float16: Uint16Array | (TryGetGlobalType<'Float16Array'> extends { prototype: infer P } ? P : never); | |
| float16: Uint16Array | TryGetGlobalType<'Float16Array', never> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion! Updated DataTypeMap['float16'] to use TryGetGlobalType<'Float16Array', never> directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see you removed manual overloads when converting to change to the DataMap, does the code from the original issue still work as expected without TS errors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just double-checked locally using the public dist/cjs entrypoint,
new Tensor(new Float16Array(...)) typechecks correctly with no TS errors.
Everything works as expected.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Huh I wonder why they even need all those manual overloads if datatypemap is enough.
onnxruntime/js/common/lib/tensor.ts
Lines 257 to 361 in c6c72e3
| // #region CPU tensor - infer element types | |
| /** | |
| * Construct a new float32 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: Float32Array, dims?: readonly number[]): TypedTensor<'float32'>; | |
| /** | |
| * Construct a new int8 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: Int8Array, dims?: readonly number[]): TypedTensor<'int8'>; | |
| /** | |
| * Construct a new uint8 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: Uint8Array, dims?: readonly number[]): TypedTensor<'uint8'>; | |
| /** | |
| * Construct a new uint8 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: Uint8ClampedArray, dims?: readonly number[]): TypedTensor<'uint8'>; | |
| /** | |
| * Construct a new uint16 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: Uint16Array, dims?: readonly number[]): TypedTensor<'uint16'>; | |
| /** | |
| * Construct a new int16 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: Int16Array, dims?: readonly number[]): TypedTensor<'int16'>; | |
| /** | |
| * Construct a new int32 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: Int32Array, dims?: readonly number[]): TypedTensor<'int32'>; | |
| /** | |
| * Construct a new int64 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: BigInt64Array, dims?: readonly number[]): TypedTensor<'int64'>; | |
| /** | |
| * Construct a new string tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: readonly string[], dims?: readonly number[]): TypedTensor<'string'>; | |
| /** | |
| * Construct a new bool tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: readonly boolean[], dims?: readonly number[]): TypedTensor<'bool'>; | |
| /** | |
| * Construct a new float64 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: Float64Array, dims?: readonly number[]): TypedTensor<'float64'>; | |
| /** | |
| * Construct a new uint32 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: Uint32Array, dims?: readonly number[]): TypedTensor<'uint32'>; | |
| /** | |
| * Construct a new uint64 tensor object from the given data and dims. | |
| * | |
| * @param data - Specify the CPU tensor data. | |
| * @param dims - Specify the dimension of the tensor. If omitted, a 1-D tensor is assumed. | |
| */ | |
| new (data: BigUint64Array, dims?: readonly number[]): TypedTensor<'uint64'>; |
Oh well, perhaps an opportunity for future cleanup. Thanks for double-checking!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tested locally and saw no errors. If I missed something, please let me know!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No I believe you :)
Fixes #26741
This change updates the TypeScript definitions to allow constructing
float16tensors usingFloat16Arrayin environments where it is available. Runtime behavior remains unchanged (float16is still represented internally asUint16Array).GlobalFloat16Arrayhelper type to safely detectFloat16Arraywithout requiring global polyfills.float16.This resolves compile-time errors when passing
Float16Arrayto theTensorconstructor in theonnxruntime-webpackage.Description
This PR enhances the TypeScript typings for
float16tensors within the ONNX Runtime JavaScript API:GlobalFloat16Array, a conditional utility type that resolves to the instance type ofFloat16Arrayonly when available.Uint16Array(existing behavior),Float16Array(new behavior, when supported by the JS environment),readonly number[].new Tensor(new Float16Array(...)).Float16Arraywithout encountering type errors.Internally, ONNX Runtime continues to treat
float16data asUint16Array, so runtime behavior is unchanged.Motivation and Context
Modern JavaScript runtimes (browsers and Node versions) have begun introducing native
Float16Arraysupport. Developers using ONNX Runtime in TypeScript projects may attempt to constructfloat16tensors using: