# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os
import pathlib
import re
import sys
import uuid
from itertools import chain
from typing import TYPE_CHECKING
import onnx.checker as onnx_checker
import onnx.onnx_cpp2py_export.checker as c_checker
from onnx.onnx_pb import (
AttributeProto,
FunctionProto,
GraphProto,
ModelProto,
TensorProto,
)
if TYPE_CHECKING:
from collections.abc import Callable, Iterable
class ExternalDataInfo:
    """Parsed view of a TensorProto's ``external_data`` key/value entries.

    Exposes ``location``, ``offset``, ``length``, ``checksum`` and ``basepath``
    as attributes, with ``offset``/``length`` normalized to ``int`` when set.
    """

    def __init__(self, tensor: TensorProto) -> None:
        # Defaults for every recognized key; proto entries override them below.
        self.location = ""
        self.offset = None
        self.length = None
        self.checksum = None
        self.basepath = ""
        for entry in tensor.external_data:
            setattr(self, entry.key, entry.value)
        # The proto stores values as strings; normalize truthy offset/length
        # to integers (empty string / None are left untouched).
        self.offset = int(self.offset) if self.offset else self.offset
        self.length = int(self.length) if self.length else self.length
def _validate_external_data_path(
base_dir: str,
data_path: str,
tensor_name: str,
*,
check_exists: bool = True,
) -> str:
"""Validate that an external data path is safe to open.
Performs three security checks:
1. Canonical path containment — resolved path must stay within base_dir.
2. Symlink rejection — final-component symlinks are not allowed.
3. Hardlink count — files with multiple hard links are rejected.
Args:
base_dir: The model base directory that data_path must be contained in.
data_path: The external data file path to validate.
tensor_name: Tensor name for error messages.
check_exists: If True (default), check hardlink count. Set to False
for save-side paths where the file may not exist yet.
Returns:
The validated data_path (unchanged).
Raises:
onnx.checker.ValidationError: If any security check fails.
"""
real_base = os.path.realpath(base_dir)
real_path = os.path.realpath(data_path)
if not real_path.startswith(real_base + os.sep) and real_path != real_base:
raise onnx_checker.ValidationError(
f"Tensor {tensor_name!r} external data path resolves to "
f"{real_path!r} which is outside the model directory {real_base!r}."
)
if os.path.islink(data_path):
raise onnx_checker.ValidationError(
f"Tensor {tensor_name!r} external data path {data_path!r} "
f"is a symbolic link, which is not allowed for security reasons."
)
if check_exists and os.path.exists(data_path) and os.stat(data_path).st_nlink > 1:
raise onnx_checker.ValidationError(
f"Tensor {tensor_name!r} external data path {data_path!r} "
f"has multiple hard links, which is not allowed for security reasons."
)
return data_path
def load_external_data_for_tensor(tensor: TensorProto, base_dir: str) -> None:
    """Loads data from an external file for tensor.

    Ideally TensorProto should not hold any raw data but if it does it will be ignored.

    Arguments:
        tensor: a TensorProto object.
        base_dir: directory that contains the external data.
    """
    info = ExternalDataInfo(tensor)
    external_data_file_path = c_checker._resolve_external_data_location(  # type: ignore[attr-defined]
        base_dir, info.location, tensor.name
    )
    # Security checks (symlink, containment, hardlink) already performed
    # by C++ _resolve_external_data_location() above.
    # O_NOFOLLOW (on platforms that have it) is defense-in-depth: it rejects a
    # symlink swapped in between the resolution above and this open().
    flags = os.O_RDONLY | getattr(os, "O_NOFOLLOW", 0)
    descriptor = os.open(external_data_file_path, flags)
    with os.fdopen(descriptor, "rb") as data_file:
        if info.offset:
            data_file.seek(info.offset)
        # A recorded length limits the read; otherwise take everything left.
        tensor.raw_data = (
            data_file.read(info.length) if info.length else data_file.read()
        )
def load_external_data_for_model(model: ModelProto, base_dir: str) -> None:
    """Loads external tensors into model

    Arguments:
        model: ModelProto to load external data to
        base_dir: directory that contains external data
    """
    for tensor in _get_all_tensors(model):
        if not uses_external_data(tensor):
            continue
        load_external_data_for_tensor(tensor, base_dir)
        # The bytes now live in raw_data: flip the state back to embedded
        # and drop the stale external references.
        tensor.data_location = TensorProto.DEFAULT
        del tensor.external_data[:]
def set_external_data(
    tensor: TensorProto,
    location: str,
    offset: int | None = None,
    length: int | None = None,
    checksum: str | None = None,
    basepath: str | None = None,
) -> None:
    """Marks the tensor as externally stored and records where its data lives.

    Replaces any existing ``external_data`` entries, sets ``data_location`` to
    EXTERNAL, and stores the provided fields as string key/value entries.

    Arguments:
        tensor: TensorProto to update; must carry a raw_data field.
        location: Relative file path where the tensor data is (to be) stored.
        offset: Byte offset of the data within the file, if known.
        length: Byte length of the data within the file, if known.
        checksum: Optional checksum of the stored data.
        basepath: Optional base path the location is relative to.

    Raises:
        ValueError: If the tensor does not have a raw_data field.
    """
    if not tensor.HasField("raw_data"):
        # Bug fix: the original concatenation was missing a space after the
        # tensor name ("Tensor Xdoes not have ...").
        raise ValueError(
            f"Tensor {tensor.name} does not have raw_data field. "
            "Cannot set external data for this tensor."
        )
    del tensor.external_data[:]
    tensor.data_location = TensorProto.EXTERNAL
    # Only the fields actually provided are recorded; all values are stored
    # as strings in the key/value entries.
    for k, v in {
        "location": location,
        "offset": int(offset) if offset is not None else None,
        "length": int(length) if length is not None else None,
        "checksum": checksum,
        "basepath": basepath,
    }.items():
        if v is not None:
            entry = tensor.external_data.add()
            entry.key = k
            entry.value = str(v)
def convert_model_to_external_data(
    model: ModelProto,
    all_tensors_to_one_file: bool = True,
    location: str | None = None,
    size_threshold: int = 1024,
    convert_attribute: bool = False,
) -> None:
    """Call to set all tensors with raw data as external data. This call should precede 'save_model'.
    'save_model' saves all the tensors data as external data after calling this function.

    Arguments:
        model (ModelProto): Model to be converted.
        all_tensors_to_one_file (bool): If true, save all tensors to one external file specified by location.
            If false, save each tensor to a file named with the tensor name.
        location: specify the external file relative to the model that all tensors to save to.
            Path is relative to the model path.
            If not specified, will use the model name.
        size_threshold: Threshold for size of data. Only when tensor's data is >= the size_threshold
            it will be converted to external data. To convert every tensor with raw data to external data set size_threshold=0.
        convert_attribute (bool): If true, convert all tensors to external data
            If false, convert only non-attribute tensors to external data

    Raise:
        ValueError: If location is not a relative path.
        FileExistsError: If a file already exists in location.
    """
    tensors = (
        _get_all_tensors(model) if convert_attribute else _get_initializer_tensors(model)
    )
    # Resolve (and validate) the shared file name up front when all tensors
    # go to a single file.
    shared_file_name = None
    if all_tensors_to_one_file:
        shared_file_name = str(uuid.uuid1()) + ".data"
        if location:
            if os.path.isabs(location):
                raise ValueError(
                    "location must be a relative path that is relative to the model path."
                )
            if os.path.exists(location):
                raise FileExistsError(f"External data file exists in {location}.")
            shared_file_name = location
    for tensor in tensors:
        # Only raw-data tensors at or above the size threshold are converted.
        if not tensor.HasField("raw_data"):
            continue
        if sys.getsizeof(tensor.raw_data) < size_threshold:
            continue
        if all_tensors_to_one_file:
            set_external_data(tensor, shared_file_name)
        else:
            # Per-tensor files: fall back to a uuid when the tensor name is
            # not a usable file name.
            tensor_location = tensor.name
            if not _is_valid_filename(tensor_location):
                tensor_location = str(uuid.uuid1())
            set_external_data(tensor, tensor_location)
def convert_model_from_external_data(model: ModelProto) -> None:
    """Call to set all tensors which use external data as embedded data.
    save_model saves all the tensors data as embedded data after
    calling this function.

    Arguments:
        model (ModelProto): Model to be converted.
    """
    for tensor in _get_all_tensors(model):
        if not uses_external_data(tensor):
            continue
        if not tensor.HasField("raw_data"):
            raise ValueError("raw_data field doesn't exist.")
        # Data is already present inline; just drop the external references
        # and mark the tensor as embedded.
        del tensor.external_data[:]
        tensor.data_location = TensorProto.DEFAULT
def save_external_data(tensor: TensorProto, base_path: str) -> None:
    """Writes tensor data to an external file according to information in the `external_data` field.

    The function checks the external location is a valid relative name located
    in folder `base_path`, then writes the tensor's raw_data at the recorded
    offset (or appends at end-of-file) and refreshes the tensor's
    external_data entries with the offset/length actually written.

    Arguments:
        tensor (TensorProto): Tensor object to be serialized
        base_path: System path of a folder where tensor data is to be stored

    Raises:
        onnx.checker.ValidationError: If the location is absolute, contains
            '..' components, escapes `base_path`, points at a symlink or a
            multiply-hardlinked file, or the tensor has no raw_data field.
    """
    info = ExternalDataInfo(tensor)
    # Let's check the tensor location is valid.
    location_path = pathlib.Path(info.location)
    # External data must be addressed relative to the model directory.
    if location_path.is_absolute():
        raise onnx_checker.ValidationError(
            f"Tensor {tensor.name!r} is external and must not be defined "
            f"with an absolute path such as {info.location!r}, "
            f"base_path={base_path!r}"
        )
    # Parent-directory components could escape base_path — reject them.
    if ".." in location_path.parts:
        raise onnx_checker.ValidationError(
            f"Tensor {tensor.name!r} is external and must be placed in folder "
            f"{base_path!r}, '..' is not needed in {info.location!r}."
        )
    # '.' or '..' as the final component is never a valid data file name.
    if location_path.name in (".", ".."):
        raise onnx_checker.ValidationError(
            f"Tensor {tensor.name!r} is external and its name "
            f"{info.location!r} is invalid."
        )
    external_data_file_path = os.path.join(base_path, info.location)
    # C++ _resolve_external_data_location() cannot be used on save path
    # (file may not exist yet), so Python performs its own security validation.
    _validate_external_data_path(
        base_path, external_data_file_path, tensor.name, check_exists=True
    )
    # Retrieve the tensor's data from raw_data or load external file
    if not tensor.HasField("raw_data"):
        raise onnx_checker.ValidationError("raw_data field doesn't exist.")
    # Atomic file creation with symlink protection (O_NOFOLLOW where available)
    open_flags = os.O_CREAT | os.O_RDWR
    if hasattr(os, "O_NOFOLLOW"):
        open_flags |= os.O_NOFOLLOW
    # Use restrictive permissions: owner read/write only (0o600)
    fd = os.open(external_data_file_path, open_flags, 0o600)
    # Open file for reading and writing at random locations ('r+b')
    with os.fdopen(fd, "r+b") as data_file:
        # Default write position: end of file (append).
        data_file.seek(0, 2)
        if info.offset is not None:
            # Pad file to required offset if needed
            file_size = data_file.tell()
            if info.offset > file_size:
                data_file.write(b"\0" * (info.offset - file_size))
            data_file.seek(info.offset)
        offset = data_file.tell()
        data_file.write(tensor.raw_data)
        # Record the offset/length that were actually written.
        set_external_data(tensor, info.location, offset, data_file.tell() - offset)
def _get_all_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Scan an ONNX model for all tensors and return as an iterator."""
    # Initializers first, then tensors held in node attributes.
    yield from _get_initializer_tensors(onnx_model_proto)
    yield from _get_attribute_tensors(onnx_model_proto)
def _recursive_attribute_processor(
    attribute: AttributeProto, func: Callable[[GraphProto], Iterable[TensorProto]]
) -> Iterable[TensorProto]:
    """Create an iterator through processing ONNX model attributes with functor."""
    # An attribute carries at most one of the graph fields, so the two
    # branches are mutually exclusive.
    if attribute.type == AttributeProto.GRAPH:
        yield from func(attribute.g)
    elif attribute.type == AttributeProto.GRAPHS:
        for subgraph in attribute.graphs:
            yield from func(subgraph)
def _get_initializer_tensors_from_graph(
    graph_or_function: GraphProto | FunctionProto, /
) -> Iterable[TensorProto]:
    """Create an iterator of initializer tensors from ONNX model graph/function."""
    # Only a GraphProto carries an initializer list; FunctionProto does not.
    if isinstance(graph_or_function, GraphProto):
        yield from graph_or_function.initializer
    # Recurse into any subgraphs held by node attributes.
    for node in graph_or_function.node:
        for attr in node.attribute:
            yield from _recursive_attribute_processor(
                attr, _get_initializer_tensors_from_graph
            )
def _get_initializer_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Create an iterator of initializer tensors from ONNX model.

    Covers the main graph and, for each function, initializers of subgraphs
    nested in the function's node attributes.
    """
    yield from _get_initializer_tensors_from_graph(onnx_model_proto.graph)
    for function in onnx_model_proto.functions:
        # Bug fix: this previously called _get_attribute_tensors_from_graph,
        # making the "initializer" iterator yield attribute tensors — which
        # _get_attribute_tensors also yields, double-counting them in
        # _get_all_tensors.
        yield from _get_initializer_tensors_from_graph(function)
def _get_attribute_tensors_from_graph(
    graph_or_function: GraphProto | FunctionProto, /
) -> Iterable[TensorProto]:
    """Create an iterator of tensors from node attributes of an ONNX model graph/function."""
    for node in graph_or_function.node:
        for attr in node.attribute:
            # Singular tensor attribute first, then any repeated tensors.
            if attr.HasField("t"):
                yield attr.t
            yield from attr.tensors
            # Recurse into subgraph-valued attributes.
            yield from _recursive_attribute_processor(
                attr, _get_attribute_tensors_from_graph
            )
def _get_attribute_tensors(onnx_model_proto: ModelProto) -> Iterable[TensorProto]:
    """Create an iterator of tensors from node attributes of an ONNX model."""
    # Walk the main graph first, then every function in the model.
    for graph_like in chain((onnx_model_proto.graph,), onnx_model_proto.functions):
        yield from _get_attribute_tensors_from_graph(graph_like)
def _is_valid_filename(filename: str) -> bool:
"""Utility to check whether the provided filename is valid."""
exp = re.compile('^[^<>:;,?"*|/]+$')
match = exp.match(filename)
return bool(match)
def uses_external_data(tensor: TensorProto) -> bool:
    """Returns true if the tensor stores data in an external location."""
    # Unset data_location short-circuits to False without touching the enum.
    has_location = tensor.HasField("data_location")
    return has_location and tensor.data_location == TensorProto.EXTERNAL
def remove_external_data_field(tensor: TensorProto, field_key: str) -> None:
    """Removes a field from a Tensor's external_data key-value store.

    Modifies tensor object in place.

    Arguments:
        tensor (TensorProto): Tensor object from which value will be removed
        field_key (string): The key of the field to be removed
    """
    # Bug fix: the original deleted entries while iterating with enumerate(),
    # which shifts indices after a deletion and can skip the following entry.
    # Iterating indices in reverse makes deletion safe.
    for i in reversed(range(len(tensor.external_data))):
        if tensor.external_data[i].key == field_key:
            del tensor.external_data[i]
def write_external_data_tensors(model: ModelProto, filepath: str) -> ModelProto:
    """Serializes data for all the tensors which have data location set to TensorProto.External.

    Note: This function also strips basepath information from all tensors' external_data fields.

    Arguments:
        model (ModelProto): Model object which is the source of tensors to serialize.
        filepath: System path to the directory which should be treated as base path for external data.

    Returns:
        ModelProto: The modified model object.
    """
    for tensor in _get_all_tensors(model):
        # Serialization is a two-pass scheme: tensors were first *marked* for
        # external storage (data_location == EXTERNAL, raw_data still set) by
        # convert_model_to_external_data; this pass writes the marked data out.
        if not (uses_external_data(tensor) and tensor.HasField("raw_data")):
            continue
        save_external_data(tensor, filepath)
        tensor.ClearField("raw_data")
    return model