def __init__(
    self,
    folder: Union[str, os.PathLike],
    with_history: bool = True,
    reset: bool = False,
) -> None:
    """Prepare a model repository rooted at *folder*.

    Args:
        folder: Directory that holds ``latest.vw`` and, when history is
            enabled, tagged ``model-<tag>.vw`` copies. It is created
            (including parents) if missing.
        with_history: Whether ``save`` should also keep a tagged copy of
            every stored model.
        reset: Delete the current ``latest.vw`` so training restarts from
            scratch. History snapshots are not deleted; a warning is logged
            when ``self.has_history()`` reports any.
    """
    self.folder = Path(folder)
    self.model_path = self.folder / "latest.vw"
    self.with_history = with_history
    if reset:
        if self.has_history():
            # Snapshots are intentionally kept; only flag them to the user.
            logger.warning(
                "There is non empty history which is recommended to be cleaned up"
            )
        # Previously the model was removed only when has_history() was True,
        # so reset=True silently kept latest.vw for repos without history.
        if self.model_path.exists():
            self.model_path.unlink()
    self.folder.mkdir(parents=True, exist_ok=True)
def save(self, workspace: "vw.Workspace") -> None:
    """Persist *workspace* as the repository's latest model.

    The serialized bytes are written to ``self.model_path``. When
    ``self.with_history`` is set, a copy is also stored in ``self.folder``
    as ``model-<tag>.vw``, where the tag comes from ``self.get_tag()``.

    Args:
        workspace: Object whose ``serialize()`` returns the model bytes.

    Raises:
        Whatever ``workspace.serialize()`` or the file operations raise. If
        serialization or writing the temporary file fails, the previously
        stored latest model is left untouched.
    """
    # Serialize before touching the filesystem, so a failing serialize()
    # can no longer truncate the existing latest model.
    model_data = workspace.serialize()
    logger.info("storing rl_chain model in: %s", self.model_path)
    # Write to a sibling temp file and swap it in with os.replace, so an
    # interrupted write never leaves a half-written latest model behind.
    tmp_path = self.model_path.with_name(self.model_path.name + ".tmp")
    try:
        with open(tmp_path, "wb") as f:
            f.write(model_data)
        os.replace(tmp_path, self.model_path)
    finally:
        # No-op after a successful replace; cleans up after a failed write.
        if tmp_path.exists():
            tmp_path.unlink()
    if self.with_history:
        # write history
        shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")
def load(self, commandline: List[str]) -> "vw.Workspace":
    """Create a ``vowpal_wabbit_next`` workspace, warm-started when possible.

    Args:
        commandline: Arguments forwarded to ``vw.Workspace``.

    Returns:
        A workspace seeded with the bytes stored at ``self.model_path`` when
        that file exists and is non-empty; otherwise a fresh workspace built
        from ``commandline`` alone.

    Raises:
        ImportError: If ``vowpal_wabbit_next`` is not installed.
    """
    try:
        import vowpal_wabbit_next as vw
    except ImportError as e:
        raise ImportError(
            "Unable to import vowpal_wabbit_next, please install with "
            "`pip install vowpal_wabbit_next`."
        ) from e

    path = self.model_path
    stored = path.read_bytes() if path.exists() else None
    # A missing file and an empty file both mean "no saved model".
    if not stored:
        return vw.Workspace(commandline)

    logger.info(f"rl_chain model is loaded from: {self.model_path}")
    return vw.Workspace(commandline, model_data=stored)