import json

import znjson

from zntrack import config
from zntrack.config import NOT_AVAILABLE, FieldTypes
from zntrack.fields.base import field
from zntrack.node import Node


def _outs_getter(self: "Node", name: str, suffix: str):
    """Load a JSON-encoded field value from the node working directory.

    The file ``(self.nwd / name).with_suffix(suffix)`` is opened through
    ``self.state.fs`` and decoded with ``znjson.ZnDecoder``. The same decoder
    also reads plain JSON, so files written by both ``_outs_save_func`` and
    ``_metrics_save_func`` can be loaded back.

    Parameters
    ----------
    self : Node
        The node owning the field.
    name : str
        Attribute name of the field; also the file stem.
    suffix : str
        File suffix including the dot, e.g. ``".json"``.

    Returns
    -------
    The decoded value.
    """
    with self.state.fs.open((self.nwd / name).with_suffix(suffix)) as f:
        return json.load(f, cls=znjson.ZnDecoder)


def _dump_field(self: "Node", name: str, suffix: str, dumps) -> None:
    """Serialize ``getattr(self, name)`` with *dumps* and write it to the nwd.

    The node working directory is created if needed. The target file is
    ``(self.nwd / name).with_suffix(suffix)``.

    Parameters
    ----------
    self : Node
        The node owning the field.
    name : str
        Attribute name of the field; also the file stem.
    suffix : str
        File suffix including the dot, e.g. ``".json"``.
    dumps : callable
        Serializer taking the value and returning a ``str``.

    Raises
    ------
    TypeError
        If *dumps* cannot serialize the value. The message names the actual
        target file (previously ``.json`` was hard-coded regardless of
        ``suffix``) and the original error is chained.
    """
    self.nwd.mkdir(parents=True, exist_ok=True)
    path = (self.nwd / name).with_suffix(suffix)
    try:
        content = dumps(getattr(self, name))
    except TypeError as err:
        raise TypeError(f"Error while saving {name} to {path}") from err
    # Only written after serialization succeeded, so a failure never leaves
    # a truncated file behind.
    path.write_text(content)


def _outs_save_func(self: "Node", name: str, suffix: str) -> None:
    """Save the field ``name`` using ``znjson.dumps``.

    Counterpart of ``_outs_getter``, which decodes with ``znjson.ZnDecoder``.

    Raises
    ------
    TypeError
        If the value cannot be serialized by ``znjson``.
    """
    _dump_field(self, name, suffix, znjson.dumps)


def _metrics_save_func(self: "Node", name: str, suffix: str) -> None:
    """Save the field ``name`` using the standard-library ``json.dumps``.

    Unlike ``_outs_save_func`` this writes plain JSON, presumably so that
    metrics files stay readable by tools that do not know about znjson
    (e.g. DVC metrics) -- NOTE(review): confirm intent.

    Raises
    ------
    TypeError
        If the value is not plain-JSON serializable.
    """
    _dump_field(self, name, suffix, json.dumps)
def outs(*, cache: bool = True, independent: bool = False, **kwargs):
    """Define output for a node.

    The value is saved to a ``.json`` file with ``_outs_save_func``, which
    uses ``znjson.dumps``. It is loaded back with ``_outs_getter``, which
    uses ``znjson.ZnDecoder``. The value must therefore be serializable by
    ``znjson``; being picklable is not sufficient.

    Parameters
    ----------
    cache : bool, optional
        Set to true to use the DVC cache for the field. Default is ``True``.
        Unlike ``metrics``, this does not fall back to
        ``zntrack.config.ALWAYS_CACHE``.
    independent : bool, optional
        Whether the output is independent of the node's inputs.
        Default is ``False``.
    **kwargs
        Forwarded unchanged to ``zntrack.fields.base.field``.

    Examples
    --------
    >>> import zntrack
    >>> class MyNode(zntrack.Node):
    ...     outs: int = zntrack.outs()
    ...
    ...     def run(self) -> None:
    ...         '''Save output to self.outs.'''
    ...         self.outs = 42
    """
    return field(
        # Sentinel until a value is assigned or loaded.
        default=NOT_AVAILABLE,
        cache=cache,
        independent=independent,
        field_type=FieldTypes.OUTS,
        dump_fn=_outs_save_func,
        suffix=".json",
        load_fn=_outs_getter,
        # Outputs are produced by the node, so they are neither constructor
        # arguments nor shown in the repr.
        repr=False,
        init=False,
        **kwargs,
    )
def metrics(*, cache: bool | None = None, independent: bool = False, **kwargs):
    """Declare a metrics field on a node.

    Metrics are written as plain JSON (via ``_metrics_save_func``, which uses
    the standard-library ``json.dumps``). The value must be a dictionary that
    ``json.dumps`` can serialize. It is read back with ``_outs_getter``.

    Parameters
    ----------
    cache : bool or None, optional
        Whether DVC should cache the metrics file. When left as ``None``, the
        value of ``zntrack.config.ALWAYS_CACHE`` at call time is used.
    independent : bool, optional
        Whether the metrics are independent of the node's inputs.
        Default is ``False``.
    **kwargs
        Passed through to ``zntrack.fields.base.field``.

    Examples
    --------
    >>> import zntrack
    >>> class MyNode(zntrack.Node):
    ...     metrics: dict = zntrack.metrics()
    ...
    ...     def run(self) -> None:
    ...         '''Save metrics to self.metrics.'''
    """
    # Looked up here rather than as a default argument so that changes to
    # config.ALWAYS_CACHE after import still take effect.
    effective_cache = config.ALWAYS_CACHE if cache is None else cache
    return field(
        default=NOT_AVAILABLE,
        cache=effective_cache,
        independent=independent,
        field_type=FieldTypes.METRICS,
        dump_fn=_metrics_save_func,
        suffix=".json",
        load_fn=_outs_getter,
        repr=False,
        init=False,
        **kwargs,
    )