Source code for langchain_experimental.rl_chain.pick_best_chain
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.prompts import BasePromptTemplate

import langchain_experimental.rl_chain.base as base
from langchain_experimental.rl_chain.helpers import embed

logger = logging.getLogger(__name__)

# sentinel object used to distinguish between
# user didn't supply anything or user explicitly supplied None
SENTINEL = object()
class PickBestSelected(base.Selected):
    """Selected class for PickBest chain."""

    index: Optional[int]
    probability: Optional[float]
    score: Optional[float]
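    # NOTE (editor): without a constructor (or dataclass decoration), the keyword
    # construction `PickBestSelected(index=..., probability=...)` used later in
    # this module would fail; a minimal constructor consistent with that usage:
    def __init__(
        self,
        index: Optional[int] = None,
        probability: Optional[float] = None,
        score: Optional[float] = None,
    ):
        self.index = index
        self.probability = probability
        self.score = score


# NOTE (editor): `PickBestEvent` is referenced throughout this module but does not
# appear in this excerpt. A minimal sketch consistent with its usage below
# (attributes `inputs`, `selected`, `to_select_from`, `based_on`):
class PickBestEvent(base.Event[PickBestSelected]):
    """Event class for PickBest chain."""

    def __init__(
        self,
        inputs: Dict[str, Any],
        to_select_from: Dict[str, Any],
        based_on: Dict[str, Any],
        selected: Optional[PickBestSelected] = None,
    ):
        super().__init__(inputs=inputs, selected=selected)
        self.to_select_from = to_select_from
        self.based_on = based_on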
class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
    """Embed the `BasedOn` and `ToSelectFrom` inputs into a format that can be used
    by the learning policy.

    Attributes:
        model (Any, optional): The type of embeddings to be used for feature representation. Defaults to a BERT SentenceTransformer.
    """  # noqa E501
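    # NOTE (editor): the embedder's constructor and `get_label` helper are not part
    # of this excerpt, although both are used below (`self.model` in
    # `get_context_and_action_embeddings`, `self.get_label(event)` in
    # `format_auto_embed_off`). A hedged sketch consistent with that usage and with
    # the docstring's SentenceTransformer default (the specific model name here is
    # an assumption):
    def __init__(
        self, auto_embed: bool, model: Optional[Any] = None, *args: Any, **kwargs: Any
    ):
        super().__init__(*args, **kwargs)

        if model is None:
            from sentence_transformers import SentenceTransformer

            model = SentenceTransformer("all-mpnet-base-v2")

        self.model = model
        self.auto_embed = auto_embed

    @staticmethod
    def get_label(event: PickBestEvent) -> tuple:
        # VW learns from costs (lower is better), so the selection score is negated.
        if event.selected:
            cost = (
                -1.0 * event.selected.score
                if event.selected.score is not None
                else None
            )
            return event.selected.index, cost, event.selected.probability
        return None, None, None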
    def get_context_and_action_embeddings(self, event: PickBestEvent) -> tuple:
        context_emb = embed(event.based_on, self.model) if event.based_on else None
        to_select_from_var_name, to_select_from = next(
            iter(event.to_select_from.items()), (None, None)
        )

        action_embs = (
            (
                embed(to_select_from, self.model, to_select_from_var_name)
                if event.to_select_from
                else None
            )
            if to_select_from
            else None
        )

        if not context_emb or not action_embs:
            raise ValueError(
                "Context and to_select_from must be provided in the inputs dictionary"
            )
        return context_emb, action_embs
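    # For illustration (hedged): with inputs such as
    #   {"User": base.BasedOn("Tom"), "action": base.ToSelectFrom(["a1", "a2"])}
    # `context_emb` is a list of namespace dicts, e.g. [{"User": "Tom"}], and
    # `action_embs` holds one namespace dict per candidate action, e.g.
    # [{"action": "a1"}, {"action": "a2"}]; values become embedded feature strings
    # when auto-embedding is enabled.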
    def format_auto_embed_off(self, event: PickBestEvent) -> str:
        """
        Converts the `BasedOn` and `ToSelectFrom` into a format that can be used by VW
        """
        chosen_action, cost, prob = self.get_label(event)
        context_emb, action_embs = self.get_context_and_action_embeddings(event)

        example_string = ""
        example_string += "shared "
        for context_item in context_emb:
            for ns, based_on in context_item.items():
                e = " ".join(based_on) if isinstance(based_on, list) else based_on
                example_string += f"|{ns} {e} "
        example_string += "\n"

        for i, action in enumerate(action_embs):
            if cost is not None and chosen_action == i:
                example_string += f"{chosen_action}:{cost}:{prob} "
            for ns, action_embedding in action.items():
                e = (
                    " ".join(action_embedding)
                    if isinstance(action_embedding, list)
                    else action_embedding
                )
                example_string += f"|{ns} {e} "
            example_string += "\n"
        # Strip the last newline
        return example_string[:-1]
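    # For illustration (hedged values): with context {"User": "Tom"}, actions
    # ["a1", "a2"], chosen action 0, cost -1.0, and probability 0.5, the emitted
    # VW `--cb_explore_adf` example text looks like:
    #
    #   shared |User Tom
    #   0:-1.0:0.5 |action a1
    #   |action a2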
class PickBest(base.RLChain[PickBestEvent]):
    """Chain that leverages the Vowpal Wabbit (VW) model for reinforcement learning
    with a context, with the goal of modifying the prompt before the LLM call.

    Each invocation of the chain's `run()` method should be equipped with a set of
    potential actions (`ToSelectFrom`) and will result in the selection of a specific
    action based on the `BasedOn` input. This chosen action then informs the LLM
    (Language Model) prompt for the subsequent response generation.

    The standard operation flow of this Chain includes:
        1. The Chain is invoked with inputs containing the `BasedOn` criteria and a list of potential actions (`ToSelectFrom`).
        2. An action is selected based on the `BasedOn` input.
        3. The LLM is called with the dynamic prompt, producing a response.
        4. If a `selection_scorer` is provided, it is used to score the selection.
        5. The internal Vowpal Wabbit model is updated with the `BasedOn` input, the chosen `ToSelectFrom` action, and the resulting score from the scorer.
        6. The final response is returned.

    Expected input dictionary format:
        - At least one variable encapsulated within `BasedOn` to serve as the selection criteria.
        - A single list variable within `ToSelectFrom`, representing potential actions for the VW model. This list can take the form of:
            - A list of strings, e.g. `action = ToSelectFrom(["action1", "action2", "action3"])`
            - A list of lists of strings, e.g. `action = ToSelectFrom([["action1", "another identifier of action1"], ["action2", "another identifier of action2"]])`
            - A list of dictionaries, where each dictionary represents an action with namespace names as keys and corresponding action strings as values. For instance, `action = ToSelectFrom([{"namespace1": ["action1", "another identifier of action1"], "namespace2": "action2"}, {"namespace1": "action3", "namespace2": "action4"}])`.

    Extends:
        RLChain

    Attributes:
        feature_embedder (PickBestFeatureEmbedder, optional): Is an advanced attribute. Responsible for embedding the `BasedOn` and `ToSelectFrom` inputs. If omitted, a default embedder is utilized.
    """  # noqa E501

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ):
        auto_embed = kwargs.get("auto_embed", False)

        feature_embedder = kwargs.get("feature_embedder", None)
        if feature_embedder:
            if "auto_embed" in kwargs:
                logger.warning(
                    "auto_embed will take no effect when explicit feature_embedder is provided"  # noqa E501
                )
            # turning auto_embed off for cli setting below
            auto_embed = False
        else:
            feature_embedder = PickBestFeatureEmbedder(auto_embed=auto_embed)
        kwargs["feature_embedder"] = feature_embedder

        vw_cmd = kwargs.get("vw_cmd", [])
        if vw_cmd:
            if "--cb_explore_adf" not in vw_cmd:
                raise ValueError(
                    "If vw_cmd is specified, it must include --cb_explore_adf"
                )
        else:
            interactions = ["--interactions=::"]
            if auto_embed:
                interactions = [
                    "--interactions=@#",
                    "--ignore_linear=@",
                    "--ignore_linear=#",
                ]
            vw_cmd = interactions + [
                "--cb_explore_adf",
                "--coin",
                "--squarecb",
                "--quiet",
            ]

        kwargs["vw_cmd"] = vw_cmd

        super().__init__(*args, **kwargs)

    def _call_before_predict(self, inputs: Dict[str, Any]) -> PickBestEvent:
        context, actions = base.get_based_on_and_to_select_from(inputs=inputs)
        if not actions:
            raise ValueError(
                "No variables using 'ToSelectFrom' found in the inputs. Please include at least one variable containing a list to select from."  # noqa E501
            )

        if len(list(actions.values())) > 1:
            raise ValueError(
                "Only one variable using 'ToSelectFrom' can be provided in the inputs for the PickBest chain. Please provide only one variable containing a list to select from."  # noqa E501
            )

        if not context:
            raise ValueError(
                "No variables using 'BasedOn' found in the inputs. Please include at least one variable containing information to base the selection of ToSelectFrom on."  # noqa E501
            )

        event = PickBestEvent(inputs=inputs, to_select_from=actions, based_on=context)
        return event

    def _call_after_predict_before_llm(
        self,
        inputs: Dict[str, Any],
        event: PickBestEvent,
        prediction: List[Tuple[int, float]],
    ) -> Tuple[Dict[str, Any], PickBestEvent]:
        import numpy as np

        prob_sum = sum(prob for _, prob in prediction)
        probabilities = [prob / prob_sum for _, prob in prediction]
        # sample from the pmf
        sampled_index = np.random.choice(len(prediction), p=probabilities)
        sampled_ap = prediction[sampled_index]
        sampled_action = sampled_ap[0]
        sampled_prob = sampled_ap[1]
        selected = PickBestSelected(index=sampled_action, probability=sampled_prob)
        event.selected = selected

        # only one key, value pair in event.to_select_from
        key, value = next(iter(event.to_select_from.items()))
        next_chain_inputs = inputs.copy()
        next_chain_inputs.update({key: value[event.selected.index]})
        return next_chain_inputs, event

    def _call_after_llm_before_scoring(
        self, llm_response: str, event: PickBestEvent
    ) -> Tuple[Dict[str, Any], PickBestEvent]:
        next_chain_inputs = event.inputs.copy()
        # only one key, value pair in event.to_select_from
        value = next(iter(event.to_select_from.values()))
        v = (
            value[event.selected.index]
            if event.selected
            else event.to_select_from.values()
        )
        next_chain_inputs.update(
            {
                self.selected_based_on_input_key: str(event.based_on),
                self.selected_input_key: v,
            }
        )
        return next_chain_inputs, event

    def _call_after_scoring_before_learning(
        self, event: PickBestEvent, score: Optional[float]
    ) -> PickBestEvent:
        if event.selected:
            event.selected.score = score
        return event

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        return super()._call(run_manager=run_manager, inputs=inputs)

    @property
    def _chain_type(self) -> str:
        return "rl_chain_pick_best"
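    # NOTE (editor): the `SENTINEL` object and the `LLMChain`, `BaseLanguageModel`,
    # `BasePromptTemplate`, `Type`, and `Union` imports at the top of this module
    # are only exercised by a `from_llm` constructor that is missing from this
    # excerpt. A hedged sketch consistent with those imports:
    @classmethod
    def from_llm(
        cls: Type[PickBest],
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate,
        selection_scorer: Union[base.AutoSelectionScorer, object] = SENTINEL,
        **kwargs: Any,
    ) -> PickBest:
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        if selection_scorer is SENTINEL:
            # SENTINEL distinguishes "not passed" from an explicit None, so a
            # default scorer is created only when the caller passed nothing.
            selection_scorer = base.AutoSelectionScorer(llm=llm_chain.llm)
        return cls(
            llm_chain=llm_chain,
            selection_scorer=selection_scorer,
            **kwargs,
        )


# Usage sketch (hedged), following the expected input format documented on PickBest:
#
#   chain = PickBest.from_llm(llm=llm, prompt=prompt)
#   response = chain.run(
#       user=base.BasedOn("Tom"),
#       action=base.ToSelectFrom(["action1", "action2", "action3"]),
#   )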