def import_clearml() -> Any:
    """Import the clearml python package and raise an error if it is not installed."""
    return guard_import("clearml")


class ClearMLCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs to ClearML.

    Parameters:
        task_type (str): The type of clearml task such as "inference", "testing" or "qc"
        project_name (str): The clearml project name
        tags (list): Tags to add to the task
        task_name (str): Name of the clearml task
        visualize (bool): Whether to visualize the run.
        complexity_metrics (bool): Whether to log complexity metrics
        stream_logs (bool): Whether to stream callback actions to ClearML

    This handler utilizes the associated callback method, formats the input of
    each callback function with metadata regarding the state of the LLM run,
    and adds the response to the list of records for both the {method}_records
    and action. It then logs the response to the ClearML console.

    A minimal usage sketch is shown after this class definition.
    """

    def __init__(
        self,
        task_type: Optional[str] = "inference",
        project_name: Optional[str] = "langchain_callback_demo",
        tags: Optional[Sequence] = None,
        task_name: Optional[str] = None,
        visualize: bool = False,
        complexity_metrics: bool = False,
        stream_logs: bool = False,
    ) -> None:
        """Initialize callback handler."""
        clearml = import_clearml()
        spacy = import_spacy()
        super().__init__()

        self.task_type = task_type
        self.project_name = project_name
        self.tags = tags
        self.task_name = task_name
        self.visualize = visualize
        self.complexity_metrics = complexity_metrics
        self.stream_logs = stream_logs

        self.temp_dir = tempfile.TemporaryDirectory()

        # Check if ClearML task already exists (e.g. in pipeline)
        if clearml.Task.current_task():
            self.task = clearml.Task.current_task()
        else:
            self.task = clearml.Task.init(
                task_type=self.task_type,
                project_name=self.project_name,
                tags=self.tags,
                task_name=self.task_name,
                output_uri=True,
            )
        self.logger = self.task.get_logger()
        warning = (
            "The clearml callback is currently in beta and is subject to change "
            "based on updates to `langchain`. Please report any issues to "
            "https://github.com/allegroai/clearml/issues with the tag `langchain`."
        )
        self.logger.report_text(warning, level=30, print_console=True)
        self.callback_columns: list = []
        self.action_records: list = []
        self.nlp = spacy.load("en_core_web_sm")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        for prompt in prompts:
            prompt_resp = deepcopy(resp)
            prompt_resp["prompts"] = prompt
            self.on_llm_start_records.append(prompt_resp)
            self.action_records.append(prompt_resp)
            if self.stream_logs:
                self.logger.report_text(prompt_resp)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.get_custom_callback_meta())

        self.on_llm_token_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.logger.report_text(resp)

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())

        for generations in response.generations:
            for generation in generations:
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))
                generation_resp.update(self.analyze_text(generation.text))
                self.on_llm_end_records.append(generation_resp)
                self.action_records.append(generation_resp)
                if self.stream_logs:
                    self.logger.report_text(generation_resp)

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        chain_input = inputs.get("input", inputs.get("human_input"))

        if isinstance(chain_input, str):
            input_resp = deepcopy(resp)
            input_resp["input"] = chain_input
            self.on_chain_start_records.append(input_resp)
            self.action_records.append(input_resp)
            if self.stream_logs:
                self.logger.report_text(input_resp)
        elif isinstance(chain_input, list):
            for inp in chain_input:
                input_resp = deepcopy(resp)
                input_resp.update(inp)
                self.on_chain_start_records.append(input_resp)
                self.action_records.append(input_resp)
                if self.stream_logs:
                    self.logger.report_text(input_resp)
        else:
            raise ValueError("Unexpected data format provided!")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_chain_end",
                "outputs": outputs.get("output", outputs.get("text")),
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.on_chain_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.logger.report_text(resp)

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self.on_tool_start_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.logger.report_text(resp)

    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running."""
        output = str(output)
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.get_custom_callback_meta())

        self.on_tool_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.logger.report_text(resp)

    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1

    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when agent is ending."""
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text", "text": text})
        resp.update(self.get_custom_callback_meta())

        self.on_text_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.logger.report_text(resp)

    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.on_agent_finish_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.logger.report_text(resp)

    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.on_agent_action_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.logger.report_text(resp)

    def analyze_text(self, text: str) -> dict:
        """Analyze text using textstat and spacy.

        Parameters:
            text (str): The text to analyze.

        Returns:
            (dict): A dictionary containing the complexity metrics.
        """
        resp = {}
        textstat = import_textstat()
        spacy = import_spacy()
        if self.complexity_metrics:
            text_complexity_metrics = {
                "flesch_reading_ease": textstat.flesch_reading_ease(text),
                "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
                "smog_index": textstat.smog_index(text),
                "coleman_liau_index": textstat.coleman_liau_index(text),
                "automated_readability_index": textstat.automated_readability_index(
                    text
                ),
                "dale_chall_readability_score": textstat.dale_chall_readability_score(
                    text
                ),
                "difficult_words": textstat.difficult_words(text),
                "linsear_write_formula": textstat.linsear_write_formula(text),
                "gunning_fog": textstat.gunning_fog(text),
                "text_standard": textstat.text_standard(text),
                "fernandez_huerta": textstat.fernandez_huerta(text),
                "szigriszt_pazos": textstat.szigriszt_pazos(text),
                "gutierrez_polini": textstat.gutierrez_polini(text),
                "crawford": textstat.crawford(text),
                "gulpease_index": textstat.gulpease_index(text),
                "osman": textstat.osman(text),
            }
            resp.update(text_complexity_metrics)

        if self.visualize and self.nlp and self.temp_dir.name is not None:
            doc = self.nlp(text)

            # Render and save the dependency parse visualization.
            dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
            dep_output_path = Path(
                self.temp_dir.name, hash_string(f"dep-{text}") + ".html"
            )
            with dep_output_path.open("w", encoding="utf-8") as dep_file:
                dep_file.write(dep_out)

            # Render and save the named-entity visualization.
            ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
            ent_output_path = Path(
                self.temp_dir.name, hash_string(f"ent-{text}") + ".html"
            )
            with ent_output_path.open("w", encoding="utf-8") as ent_file:
                ent_file.write(ent_out)

            self.logger.report_media(
                "Dependencies Plot", text, local_path=dep_output_path
            )
            self.logger.report_media("Entities Plot", text, local_path=ent_output_path)
        return resp

    @staticmethod
    def _build_llm_df(
        base_df: pd.DataFrame, base_df_fields: Sequence, rename_map: Mapping
    ) -> pd.DataFrame:
        """Select the requested columns that exist in ``base_df``, drop empty
        columns, and rename the remaining columns according to ``rename_map``."""
        base_df_fields = [field for field in base_df_fields if field in base_df]
        rename_map = {
            map_entry_k: map_entry_v
            for map_entry_k, map_entry_v in rename_map.items()
            if map_entry_k in base_df_fields
        }
        llm_df = base_df[base_df_fields].dropna(axis=1)
        if rename_map:
            llm_df = llm_df.rename(rename_map, axis=1)

        return llm_df

    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session."""
        pd = import_pandas()
        on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        llm_input_prompts_df = ClearMLCallbackHandler._build_llm_df(
            base_df=on_llm_end_records_df,
            base_df_fields=["step", "prompts"]
            + (["name"] if "name" in on_llm_end_records_df else ["id"]),
            rename_map={"step": "prompt_step"},
        )
        complexity_metrics_columns = []
        visualizations_columns: List = []

        if self.complexity_metrics:
            complexity_metrics_columns = [
                "flesch_reading_ease",
                "flesch_kincaid_grade",
                "smog_index",
                "coleman_liau_index",
                "automated_readability_index",
                "dale_chall_readability_score",
                "difficult_words",
                "linsear_write_formula",
                "gunning_fog",
                "text_standard",
                "fernandez_huerta",
                "szigriszt_pazos",
                "gutierrez_polini",
                "crawford",
                "gulpease_index",
                "osman",
            ]

        llm_outputs_df = ClearMLCallbackHandler._build_llm_df(
            on_llm_end_records_df,
            [
                "step",
                "text",
                "token_usage_total_tokens",
                "token_usage_prompt_tokens",
                "token_usage_completion_tokens",
            ]
            + complexity_metrics_columns
            + visualizations_columns,
            {"step": "output_step", "text": "output"},
        )
        session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
        return session_analysis_df

    def flush_tracker(
        self,
        name: Optional[str] = None,
        langchain_asset: Any = None,
        finish: bool = False,
    ) -> None:
        """Flush the tracker and set up the session.

        Everything logged after this call will belong to a new table.

        Args:
            name: Name of the session performed so far, used to identify the
                logged tables and the saved model.
            langchain_asset: The langchain asset (e.g. LLM, chain or agent) to save.
            finish: Whether to finish and close the ClearML task.

        Returns:
            None
        """
        pd = import_pandas()
        clearml = import_clearml()

        # Log the action records
        self.logger.report_table(
            "Action Records", name, table_plot=pd.DataFrame(self.action_records)
        )

        # Session analysis
        session_analysis_df = self._create_session_analysis_df()
        self.logger.report_table(
            "Session Analysis", name, table_plot=session_analysis_df
        )

        if self.stream_logs:
            self.logger.report_text(
                {
                    "action_records": pd.DataFrame(self.action_records),
                    "session_analysis": session_analysis_df,
                }
            )

        if langchain_asset:
            langchain_asset_path = Path(self.temp_dir.name, "model.json")
            try:
                langchain_asset.save(langchain_asset_path)
                # Create output model and connect it to the task
                output_model = clearml.OutputModel(
                    task=self.task, config_text=load_json(langchain_asset_path)
                )
                output_model.update_weights(
                    weights_filename=str(langchain_asset_path),
                    auto_delete_file=False,
                    target_filename=name,
                )
            except ValueError:
                langchain_asset.save_agent(langchain_asset_path)
                output_model = clearml.OutputModel(
                    task=self.task, config_text=load_json(langchain_asset_path)
                )
                output_model.update_weights(
                    weights_filename=str(langchain_asset_path),
                    auto_delete_file=False,
                    target_filename=name,
                )
            except NotImplementedError as e:
                print("Could not save model.")  # noqa: T201
                print(repr(e))  # noqa: T201

        # Cleanup after adding everything to ClearML
        self.task.flush(wait_for_uploads=True)
        self.temp_dir.cleanup()
        self.temp_dir = tempfile.TemporaryDirectory()
        self.reset_callback_meta()
        if finish:
            self.task.close()
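

# --------------------------------------------------------------------------- #
# Example usage: a minimal sketch, not part of the module above.
#
# Assumptions (adjust to your setup): the handler is importable as
# ``langchain_community.callbacks.ClearMLCallbackHandler``, the optional
# dependencies are installed (``pip install clearml spacy textstat`` plus
# ``python -m spacy download en_core_web_sm``), ClearML credentials are
# configured, and an LLM wrapper such as ``langchain_openai.OpenAI`` is
# available with a valid API key; any other LLM with callback support works.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    from langchain_community.callbacks import ClearMLCallbackHandler
    from langchain_openai import OpenAI  # assumed LLM wrapper for illustration

    # Instantiate the handler; this creates (or reuses) a ClearML task.
    clearml_callback = ClearMLCallbackHandler(
        task_type="inference",
        project_name="langchain_callback_demo",
        task_name="llm",
        tags=["test"],
        visualize=True,
        complexity_metrics=True,
        stream_logs=False,
    )

    # Attach the handler to an LLM so every callback event is recorded.
    llm = OpenAI(temperature=0, callbacks=[clearml_callback])
    llm.generate(["Tell me a joke", "Tell me a poem"])

    # Push the accumulated action records and session analysis tables to
    # ClearML, save the LLM configuration as an output model, and close the
    # task since this is the last run of the session.
    clearml_callback.flush_tracker(
        langchain_asset=llm, name="simple_prompts", finish=True
    )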