class CometCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs to Comet.

    Parameters:
        task_type (str): The type of comet_ml task such as "inference",
            "testing" or "qc"
        workspace (str): The comet_ml workspace
        project_name (str): The comet_ml project name
        tags (list): Tags to add to the task
        name (str): Name of the comet_ml experiment
        visualizations (list): The visualization styles to log for each
            generated output
        complexity_metrics (bool): Whether to log complexity metrics
        custom_metrics (callable): Custom metrics function applied to each
            generated output
        stream_logs (bool): Whether to stream callback actions to Comet

    This handler will utilize the associated callback method and formats
    the input of each callback function with metadata regarding the state of LLM run,
    and adds the response to the list of records for both the {method}_records and
    action. It then logs the response to Comet.
    """
def __init__(
    self,
    task_type: Optional[str] = "inference",
    workspace: Optional[str] = None,
    project_name: Optional[str] = None,
    tags: Optional[Sequence] = None,
    name: Optional[str] = None,
    visualizations: Optional[List[str]] = None,
    complexity_metrics: bool = False,
    custom_metrics: Optional[Callable] = None,
    stream_logs: bool = True,
) -> None:
    """Initialize callback handler.

    Args:
        task_type: The comet_ml task type, e.g. "inference", "testing" or "qc".
        workspace: The comet_ml workspace.
        project_name: The comet_ml project name.
        tags: Tags to add to the experiment.
        name: Name of the comet_ml experiment.
        visualizations: spaCy visualization styles to log for each generation.
        complexity_metrics: Whether to log text complexity metrics.
        custom_metrics: Optional callable computing metrics per generation.
        stream_logs: Whether to stream callback actions to Comet.
    """
    self.comet_ml = import_comet_ml()
    super().__init__()

    self.task_type = task_type
    self.workspace = workspace
    self.project_name = project_name
    self.tags = tags
    self.visualizations = visualizations
    self.complexity_metrics = complexity_metrics
    self.custom_metrics = custom_metrics
    self.stream_logs = stream_logs
    # Scratch directory for serialized model assets; recreated on _reset().
    self.temp_dir = tempfile.TemporaryDirectory()

    self.experiment = _get_experiment(workspace, project_name)
    self.experiment.log_other("Created from", "langchain")
    if tags:
        self.experiment.add_tags(tags)
    self.name = name
    if self.name:
        self.experiment.set_name(self.name)

    warning = (
        "The comet_ml callback is currently in beta and is subject to change "
        "based on updates to `langchain`. Please report any issues to "
        "https://github.com/comet-ml/issue-tracking/issues with the tag "
        "`langchain`."
    )
    self.comet_ml.LOGGER.warning(warning)

    self.callback_columns: list = []
    self.action_records: list = []
    # NOTE: the original re-assigned ``self.complexity_metrics`` a second
    # time here with the identical value; the redundant assignment is removed.
    if self.visualizations:
        spacy = import_spacy()
        self.nlp = spacy.load("en_core_web_sm")
    else:
        self.nlp = None
def on_llm_start(
    self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
    """Run when LLM starts."""
    self.step += 1
    self.llm_starts += 1
    self.starts += 1

    base_resp = self._init_resp()
    base_resp.update({"action": "on_llm_start"})
    base_resp.update(flatten_dict(serialized))
    base_resp.update(self.get_custom_callback_meta())

    for prompt in prompts:
        record = deepcopy(base_resp)
        record["prompts"] = prompt
        self.on_llm_start_records.append(record)
        self.action_records.append(record)
        if self.stream_logs:
            self._log_stream(prompt, base_resp, self.step)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
    """Run when LLM generates a new token."""
    self.step += 1
    self.llm_streams += 1

    record = self._init_resp()
    record.update({"action": "on_llm_new_token", "token": token})
    record.update(self.get_custom_callback_meta())
    self.action_records.append(record)
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
    """Run when LLM ends running."""
    self.step += 1
    self.llm_ends += 1
    self.ends += 1

    base_resp = self._init_resp()
    base_resp.update({"action": "on_llm_end"})
    base_resp.update(flatten_dict(response.llm_output or {}))
    base_resp.update(self.get_custom_callback_meta())

    collected_complexity = []
    collected_custom = []
    for prompt_idx, generations in enumerate(response.generations):
        for gen_idx, generation in enumerate(generations):
            text = generation.text

            record = deepcopy(base_resp)
            record.update(flatten_dict(generation.dict()))

            complexity_metrics = self._get_complexity_metrics(text)
            if complexity_metrics:
                collected_complexity.append(complexity_metrics)
                record.update(complexity_metrics)

            custom_metrics = self._get_custom_metrics(generation, prompt_idx, gen_idx)
            if custom_metrics:
                collected_custom.append(custom_metrics)
                record.update(custom_metrics)

            if self.stream_logs:
                self._log_stream(text, base_resp, self.step)

            self.action_records.append(record)
            self.on_llm_end_records.append(record)

    self._log_text_metrics(collected_complexity, step=self.step)
    self._log_text_metrics(collected_custom, step=self.step)
def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
    """Run when LLM errors; only increments the step/error counters."""
    self.step += 1
    self.errors += 1
def on_chain_start(
    self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> None:
    """Run when chain starts running."""
    self.step += 1
    self.chain_starts += 1
    self.starts += 1

    base_resp = self._init_resp()
    base_resp.update({"action": "on_chain_start"})
    base_resp.update(flatten_dict(serialized))
    base_resp.update(self.get_custom_callback_meta())

    for input_key, input_val in inputs.items():
        # Only plain-string inputs are loggable; warn and skip everything else.
        if not isinstance(input_val, str):
            self.comet_ml.LOGGER.warning(
                f"Unexpected data format provided! "
                f"Input Value for {input_key} will not be logged"
            )
            continue
        record = deepcopy(base_resp)
        if self.stream_logs:
            self._log_stream(input_val, base_resp, self.step)
        record.update({input_key: input_val})
        self.action_records.append(record)
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
    """Run when chain ends running."""
    self.step += 1
    self.chain_ends += 1
    self.ends += 1

    base_resp = self._init_resp()
    base_resp.update({"action": "on_chain_end"})
    base_resp.update(self.get_custom_callback_meta())

    for output_key, output_val in outputs.items():
        # Only plain-string outputs are loggable; warn and skip everything else.
        if not isinstance(output_val, str):
            self.comet_ml.LOGGER.warning(
                f"Unexpected data format provided! "
                f"Output Value for {output_key} will not be logged"
            )
            continue
        record = deepcopy(base_resp)
        if self.stream_logs:
            self._log_stream(output_val, base_resp, self.step)
        record.update({output_key: output_val})
        self.action_records.append(record)
def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
    """Run when chain errors; only increments the step/error counters."""
    self.step += 1
    self.errors += 1
def on_tool_start(
    self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> None:
    """Run when tool starts running."""
    self.step += 1
    self.tool_starts += 1
    self.starts += 1

    record = self._init_resp()
    record.update({"action": "on_tool_start"})
    record.update(flatten_dict(serialized))
    record.update(self.get_custom_callback_meta())

    if self.stream_logs:
        # Stream the raw input before it is folded into the record.
        self._log_stream(input_str, record, self.step)

    record.update({"input_str": input_str})
    self.action_records.append(record)
def on_tool_end(self, output: Any, **kwargs: Any) -> None:
    """Run when tool ends running."""
    # Tool outputs may be arbitrary objects; coerce to str for logging.
    output = str(output)
    self.step += 1
    self.tool_ends += 1
    self.ends += 1

    record = self._init_resp()
    record.update({"action": "on_tool_end"})
    record.update(self.get_custom_callback_meta())

    if self.stream_logs:
        self._log_stream(output, record, self.step)

    record.update({"output": output})
    self.action_records.append(record)
def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
    """Run when tool errors; only increments the step/error counters."""
    self.step += 1
    self.errors += 1
def on_text(self, text: str, **kwargs: Any) -> None:
    """Run when the agent is ending."""
    self.step += 1
    self.text_ctr += 1

    record = self._init_resp()
    record.update({"action": "on_text"})
    record.update(self.get_custom_callback_meta())

    if self.stream_logs:
        self._log_stream(text, record, self.step)

    record.update({"text": text})
    self.action_records.append(record)
def on_agent_finish(self, finish: "AgentFinish", **kwargs: Any) -> None:
    """Run when agent ends running."""
    self.step += 1
    self.agent_ends += 1
    self.ends += 1

    record = self._init_resp()
    output = finish.return_values["output"]
    log = finish.log
    record.update({"action": "on_agent_finish", "log": log})
    record.update(self.get_custom_callback_meta())

    if self.stream_logs:
        self._log_stream(output, record, self.step)

    record.update({"output": output})
    self.action_records.append(record)
def on_agent_action(self, action: "AgentAction", **kwargs: Any) -> Any:
    """Run on agent action."""
    self.step += 1
    self.tool_starts += 1
    self.starts += 1

    tool = action.tool
    tool_input = str(action.tool_input)
    log = action.log

    record = self._init_resp()
    record.update({"action": "on_agent_action", "log": log, "tool": tool})
    record.update(self.get_custom_callback_meta())

    if self.stream_logs:
        self._log_stream(tool_input, record, self.step)

    record.update({"tool_input": tool_input})
    self.action_records.append(record)
def_get_complexity_metrics(self,text:str)->dict:"""Compute text complexity metrics using textstat. Parameters: text (str): The text to analyze. Returns: (dict): A dictionary containing the complexity metrics. """resp={}ifself.complexity_metrics:text_complexity_metrics=_fetch_text_complexity_metrics(text)resp.update(text_complexity_metrics)returnrespdef_get_custom_metrics(self,generation:Generation,prompt_idx:int,gen_idx:int)->dict:"""Compute Custom Metrics for an LLM Generated Output Args: generation (LLMResult): Output generation from an LLM prompt_idx (int): List index of the input prompt gen_idx (int): List index of the generated output Returns: dict: A dictionary containing the custom metrics. """resp={}ifself.custom_metrics:custom_metrics=self.custom_metrics(generation,prompt_idx,gen_idx)resp.update(custom_metrics)returnresp
def flush_tracker(
    self,
    langchain_asset: Any = None,
    task_type: Optional[str] = "inference",
    workspace: Optional[str] = None,
    project_name: Optional[str] = "comet-langchain-demo",
    tags: Optional[Sequence] = None,
    name: Optional[str] = None,
    visualizations: Optional[List[str]] = None,
    complexity_metrics: bool = False,
    custom_metrics: Optional[Callable] = None,
    finish: bool = False,
    reset: bool = False,
) -> None:
    """Flush the tracker and set up the session.

    Everything after this will be a new table.

    Args:
        name: Name of the performed session so far so it is identifiable.
        langchain_asset: The langchain asset to save.
        finish: Whether to finish the run.

    Returns:
        None
    """
    self._log_session(langchain_asset)

    if langchain_asset:
        try:
            self._log_model(langchain_asset)
        except Exception:
            self.comet_ml.LOGGER.error(
                "Failed to export agent or LLM to Comet",
                exc_info=True,
                extra={"show_traceback": True},
            )

    if finish:
        self.experiment.end()

    if reset:
        self._reset(
            task_type,
            workspace,
            project_name,
            tags,
            name,
            visualizations,
            complexity_metrics,
            custom_metrics,
        )
def_log_stream(self,prompt:str,metadata:dict,step:int)->None:self.experiment.log_text(prompt,metadata=metadata,step=step)def_log_model(self,langchain_asset:Any)->None:model_parameters=self._get_llm_parameters(langchain_asset)self.experiment.log_parameters(model_parameters,prefix="model")langchain_asset_path=Path(self.temp_dir.name,"model.json")model_name=self.nameifself.nameelseLANGCHAIN_MODEL_NAMEtry:ifhasattr(langchain_asset,"save"):langchain_asset.save(langchain_asset_path)self.experiment.log_model(model_name,str(langchain_asset_path))except(ValueError,AttributeError,NotImplementedError)ase:ifhasattr(langchain_asset,"save_agent"):langchain_asset.save_agent(langchain_asset_path)self.experiment.log_model(model_name,str(langchain_asset_path))else:self.comet_ml.LOGGER.error(f"{e}"" Could not save Langchain Asset "f"for {langchain_asset.__class__.__name__}")def_log_session(self,langchain_asset:Optional[Any]=None)->None:try:llm_session_df=self._create_session_analysis_dataframe(langchain_asset)# Log the cleaned dataframe as a tableself.experiment.log_table("langchain-llm-session.csv",llm_session_df)exceptException:self.comet_ml.LOGGER.warning("Failed to log session data to Comet",exc_info=True,extra={"show_traceback":True},)try:metadata={"langchain_version":str(langchain_community.__version__)}# Log the langchain low-level records as a JSON file directlyself.experiment.log_asset_data(self.action_records,"langchain-action_records.json",metadata=metadata)exceptException:self.comet_ml.LOGGER.warning("Failed to log session data to Comet",exc_info=True,extra={"show_traceback":True},)try:self._log_visualizations(llm_session_df)exceptException:self.comet_ml.LOGGER.warning("Failed to log visualizations to 
Comet",exc_info=True,extra={"show_traceback":True},)def_log_text_metrics(self,metrics:Sequence[dict],step:int)->None:ifnotmetrics:returnmetrics_summary=_summarize_metrics_for_generated_outputs(metrics)forkey,valueinmetrics_summary.items():self.experiment.log_metrics(value,prefix=key,step=step)def_log_visualizations(self,session_df:Any)->None:ifnot(self.visualizationsandself.nlp):returnspacy=import_spacy()prompts=session_df["prompts"].tolist()outputs=session_df["text"].tolist()foridx,(prompt,output)inenumerate(zip(prompts,outputs)):doc=self.nlp(output)sentence_spans=list(doc.sents)forvisualizationinself.visualizations:try:html=spacy.displacy.render(sentence_spans,style=visualization,options={"compact":True},jupyter=False,page=True,)self.experiment.log_asset_data(html,name=f"langchain-viz-{visualization}-{idx}.html",metadata={"prompt":prompt},step=idx,)exceptExceptionase:self.comet_ml.LOGGER.warning(e,exc_info=True,extra={"show_traceback":True})returndef_reset(self,task_type:Optional[str]=None,workspace:Optional[str]=None,project_name:Optional[str]=None,tags:Optional[Sequence]=None,name:Optional[str]=None,visualizations:Optional[List[str]]=None,complexity_metrics:bool=False,custom_metrics:Optional[Callable]=None,)->None:_task_type=task_typeiftask_typeelseself.task_type_workspace=workspaceifworkspaceelseself.workspace_project_name=project_nameifproject_nameelseself.project_name_tags=tagsiftagselseself.tags_name=nameifnameelseself.name_visualizations=visualizationsifvisualizationselseself.visualizations_complexity_metrics=(complexity_metricsifcomplexity_metricselseself.complexity_metrics)_custom_metrics=custom_metricsifcustom_metricselseself.custom_metricsself.__init__(# type: 
ignoretask_type=_task_type,workspace=_workspace,project_name=_project_name,tags=_tags,name=_name,visualizations=_visualizations,complexity_metrics=_complexity_metrics,custom_metrics=_custom_metrics,)self.reset_callback_meta()self.temp_dir=tempfile.TemporaryDirectory()def_create_session_analysis_dataframe(self,langchain_asset:Any=None)->dict:pd=import_pandas()llm_parameters=self._get_llm_parameters(langchain_asset)num_generations_per_prompt=llm_parameters.get("n",1)llm_start_records_df=pd.DataFrame(self.on_llm_start_records)# Repeat each input row based on the number of outputs generated per promptllm_start_records_df=llm_start_records_df.loc[llm_start_records_df.index.repeat(num_generations_per_prompt)].reset_index(drop=True)llm_end_records_df=pd.DataFrame(self.on_llm_end_records)llm_session_df=pd.merge(llm_start_records_df,llm_end_records_df,left_index=True,right_index=True,suffixes=["_llm_start","_llm_end"],)returnllm_session_dfdef_get_llm_parameters(self,langchain_asset:Any=None)->dict:ifnotlangchain_asset:return{}try:ifhasattr(langchain_asset,"agent"):llm_parameters=langchain_asset.agent.llm_chain.llm.dict()elifhasattr(langchain_asset,"llm_chain"):llm_parameters=langchain_asset.llm_chain.llm.dict()elifhasattr(langchain_asset,"llm"):llm_parameters=langchain_asset.llm.dict()else:llm_parameters=langchain_asset.dict()exceptException:return{}returnllm_parameters