Source code for zntrack.fields.base

import dataclasses
import functools
import typing as t

import znfields

from zntrack.config import (
    FIELD_TYPE,
    ZNTRACK_CACHE,
    ZNTRACK_FIELD_DUMP,
    ZNTRACK_FIELD_LOAD,
    ZNTRACK_FIELD_SUFFIX,
    ZNTRACK_INDEPENDENT_OUTPUT_TYPE,
    FieldTypes,
)
from zntrack.node import Node
from zntrack.plugins import base_getter, plugin_getter

# Callback signature used for ``dump_fn`` / ``load_fn`` when no suffix is
# passed: called as ``(node, name)``.
FN_WITHOUT_SUFFIX = t.Callable[["Node", str], t.Any]
# Callback signature used for ``dump_fn`` / ``load_fn`` when a suffix is
# passed as well: called as ``(node, name, suffix)``.
FN_WITH_SUFFIX = t.Callable[["Node", str, str], t.Any]


def field(
    default=dataclasses.MISSING,
    *,
    cache: bool = True,
    independent: bool = False,
    field_type: FieldTypes,
    dump_fn: FN_WITH_SUFFIX | FN_WITHOUT_SUFFIX | None = None,
    suffix: str | None = None,
    load_fn: FN_WITHOUT_SUFFIX | FN_WITH_SUFFIX | None = None,
    **kwargs,
):
    """Create a custom field.

    The ZnTrack-specific options are stored in the field's ``metadata``
    mapping and the field itself is created through ``znfields.field`` with
    ``plugin_getter`` as its getter.

    Arguments
    ---------
    default : t.Any
        Default value of the field.
        For an output field, this should be ``zntrack.NOT_AVAILABLE``
        and should not be set during initialization.
    cache : bool
        Use the DVC cache for the field.
    independent : bool
        If the output of this field can be independent of the input.
    field_type : FieldTypes
        The type of the field.
    dump_fn : FN_WITH_SUFFIX | FN_WITHOUT_SUFFIX
        Function to dump the field. Only recorded in the metadata when
        it is not ``None``.
    suffix : str
        Suffix to append to the field name.
        Can be None if the output is a directory.
    load_fn : FN_WITHOUT_SUFFIX | FN_WITH_SUFFIX
        Function to load the field. When given, it is stored wrapped as
        ``functools.partial(base_getter, func=load_fn)``.
    **kwargs
        Additional arguments to pass to the field. A ``metadata`` mapping
        may be supplied; it is copied before the ZnTrack entries are added,
        so the caller's object is never modified.

    Returns
    -------
    The field object produced by ``znfields.field``.

    Examples
    --------
    >>> import numpy as np
    >>> import zntrack
    ...
    >>> def _load_fn(self: zntrack.Node, name: str, suffix: str) -> np.ndarray:
    ...     with self.state.fs.open(
    ...         (self.nwd / name).with_suffix(suffix), mode="rb"
    ...     ) as f:
    ...         return np.load(f)
    ...
    >>> def _dump_fn(self: zntrack.Node, name: str, suffix: str) -> None:
    ...     with open((self.nwd / name).with_suffix(suffix), mode="wb") as f:
    ...         np.save(f, getattr(self, name))
    ...
    >>> def numpy_field(*, cache: bool = True, independent: bool = False, **kwargs):
    ...     return field(
    ...         default=zntrack.NOT_AVAILABLE,
    ...         cache=cache,
    ...         independent=independent,
    ...         field_type=zntrack.FieldTypes.OUTS,
    ...         dump_fn=_dump_fn,
    ...         suffix=".npy",
    ...         load_fn=_load_fn,
    ...         **kwargs,
    ...     )
    ...
    >>> class MyNode(zntrack.Node):
    ...     data: np.ndarray = numpy_field()
    ...
    ...     def run(self) -> None:
    ...         self.data = np.arange(9).reshape(3, 3)
    """
    # Work on a private copy: mutating the caller's dict in place would leak
    # these entries into every other field built from the same metadata
    # object. ``or {}`` additionally tolerates an explicit ``metadata=None``.
    metadata = dict(kwargs.pop("metadata", None) or {})

    metadata[FIELD_TYPE] = field_type
    metadata[ZNTRACK_CACHE] = cache
    metadata[ZNTRACK_INDEPENDENT_OUTPUT_TYPE] = independent

    # Optional hooks are only recorded when provided, so their absence can be
    # detected by a missing key.
    if load_fn is not None:
        metadata[ZNTRACK_FIELD_LOAD] = functools.partial(base_getter, func=load_fn)
    if dump_fn is not None:
        metadata[ZNTRACK_FIELD_DUMP] = dump_fn
    if suffix is not None:
        metadata[ZNTRACK_FIELD_SUFFIX] = suffix

    return znfields.field(
        default=default, getter=plugin_getter, metadata=metadata, **kwargs
    )