More about Nodes

This section describes some special cases for Node definitions.

On and Off Graph Nodes

The Node instances we have seen so far are all placed onto the graph. In other words, they are defined within the context of the Project and will have a run method that is executed when the Project runs.

Note

Each of these Node instances is represented by an individual stage in the DVC graph.

In some cases, a Node only provides methods that are called from within other Node instances, rather than being executed as a stage of its own. Such a Node is called “off-graph” and can be represented by a Python dataclass. Off-graph Nodes are often used to define interchangeable models, as illustrated in the example on Scikit-learn Classifier Comparison. Another use case is reusing a Node from another project: a Node loaded via zntrack.from_rev can also be used as an off-graph Node.

Note

Just like on-graph Node definitions, dataclass-based Nodes must be importable. It is therefore recommended to place them alongside the on-graph Node definitions, e.g., in the same module. If you define them inside main.py, construct the Project inside an if __name__ == "__main__": block so that importing the Node does not also execute the script.

from dataclasses import dataclass
import zntrack

@dataclass
class Shift:
    shift: float

    def compute(self, input: float) -> float:
        return input + self.shift

@dataclass
class Scale:
    scale: float

    def compute(self, input: float) -> float:
        return input * self.scale

class ManipulateNumber(zntrack.Node):
    number: float = zntrack.params()
    method: Shift | Scale = zntrack.deps()

    result: float = zntrack.outs()

    def run(self) -> None:
        self.result = self.method.compute(self.number)

if __name__ == "__main__":
    project = zntrack.Project()

    # You can define these Nodes anywhere, but
    # to avoid confusion, they should be placed outside the Project context
    shift = Shift(shift=1.0)
    scale = Scale(scale=2.0)

    with project:
        shifted_number = ManipulateNumber(number=1.0, method=shift)
        scaled_number = ManipulateNumber(number=1.0, method=scale)
    project.repro()
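Because off-graph Nodes are plain dataclasses, their methods can be exercised without building or reproducing a Project. The sketch below repeats Shift and Scale from the example and adds a typing.Protocol named Method, a hypothetical name of our own (ZnTrack does not require it), to spell out the shared compute interface that makes the two interchangeable. The manipulate function is likewise hypothetical and only mirrors what ManipulateNumber.run does.

```python
from dataclasses import dataclass
from typing import Protocol


class Method(Protocol):
    """Hypothetical interface shared by the interchangeable off-graph Nodes."""

    def compute(self, input: float) -> float: ...


@dataclass
class Shift:
    shift: float

    def compute(self, input: float) -> float:
        return input + self.shift


@dataclass
class Scale:
    scale: float

    def compute(self, input: float) -> float:
        return input * self.scale


def manipulate(method: Method, number: float) -> float:
    # Mirrors ManipulateNumber.run: only `compute` is used,
    # so any object providing it can be swapped in.
    return method.compute(number)


if __name__ == "__main__":
    for method in (Shift(shift=1.0), Scale(scale=2.0)):
        print(method, manipulate(method, 3.0))
```

Testing off-graph Nodes this way is cheap, since no DVC stage is created or run.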

Always Changed

In some cases, you may want a Node to always run, even if the inputs have not changed. This can be useful when debugging a new Node. In such cases, you can set always_changed=True.

import zntrack.examples

project = zntrack.Project()

with project:
    node = zntrack.examples.ParamsToOuts(params=42, always_changed=True)

project.repro()
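As a rough mental model (a simplification, not DVC's actual caching logic), a stage is skipped when a hash of its inputs matches the one recorded after the previous run, and always_changed=True bypasses that comparison. The helpers deps_hash and needs_run are hypothetical names used only for this sketch:

```python
import hashlib
import json
from typing import Optional


def deps_hash(params: dict) -> str:
    """Hypothetical helper: hash the parameters independently of key order."""
    payload = json.dumps(params, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()


def needs_run(params: dict, recorded_hash: Optional[str], always_changed: bool = False) -> bool:
    """Simplified model of the "skip if unchanged" decision."""
    if always_changed:
        # Forced rerun: the cache comparison is skipped entirely.
        return True
    # Run if nothing was recorded yet or the inputs differ from the last run.
    return recorded_hash != deps_hash(params)


if __name__ == "__main__":
    recorded = deps_hash({"params": 42})
    print(needs_run({"params": 42}, recorded))  # unchanged inputs: skipped
    print(needs_run({"params": 42}, recorded, always_changed=True))  # forced
```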

Node State

Each Node provides a state attribute to access metadata or the DVCFileSystem. This attribute is a zntrack.state.NodeStatus instance, which is frozen and read-only.

class zntrack.state.NodeStatus(remote: str | None = None, rev: str | None = None, run_count: int = 0, state: zntrack.config.NodeStatusEnum = NodeStatusEnum.CREATED, lazy_evaluation: bool = True, tmp_path: pathlib.Path | None = None, node: Node | None = None, plugins: dict[str, zntrack.plugins.base.ZnTrackPlugin] = <factory>, group: zntrack.group.Group | None = None, run_time: datetime.timedelta | None = None)

Node status object.

Parameters:
  • remote (str, optional) – The repository remote, e.g. the URL of the git repository.

  • rev (str, optional) – The revision of the repository, e.g. the git commit hash.

  • run_count (int) – How often this Node has been run. Only incremented when the Node is restarted.

  • state (NodeStatusEnum) – The state of the Node.

  • lazy_evaluation (bool) – Whether to load fields lazily.

  • tmp_path (pathlib.Path, optional) – The temporary path when using ‘use_tmp_path’.

  • node (Node, optional) – The Node object.

  • plugins (dict[str, ZnTrackPlugin], optional) – Active plugins. In addition to the default plugins, MLflow or Aim plugins will be added here.

  • group (Group, optional) – The group of the Node.

  • run_time (datetime.timedelta, optional) – The total run time of the Node.

  • name (str) – The name of the Node.

  • nwd (pathlib.Path) – The node working directory.

  • restarted (bool) – Whether the Node was restarted and has been run at least once before.
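The <factory> default in the signature above suggests that NodeStatus is implemented as a dataclass. Assuming it is a frozen dataclass, "read-only" means that assigning to any field raises an error. The hypothetical StatusLike class below mimics only that behavior with dataclasses.dataclass(frozen=True); dataclasses.replace shows how a modified copy can be produced instead of mutating the original.

```python
import dataclasses
from pathlib import Path
from typing import Optional


@dataclasses.dataclass(frozen=True)
class StatusLike:
    """Hypothetical stand-in with a few NodeStatus-like fields."""

    remote: Optional[str] = None
    rev: Optional[str] = None
    run_count: int = 0
    tmp_path: Optional[Path] = None


def try_to_modify(status: StatusLike) -> bool:
    """Return True if the assignment was rejected, as a frozen object should."""
    try:
        status.rev = "HEAD"  # type: ignore[misc]
    except dataclasses.FrozenInstanceError:
        return True
    return False


if __name__ == "__main__":
    status = StatusLike(rev="HEAD~1")
    print(try_to_modify(status))
    # A modified copy can be created without touching the original.
    print(dataclasses.replace(status, rev="HEAD"), status)
```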

property fs: AbstractFileSystem

Get the file system of the Node.

If remote is None, the local file system is returned; otherwise, a DVCFileSystem is returned. Use this file system to open files so that the correct version of each file is loaded.

Examples

>>> import zntrack
>>> from pathlib import Path
>>>
>>> class MyNode(zntrack.Node):
...     outs_path: Path = zntrack.outs_path(zntrack.nwd / "file.txt")
...
...     def run(self):
...         self.outs_path.parent.mkdir(parents=True, exist_ok=True)
...         self.outs_path.write_text("Hello World!")
...
...     @property
...     def data(self):
...         with self.state.fs.open(self.outs_path) as f:
...             return f.read()
...
>>> # build and run the graph and make multiple commits.
>>> # the filesystem ensures that the correct version of the file is loaded.
>>>
>>> zntrack.from_rev("MyNode", rev="HEAD").data
>>> zntrack.from_rev("MyNode", rev="HEAD~1").data
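To see why reading through state.fs matters, consider a toy model in which each revision maps a path to its contents. The ToyVersionedFS class, the HISTORY data, and the MyNode/file.txt path are hypothetical and stdlib-only; they resemble the real DVCFileSystem only in that the same path resolves to different contents depending on the revision the Node was loaded from.

```python
import io
from typing import Dict


class ToyVersionedFS:
    """Hypothetical in-memory file system pinned to a single revision."""

    def __init__(self, history: Dict[str, Dict[str, str]], rev: str) -> None:
        self._files = history[rev]

    def open(self, path: str, mode: str = "r") -> io.StringIO:
        # Only reading is modelled; StringIO also works as a context manager.
        if mode != "r":
            raise ValueError("read-only toy file system")
        return io.StringIO(self._files[path])


# Two hypothetical commits in which the same output file differs.
HISTORY = {
    "HEAD~1": {"MyNode/file.txt": "Hello World!"},
    "HEAD": {"MyNode/file.txt": "Hello again!"},
}


def read_output(fs: ToyVersionedFS, path: str) -> str:
    # Same pattern as MyNode.data: always go through the file system object.
    with fs.open(path) as f:
        return f.read()


if __name__ == "__main__":
    for rev in ("HEAD~1", "HEAD"):
        print(rev, read_output(ToyVersionedFS(HISTORY, rev), "MyNode/file.txt"))
```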

use_tmp_path(path: Path | None = None) -> Iterator[Path]

Load the data for *_path into a temporary directory.

If you cannot use node.state.fs.open, you can use this as an alternative. It loads the data into a temporary directory and deletes the directory afterwards. Inside the context manager, the respective node.*_path attributes are automatically replaced with the corresponding paths in the temporary directory.

This only takes effect if either remote or rev is set. Otherwise, the data is loaded from the current directory.

Examples

>>> import zntrack
>>> from pathlib import Path
>>>
>>> class MyNode(zntrack.Node):
...     outs_path: Path = zntrack.outs_path(zntrack.nwd / "file.txt")
...
...     def run(self):
...         self.outs_path.parent.mkdir(parents=True, exist_ok=True)
...         self.outs_path.write_text("Hello World!")
...
...     @property
...     def data(self):
...         with self.state.use_tmp_path():
...             with open(self.outs_path) as f:
...                 return f.read()
...
>>> # build and run the graph and make multiple commits.
>>> # the `use_tmp_path` ensures that the correct version
>>> # of the file is loaded in the temporary directory and
>>> # the `self.outs_path` is updated accordingly.
>>>
>>> zntrack.from_rev("MyNode", rev="HEAD").data
>>> zntrack.from_rev("MyNode", rev="HEAD~1").data
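The mechanism described above (materialize the file in a temporary directory, point the *_path attribute at the copy, clean up on exit) can be approximated with the standard library. FakeNode and swap_to_tmp_copy are hypothetical; this is not ZnTrack's implementation, which fetches files from the requested revision rather than copying a local file.

```python
import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator


class FakeNode:
    """Hypothetical object with a single *_path attribute."""

    def __init__(self, outs_path: Path) -> None:
        self.outs_path = outs_path


@contextmanager
def swap_to_tmp_copy(node: FakeNode, attr: str = "outs_path") -> Iterator[Path]:
    original = getattr(node, attr)
    with tempfile.TemporaryDirectory() as tmp:
        tmp_dir = Path(tmp)
        # Materialize the file inside the temporary directory.
        tmp_file = tmp_dir / original.name
        shutil.copy(original, tmp_file)
        setattr(node, attr, tmp_file)
        try:
            yield tmp_dir
        finally:
            # Restore the original path before the directory is deleted.
            setattr(node, attr, original)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as workdir:
        source = Path(workdir) / "file.txt"
        source.write_text("Hello World!")
        node = FakeNode(source)
        with swap_to_tmp_copy(node) as tmp_dir:
            print(node.outs_path, node.outs_path.read_text())
        print(node.outs_path == source, tmp_dir.exists())
```

Restoring the attribute in a finally clause keeps the object usable even if the code inside the with block raises.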

Custom Run Methods

By default, a Node will execute the run method. Sometimes, it is useful to define multiple methods for a single Node with slightly different behavior. This can be achieved by using zntrack.apply().

zntrack.apply(obj: o, method: str) -> o

Update the default run method of zntrack.Node.

Parameters:
  • obj (zntrack.Node) – The Node class to copy; the copy's default run method is replaced.

  • method (str) – The name of the method to use instead of the default run method.

Returns:

A new class which uses the new method instead of the default run method.

Return type:

zntrack.Node

Examples

>>> import zntrack
>>> class MyNode(zntrack.Node):
...     outs: str = zntrack.outs()
...
...     def run(self):
...         self.outs = "Hello, World!"
...
...     def my_run(self):
...         self.outs = "Hello, Custom World!"
...
>>> OtherMyNode = zntrack.apply(MyNode, "my_run")
>>> with zntrack.Project() as proj:
...     a = MyNode()
...     b = OtherMyNode()
>>> proj.repro()
>>> a.outs
'Hello, World!'
>>> b.outs
'Hello, Custom World!'
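Conceptually, zntrack.apply returns a new class whose default entry point is the named method. A plain-Python approximation is sketched below; Greeter and with_run_method are hypothetical, and this is not ZnTrack's implementation, which additionally has to make DVC invoke the chosen method. It derives a subclass whose run delegates to the requested method.

```python
class Greeter:
    """Plain-Python stand-in for MyNode (no ZnTrack involved)."""

    def __init__(self) -> None:
        self.outs = ""

    def run(self) -> None:
        self.outs = "Hello, World!"

    def my_run(self) -> None:
        self.outs = "Hello, Custom World!"


def with_run_method(cls: type, method: str) -> type:
    """Hypothetical helper: return a subclass of `cls` whose run calls `method`."""
    original = getattr(cls, method, None)
    if not callable(original):
        raise AttributeError(f"{cls.__name__} has no method {method!r}")

    def run(self) -> None:
        # Delegate to the method looked up on the original class,
        # so passing method="run" does not recurse.
        original(self)

    # The original class is left untouched; only the new subclass changes.
    return type(f"{cls.__name__}_{method}", (cls,), {"run": run})


if __name__ == "__main__":
    OtherGreeter = with_run_method(Greeter, "my_run")
    a, b = Greeter(), OtherGreeter()
    a.run()
    b.run()
    print(a.outs)
    print(b.outs)
```

Returning a subclass rather than patching the original class is what allows MyNode and OtherMyNode to coexist in the same Project, as in the example above.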