class ArizeCallbackHandler(BaseCallbackHandler):
    """Callback Handler that logs LLM prompt/response pairs to Arize.

    On each ``on_llm_end`` event, the handler embeds the prompt and the
    generated response with a DistilBERT embedding generator and logs one
    row per generation to Arize as a ``GENERATIVE_LLM`` model in the
    production environment.
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        model_version: Optional[str] = None,
        SPACE_KEY: Optional[str] = None,
        API_KEY: Optional[str] = None,
    ) -> None:
        """Initialize callback handler.

        Args:
            model_id: Arize model id to log under.
            model_version: Arize model version to log under.
            SPACE_KEY: Arize space key credential.
            API_KEY: Arize API key credential.

        Raises:
            ValueError: If the placeholder strings ``"SPACE_KEY"`` /
                ``"API_KEY"`` were not replaced with real credentials.
        """
        super().__init__()
        self.model_id = model_id
        self.model_version = model_version
        self.space_key = SPACE_KEY
        self.api_key = API_KEY

        # Accumulators populated as callbacks fire; `step` indexes into
        # `prompt_records` so responses are matched to their prompts.
        self.prompt_records: List[str] = []
        self.response_records: List[str] = []
        self.prediction_ids: List[str] = []
        self.pred_timestamps: List[int] = []
        self.response_embeddings: List[float] = []
        self.prompt_embeddings: List[float] = []
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        self.step = 0

        # Imported lazily so the handler module can be imported without
        # the optional `arize` dependency installed.
        from arize.pandas.embeddings import EmbeddingGenerator, UseCases
        from arize.pandas.logger import Client

        self.generator = EmbeddingGenerator.from_use_case(
            use_case=UseCases.NLP.SEQUENCE_CLASSIFICATION,
            model_name="distilbert-base-uncased",
            tokenizer_max_length=512,
            batch_size=256,
        )

        # Validate credentials BEFORE constructing the client so we never
        # hand obviously-placeholder keys to the Arize SDK.
        if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
            raise ValueError("❌ CHANGE SPACE AND API KEYS")
        self.arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
        print("✅ Arize client setup done! Now you can start using Arize!")  # noqa: T201

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Log each generation in ``response`` to Arize.

        For every generation, embeds the matching recorded prompt and the
        generated text, then logs a one-row DataFrame (timestamp, texts,
        embedding vectors, token-count tags) via the Arize client.
        """
        pd = import_pandas()
        from arize.utils.types import (
            EmbeddingColumnNames,
            Environments,
            ModelTypes,
            Schema,
        )

        # Safe check: 'llm_output' and 'token_usage' may be absent
        # depending on the provider; default all counts to 0.
        if response.llm_output and "token_usage" in response.llm_output:
            usage = response.llm_output["token_usage"]
            self.prompt_tokens = usage.get("prompt_tokens", 0)
            self.total_tokens = usage.get("total_tokens", 0)
            self.completion_tokens = usage.get("completion_tokens", 0)
        else:
            self.prompt_tokens = self.total_tokens = self.completion_tokens = 0

        for generations in response.generations:
            for generation in generations:
                # Pair this generation with the prompt recorded at the
                # same step index, then advance the cursor.
                prompt = self.prompt_records[self.step]
                self.step = self.step + 1
                prompt_embedding = pd.Series(
                    self.generator.generate_embeddings(
                        text_col=pd.Series(prompt.replace("\n", " "))
                    ).reset_index(drop=True)
                )

                # Assigning text to response_text instead of response
                response_text = generation.text.replace("\n", " ")
                response_embedding = pd.Series(
                    self.generator.generate_embeddings(
                        text_col=pd.Series(generation.text.replace("\n", " "))
                    ).reset_index(drop=True)
                )
                pred_timestamp = datetime.now().timestamp()

                # Define the columns and data.
                columns = [
                    "prediction_ts",
                    "response",
                    "prompt",
                    "response_vector",
                    "prompt_vector",
                    "prompt_token",
                    "completion_token",
                    "total_token",
                ]
                # BUGFIX: the data row must follow the column order above —
                # previously total_tokens and completion_tokens were swapped,
                # logging each count under the other's column.
                data = [
                    [
                        pred_timestamp,
                        response_text,
                        prompt,
                        response_embedding[0],
                        prompt_embedding[0],
                        self.prompt_tokens,
                        self.completion_tokens,
                        self.total_tokens,
                    ]
                ]

                # Create the DataFrame
                df = pd.DataFrame(data, columns=columns)

                # Declare prompt and response columns
                prompt_columns = EmbeddingColumnNames(
                    vector_column_name="prompt_vector", data_column_name="prompt"
                )
                response_columns = EmbeddingColumnNames(
                    vector_column_name="response_vector", data_column_name="response"
                )

                schema = Schema(
                    timestamp_column_name="prediction_ts",
                    tag_column_names=[
                        "prompt_token",
                        "completion_token",
                        "total_token",
                    ],
                    prompt_column_names=prompt_columns,
                    response_column_names=response_columns,
                )

                response_from_arize = self.arize_client.log(
                    dataframe=df,
                    schema=schema,
                    model_id=self.model_id,
                    model_version=self.model_version,
                    model_type=ModelTypes.GENERATIVE_LLM,
                    environment=Environments.PRODUCTION,
                )
                if response_from_arize.status_code == 200:
                    print("✅ Successfully logged data to Arize!")  # noqa: T201
                else:
                    print(f'❌ Logging failed "{response_from_arize.text}"')  # noqa: T201