onnx._custom_element_types

This module defines custom dtypes that numpy does not support. The functions onnx.numpy_helper.from_array() and onnx.numpy_helper.to_array() use them to convert arrays to and from these types, and the class onnx.reference.ReferenceEvaluator uses them as well. To create such an array, for example in a unit test, it is convenient to write something like the following:

import numpy as np
from onnx import TensorProto
from onnx.reference.ops.op_cast import Cast_19 as Cast

tensor_bfloat16 = Cast.eval(np.array([0, 1], dtype=np.float32), to=TensorProto.BFLOAT16)

The numpy dtypes listed below are meant for internal use. They may change in the future as the industry standardizes numpy representations for these types.

onnx._custom_element_types.bfloat16 = dtype((numpy.uint16, [('bfloat16', '<u2')]))

Defines a bfloat16 as a uint16.
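To illustrate what storing a bfloat16 in a uint16 means, the sketch below keeps only the upper 16 bits of each float32 value and restores them later. The helper names are hypothetical and not part of onnx, and the conversion truncates for simplicity, so results may differ from a conversion that rounds.

```python
import numpy as np


def float32_to_bfloat16_bits(x):
    """Hypothetical helper: keep the upper 16 bits of each float32 (truncation)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16)


def bfloat16_bits_to_float32(b):
    """Hypothetical helper: put the 16 stored bits back into the top of a float32."""
    bits = np.asarray(b, dtype=np.uint16).astype(np.uint32) << 16
    return bits.view(np.float32)


stored = float32_to_bfloat16_bits(np.array([0, 1, -2], dtype=np.float32))
restored = bfloat16_bits_to_float32(stored)
```

The values 0, 1 and -2 are exactly representable in bfloat16, so the round trip returns them unchanged; values that need more than 8 significant bits lose precision.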

onnx._custom_element_types.float4e2m1 = dtype((numpy.uint8, [('float4e2m1', 'u1')]))

Defines the float 4 e2m1 type, see Float stored in 4 bits for technical details. Note that each value is stored in a full byte and therefore takes twice its onnx size.
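As a rough illustration of the e2m1 layout (1 sign bit, 2 exponent bits, 1 mantissa bit, exponent bias 1, no infinities or NaN), the following sketch decodes a 4-bit code held in the low bits of an integer. The function name is hypothetical and not part of onnx.

```python
def decode_float4e2m1(code):
    """Hypothetical helper: decode a 4-bit e2m1 code (0..15) into a Python float."""
    sign = -1.0 if code & 0b1000 else 1.0
    exponent = (code >> 1) & 0b11
    mantissa = code & 0b1
    if exponent == 0:
        # Subnormal range: no implicit leading one, so only 0 and 0.5.
        return sign * mantissa * 0.5
    # Normal range: implicit leading one, exponent bias of 1.
    return sign * (1.0 + mantissa * 0.5) * 2.0 ** (exponent - 1)


positive_values = [decode_float4e2m1(c) for c in range(8)]
```

The eight non-negative codes decode to 0, 0.5, 1, 1.5, 2, 3, 4 and 6; setting the top bit negates them.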

onnx._custom_element_types.float8e4m3fn = dtype((numpy.uint8, [('e4m3fn', 'u1')]))

Defines the float 8 e4m3fn type, see Float stored in 8 bits for technical details.
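The sketch below decodes a single e4m3fn byte, assuming the usual layout of 1 sign bit, 4 exponent bits and 3 mantissa bits with an exponent bias of 7, no infinities, and NaN encoded as all exponent and mantissa bits set. The function name is hypothetical and not part of onnx.

```python
import math


def decode_float8e4m3fn(code):
    """Hypothetical helper: decode one e4m3fn byte (0..255) into a Python float."""
    sign = -1.0 if code & 0x80 else 1.0
    exponent = (code >> 3) & 0x0F
    mantissa = code & 0x07
    if exponent == 0x0F and mantissa == 0x07:
        # The only NaN pattern; the format has no infinities.
        return math.nan
    if exponent == 0:
        # Subnormal range: no implicit leading one.
        return sign * (mantissa / 8.0) * 2.0 ** -6
    return sign * (1.0 + mantissa / 8.0) * 2.0 ** (exponent - 7)
```

For instance, 0x38 decodes to 1.0 and 0x7E to 448.0, the largest finite value of the format.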

onnx._custom_element_types.float8e4m3fnuz = dtype((numpy.uint8, [('e4m3fnuz', 'u1')]))

Defines the float 8 e4m3fnuz type, see Float stored in 8 bits for technical details.

onnx._custom_element_types.float8e5m2 = dtype((numpy.uint8, [('e5m2', 'u1')]))

Defines the float 8 e5m2 type, see Float stored in 8 bits for technical details.

onnx._custom_element_types.float8e5m2fnuz = dtype((numpy.uint8, [('e5m2fnuz', 'u1')]))

Defines the float 8 e5m2fnuz type, see Float stored in 8 bits for technical details.

onnx._custom_element_types.int4 = dtype((numpy.int8, [('int4', 'i1')]))

Defines int4, see 4 bit integer types for technical details. Note that each integer is stored in a full byte and therefore takes twice its onnx size.

onnx._custom_element_types.uint4 = dtype((numpy.uint8, [('uint4', 'u1')]))

Defines uint4, see 4 bit integer types for technical details. Note that each integer is stored in a full byte and therefore takes twice its onnx size.
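The size difference can be seen by packing unsigned 4-bit values two per byte. This is a minimal sketch with hypothetical helper names, assuming the first element of each pair goes into the low 4 bits and the second into the high 4 bits; it is not the onnx implementation.

```python
import numpy as np


def pack_uint4(values):
    """Hypothetical helper: pack one-value-per-byte uint4 data, two values per byte."""
    flat = np.asarray(values, dtype=np.uint8).ravel() & 0x0F
    if flat.size % 2:
        # Pad odd lengths with a zero nibble.
        flat = np.append(flat, np.uint8(0))
    return (flat[0::2] | (flat[1::2] << 4)).astype(np.uint8)


def unpack_uint4(packed, count):
    """Hypothetical helper: expand packed nibbles back to one value per byte."""
    packed = np.asarray(packed, dtype=np.uint8)
    out = np.empty(packed.size * 2, dtype=np.uint8)
    out[0::2] = packed & 0x0F
    out[1::2] = packed >> 4
    return out[:count]


unpacked = np.array([1, 2, 3, 4], dtype=np.uint8)  # one value per byte
packed = pack_uint4(unpacked)                      # two values per byte
```

Here `unpacked.nbytes` is 4 while `packed.nbytes` is 2, which is the factor of two mentioned above. A signed int4 variant would additionally need sign extension when unpacking.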