Source code for langchain_community.tools.memorize.tool

from abc import abstractmethod
from typing import Any, Optional, Protocol, Sequence, runtime_checkable

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import Field
from langchain_core.tools import BaseTool

from langchain_community.llms.gradient_ai import TrainResult


@runtime_checkable
class TrainableLLM(Protocol):
    """Protocol for trainable language models.

    Any object exposing both ``train_unsupervised`` and
    ``atrain_unsupervised`` satisfies this protocol structurally. Because
    the class is decorated with ``runtime_checkable``, an
    ``isinstance(obj, TrainableLLM)`` check only verifies that those members
    exist; it does not check their signatures.
    """

    @abstractmethod
    def train_unsupervised(
        self,
        inputs: Sequence[str],
        **kwargs: Any,
    ) -> TrainResult:
        """Train the model on ``inputs`` and return a ``TrainResult``.

        Callers such as ``Memorize`` read the ``'loss'`` key of the result.
        Extra keyword arguments are implementation-specific.
        """

    @abstractmethod
    async def atrain_unsupervised(
        self,
        inputs: Sequence[str],
        **kwargs: Any,
    ) -> TrainResult:
        """Async counterpart of ``train_unsupervised``."""
class Memorize(BaseTool):
    """Tool that trains a language model.

    The agent passes the information it wants to retain as the tool input.
    The tool then runs one unsupervised training call on the wrapped ``llm``
    and replies with the loss reported by that call.
    """

    name: str = "memorize"
    description: str = (
        "Useful whenever you observed novel information "
        "from previous conversation history, "
        "i.e., another tool's action outputs or human comments. "
        "The action input should include observed information in detail, "
        "then the tool will fine-tune yourself to remember it."
    )
    # Model that gets trained; must offer sync and async unsupervised training.
    llm: TrainableLLM = Field()

    @staticmethod
    def _summarize(result: TrainResult) -> str:
        """Render a training result as the tool's reply text."""
        return f"Train complete. Loss: {result['loss']}"

    def _run(
        self,
        information_to_learn: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Train synchronously on ``information_to_learn`` and report the loss.

        ``run_manager`` is accepted to match the ``BaseTool`` hook signature
        but is not used here.
        """
        # The model expects a sequence of texts; pass a one-element tuple.
        batch = (information_to_learn,)
        return self._summarize(self.llm.train_unsupervised(batch))

    async def _arun(
        self,
        information_to_learn: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async variant of ``_run``; ``run_manager`` is likewise unused."""
        batch = (information_to_learn,)
        outcome = await self.llm.atrain_unsupervised(batch)
        return self._summarize(outcome)