def import_wandb() -> Any:
    """Import the wandb python package and raise an error if it is not installed."""
    return guard_import("wandb")
def load_json_to_dict(json_path: Union[str, Path]) -> dict:
    """Load json file to a dictionary.

    Parameters:
        json_path (str): The path to the json file.

    Returns:
        (dict): The dictionary representation of the json file.
    """
    with open(json_path, "r") as f:
        data = json.load(f)
    return data
def analyze_text(
    text: str,
    complexity_metrics: bool = True,
    visualize: bool = True,
    nlp: Any = None,
    output_dir: Optional[Union[str, Path]] = None,
) -> dict:
    """Analyze text using textstat and spacy.

    Parameters:
        text (str): The text to analyze.
        complexity_metrics (bool): Whether to compute complexity metrics.
        visualize (bool): Whether to visualize the text.
        nlp (spacy.lang): The spacy language model to use for visualization.
        output_dir (str): The directory to save the visualization files to.

    Returns:
        (dict): A dictionary containing the complexity metrics and visualization
            files serialized in a wandb.Html element.
    """
    resp = {}
    textstat = import_textstat()
    wandb = import_wandb()
    spacy = import_spacy()
    if complexity_metrics:
        text_complexity_metrics = {
            "flesch_reading_ease": textstat.flesch_reading_ease(text),
            "flesch_kincaid_grade": textstat.flesch_kincaid_grade(text),
            "smog_index": textstat.smog_index(text),
            "coleman_liau_index": textstat.coleman_liau_index(text),
            "automated_readability_index": textstat.automated_readability_index(text),
            "dale_chall_readability_score": textstat.dale_chall_readability_score(text),
            "difficult_words": textstat.difficult_words(text),
            "linsear_write_formula": textstat.linsear_write_formula(text),
            "gunning_fog": textstat.gunning_fog(text),
            "text_standard": textstat.text_standard(text),
            "fernandez_huerta": textstat.fernandez_huerta(text),
            "szigriszt_pazos": textstat.szigriszt_pazos(text),
            "gutierrez_polini": textstat.gutierrez_polini(text),
            "crawford": textstat.crawford(text),
            "gulpease_index": textstat.gulpease_index(text),
            "osman": textstat.osman(text),
        }
        resp.update(text_complexity_metrics)
    if visualize and nlp and output_dir is not None:
        doc = nlp(text)

        dep_out = spacy.displacy.render(doc, style="dep", jupyter=False, page=True)
        dep_output_path = Path(output_dir, hash_string(f"dep-{text}") + ".html")
        dep_output_path.open("w", encoding="utf-8").write(dep_out)

        ent_out = spacy.displacy.render(doc, style="ent", jupyter=False, page=True)
        ent_output_path = Path(output_dir, hash_string(f"ent-{text}") + ".html")
        ent_output_path.open("w", encoding="utf-8").write(ent_out)

        text_visualizations = {
            "dependency_tree": wandb.Html(str(dep_output_path)),
            "entities": wandb.Html(str(ent_output_path)),
        }
        resp.update(text_visualizations)
    return resp
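# Illustrative usage sketch for analyze_text (not part of the module): it assumes
# the optional textstat, spacy, and wandb dependencies are installed along with the
# "en_core_web_sm" spacy model. The output directory path is a placeholder chosen
# for this example.
def _example_analyze_text() -> None:
    spacy = import_spacy()
    nlp = spacy.load("en_core_web_sm")
    metrics = analyze_text(
        "The quick brown fox jumps over the lazy dog.",
        complexity_metrics=True,
        visualize=True,
        nlp=nlp,
        output_dir="/tmp/wandb_text_viz",  # placeholder directory
    )
    # `metrics` now holds the textstat scores plus "dependency_tree" and
    # "entities" wandb.Html entries, all of which can be passed to run.log().
    print(metrics["flesch_reading_ease"])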
def construct_html_from_prompt_and_generation(prompt: str, generation: str) -> Any:
    """Construct an html element from a prompt and a generation.

    Parameters:
        prompt (str): The prompt.
        generation (str): The generation.

    Returns:
        (wandb.Html): The html element.
    """
    wandb = import_wandb()
    formatted_prompt = prompt.replace("\n", "<br>")
    formatted_generation = generation.replace("\n", "<br>")

    return wandb.Html(
        f"""
        <p style="color:black;">{formatted_prompt}:</p>
        <blockquote>
          <p style="color:green;">{formatted_generation}</p>
        </blockquote>
        """,
        inject=False,
    )
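# Illustrative sketch (not part of the module): wraps a single prompt/completion
# pair in a wandb.Html panel and logs it on its own, outside the session table.
# Assumes wandb is installed and the user is logged in; the project name is the
# handler's default and is used here only as an example.
def _example_chat_html() -> None:
    wandb = import_wandb()
    run = wandb.init(project="langchain_callback_demo")
    html = construct_html_from_prompt_and_generation(
        "What is 2 + 2?",
        "2 + 2 equals 4.",
    )
    run.log({"chat_html": html})
    run.finish()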
class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
    """Callback Handler that logs to Weights and Biases.

    Parameters:
        job_type (str): The type of job.
        project (str): The project to log to.
        entity (str): The entity to log to.
        tags (list): The tags to log.
        group (str): The group to log to.
        name (str): The name of the run.
        notes (str): The notes to log.
        visualize (bool): Whether to visualize the run.
        complexity_metrics (bool): Whether to log complexity metrics.
        stream_logs (bool): Whether to stream callback actions to W&B.

    This handler formats the input of each callback method it receives with
    metadata about the state of the LLM run, appends the result to both the
    corresponding {method}_records list and the action records, and then logs
    the response to Weights and Biases via run.log().
    """
    def __init__(
        self,
        job_type: Optional[str] = None,
        project: Optional[str] = "langchain_callback_demo",
        entity: Optional[str] = None,
        tags: Optional[Sequence] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        visualize: bool = False,
        complexity_metrics: bool = False,
        stream_logs: bool = False,
    ) -> None:
        """Initialize callback handler."""
        wandb = import_wandb()
        import_pandas()
        import_textstat()
        spacy = import_spacy()
        super().__init__()

        self.job_type = job_type
        self.project = project
        self.entity = entity
        self.tags = tags
        self.group = group
        self.name = name
        self.notes = notes
        self.visualize = visualize
        self.complexity_metrics = complexity_metrics
        self.stream_logs = stream_logs

        self.temp_dir = tempfile.TemporaryDirectory()
        self.run = wandb.init(
            job_type=self.job_type,
            project=self.project,
            entity=self.entity,
            tags=self.tags,
            group=self.group,
            name=self.name,
            notes=self.notes,
        )
        warning = (
            "DEPRECATION: The `WandbCallbackHandler` will soon be deprecated in favor "
            "of the `WandbTracer`. Please update your code to use the `WandbTracer` "
            "instead."
        )
        wandb.termwarn(
            warning,
            repeat=False,
        )
        self.callback_columns: list = []
        self.action_records: list = []
        self.complexity_metrics = complexity_metrics
        self.visualize = visualize
        self.nlp = spacy.load("en_core_web_sm")
        warn_deprecated(
            "0.3.8",
            pending=False,
            message=(
                "Please use the WeaveTracer instead of the WandbCallbackHandler. "
                "The WeaveTracer is a more flexible and powerful tool for logging "
                "and tracing your LangChain callables. "
                "Find more information at "
                "https://weave-docs.wandb.ai/guides/integrations/langchain"
            ),
            alternative=(
                "Please instantiate the WeaveTracer from "
                "weave.integrations.langchain import WeaveTracer. "
                "For autologging simply use weave.init() and log all traces "
                "from your LangChain callables."
            ),
        )
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts."""
        self.step += 1
        self.llm_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        for prompt in prompts:
            prompt_resp = deepcopy(resp)
            prompt_resp["prompts"] = prompt
            self.on_llm_start_records.append(prompt_resp)
            self.action_records.append(prompt_resp)
            if self.stream_logs:
                self.run.log(prompt_resp)
    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run when LLM generates a new token."""
        self.step += 1
        self.llm_streams += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_new_token", "token": token})
        resp.update(self.get_custom_callback_meta())

        self.on_llm_token_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)
    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Run when LLM ends running."""
        self.step += 1
        self.llm_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_llm_end"})
        resp.update(flatten_dict(response.llm_output or {}))
        resp.update(self.get_custom_callback_meta())

        for generations in response.generations:
            for generation in generations:
                generation_resp = deepcopy(resp)
                generation_resp.update(flatten_dict(generation.dict()))
                generation_resp.update(
                    analyze_text(
                        generation.text,
                        complexity_metrics=self.complexity_metrics,
                        visualize=self.visualize,
                        nlp=self.nlp,
                        output_dir=self.temp_dir.name,
                    )
                )
                self.on_llm_end_records.append(generation_resp)
                self.action_records.append(generation_resp)
                if self.stream_logs:
                    self.run.log(generation_resp)
    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when LLM errors."""
        self.step += 1
        self.errors += 1
    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Run when chain starts running."""
        self.step += 1
        self.chain_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_start"})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        chain_input = inputs["input"]

        if isinstance(chain_input, str):
            input_resp = deepcopy(resp)
            input_resp["input"] = chain_input
            self.on_chain_start_records.append(input_resp)
            self.action_records.append(input_resp)
            if self.stream_logs:
                self.run.log(input_resp)
        elif isinstance(chain_input, list):
            for inp in chain_input:
                input_resp = deepcopy(resp)
                input_resp.update(inp)
                self.on_chain_start_records.append(input_resp)
                self.action_records.append(input_resp)
                if self.stream_logs:
                    self.run.log(input_resp)
        else:
            raise ValueError("Unexpected data format provided!")
    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Run when chain ends running."""
        self.step += 1
        self.chain_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_chain_end", "outputs": outputs["output"]})
        resp.update(self.get_custom_callback_meta())

        self.on_chain_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)
    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when chain errors."""
        self.step += 1
        self.errors += 1
    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Run when tool starts running."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_start", "input_str": input_str})
        resp.update(flatten_dict(serialized))
        resp.update(self.get_custom_callback_meta())

        self.on_tool_start_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)
    def on_tool_end(self, output: Any, **kwargs: Any) -> None:
        """Run when tool ends running."""
        output = str(output)
        self.step += 1
        self.tool_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update({"action": "on_tool_end", "output": output})
        resp.update(self.get_custom_callback_meta())

        self.on_tool_end_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)
    def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
        """Run when tool errors."""
        self.step += 1
        self.errors += 1
    def on_text(self, text: str, **kwargs: Any) -> None:
        """Run when agent is ending."""
        self.step += 1
        self.text_ctr += 1

        resp = self._init_resp()
        resp.update({"action": "on_text", "text": text})
        resp.update(self.get_custom_callback_meta())

        self.on_text_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)
    def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> None:
        """Run when agent ends running."""
        self.step += 1
        self.agent_ends += 1
        self.ends += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_finish",
                "output": finish.return_values["output"],
                "log": finish.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.on_agent_finish_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)
    def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
        """Run on agent action."""
        self.step += 1
        self.tool_starts += 1
        self.starts += 1

        resp = self._init_resp()
        resp.update(
            {
                "action": "on_agent_action",
                "tool": action.tool,
                "tool_input": action.tool_input,
                "log": action.log,
            }
        )
        resp.update(self.get_custom_callback_meta())

        self.on_agent_action_records.append(resp)
        self.action_records.append(resp)
        if self.stream_logs:
            self.run.log(resp)
    def _create_session_analysis_df(self) -> Any:
        """Create a dataframe with all the information from the session."""
        pd = import_pandas()
        on_llm_start_records_df = pd.DataFrame(self.on_llm_start_records)
        on_llm_end_records_df = pd.DataFrame(self.on_llm_end_records)

        llm_input_prompts_df = (
            on_llm_start_records_df[["step", "prompts", "name"]]
            .dropna(axis=1)
            .rename({"step": "prompt_step"}, axis=1)
        )
        complexity_metrics_columns = []
        visualizations_columns = []

        if self.complexity_metrics:
            complexity_metrics_columns = [
                "flesch_reading_ease",
                "flesch_kincaid_grade",
                "smog_index",
                "coleman_liau_index",
                "automated_readability_index",
                "dale_chall_readability_score",
                "difficult_words",
                "linsear_write_formula",
                "gunning_fog",
                "text_standard",
                "fernandez_huerta",
                "szigriszt_pazos",
                "gutierrez_polini",
                "crawford",
                "gulpease_index",
                "osman",
            ]

        if self.visualize:
            visualizations_columns = ["dependency_tree", "entities"]

        llm_outputs_df = (
            on_llm_end_records_df[
                [
                    "step",
                    "text",
                    "token_usage_total_tokens",
                    "token_usage_prompt_tokens",
                    "token_usage_completion_tokens",
                ]
                + complexity_metrics_columns
                + visualizations_columns
            ]
            .dropna(axis=1)
            .rename({"step": "output_step", "text": "output"}, axis=1)
        )
        session_analysis_df = pd.concat([llm_input_prompts_df, llm_outputs_df], axis=1)
        session_analysis_df["chat_html"] = session_analysis_df[
            ["prompts", "output"]
        ].apply(
            lambda row: construct_html_from_prompt_and_generation(
                row["prompts"], row["output"]
            ),
            axis=1,
        )
        return session_analysis_df
    def flush_tracker(
        self,
        langchain_asset: Any = None,
        reset: bool = True,
        finish: bool = False,
        job_type: Optional[str] = None,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        tags: Optional[Sequence] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        visualize: Optional[bool] = None,
        complexity_metrics: Optional[bool] = None,
    ) -> None:
        """Flush the tracker and reset the session.

        Args:
            langchain_asset: The langchain asset to save.
            reset: Whether to reset the session.
            finish: Whether to finish the run.
            job_type: The job type.
            project: The project.
            entity: The entity.
            tags: The tags.
            group: The group.
            name: The name.
            notes: The notes.
            visualize: Whether to visualize.
            complexity_metrics: Whether to compute complexity metrics.

        Returns:
            None
        """
        pd = import_pandas()
        wandb = import_wandb()
        action_records_table = wandb.Table(dataframe=pd.DataFrame(self.action_records))
        session_analysis_table = wandb.Table(
            dataframe=self._create_session_analysis_df()
        )
        self.run.log(
            {
                "action_records": action_records_table,
                "session_analysis": session_analysis_table,
            }
        )

        if langchain_asset:
            langchain_asset_path = Path(self.temp_dir.name, "model.json")
            model_artifact = wandb.Artifact(name="model", type="model")
            model_artifact.add(action_records_table, name="action_records")
            model_artifact.add(session_analysis_table, name="session_analysis")
            try:
                langchain_asset.save(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
            except ValueError:
                langchain_asset.save_agent(langchain_asset_path)
                model_artifact.add_file(str(langchain_asset_path))
                model_artifact.metadata = load_json_to_dict(langchain_asset_path)
            except NotImplementedError as e:
                print("Could not save model.")  # noqa: T201
                print(repr(e))  # noqa: T201
                pass
            self.run.log_artifact(model_artifact)

        if finish or reset:
            self.run.finish()
            self.temp_dir.cleanup()
            self.reset_callback_meta()
        if reset:
            self.__init__(  # type: ignore
                job_type=job_type if job_type else self.job_type,
                project=project if project else self.project,
                entity=entity if entity else self.entity,
                tags=tags if tags else self.tags,
                group=group if group else self.group,
                name=name if name else self.name,
                notes=notes if notes else self.notes,
                visualize=visualize if visualize else self.visualize,
                complexity_metrics=(
                    complexity_metrics
                    if complexity_metrics
                    else self.complexity_metrics
                ),
            )
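# Illustrative end-to-end sketch (not part of the module): attach the handler to a
# LangChain LLM via the `callbacks` argument, then use flush_tracker to push the
# accumulated records to W&B. The OpenAI wrapper, prompt, and run names below are
# assumptions made for this example only; any LangChain callable that accepts
# callbacks would be wired up the same way.
def _example_wandb_callback_usage() -> None:
    from langchain_community.llms import OpenAI  # assumed to be installed

    handler = WandbCallbackHandler(
        project="langchain_callback_demo",  # the handler's default project
        name="llm-run",  # placeholder run name
        complexity_metrics=True,
        stream_logs=False,
    )
    llm = OpenAI(temperature=0, callbacks=[handler])
    llm.invoke("Briefly explain gradient descent.")

    # Log the action_records and session_analysis tables, save the LLM as a
    # W&B artifact, and rotate into a fresh run for the next experiment.
    handler.flush_tracker(langchain_asset=llm, reset=True, name="next-run")

    # When done entirely, finish the current run without starting a new one.
    handler.flush_tracker(finish=True, reset=False)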