Source code for zntrack.state

import contextlib
import dataclasses
import datetime
import importlib.metadata
import json
import pathlib
import tempfile
import typing as t
import warnings

import dvc.api
import dvc.repo
import dvc.stage.serialize
from dvc.utils import dict_sha256
from fsspec.implementations.local import LocalFileSystem
from fsspec.spec import AbstractFileSystem

from zntrack.config import NodeStatusEnum
from zntrack.group import Group
from zntrack.plugins import ZnTrackPlugin
from zntrack.utils.node_wd import get_nwd

if t.TYPE_CHECKING:
    from zntrack import Node

# Type alias: a list of plugin classes (``ZnTrackPlugin`` or subclasses of it,
# i.e. the types themselves, not instances).
PLUGIN_LIST = list[t.Type[ZnTrackPlugin]]
# Type alias: a mapping from string keys to plugin instances; used to annotate
# ``NodeStatus.plugins`` below.
PLUGIN_DICT = dict[str, ZnTrackPlugin]


@dataclasses.dataclass(frozen=True)
class NodeStatus:
    """Node status object.

    The dataclass is frozen. Methods that need to change the status
    (``use_tmp_path``, ``add_run_time``, ``increment_run_count``) therefore
    write into ``self.node.__dict__["state"]`` instead of mutating ``self``.

    Parameters
    ----------
    remote : str, optional
        The repository remote, e.g. the URL of the git repository.
    rev : str, optional
        The revision of the repository, e.g. the git commit hash.
    run_count : int
        How often this Node has been run.
        Only incremented when the Node is restarted.
    state : NodeStatusEnum
        The state of the Node.
    lazy_evaluation : bool
        Whether to load fields lazily.
    tmp_path : pathlib.Path, optional
        The temporary path when using 'use_tmp_path'.
    node : Node, optional
        The Node object.
    plugins : dict[str, ZnTrackPlugin], optional
        Active plugins. In addition to the default plugins, MLFlow or AIM
        plugins will be added here.
    group : Group, optional
        The group of the Node.
    run_time : datetime.timedelta, optional
        The total run time of the Node.
    name: str
        The name of the Node.
    nwd: pathlib.Path
        The node working directory.
    restarted: bool
        Whether the Node was restarted and has been run at least once before.

    """

    remote: str | None = None
    rev: str | None = None
    run_count: int = 0
    state: NodeStatusEnum = NodeStatusEnum.CREATED
    lazy_evaluation: bool = True
    tmp_path: pathlib.Path | None = None
    # The back-reference to the Node is excluded from repr/eq/hash.
    node: "Node|None" = dataclasses.field(
        default=None, repr=False, compare=False, hash=False
    )
    plugins: PLUGIN_DICT = dataclasses.field(
        default_factory=dict, compare=False, repr=False
    )
    group: Group | None = None
    run_time: datetime.timedelta | None = None

    # TODO: move node name and nwd to here as well

    @property
    def name(self) -> str:
        """Return the name of the attached Node.

        Raises ``AttributeError`` if no Node is attached (``node`` is None).
        """
        return self.node.name

    @property
    def nwd(self) -> pathlib.Path:
        """Return the node working directory.

        While ``tmp_path`` is set it takes precedence; otherwise the
        directory is resolved via ``get_nwd`` for the attached Node.
        """
        if self.tmp_path is not None:
            return self.tmp_path
        return get_nwd(self.node)

    @property
    def fs(self) -> AbstractFileSystem:
        """Get the file system of the Node.

        If both ``remote`` and ``rev`` are None, the local file system is
        returned. Otherwise, a DVCFileSystem is returned.
        The FileSystem should be used to open files to ensure that the
        correct version of the file is loaded.

        Examples
        --------
        >>> import zntrack
        >>> from pathlib import Path
        >>>
        >>> class MyNode(zntrack.Node):
        ...     outs_path: Path = zntrack.outs_path(zntrack.nwd / "file.txt")
        ...
        ...     def run(self):
        ...         self.outs_path.parent.mkdir(parents=True, exist_ok=True)
        ...         self.outs_path.write_text("Hello World!")
        ...
        ...     @property
        ...     def data(self):
        ...         with self.state.fs.open(self.outs_path) as f:
        ...             return f.read()
        ...
        >>> # build and run the graph and make multiple commits.
        >>> # the filesystem ensures that the correct version of the file is loaded.
        >>>
        >>> zntrack.from_rev("MyNode", rev="HEAD").data
        >>> zntrack.from_remote("MyNode", rev="HEAD~1").data

        """
        if self.remote is None and self.rev is None:
            return LocalFileSystem()
        return dvc.api.DVCFileSystem(
            url=self.remote,
            rev=self.rev,
        )

    @property
    def dvc_fs(self) -> dvc.api.DVCFileSystem:
        """Get a DVCFileSystem for the Node.

        Unlike ``fs``, this always returns a DVCFileSystem, even when both
        ``remote`` and ``rev`` are None.
        """
        return dvc.api.DVCFileSystem(
            url=self.remote,
            rev=self.rev,
        )

    @property
    def restarted(self) -> bool:
        """Whether the node was restarted (``run_count`` is greater than 1)."""
        return self.run_count > 1

    @contextlib.contextmanager
    def use_tmp_path(
        self, path: pathlib.Path | None = None
    ) -> t.Iterator[pathlib.Path]:
        """Load the data for ``*_path`` into a temporary directory.

        If you can not use ``node.state.fs.open`` you can use
        this as an alternative. This will load the data into
        a temporary directory and then delete it afterwards.
        The respective paths ``node.*_path`` will be replaced
        automatically inside the context manager.

        This is only set, if either ``remote`` or ``rev`` are set.
        Otherwise, the data will be loaded from the current directory.

        Parameters
        ----------
        path : pathlib.Path, optional
            Reserved for a custom target directory. Must currently be None.

        Yields
        ------
        pathlib.Path
            The temporary directory, or the current working directory when
            neither ``remote`` nor ``rev`` is set.

        Raises
        ------
        NotImplementedError
            If ``path`` is not None.

        Examples
        --------
        >>> import zntrack
        >>> from pathlib import Path
        >>>
        >>> class MyNode(zntrack.Node):
        ...     outs_path: Path = zntrack.outs_path(zntrack.nwd / "file.txt")
        ...
        ...     def run(self):
        ...         self.outs_path.parent.mkdir(parents=True, exist_ok=True)
        ...         self.outs_path.write_text("Hello World!")
        ...
        ...     @property
        ...     def data(self):
        ...         with self.state.use_tmp_path():
        ...             with open(self.outs_path) as f:
        ...                 return f.read()
        ...
        >>> # build and run the graph and make multiple commits.
        >>> # the `use_tmp_path` ensures that the correct version
        >>> # of the file is loaded in the temporary directory and
        >>> # the `self.outs_path` is updated accordingly.
        >>>
        >>> zntrack.from_rev("MyNode", rev="HEAD").data
        >>> zntrack.from_remote("MyNode", rev="HEAD~1").data

        """
        if path is not None:
            raise NotImplementedError("Custom paths are not implemented yet.")

        # This feature is only required when the node is loaded,
        # not when it is saved/executed.
        if self.remote is None and self.rev is None:
            # Adjacent string literals are joined without a separator, so
            # each fragment carries its own trailing space.
            warnings.warn(
                "The temporary path is not used when neither remote or rev are set. "
                "Consider checking for `self.state.remote` and `self.state.rev` when "
                "using `with node.state.use_tmp_path(): ...` ."
            )
            yield pathlib.Path.cwd()
            return

        with tempfile.TemporaryDirectory() as tmpdir:
            # This dataclass is frozen, so the path is recorded in the Node's
            # state dict rather than on ``self``.
            self.node.__dict__["state"]["tmp_path"] = pathlib.Path(tmpdir)
            try:
                yield pathlib.Path(tmpdir)
            finally:
                # Always unregister the path, even if the with-body raised.
                self.node.__dict__["state"].pop("tmp_path")

    def get_stage(self) -> dvc.stage.Stage:
        """Access to the internal dvc.repo api.

        Returns the first stage collected for ``self.name`` from the
        repository behind ``dvc_fs``.
        """
        stage = next(iter(self.dvc_fs.repo.stage.collect(self.name)))
        if self.rev is None and self.remote is None:
            # If we want to look at the current workspace result, we need to
            # load all the information, not just dvc.yaml
            stage.save(allow_missing=True, run_cache=False)
        return stage

    def get_stage_lock(self) -> dict:
        """Access to the internal dvc.repo api.

        Returns the stage from ``get_stage`` serialized as a single-stage
        lockfile entry.
        """
        stage = self.get_stage()
        return dvc.stage.serialize.to_single_stage_lockfile(stage)

    def get_stage_hash(self, include_outs: bool = False) -> str:
        """Get the hash of the stage.

        Parameters
        ----------
        include_outs : bool, default=False
            If True, hash the complete lock entry. Otherwise only the
            ``cmd``, ``deps`` and ``params`` entries are hashed.

        Returns
        -------
        str
            The ``dict_sha256`` digest of the (optionally filtered) lock entry.
        """
        stage_lock = self.get_stage_lock()
        if not include_outs:
            stage_lock = {
                k: v for k, v in stage_lock.items() if k in ("cmd", "deps", "params")
            }
        return dict_sha256(stage_lock)

    def to_dict(self) -> dict:
        """Convert the NodeStatus to a dictionary without the ``node`` entry."""
        content = dataclasses.asdict(self)
        content.pop("node")
        return content

    def get_field(self, attribute: str) -> dataclasses.Field:
        """Return the dataclass field of the Node named ``attribute``.

        Raises
        ------
        AttributeError
            If the Node has no dataclass field with that name.
        """
        for field in dataclasses.fields(self.node):
            if field.name == attribute:
                return field
        raise AttributeError(f"Unable to locate '{attribute}' in {self.node}.")

    def add_run_time(self, run_time: datetime.timedelta) -> None:
        """Add the run time to the node.

        The value is stored in the Node's state dict: it is set on the first
        call and accumulated on later calls.
        """
        if self.run_time is None:
            self.node.__dict__["state"]["run_time"] = run_time
        else:
            self.node.__dict__["state"]["run_time"] += run_time

    def increment_run_count(self) -> None:
        """Store ``run_count + 1`` as the new run count in the Node's state dict."""
        self.node.__dict__["state"]["run_count"] = self.run_count + 1

    def save_node_meta(self) -> None:
        """Write ``node-meta.json`` into the node working directory.

        The file contains the Node's uuid, the run count and the installed
        zntrack version. The run time in seconds is added when available.
        The version of the Node's top-level package is added when that
        package is an installed distribution.
        """
        node_meta_content = {
            "uuid": str(self.node.uuid),
            "run_count": self.run_count,
            "zntrack_version": importlib.metadata.version("zntrack"),
        }
        if self.run_time is not None:
            node_meta_content["run_time"] = self.run_time.total_seconds()
        # Best effort: the Node may live in a module that is not an
        # installed distribution (e.g. a local script).
        with contextlib.suppress(importlib.metadata.PackageNotFoundError):
            module = self.node.__module__.split(".")[0]
            node_meta_content["package_version"] = importlib.metadata.version(module)

        self.nwd.mkdir(parents=True, exist_ok=True)
        (self.nwd / "node-meta.json").write_text(json.dumps(node_meta_content, indent=2))