More about Nodes
This section describes some special cases for Node definitions.
On and Off Graph Nodes
The Node instances we have seen so far are all placed onto the graph.
In other words, they are defined within the context of the Project and have a run method that is executed when the Project runs.
In some cases, a Node only provides methods that other Node instances call and is never run on its own.
Such a Node is called “off-graph” and can be represented by a Python dataclass.
They are often used to define an interchangeable model, as illustrated in the example on Scikit-learn Classifier Comparison.
Another use case for off-graph Node instances is reusing a Node from another project.
If you load a Node via zntrack.from_rev, you can also use it as an off-graph Node.
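Because an off-graph Node is just a dataclass, its interchangeability can be made explicit with a shared typing.Protocol. The sketch below is plain Python and does not use ZnTrack at all; the protocol name NumberMethod and the helper evaluate are hypothetical and only illustrate the pattern of swapping one model for another.

```python
from dataclasses import dataclass
from typing import Protocol


class NumberMethod(Protocol):
    """Hypothetical interface shared by interchangeable off-graph models."""

    def compute(self, input: float) -> float: ...


@dataclass
class Shift:
    shift: float

    def compute(self, input: float) -> float:
        return input + self.shift


@dataclass
class Scale:
    scale: float

    def compute(self, input: float) -> float:
        return input * self.scale


def evaluate(method: NumberMethod, number: float) -> float:
    # Any object with a matching compute() method can be swapped in here.
    return method.compute(number)


results = [evaluate(m, 3.0) for m in (Shift(shift=1.0), Scale(scale=2.0))]
print(results)  # [4.0, 6.0]
```

A Protocol is only one option; the union annotation used in the example below (Shift | Scale) achieves the same for a closed set of models.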
Note
Just like on-graph Node definitions, the dataclass-derived Node must be importable.
Therefore, it is recommended to place them alongside on-graph Node definitions, e.g., in the same module.
If you define them inside main.py, you must construct the Project inside an if __name__ == "__main__": block so that importing the Node does not execute the script.
from dataclasses import dataclass

import zntrack


@dataclass
class Shift:
    shift: float

    def compute(self, input: float) -> float:
        return input + self.shift


@dataclass
class Scale:
    scale: float

    def compute(self, input: float) -> float:
        return input * self.scale


class ManipulateNumber(zntrack.Node):
    number: float = zntrack.params()
    method: Shift | Scale = zntrack.deps()
    result: float = zntrack.outs()

    def run(self) -> None:
        self.result = self.method.compute(self.number)


if __name__ == "__main__":
    project = zntrack.Project()

    # You can define these Nodes anywhere, but
    # to avoid confusion, they should be placed outside the Project context
    shift = Shift(shift=1.0)
    scale = Scale(scale=2.0)

    with project:
        shifted_number = ManipulateNumber(number=1.0, method=shift)
        scaled_number = ManipulateNumber(number=1.0, method=scale)

    project.repro()
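To see why the if __name__ == "__main__": guard from the note above matters, the following stdlib-only sketch writes a tiny, hypothetical stand-in for main.py to a temporary directory and imports it with importlib. The class definition is importable, while the guarded block (where the Project would be built) does not run on import.

```python
import importlib.util
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical stand-in for main.py: a dataclass plus a guarded "script" part.
MODULE_SOURCE = textwrap.dedent(
    """
    from dataclasses import dataclass

    ran_as_script = False

    @dataclass
    class Shift:
        shift: float

    if __name__ == "__main__":
        ran_as_script = True  # here the Project would be built and run
    """
)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "main.py"
    path.write_text(MODULE_SOURCE)

    # Import the file as a module named "main" (so __name__ != "__main__").
    spec = importlib.util.spec_from_file_location("main", path)
    module = importlib.util.module_from_spec(spec)
    sys.modules["main"] = module
    spec.loader.exec_module(module)

print(module.Shift(shift=1.0))  # Shift(shift=1.0)
print(module.ran_as_script)     # False
```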
Always Changed
In some cases, you may want a Node to always run, even if the inputs have not changed.
This can be useful when debugging a new Node.
In such cases, you can set always_changed=True.
import zntrack.examples

project = zntrack.Project()

with project:
    node = zntrack.examples.ParamsToOuts(params=42, always_changed=True)

project.repro()
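Conceptually, a Node is skipped when a fingerprint of its inputs matches the one recorded at its last run, and always_changed=True bypasses that comparison. The stdlib-only model below is a simplification for illustration, assuming a hash over JSON-serialized parameters; it is not how ZnTrack or DVC actually track changes, and fingerprint and needs_run are hypothetical names.

```python
import hashlib
import json
from typing import Optional


def fingerprint(params: dict) -> str:
    # Stable hash of the inputs; sort_keys makes key order irrelevant.
    payload = json.dumps(params, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()


def needs_run(params: dict, last_fingerprint: Optional[str], always_changed: bool = False) -> bool:
    # always_changed forces a run regardless of the recorded fingerprint.
    if always_changed:
        return True
    return fingerprint(params) != last_fingerprint


recorded = fingerprint({"params": 42})
print(needs_run({"params": 42}, recorded))                       # False
print(needs_run({"params": 42}, recorded, always_changed=True))  # True
print(needs_run({"params": 43}, recorded))                       # True
```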
Node State
Each Node provides a state attribute to access metadata or the DVCFileSystem.
The zntrack.state.NodeStatus object is frozen and read-only.
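The NodeStatus signature below (note the <factory> default) suggests a dataclass. Assuming it is a frozen dataclass, reassigning a field raises dataclasses.FrozenInstanceError. The stand-in class StatusSketch is hypothetical and mirrors only a few of the documented fields; dataclasses.replace is standard-library behavior, not a documented ZnTrack API.

```python
import dataclasses
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class StatusSketch:
    """Hypothetical stand-in mirroring a few NodeStatus fields."""

    remote: Optional[str] = None
    rev: Optional[str] = None
    run_count: int = 0


status = StatusSketch(rev="HEAD")

try:
    status.run_count = 1  # frozen: assignment is rejected
except dataclasses.FrozenInstanceError as err:
    print(type(err).__name__)  # FrozenInstanceError

# A modified copy can still be created without touching the original.
updated = dataclasses.replace(status, run_count=1)
print(status.run_count, updated.run_count)  # 0 1
```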
- class zntrack.state.NodeStatus(remote: str | None = None, rev: str | None = None, run_count: int = 0, state: zntrack.config.NodeStatusEnum = NodeStatusEnum.CREATED, lazy_evaluation: bool = True, tmp_path: pathlib.Path | None = None, node: Node | None = None, plugins: dict[str, zntrack.plugins.base.ZnTrackPlugin] = <factory>, group: zntrack.group.Group | None = None, run_time: datetime.timedelta | None = None)
Node status object.
- Parameters:
remote (str, optional) – The repository remote, e.g. the URL of the git repository.
rev (str, optional) – The revision of the repository, e.g. the git commit hash.
run_count (int) – How often this Node has been run. Only incremented when the Node is restarted.
state (NodeStatusEnum) – The state of the Node.
lazy_evaluation (bool) – Whether to load fields lazily.
tmp_path (pathlib.Path, optional) – The temporary path when using ‘use_tmp_path’.
node (Node, optional) – The Node object.
plugins (dict[str, ZnTrackPlugin], optional) – Active plugins. In addition to the default plugins, MLflow or AIM plugins will be added here.
group (Group, optional) – The group of the Node.
run_time (datetime.timedelta, optional) – The total run time of the Node.
name (str) – The name of the Node.
nwd (pathlib.Path) – The node working directory.
restarted (bool) – Whether the Node was restarted and has been run at least once before.
- property fs: AbstractFileSystem
Get the file system of the Node.
If the remote is None, the local file system is returned. Otherwise, a DVCFileSystem is returned. Use this file system to open files to ensure that the correct version of each file is loaded.
Examples
>>> import zntrack
>>> from pathlib import Path
>>>
>>> class MyNode(zntrack.Node):
...     outs_path: Path = zntrack.outs_path(zntrack.nwd / "file.txt")
...
...     def run(self):
...         self.outs_path.parent.mkdir(parents=True, exist_ok=True)
...         self.outs_path.write_text("Hello World!")
...
...     @property
...     def data(self):
...         with self.state.fs.open(self.outs_path) as f:
...             return f.read()
...
>>> # build and run the graph and make multiple commits.
>>> # the filesystem ensures that the correct version of the file is loaded.
>>>
>>> zntrack.from_rev("MyNode", rev="HEAD").data
>>> zntrack.from_remote("MyNode", rev="HEAD~1").data
- use_tmp_path(path: Path | None = None) → Iterator[Path]
Load the data for *_path fields into a temporary directory.
If you cannot use node.state.fs.open, you can use this as an alternative. It loads the data into a temporary directory and deletes it afterwards. The respective node.*_path paths are replaced automatically inside the context manager.
This only takes effect if either remote or rev is set. Otherwise, the data is loaded from the current directory.
Examples
>>> import zntrack
>>> from pathlib import Path
>>>
>>> class MyNode(zntrack.Node):
...     outs_path: Path = zntrack.outs_path(zntrack.nwd / "file.txt")
...
...     def run(self):
...         self.outs_path.parent.mkdir(parents=True, exist_ok=True)
...         self.outs_path.write_text("Hello World!")
...
...     @property
...     def data(self):
...         with self.state.use_tmp_path():
...             with open(self.outs_path) as f:
...                 return f.read()
...
>>> # build and run the graph and make multiple commits.
>>> # the `use_tmp_path` ensures that the correct version
>>> # of the file is loaded in the temporary directory and
>>> # the `self.outs_path` is updated accordingly.
>>>
>>> zntrack.from_rev("MyNode", rev="HEAD").data
>>> zntrack.from_remote("MyNode", rev="HEAD~1").data
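The general pattern behind use_tmp_path (copy data into a temporary directory, hand back the new location, clean up on exit) can be sketched with the standard library. This only illustrates the pattern: copy_to_tmp is a hypothetical helper that copies a local file, and it does not fetch anything from a remote or a Git revision the way ZnTrack does.

```python
import shutil
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator


@contextmanager
def copy_to_tmp(path: Path) -> Iterator[Path]:
    """Hypothetical helper: expose a copy of `path` inside a temp directory."""
    tmp_dir = Path(tempfile.mkdtemp())
    try:
        tmp_path = tmp_dir / path.name
        shutil.copy2(path, tmp_path)
        yield tmp_path
    finally:
        # The temporary copy is removed when the context manager exits.
        shutil.rmtree(tmp_dir)


with tempfile.TemporaryDirectory() as workdir:
    source = Path(workdir) / "file.txt"
    source.write_text("Hello World!")

    with copy_to_tmp(source) as tmp_file:
        content = tmp_file.read_text()
        inside = tmp_file.exists()

    print(content)            # Hello World!
    print(inside)             # True
    print(tmp_file.exists())  # False
```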
Custom Run Methods
By default, a Node will execute the run method.
Sometimes, it is useful to define multiple methods for a single Node with slightly different behavior.
This can be achieved by using zntrack.apply().
- zntrack.apply(obj: o, method: str) → o
Update the default run method of a zntrack.Node.
- Parameters:
obj (zntrack.Node) – The Node class to copy and whose run method is replaced.
method (str) – The name of the method to use instead of the default run method.
- Returns:
A new class which uses the new method instead of the default run method.
- Return type:
zntrack.Node
Examples
>>> import zntrack
>>> class MyNode(zntrack.Node):
...     outs: str = zntrack.outs()
...
...     def run(self):
...         self.outs = "Hello, World!"
...
...     def my_run(self):
...         self.outs = "Hello, Custom World!"
...
>>> OtherMyNode = zntrack.apply(MyNode, "my_run")
>>> with zntrack.Project() as proj:
...     a = MyNode()
...     b = OtherMyNode()
>>> proj.repro()
>>> a.outs
'Hello, World!'
>>> b.outs
'Hello, Custom World!'
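One plausible way to obtain the behavior described for zntrack.apply() is to derive a subclass whose run attribute points at the chosen method, leaving the original class untouched. The stdlib-only sketch below shows that idea with a plain class; apply_sketch is hypothetical and is not ZnTrack's implementation, which additionally has to integrate with the Project graph.

```python
class MyNode:
    """Plain-Python stand-in for a Node with two run variants."""

    def run(self) -> None:
        self.outs = "Hello, World!"

    def my_run(self) -> None:
        self.outs = "Hello, Custom World!"


def apply_sketch(cls: type, method: str) -> type:
    # Build a subclass whose `run` is the requested method; `cls` is unchanged.
    return type(f"{cls.__name__}_{method}", (cls,), {"run": getattr(cls, method)})


OtherMyNode = apply_sketch(MyNode, "my_run")

a, b = MyNode(), OtherMyNode()
a.run()
b.run()
print(a.outs)  # Hello, World!
print(b.outs)  # Hello, Custom World!
```

Returning a new subclass rather than patching MyNode in place is what allows both variants to coexist in the same Project.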