import io
import json
import time
from typing import Any, Dict, List, Optional, Tuple

from langchain_community.document_loaders.base import BaseLoader


class AthenaLoader(BaseLoader):
    """Load documents from `AWS Athena`.

    Each document represents one row of the result.

    - By default, all columns are written into the `page_content` of the
      document and none into the `metadata` of the document.
    - If `metadata_columns` are provided then these columns are written
      into the `metadata` of the document while the rest of the columns
      are written into the `page_content` of the document.

    To authenticate, the AWS client uses this method to automatically load
    credentials:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, you must pass
    the name of the profile from the ~/.aws/credentials file that is to be
    used.

    Make sure the credentials / roles used have the required policies to
    access the Amazon Athena service.
    """
    def __init__(
        self,
        query: str,
        database: str,
        s3_output_uri: str,
        profile_name: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,
    ):
        """Initialize Athena document loader.

        Args:
            query: The query to run in Athena.
            database: Athena database.
            s3_output_uri: Athena output path.
            profile_name: Optional. AWS credential profile, if profiles
                are being used.
            metadata_columns: Optional. Columns written to Document
                `metadata`.
        """
        self.query = query
        self.database = database
        self.s3_output_uri = s3_output_uri
        self.metadata_columns = (
            metadata_columns if metadata_columns is not None else []
        )

        try:
            import boto3
        except ImportError:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            )

        try:
            session = (
                boto3.Session(profile_name=profile_name)
                if profile_name is not None
                else boto3.Session()
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e

        self.athena_client = session.client("athena")
        self.s3_client = session.client("s3")
    def _execute_query(self) -> List[Dict[str, Any]]:
        response = self.athena_client.start_query_execution(
            QueryString=self.query,
            QueryExecutionContext={"Database": self.database},
            ResultConfiguration={"OutputLocation": self.s3_output_uri},
        )
        query_execution_id = response["QueryExecutionId"]
        # Poll the execution until it reaches a terminal state.
        while True:
            response = self.athena_client.get_query_execution(
                QueryExecutionId=query_execution_id
            )
            state = response["QueryExecution"]["Status"]["State"]
            if state == "SUCCEEDED":
                break
            elif state == "FAILED":
                resp_status = response["QueryExecution"]["Status"]
                state_change_reason = resp_status["StateChangeReason"]
                err = f"Query Failed: {state_change_reason}"
                raise Exception(err)
            elif state == "CANCELLED":
                raise Exception("Query was cancelled by the user.")
            time.sleep(1)

        result_set = self._get_result_set(query_execution_id)
        return json.loads(result_set.to_json(orient="records"))

    def _remove_suffix(self, input_string: str, suffix: str) -> str:
        if suffix and input_string.endswith(suffix):
            return input_string[: -len(suffix)]
        return input_string

    def _remove_prefix(self, input_string: str, prefix: str) -> str:
        if prefix and input_string.startswith(prefix):
            return input_string[len(prefix) :]
        return input_string

    def _get_result_set(self, query_execution_id: str) -> Any:
        try:
            import pandas as pd
        except ImportError:
            raise ImportError(
                "Could not import pandas python package. "
                "Please install it with `pip install pandas`."
            )

        # Athena writes the result of the execution as a CSV object under
        # the configured output location, named after the execution id.
        output_uri = self.s3_output_uri
        tokens = self._remove_prefix(
            self._remove_suffix(output_uri, "/"), "s3://"
        ).split("/")
        bucket = tokens[0]
        key = "/".join(tokens[1:] + [query_execution_id]) + ".csv"

        obj = self.s3_client.get_object(Bucket=bucket, Key=key)
        df = pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8")
        return df

    def _get_columns(
        self, query_result: List[Dict[str, Any]]
    ) -> Tuple[List[str], List[str]]:
        content_columns = []
        metadata_columns = []
        all_columns = list(query_result[0].keys())
        for key in all_columns:
            if key in self.metadata_columns:
                metadata_columns.append(key)
            else:
                content_columns.append(key)
        return content_columns, metadata_columns
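
# For illustration only: how `_get_result_set` maps the configured output
# URI and a query execution id onto the S3 object that Athena writes.
# The URI and execution id below are hypothetical placeholders.
#
#   s3_output_uri      = "s3://my-bucket/athena-results/"
#   query_execution_id = "11111111-2222-3333-4444-555555555555"
#
# Stripping the trailing "/" and the "s3://" prefix yields the tokens
# ["my-bucket", "athena-results"], so:
#
#   bucket = "my-bucket"
#   key    = "athena-results/11111111-2222-3333-4444-555555555555.csv"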
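
# A minimal usage sketch, not part of the loader implementation above. The
# query, database, bucket, and profile values are hypothetical placeholders;
# substitute values from your own AWS account. `load()` is assumed to be
# provided by `BaseLoader` on top of a `lazy_load` implementation defined
# elsewhere in this module.
if __name__ == "__main__":
    loader = AthenaLoader(
        query="SELECT * FROM my_table LIMIT 10",  # hypothetical query
        database="my_database",  # hypothetical Athena database
        s3_output_uri="s3://my-bucket/athena-results/",  # hypothetical bucket
        profile_name="default",
    )
    # Each Document holds one result row: all columns land in page_content
    # unless listed in metadata_columns, which are routed to metadata instead.
    for doc in loader.load():
        print(doc.page_content, doc.metadata)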