Source code for langchain_community.embeddings.oracleai
# Authors:# Harichandan Roy (hroy)# David Jiang (ddjiang)## -----------------------------------------------------------------------------# oracleai.py# -----------------------------------------------------------------------------from__future__importannotationsimportjsonimportloggingimporttracebackfromtypingimportTYPE_CHECKING,Any,Dict,List,Optionalfromlangchain_core.embeddingsimportEmbeddingsfrompydanticimportBaseModel,ConfigDictifTYPE_CHECKING:fromoracledbimportConnectionlogger=logging.getLogger(__name__)"""OracleEmbeddings class"""
class OracleEmbeddings(BaseModel, Embeddings):
    """Get Embeddings

    Embeddings implementation that delegates the embedding computation to an
    Oracle Database through ``dbms_vector_chain.utl_to_embeddings``.
    """

    # Oracle Connection: used to create cursors and to look up the
    # SYS.VECTOR_ARRAY_T object type.
    conn: Any = None
    # Embedding Parameters: serialized with json.dumps and bound as the
    # json(:params) argument of utl_to_embeddings.
    params: Dict[str, Any]
    # Proxy: when set (truthy), passed to utl_http.set_proxy before embedding.
    proxy: Optional[str] = None

    model_config = ConfigDict(
        extra="forbid",
    )

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)

    @staticmethod
    def load_onnx_model(
        conn: Connection, dir: str, onnx_file: str, model_name: str
    ) -> None:
        """Load an ONNX model to Oracle Database.

        Any existing model named ``model_name`` is dropped first
        (``force => true``), then the ONNX file is loaded as an embedding
        model.

        Privileges:
            1 - user needs to have create procedure,
                create mining model, create any directory privilege.
            2 - grant create procedure, create mining model,
                create any directory to <user>;

        Args:
            conn: Oracle Connection,
            dir: Oracle Directory,
            onnx_file: ONNX file name,
            model_name: Name of the model.

        Raises:
            ValueError: If any argument is None.
            Exception: Any error raised while executing the PL/SQL block is
                logged and re-raised unchanged.
        """
        # Bound before the try block so the cleanup below never touches an
        # unassigned name when the failure happens before cursor creation.
        cursor = None
        try:
            if conn is None or dir is None or onnx_file is None or model_name is None:
                raise ValueError("Invalid input")

            cursor = conn.cursor()
            cursor.execute(
                """
                begin
                    dbms_data_mining.drop_model(model_name => :model, force => true);
                    SYS.DBMS_VECTOR.load_onnx_model(:path, :filename, :model,
                        json('{"function" : "embedding", "embeddingOutput" : "embedding", "input": {"input": ["DATA"]}}'));
                end;""",
                path=dir,
                filename=onnx_file,
                model=model_name,
            )
        except Exception as ex:
            logger.info("An exception occurred :: %s", ex)
            traceback.print_exc()
            raise
        finally:
            if cursor is not None:
                cursor.close()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using an OracleEmbeddings.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each input text. ``None`` is returned
            when ``texts`` is ``None``.

        Raises:
            ImportError: If the ``oracledb`` package is not installed.
            Exception: Any error raised while talking to the database or
                decoding its response is logged and re-raised unchanged.
        """
        try:
            import oracledb
        except ImportError as e:
            raise ImportError(
                "Unable to import oracledb, please install with "
                "`pip install -U oracledb`."
            ) from e

        if texts is None:
            # Kept for backward compatibility with existing callers.
            return None

        embeddings: List[List[float]] = []
        # Bound before the try block so the cleanup below never touches an
        # unassigned name when the failure happens before cursor creation.
        cursor = None
        try:
            # returns strings or bytes instead of a locator
            oracledb.defaults.fetch_lobs = False
            cursor = self.conn.cursor()

            if self.proxy:
                cursor.execute(
                    "begin utl_http.set_proxy(:proxy); end;", proxy=self.proxy
                )

            # Each text becomes a JSON document carrying a 1-based chunk id.
            chunks = [
                json.dumps({"chunk_id": i, "chunk_data": text})
                for i, text in enumerate(texts, start=1)
            ]

            vector_array_type = self.conn.gettype("SYS.VECTOR_ARRAY_T")
            inputs = vector_array_type.newobject(chunks)
            cursor.execute(
                "select t.* "
                "from dbms_vector_chain.utl_to_embeddings(:content, "
                "json(:params)) t",
                content=inputs,
                params=json.dumps(self.params),
            )

            for row in cursor:
                if row is None:
                    embeddings.append([])
                else:
                    rdata = json.loads(row[0])
                    # dereference string as array
                    vec = json.loads(rdata["embed_vector"])
                    embeddings.append(vec)

            return embeddings
        except Exception as ex:
            logger.info("An exception occurred :: %s", ex)
            traceback.print_exc()
            raise
        finally:
            if cursor is not None:
                cursor.close()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embedding using an OracleEmbeddings.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        return self.embed_documents([text])[0]
# uncomment the following code block to run the test

"""
# A sample unit test.

import oracledb

# get the Oracle connection
conn = oracledb.connect(
    user="<user>",
    password="<password>",
    dsn="<hostname>/<service_name>",
)
print("Oracle connection is established...")

# params
embedder_params = {"provider": "database", "model": "demo_model"}
proxy = ""

# instance
embedder = OracleEmbeddings(conn=conn, params=embedder_params, proxy=proxy)

docs = ["hello world!", "hi everyone!", "greetings!"]
embeds = embedder.embed_documents(docs)
print(f"Total Embeddings: {len(embeds)}")
print(f"Embedding generated by OracleEmbeddings: {embeds[0]}\n")

embed = embedder.embed_query("Hello World!")
print(f"Embedding generated by OracleEmbeddings: {embed}")

conn.close()
print("Connection is closed.")
"""