# Source code for langchain_core.prompts.dict

"""Dict prompt template."""

import warnings
from functools import cached_property
from typing import Any, Literal, Optional

from typing_extensions import override

from langchain_core.load import dumpd
from langchain_core.prompts.string import (
    DEFAULT_FORMATTER_MAPPING,
    get_template_variables,
)
from langchain_core.runnables import RunnableConfig, RunnableSerializable
from langchain_core.runnables.config import ensure_config


class DictPromptTemplate(RunnableSerializable[dict, dict]):
    """Prompt template backed by a dict.

    Template variables are recognized in string values written in f-string or
    mustache syntax. Dict keys are never treated as templates. Nested values
    are processed recursively by the module-level helper functions.
    """

    # Dict whose string values may reference template variables.
    template: dict[str, Any]
    # Syntax used by the string values inside ``template``.
    template_format: Literal["f-string", "mustache"]

    @property
    def input_variables(self) -> list[str]:
        """Variable names referenced anywhere inside ``template``."""
        variables = _get_input_variables(self.template, self.template_format)
        return variables

    def format(self, **kwargs: Any) -> dict[str, Any]:
        """Render ``template`` using ``kwargs`` as the variable values."""
        rendered = _insert_input_variables(
            self.template, kwargs, self.template_format
        )
        return rendered

    async def aformat(self, **kwargs: Any) -> dict[str, Any]:
        """Async counterpart of ``format``; delegates to the sync implementation."""
        return self.format(**kwargs)

    @override
    def invoke(
        self, input: dict, config: Optional[RunnableConfig] = None, **kwargs: Any
    ) -> dict:
        """Render the prompt, using the keys of ``input`` as variable values.

        The rendering is run through ``_call_with_config`` with
        ``run_type="prompt"`` and this template's serialized form.
        """

        def _render(values: dict) -> dict:
            return self.format(**values)

        return self._call_with_config(
            _render,
            input,
            ensure_config(config),
            run_type="prompt",
            serialized=self._serialized,
            **kwargs,
        )

    @property
    def _prompt_type(self) -> str:
        return "dict-prompt"

    @cached_property
    def _serialized(self) -> dict[str, Any]:
        # Cached with functools.cached_property so dumpd runs at most once here.
        return dumpd(self)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this class supports LangChain serialization.

        Returns:
            Always True.
        """
        return True

    @classmethod
    def get_lc_namespace(cls) -> list[str]:
        """Namespace under which this class is serialized."""
        return ["langchain_core", "prompts", "dict"]

    def pretty_repr(self, *, html: bool = False) -> str:
        """Human-readable representation (not supported for this template).

        Args:
            html: Whether HTML output is requested. Defaults to False.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
def _get_input_variables(
    template: dict, template_format: Literal["f-string", "mustache"]
) -> list[str]:
    """Return the distinct template variables referenced by ``template``.

    String values are parsed with ``get_template_variables``; dict values and
    list/tuple values (including lists nested inside lists) are searched
    recursively. Dict keys and non-string scalar values are ignored.

    Args:
        template: The dict template to inspect.
        template_format: Syntax of the string values ("f-string" or "mustache").

    Returns:
        Variable names without duplicates, ordered by first occurrence during
        the traversal.
    """
    found: list[str] = []
    for value in template.values():
        _collect_input_variables(value, template_format, found)
    # dict.fromkeys de-duplicates while keeping first-seen order, so the result
    # is stable across interpreter runs (set() order depends on string hashing).
    return list(dict.fromkeys(found))


def _collect_input_variables(
    value: Any,
    template_format: Literal["f-string", "mustache"],
    found: list[str],
) -> None:
    """Append every variable referenced by ``value`` (recursively) to ``found``."""
    if isinstance(value, str):
        found.extend(get_template_variables(value, template_format))
    elif isinstance(value, dict):
        for item in value.values():
            _collect_input_variables(item, template_format, found)
    elif isinstance(value, (list, tuple)):
        for item in value:
            _collect_input_variables(item, template_format, found)
    # Any other value (numbers, None, bools, ...) holds no template variables.


def _insert_input_variables(
    template: dict[str, Any],
    inputs: dict[str, Any],
    template_format: Literal["f-string", "mustache"],
) -> dict[str, Any]:
    """Return a new dict with ``inputs`` substituted into ``template``.

    String values are rendered with the formatter registered for
    ``template_format``. Dicts and lists/tuples are processed recursively, and
    lists and tuples keep their type. Every other value, including non-string
    items inside lists, is passed through unchanged. Dict keys are never
    formatted.

    A warning is emitted when an ``"image_url"`` entry is a dict containing a
    ``"path"`` key; that entry is still rendered like any other dict.

    Args:
        template: The dict template to render.
        inputs: Variable values, passed as keyword arguments to the formatter.
        template_format: Syntax of the string values ("f-string" or "mustache").

    Returns:
        The rendered dict.
    """
    formatter = DEFAULT_FORMATTER_MAPPING[template_format]
    formatted: dict[str, Any] = {}
    for key, value in template.items():
        if isinstance(value, dict) and key == "image_url" and "path" in value:
            msg = (
                "Specifying image inputs via file path in environments with "
                "user-input paths is a security vulnerability. Out of an abundance "
                "of caution, the utility has been removed to prevent possible "
                "misuse."
            )
            warnings.warn(msg, stacklevel=2)
        formatted[key] = _format_value(value, inputs, template_format, formatter)
    return formatted


def _format_value(
    value: Any,
    inputs: dict[str, Any],
    template_format: Literal["f-string", "mustache"],
    formatter: Any,
) -> Any:
    """Render one template value, recursing into dicts, lists and tuples."""
    if isinstance(value, str):
        return formatter(value, **inputs)
    if isinstance(value, dict):
        return _insert_input_variables(value, inputs, template_format)
    if isinstance(value, (list, tuple)):
        # Every item is kept: strings/dicts/nested sequences are rendered and
        # anything else (numbers, None, ...) passes through instead of being
        # silently dropped from the output list.
        rendered = [
            _format_value(item, inputs, template_format, formatter) for item in value
        ]
        return type(value)(rendered)
    return value