Source code for langchain.chains.transform
"""Chain that runs an arbitrary python function."""
import functools
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from pydantic import Field
from langchain.chains.base import Chain
logger = logging.getLogger(__name__)
[docs]
class TransformChain(Chain):
    """Chain that delegates its work to a user-supplied transform function.

    The chain forwards its ``inputs`` dictionary to ``transform`` (or, on the
    async path, to ``atransform`` when one was supplied) and returns whatever
    dictionary that callable produces.

    Example:
        .. code-block:: python

            from langchain.chains import TransformChain

            def func(inputs: dict) -> dict:
                return {"entities": inputs["text"].split()}

            transform_chain = TransformChain(
                input_variables=["text"],
                output_variables=["entities"],
                transform=func,
            )
    """

    input_variables: List[str]
    """The keys expected by the transform's input dictionary."""
    output_variables: List[str]
    """The keys returned by the transform's output dictionary."""
    # Typed with ``Any`` values so the sync and async callbacks share one
    # contract; ``_acall`` passes the same inputs dict to either of them.
    transform_cb: Callable[[Dict[str, Any]], Dict[str, Any]] = Field(alias="transform")
    """The transform function."""
    atransform_cb: Optional[Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]] = (
        Field(None, alias="atransform")
    )
    """The async coroutine transform function."""

    @staticmethod
    @functools.lru_cache
    def _log_once(msg: str) -> None:
        """Log a warning, suppressing repeats of the same message.

        ``functools.lru_cache`` memoizes the call by ``msg``, so a repeated
        call with an identical message is answered from the cache and does not
        reach the logger again (for as long as the entry stays cached).

        Args:
            msg: The warning text to emit.

        :meta private:
        """
        logger.warning(msg)

    @property
    def input_keys(self) -> List[str]:
        """Expect input keys.

        Returns:
            The configured ``input_variables``.

        :meta private:
        """
        return self.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Return output keys.

        Returns:
            The configured ``output_variables``.

        :meta private:
        """
        return self.output_variables

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the synchronous transform on ``inputs``.

        Args:
            inputs: The chain inputs, passed unchanged to ``transform``.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The dictionary returned by ``transform``.
        """
        return self.transform_cb(inputs)

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the async transform, falling back to the synchronous one.

        Args:
            inputs: The chain inputs, passed unchanged to the transform.
            run_manager: Accepted for interface compatibility; not used here.

        Returns:
            The dictionary produced by ``atransform`` when it is set,
            otherwise the dictionary returned by ``transform``.
        """
        if self.atransform_cb is not None:
            return await self.atransform_cb(inputs)

        # No coroutine was supplied: warn (deduplicated by ``_log_once``) and
        # call the synchronous transform directly from this coroutine.
        self._log_once(
            "TransformChain's atransform is not provided, falling"
            " back to synchronous transform"
        )
        return self.transform_cb(inputs)