def ToSelectFrom(anything: Any) -> _ToSelectFrom:
    """Wrap a value to indicate that it should be selected from."""
    if not isinstance(anything, list):
        raise ValueError("ToSelectFrom must be a list to select from")
    return _ToSelectFrom(anything)
def Embed(anything: Any, keep: bool = False) -> Any:
    """Wrap a value to indicate that it should be embedded."""
    if isinstance(anything, _ToSelectFrom):
        return ToSelectFrom(Embed(anything.value, keep=keep))
    elif isinstance(anything, _BasedOn):
        return BasedOn(Embed(anything.value, keep=keep))
    if isinstance(anything, list):
        return [Embed(v, keep=keep) for v in anything]
    elif isinstance(anything, dict):
        return {k: Embed(v, keep=keep) for k, v in anything.items()}
    elif isinstance(anything, _Embed):
        return anything
    return _Embed(anything, keep=keep)
def EmbedAndKeep(anything: Any) -> Any:
    """Wrap a value to indicate that it should be embedded and kept."""
    return Embed(anything, keep=True)
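# Illustrative only (not part of the module API): a minimal sketch of how the
# wrappers above are combined on chain inputs. The keys "actions", "user", and
# "preference" are hypothetical; BasedOn is the companion wrapper that Embed
# above uses to mark context values.
def _example_wrapped_inputs() -> Dict[str, Any]:
    return {
        # candidates the learned policy should select from
        "actions": ToSelectFrom(["article A", "article B", "article C"]),
        # context that the selection is conditioned on
        "user": BasedOn("a reader interested in sports"),
        # EmbedAndKeep requests both the raw text and its embedding as features
        "preference": BasedOn(EmbedAndKeep("prefers short articles")),
    }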
# helper functions
def parse_lines(parser: "vw.TextFormatParser", input_str: str) -> List["vw.Example"]:
    """Parse the input string into a list of examples."""
    return [parser.parse_line(line) for line in input_str.split("\n")]
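# Illustrative only, assuming the vowpal-wabbit-next package (imported in this
# module as `vw`): parse_lines splits a multi-line VW text-format string into
# the Example objects that a workspace's predict/learn calls expect. The
# "--cb_explore_adf" arguments and the feature names below are hypothetical.
def _example_parse_lines() -> List["vw.Example"]:
    workspace = vw.Workspace(["--cb_explore_adf", "--quiet"])
    parser = vw.TextFormatParser(workspace)
    vw_text = "shared |User user=sports\n|Action article=A\n|Action article=B"
    return parse_lines(parser, vw_text)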
def get_based_on_and_to_select_from(inputs: Dict[str, Any]) -> Tuple[Dict, Dict]:
    """Get the BasedOn and ToSelectFrom from the inputs."""
    to_select_from = {
        k: inputs[k].value
        for k in inputs.keys()
        if isinstance(inputs[k], _ToSelectFrom)
    }

    if not to_select_from:
        raise ValueError(
            "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."  # noqa: E501
        )

    based_on = {
        k: inputs[k].value if isinstance(inputs[k].value, list) else [inputs[k].value]
        for k in inputs.keys()
        if isinstance(inputs[k], _BasedOn)
    }

    return based_on, to_select_from
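# Illustrative only: what the split looks like for wrapped inputs. The keys and
# values are hypothetical; note that scalar BasedOn values are normalized into
# single-element lists.
def _example_split_inputs() -> Tuple[Dict, Dict]:
    inputs = {
        "user": BasedOn("sports fan"),
        "actions": ToSelectFrom(["article A", "article B"]),
    }
    based_on, to_select_from = get_based_on_and_to_select_from(inputs)
    # based_on == {"user": ["sports fan"]}
    # to_select_from == {"actions": ["article A", "article B"]}
    return based_on, to_select_from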
def prepare_inputs_for_autoembed(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Prepare the inputs for auto embedding.

    Go over all the inputs; if a value is wrapped in _ToSelectFrom or _BasedOn
    and its inner value is not already an _Embed, wrap the inner value in
    EmbedAndKeep while retaining the _ToSelectFrom or _BasedOn status.
    """
    next_inputs = inputs.copy()
    for k, v in next_inputs.items():
        if isinstance(v, _ToSelectFrom) or isinstance(v, _BasedOn):
            if not isinstance(v.value, _Embed):
                next_inputs[k].value = EmbedAndKeep(v.value)
    return next_inputs
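# Illustrative only: prepare_inputs_for_autoembed leaves unwrapped values alone
# and wraps the inner value of every _ToSelectFrom/_BasedOn in EmbedAndKeep, so
# auto-embedding can later featurize both the raw text and its embedding. The
# keys below are hypothetical.
def _example_autoembed_prep() -> Dict[str, Any]:
    inputs = {
        "user": BasedOn("sports fan"),
        "note": "not wrapped, so left untouched",
    }
    prepared = prepare_inputs_for_autoembed(inputs)
    # prepared["user"] is still a _BasedOn, but its .value is now an _Embed
    # created with keep=True; prepared["note"] is unchanged.
    return prepared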
# end helper functions
class Selected(ABC):
    """Abstract class to represent the selected item."""

    pass
TSelected = TypeVar("TSelected", bound=Selected)
class Event(Generic[TSelected], ABC):
    """Abstract class to represent an event."""

    inputs: Dict[str, Any]
    selected: Optional[TSelected]
    @staticmethod
    def get_default_system_prompt() -> SystemMessagePromptTemplate:
        return SystemMessagePromptTemplate.from_template(
            "PLEASE RESPOND ONLY WITH A SINGLE FLOAT AND NO OTHER TEXT EXPLANATION\n \
            You are a strict judge that is called on to rank a response based on \
            given criteria. You must respond with your ranking by providing a \
            single float within the range [0, 1], 0 being very bad \
            response and 1 being very good response."
        )
    @staticmethod
    def get_default_prompt() -> ChatPromptTemplate:
        human_template = 'Given this based_on "{rl_chain_selected_based_on}" \
            as the most important attribute, rank how good or bad this text is: \
            "{rl_chain_selected}".'
        human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
        default_system_prompt = AutoSelectionScorer.get_default_system_prompt()
        chat_prompt = ChatPromptTemplate.from_messages(
            [default_system_prompt, human_message_prompt]
        )
        return chat_prompt
    def score_response(
        self, inputs: Dict[str, Any], llm_response: str, event: Event
    ) -> float:
        ranking = self.llm_chain.predict(llm_response=llm_response, **inputs)
        ranking = ranking.strip()
        try:
            resp = float(ranking)
            return resp
        except Exception as e:
            raise RuntimeError(
                f"The auto selection scorer did not manage to score the response, there is always the option to try again or tweak the reward prompt. Error: {e}"  # noqa: E501
            )
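# Illustrative only: how the default grading prompt is rendered. The values are
# hypothetical; at run time the chain fills rl_chain_selected_based_on and
# rl_chain_selected from the event before score_response parses the judge LLM's
# reply into a float.
def _example_default_prompt() -> None:
    prompt = AutoSelectionScorer.get_default_prompt()
    messages = prompt.format_messages(
        rl_chain_selected_based_on="sports fan",
        rl_chain_selected="article B",
    )
    # messages[0] is the strict system instruction; messages[1] asks the judge
    # LLM to grade "article B" given the "sports fan" context.
    print(messages[1].content)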
class RLChain(Chain, Generic[TEvent]):
    """Chain that leverages the Vowpal Wabbit (VW) model as a learned policy
    for reinforcement learning.

    Attributes:
        - llm_chain (Chain): Represents the underlying Language Model chain.
        - prompt (BasePromptTemplate): The template for the base prompt.
        - selection_scorer (Union[SelectionScorer, None]): Scorer for the selection. Can be set to None.
        - policy (Optional[Policy]): The policy used by the chain to learn to populate a dynamic prompt.
        - auto_embed (bool): Determines if embedding should be automatic. Default is False.
        - metrics (Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]]): Tracker for metrics, can be set to None.

    Initialization Attributes:
        - feature_embedder (Embedder): Embedder used for the `BasedOn` and `ToSelectFrom` inputs.
        - model_save_dir (str, optional): Directory for saving the VW model. Default is the current directory.
        - reset_model (bool): If set to True, the model starts training from scratch. Default is False.
        - vw_cmd (List[str], optional): Command line arguments for the VW model.
        - policy (Type[VwPolicy]): Policy used by the chain.
        - vw_logs (Optional[Union[str, os.PathLike]]): Path for the VW logs.
        - metrics_step (int): Step for the metrics tracker. Default is -1. If set without metrics_window_size, average metrics will be tracked, otherwise rolling window metrics will be tracked.
        - metrics_window_size (int): Window size for the metrics tracker. Default is -1. If set, rolling window metrics will be tracked.

    Notes:
        The class initializes the VW model using the provided arguments. If `selection_scorer` is not provided, a warning is logged, indicating that no reinforcement learning will occur unless the `update_with_delayed_score` method is called.
    """  # noqa: E501

    class _NoOpPolicy(Policy):
        """Placeholder policy that does nothing"""

        def predict(self, event: TEvent) -> Any:
            return None

        def learn(self, event: TEvent) -> None:
            pass

        def log(self, event: TEvent) -> None:
            pass

    llm_chain: Chain

    output_key: str = "result"  #: :meta private:
    prompt: BasePromptTemplate
    selection_scorer: Union[SelectionScorer, None]
    active_policy: Policy = _NoOpPolicy()
    auto_embed: bool = False
    selection_scorer_activated: bool = True
    selected_input_key: str = "rl_chain_selected"
    selected_based_on_input_key: str = "rl_chain_selected_based_on"
    metrics: Optional[Union[MetricsTrackerRollingWindow, MetricsTrackerAverage]] = None

    def __init__(
        self,
        feature_embedder: Embedder,
        model_save_dir: str = "./",
        reset_model: bool = False,
        vw_cmd: Optional[List[str]] = None,
        policy: Type[Policy] = VwPolicy,
        vw_logs: Optional[Union[str, os.PathLike]] = None,
        metrics_step: int = -1,
        metrics_window_size: int = -1,
        *args: Any,
        **kwargs: Any,
    ):
        super().__init__(*args, **kwargs)
        if self.selection_scorer is None:
            logger.warning(
                "No selection scorer provided, which means that no \
                reinforcement learning will be done in the RL chain \
                unless update_with_delayed_score is called."
            )

        if isinstance(self.active_policy, RLChain._NoOpPolicy):
            self.active_policy = policy(
                model_repo=ModelRepository(
                    model_save_dir, with_history=True, reset=reset_model
                ),
                vw_cmd=vw_cmd or [],
                feature_embedder=feature_embedder,
                vw_logger=VwLogger(vw_logs),
            )

        if metrics_window_size > 0:
            self.metrics = MetricsTrackerRollingWindow(
                step=metrics_step, window_size=metrics_window_size
            )
        else:
            self.metrics = MetricsTrackerAverage(step=metrics_step)

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return []

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]
    def update_with_delayed_score(
        self, score: float, chain_response: Dict[str, Any], force_score: bool = False
    ) -> None:
        """
        Update the learned policy with the provided score.

        Raises an error if a selection_scorer is set and force_score=True was
        not passed to this call. A sketch of this delayed-scoring flow appears
        after the class definition.
        """
        if self._can_use_selection_scorer() and not force_score:
            raise RuntimeError(
                "The selection scorer is set, and force_score was not set to True. Please set force_score=True to use this function."  # noqa: E501
            )
        if self.metrics:
            self.metrics.on_feedback(score)
        event: TEvent = chain_response["selection_metadata"]
        self._call_after_scoring_before_learning(event=event, score=score)
        self.active_policy.learn(event=event)
        self.active_policy.log(event=event)
    def deactivate_selection_scorer(self) -> None:
        """
        Deactivates the selection scorer, meaning that the chain will no longer
        attempt to use the selection scorer to score responses.
        """
        self.selection_scorer_activated = False
    def activate_selection_scorer(self) -> None:
        """
        Activates the selection scorer, meaning that the chain will attempt
        to use the selection scorer to score responses.
        """
        self.selection_scorer_activated = True
    def save_progress(self) -> None:
        """
        This function should be called to save the state of the learned policy model.
        """
        self.active_policy.save()
    def _validate_inputs(self, inputs: Dict[str, Any]) -> None:
        super()._validate_inputs(inputs)
        if (
            self.selected_input_key in inputs.keys()
            or self.selected_based_on_input_key in inputs.keys()
        ):
            raise ValueError(
                f"The rl chain does not accept '{self.selected_input_key}' or '{self.selected_based_on_input_key}' as input keys, they are reserved for internal use during auto reward."  # noqa: E501
            )

    def _can_use_selection_scorer(self) -> bool:
        """
        Returns whether the chain can use the selection scorer to score responses or not.
        """  # noqa: E501
        return self.selection_scorer is not None and self.selection_scorer_activated

    @abstractmethod
    def _call_before_predict(self, inputs: Dict[str, Any]) -> TEvent:
        ...

    @abstractmethod
    def _call_after_predict_before_llm(
        self, inputs: Dict[str, Any], event: TEvent, prediction: Any
    ) -> Tuple[Dict[str, Any], TEvent]:
        ...

    @abstractmethod
    def _call_after_llm_before_scoring(
        self, llm_response: str, event: TEvent
    ) -> Tuple[Dict[str, Any], TEvent]:
        ...

    @abstractmethod
    def _call_after_scoring_before_learning(
        self, event: TEvent, score: Optional[float]
    ) -> TEvent:
        ...

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        event: TEvent = self._call_before_predict(inputs=inputs)
        prediction = self.active_policy.predict(event=event)
        if self.metrics:
            self.metrics.on_decision()

        next_chain_inputs, event = self._call_after_predict_before_llm(
            inputs=inputs, event=event, prediction=prediction
        )

        t = self.llm_chain.run(**next_chain_inputs, callbacks=_run_manager.get_child())
        _run_manager.on_text(t, color="green", verbose=self.verbose)
        t = t.strip()

        if self.verbose:
            _run_manager.on_text("\nCode: ", verbose=self.verbose)

        output = t
        _run_manager.on_text("\nAnswer: ", verbose=self.verbose)
        _run_manager.on_text(output, color="yellow", verbose=self.verbose)

        next_chain_inputs, event = self._call_after_llm_before_scoring(
            llm_response=output, event=event
        )

        score = None
        try:
            if self._can_use_selection_scorer():
                score = self.selection_scorer.score_response(  # type: ignore
                    inputs=next_chain_inputs, llm_response=output, event=event
                )
        except Exception as e:
            logger.info(
                f"The selection scorer was not able to score, \
                and the chain was not able to adjust to this response, error: {e}"
            )

        if self.metrics and score is not None:
            self.metrics.on_feedback(score)

        event = self._call_after_scoring_before_learning(score=score, event=event)
        self.active_policy.learn(event=event)
        self.active_policy.log(event=event)

        return {self.output_key: {"response": output, "selection_metadata": event}}

    @property
    def _chain_type(self) -> str:
        return "llm_personalizer_chain"
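# Illustrative only: a sketch of the delayed-scoring flow referenced in
# update_with_delayed_score above, assuming `chain` is a concrete RLChain
# subclass created with selection_scorer=None (or with the scorer deactivated).
# The input keys and the 0.8 reward are hypothetical and would come from real
# user feedback.
def _example_delayed_scoring(chain: RLChain) -> None:
    response = chain.run(
        actions=ToSelectFrom(["article A", "article B"]),
        user=BasedOn("sports fan"),
    )
    # ... later, once the real outcome of the selection is known ...
    chain.update_with_delayed_score(score=0.8, chain_response=response)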