onnx._custom_element_types
This module defines custom dtypes not supported by numpy.
Functions onnx.numpy_helper.from_array() and onnx.numpy_helper.to_array() use them to convert arrays from and to these types.
Class onnx.reference.ReferenceEvaluator also uses them.
To create such an array, for example in a unit test, it is convenient to write something like the following:
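```python
import numpy as np
from onnx import TensorProto
from onnx.reference.ops.op_cast import Cast_19 as Cast

# Cast float32 values to BFLOAT16 with the reference implementation of Cast (opset 19).
tensor_bfloat16 = Cast.eval(np.array([0, 1], dtype=np.float32), to=TensorProto.BFLOAT16)
```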
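The same bit patterns can also be produced with plain numpy when the reference evaluator is not needed. The sketch below assumes a local copy of the bfloat16 dtype definition listed on this page rather than importing the internal module, and the helpers to_bfloat16 and from_bfloat16 are hypothetical names. It encodes by truncation, keeping the upper 16 bits of each float32, so values that are not exactly representable in bfloat16 may differ from what Cast produces, since Cast may round instead.

```python
import numpy as np

# Local copy of the bfloat16 dtype definition listed on this page (assumption:
# rebuilt with plain numpy instead of importing the internal onnx module).
bfloat16 = np.dtype((np.uint16, [("bfloat16", "<u2")]))


def to_bfloat16(values):
    """Hypothetical helper: encode float32 values as bfloat16 by truncation.

    bfloat16 keeps the sign, the 8 exponent bits and the top 7 mantissa bits
    of a float32, i.e. its upper 16 bits. No rounding is applied here.
    """
    f32 = np.asarray(values, dtype=np.float32)
    upper = (f32.view(np.uint32) >> 16).astype(np.uint16)
    return upper.view(bfloat16)


def from_bfloat16(arr):
    """Hypothetical helper: decode bfloat16 storage back to float32."""
    bits = arr.view(np.uint16).astype(np.uint32) << 16  # low 16 bits become zero
    return bits.view(np.float32)


x = np.array([0.0, 1.0, -2.5], dtype=np.float32)
encoded = to_bfloat16(x)
print(encoded.view(np.uint16))  # raw stored bit patterns
print(from_bfloat16(encoded))   # these three values are exact in bfloat16
```

Viewing the result as numpy.uint16 exposes the stored bits, for instance 0x3F80 for 1.0.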
The numpy dtypes listed below are meant for internal use. They may change in the future depending on how the industry standardizes numpy support for these types.
- onnx._custom_element_types.bfloat16 = dtype((numpy.uint16, [('bfloat16', '<u2')]))
Defines bfloat16, stored as a uint16.
- onnx._custom_element_types.float4e2m1 = dtype((numpy.uint8, [('float4e2m1', 'u1')]))
Defines the float 4 e2m1 type, see Float stored in 4 bits for technical details. Note that each value is stored in a full byte and therefore takes twice as much space as in ONNX, where two 4-bit values share one byte.
- onnx._custom_element_types.float8e4m3fn = dtype((numpy.uint8, [('e4m3fn', 'u1')]))
Defines the float 8 e4m3fn type, see Float stored in 8 bits for technical details.
- onnx._custom_element_types.float8e4m3fnuz = dtype((numpy.uint8, [('e4m3fnuz', 'u1')]))
Defines the float 8 e4m3fnuz type, see Float stored in 8 bits for technical details.
- onnx._custom_element_types.float8e5m2 = dtype((numpy.uint8, [('e5m2', 'u1')]))
Defines the float 8 e5m2 type, see Float stored in 8 bits for technical details.
- onnx._custom_element_types.float8e5m2fnuz = dtype((numpy.uint8, [('e5m2fnuz', 'u1')]))
Defines the float 8 e5m2fnuz type, see Float stored in 8 bits for technical details.
- onnx._custom_element_types.int4 = dtype((numpy.int8, [('int4', 'i1')]))
Defines int4, see 4 bit integer types for technical details. Note that each integer is stored in a full byte and therefore takes twice as much space as in ONNX, where two 4-bit values share one byte.
- onnx._custom_element_types.uint4 = dtype((numpy.uint8, [('uint4', 'u1')]))
Defines uint4, see 4 bit integer types for technical details. Note that each integer is stored in a full byte and therefore takes twice as much space as in ONNX, where two 4-bit values share one byte.
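The remark that 4-bit types take twice their ONNX size can be illustrated with plain numpy: ONNX stores two 4-bit elements per byte, while the int4, uint4 and float4e2m1 dtypes above keep one element per byte. The sketch below packs and unpacks such values and decodes float4e2m1 codes from the E2M1 layout (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit, no infinity or NaN). The helper names pack_4bit, unpack_4bit and decode_float4e2m1 are hypothetical, and the nibble order (first element in the low 4 bits) is an assumption to check against 4 bit integer types.

```python
import numpy as np


def pack_4bit(values):
    """Pack 4-bit codes two per byte (assumption: first element in the low nibble)."""
    codes = np.asarray(values, dtype=np.int16) & 0x0F  # negative int4 -> two's complement nibble
    if codes.size % 2:
        codes = np.append(codes, 0)  # pad an odd count with a zero nibble
    return (codes[0::2] | (codes[1::2] << 4)).astype(np.uint8)


def unpack_4bit(packed, count, signed):
    """Expand packed nibbles to one byte per element, like the int4/uint4 dtypes."""
    packed = np.asarray(packed, dtype=np.uint8)
    nibbles = np.empty(packed.size * 2, dtype=np.int16)
    nibbles[0::2] = packed & 0x0F
    nibbles[1::2] = packed >> 4
    if signed:
        nibbles = np.where(nibbles >= 8, nibbles - 16, nibbles)  # sign-extend int4
    return nibbles[:count].astype(np.int8 if signed else np.uint8)


def decode_float4e2m1(codes):
    """Decode float4e2m1 codes (1 sign, 2 exponent bits with bias 1, 1 mantissa bit)."""
    c = np.asarray(codes, dtype=np.uint8) & 0x0F
    sign = np.where(c & 0x8, -1.0, 1.0)
    exponent = ((c >> 1) & 0x3).astype(np.float32)
    mantissa = (c & 0x1).astype(np.float32)
    subnormal = mantissa * 0.5  # exponent field 0: values 0 and 0.5
    normal = (1.0 + mantissa / 2) * 2.0 ** (exponent - 1)
    return (sign * np.where(exponent == 0, subnormal, normal)).astype(np.float32)


packed = pack_4bit([-8, 7, 3, -1])
unpacked = unpack_4bit(packed, 4, signed=True)
print(packed.nbytes, unpacked.nbytes)  # packed size vs. one byte per value
print(decode_float4e2m1(range(8)))     # all non-negative float4e2m1 values
```

Unpacking four packed values yields four bytes from two, which is the factor of two mentioned above.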
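The four float8 dtypes above only hold raw bytes, so reading them requires the bit layout of each format. The sketch below decodes a single byte to a Python float. Its table of exponent width, mantissa width and bias, and its special-value rules, follow the float8 formats as commonly defined (E4M3FN: bias 7, no infinity; E5M2: IEEE-like, bias 15; the FNUZ variants: bias 8 and 16, 0x80 is the only NaN, no negative zero) and should be checked against Float stored in 8 bits. FORMATS, decode_float8 and decode_e5m2_via_float16 are hypothetical names.

```python
import math

import numpy as np

# Hypothetical table: exponent bits, mantissa bits, exponent bias, special-value rule.
FORMATS = {
    "e4m3fn": (4, 3, 7, "fn"),      # no infinity, S.1111.111 is NaN
    "e4m3fnuz": (4, 3, 8, "fnuz"),  # no infinity, no -0, 0x80 is NaN
    "e5m2": (5, 2, 15, "ieee"),     # IEEE-like: infinities and NaNs
    "e5m2fnuz": (5, 2, 16, "fnuz"),
}


def decode_float8(byte, fmt):
    """Decode one float8 byte (as stored by the uint8-based dtypes) to a float."""
    ebits, mbits, bias, rule = FORMATS[fmt]
    byte = int(byte) & 0xFF
    sign = -1.0 if byte & 0x80 else 1.0
    exponent = (byte >> mbits) & ((1 << ebits) - 1)
    mantissa = byte & ((1 << mbits) - 1)
    all_ones_exp = exponent == (1 << ebits) - 1
    if rule == "fnuz" and byte == 0x80:
        return math.nan
    if rule == "fn" and all_ones_exp and mantissa == (1 << mbits) - 1:
        return math.nan
    if rule == "ieee" and all_ones_exp:
        return sign * math.inf if mantissa == 0 else math.nan
    if exponent == 0:  # subnormal: no implicit leading 1
        return sign * mantissa * 2.0 ** (1 - bias - mbits)
    return sign * (1 + mantissa / (1 << mbits)) * 2.0 ** (exponent - bias)


def decode_e5m2_via_float16(byte):
    """Cross-check: an e5m2 byte is the upper byte of a float16 with the same value."""
    return float(np.array([int(byte) << 8], dtype=np.uint16).view(np.float16)[0])
```

For instance, decode_float8(0x7E, "e4m3fn") returns 448.0, the largest finite e4m3fn value. The float16 cross-check works because e5m2 shares float16's sign bit, 5-bit exponent and bias, keeping only the top 2 mantissa bits.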