class BaseModel(Base):
    """Base model for all SQL stores.

    Declares the UUID primary key shared by the concrete store models that
    subclass it.
    """

    # Marks the declarative class as abstract, so SQLAlchemy does not map it
    # to a table of its own; only subclasses get tables.
    __abstract__ = True
    # Primary key. `default=uuid.uuid4` passes the callable (not a value), so
    # a fresh UUID is generated per inserted row when none is supplied.
    uuid = sqlalchemy.Column(
        UUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
@classmethod
def get_or_create(
    cls,
    session: Session,
    name: str,
    cmetadata: Optional[dict] = None,
) -> Tuple["CollectionStore", bool]:
    """Fetch the collection called ``name``, creating it when absent.

    Args:
        session: Open session used for the lookup and, if needed, the insert.
        name: Name of the collection to look up or create.
        cmetadata: Metadata stored on the collection only when it is created.

    Returns:
        A ``(collection, created)`` pair; ``created`` is True only when a new
        collection was added and committed by this call.
    """
    existing = cls.get_by_name(session, name)
    if existing:
        # Already present: hand it back untouched.
        return existing, False

    fresh = cls(name=name, cmetadata=cmetadata)
    session.add(fresh)
    session.commit()
    return fresh, True
class EmbeddingStore(BaseModel):
    """Embedding store.

    One row of the ``langchain_pg_embedding`` table: an embedding vector
    together with the document text, metadata and optional user-supplied id
    it was stored with.
    """

    __tablename__ = "langchain_pg_embedding"

    # Foreign key to the owning collection's `uuid` column; rows are removed
    # with their collection (ON DELETE CASCADE).
    collection_id = sqlalchemy.Column(
        UUID(as_uuid=True),
        sqlalchemy.ForeignKey(
            f"{CollectionStore.__tablename__}.uuid",
            ondelete="CASCADE",
        ),
    )
    # ORM-side link to the owning collection; pairs with the `embeddings`
    # relationship named in `back_populates`.
    collection = relationship(CollectionStore, back_populates="embeddings")

    # The embedding vector, stored as a Postgres array of REAL values.
    embedding = sqlalchemy.Column(sqlalchemy.ARRAY(sqlalchemy.REAL))  # type: ignore
    # Text the embedding was stored with (nullable).
    document = sqlalchemy.Column(sqlalchemy.String, nullable=True)
    # Metadata stored as JSON (nullable).
    cmetadata = sqlalchemy.Column(JSON, nullable=True)

    # custom_id : any user defined id
    custom_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
class QueryResult:
    """Result from a query.

    Describes one row returned by the similarity search query, which selects
    an ``EmbeddingStore`` entity plus an expression labelled ``distance``.
    """

    # The matched ORM row; the attribute is named after the entity class.
    EmbeddingStore: EmbeddingStore
    # Value of the query column labelled "distance" for this row.
    distance: float
class PGEmbedding(VectorStore):
    """`Postgres` with the `pg_embedding` extension as a vector store.

    pg_embedding uses a sequential scan by default, but you can create an
    HNSW index using the `create_hnsw_index` method.

    - `connection_string` is a postgres connection string.
    - `embedding_function` is any embedding function implementing the
        `langchain.embeddings.base.Embeddings` interface.
    - `collection_name` is the name of the collection to use.
        (default: langchain)
        - NOTE: This is not the name of the table, but the name of the
            collection. The tables will be created when initializing the
            store (if they do not exist), so make sure the user has the
            right permissions to create tables.
    - `distance_strategy` is the distance strategy to use.
        (default: EUCLIDEAN)
        - `EUCLIDEAN` is the euclidean distance.
    - `pre_delete_collection` if True, will delete the collection if it
        exists. (default: False)
        - Useful for testing.
    """
def create_hnsw_extension(self) -> None:
    """Run ``CREATE EXTENSION IF NOT EXISTS embedding`` on the connection.

    Best-effort: any exception raised while opening the session, executing
    the statement or committing is logged through ``self.logger`` and not
    re-raised.
    """
    try:
        with Session(self._conn) as db:
            db.execute(
                sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS embedding")
            )
            db.commit()
    except Exception as exc:
        self.logger.exception(exc)
def create_hnsw_index(
    self,
    max_elements: int = 10000,
    dims: int = ADA_TOKEN_COUNT,
    m: int = 8,
    ef_construction: int = 16,
    ef_search: int = 16,
) -> None:
    """Create the HNSW index on ``langchain_pg_embedding.embedding``.

    Issues ``CREATE INDEX IF NOT EXISTS langchain_pg_embedding_idx ... USING
    hnsw`` with the given options. Best-effort: failures are logged through
    ``self.logger`` and not re-raised.

    Args:
        max_elements: Value for the index's ``maxelements`` option.
        dims: Value for the index's ``dims`` option.
        m: Value for the index's ``m`` option.
        ef_construction: Value for the index's ``efconstruction`` option.
        ef_search: Value for the index's ``efsearch`` option.
    """
    try:
        # The options are interpolated into the DDL text rather than bound as
        # parameters, so force each one to an int first: anything that is not
        # a plain number is rejected here instead of reaching the SQL string.
        options = [
            int(option)
            for option in (max_elements, dims, m, ef_construction, ef_search)
        ]
        create_index_query = sqlalchemy.text(
            "CREATE INDEX IF NOT EXISTS langchain_pg_embedding_idx "
            "ON langchain_pg_embedding USING hnsw (embedding) "
            "WITH ("
            "maxelements = {}, "
            "dims = {}, "
            "m = {}, "
            "efconstruction = {}, "
            "efsearch = {}"
            ");".format(*options)
        )
        with Session(self._conn) as session:
            session.execute(create_index_query)
            session.commit()
    except Exception:
        # Same reporting channel as create_hnsw_extension; the traceback is
        # attached by logger.exception.
        self.logger.exception("Failed to create HNSW index")
    else:
        self.logger.info("HNSW index created successfully.")
def delete_collection(self) -> None:
    """Delete this store's collection, if it exists.

    Looks the collection up with ``self.get_collection``; when it is found it
    is deleted and the change committed, otherwise a warning is logged and
    nothing is modified.
    """
    self.logger.debug("Trying to delete collection")
    with Session(self._conn) as db:
        target = self.get_collection(db)
        if target:
            db.delete(target)
            db.commit()
        else:
            self.logger.warning("Collection not found")
def add_embeddings(
    self,
    texts: List[str],
    embeddings: List[List[float]],
    metadatas: List[dict],
    ids: List[str],
    **kwargs: Any,
) -> None:
    """Store precomputed embeddings in this store's collection.

    The four lists are walked in lockstep with ``zip``, so only as many rows
    as the shortest list are written.

    Args:
        texts: Document texts, one per embedding.
        embeddings: Embedding vectors.
        metadatas: Metadata dicts, one per embedding.
        ids: Values stored in each row's ``custom_id``.
        **kwargs: Accepted but not used here.

    Raises:
        ValueError: If ``self.get_collection`` finds no collection.
    """
    with Session(self._conn) as db:
        owner = self.get_collection(db)
        if not owner:
            raise ValueError("Collection not found")

        rows = zip(texts, metadatas, embeddings, ids)
        for doc_text, doc_meta, vector, custom_id in rows:
            record = EmbeddingStore(
                embedding=vector,
                document=doc_text,
                cmetadata=doc_meta,
                custom_id=custom_id,
            )
            owner.embeddings.append(record)
            db.add(record)
        # Single commit for the whole batch.
        db.commit()
def add_texts(
    self,
    texts: Iterable[str],
    metadatas: Optional[List[dict]] = None,
    ids: Optional[List[str]] = None,
    **kwargs: Any,
) -> List[str]:
    """Embed ``texts`` and store them in this store's collection.

    Args:
        texts: Texts to embed. Any iterable is accepted, including one-shot
            iterators such as generators.
        metadatas: Optional metadata dicts, one per text. When missing or
            empty, every text gets an empty dict.
        ids: Optional ids stored as each row's ``custom_id``. When None, a
            random UUID4 string is generated per text.
        **kwargs: Forwarded to ``add_embeddings``.

    Returns:
        The ids associated with the stored texts, in input order.

    Raises:
        ValueError: If the collection cannot be found (raised by
            ``add_embeddings``).
    """
    # Materialize once: `texts` is iterated several times below, and a
    # generator would be exhausted after the first pass, leaving nothing to
    # embed or store.
    texts = list(texts)

    if ids is None:
        ids = [str(uuid.uuid4()) for _ in texts]

    embeddings = self.embedding_function.embed_documents(texts)

    if not metadatas:
        metadatas = [{} for _ in texts]

    # Persisting is identical to add_embeddings, so reuse it rather than
    # keeping a second copy of the insert loop.
    self.add_embeddings(
        texts=texts,
        embeddings=embeddings,
        metadatas=metadatas,
        ids=ids,
        **kwargs,
    )
    return ids
def similarity_search_with_score_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[dict] = None,
) -> List[Tuple[Document, float]]:
    """Return the ``k`` stored documents closest to ``embedding``.

    Rows are restricted to this store's collection, optionally narrowed by
    metadata, and ordered by ascending ``abs(embedding <-> query)``.

    Args:
        embedding: Query vector.
        k: Maximum number of rows to return.
        filter: Optional mapping of metadata key to condition. Each condition
            is matched against the key's text value:

            - ``{"in": [...]}``: value is one of the listed items.
            - ``{"substring": "..."}``: case-insensitive substring match.
            - anything else: equality with ``str(value)``.

            Operator names are matched case-insensitively.

    Returns:
        ``(Document, distance)`` pairs in ascending distance order.

    Raises:
        ValueError: If ``self.get_collection`` finds no collection.
    """
    with Session(self._conn) as session:
        collection = self.get_collection(session)
        set_enable_seqscan_stmt = sqlalchemy.text("SET enable_seqscan = off")
        session.execute(set_enable_seqscan_stmt)
        if not collection:
            raise ValueError("Collection not found")

        filter_by = EmbeddingStore.collection_id == collection.uuid

        if filter is not None:
            filter_clauses = []
            for key, value in filter.items():
                metadata_text = EmbeddingStore.cmetadata[key].astext
                # Lower-case the operator names once and look operands up in
                # that normalized mapping, so detection and access always
                # agree. Previously {"SUBSTRING": ...} was detected but then
                # read as value["substring"], raising KeyError. Non-string
                # keys cannot be operators and are skipped.
                operators = (
                    {
                        name.lower(): operand
                        for name, operand in value.items()
                        if isinstance(name, str)
                    }
                    if isinstance(value, dict)
                    else {}
                )
                if "in" in operators:
                    clause = metadata_text.in_(operators["in"])
                elif "substring" in operators:
                    clause = metadata_text.ilike(
                        f"%{operators['substring']}%"
                    )
                else:
                    clause = metadata_text == str(value)
                filter_clauses.append(clause)

            filter_by = sqlalchemy.and_(filter_by, *filter_clauses)

        results: List[QueryResult] = (
            session.query(
                EmbeddingStore,
                func.abs(EmbeddingStore.embedding.op("<->")(embedding)).label(
                    "distance"
                ),
            )
            .filter(filter_by)
            # `<->` is a PostgreSQL-specific operator applied to the
            # embedding column.
            .order_by(
                func.abs(EmbeddingStore.embedding.op("<->")(embedding)).asc()
            )
            .limit(k)
            .all()
        )

        return [
            (
                Document(
                    page_content=result.EmbeddingStore.document,  # type: ignore[arg-type]
                    metadata=result.EmbeddingStore.cmetadata,
                ),
                # Score is reported as 0.0 when no embedding function is set.
                result.distance if self.embedding_function is not None else 0.0,
            )
            for result in results
        ]
@classmethod
def get_connection_string(cls, kwargs: Dict[str, Any]) -> str:
    """Resolve the Postgres connection string.

    Looks up ``connection_string`` in ``kwargs`` via ``get_from_dict_or_env``,
    with ``POSTGRES_CONNECTION_STRING`` as the environment variable name.

    Args:
        kwargs: Keyword arguments that may contain ``connection_string``.

    Returns:
        The resolved connection string.

    Raises:
        ValueError: If the resolved value is empty.
    """
    connection_string: str = get_from_dict_or_env(
        data=kwargs,
        key="connection_string",
        env_key="POSTGRES_CONNECTION_STRING",
    )

    if not connection_string:
        # Adjacent literals are concatenated as-is, so each fragment carries
        # its own separator; without them the message ran together
        # ("...is requiredEither pass it as a parameteror set...").
        raise ValueError(
            "Postgres connection string is required. "
            "Either pass it as a parameter "
            "or set the POSTGRES_CONNECTION_STRING environment variable."
        )

    return connection_string