"""Experiment with different models."""from__future__importannotationsfromtypingimportList,Optional,Sequencefromlangchain_core.language_models.llmsimportBaseLLMfromlangchain_core.prompts.promptimportPromptTemplatefromlangchain_core.utils.inputimportget_color_mapping,print_textfromlangchain.chains.baseimportChainfromlangchain.chains.llmimportLLMChain
class ModelLaboratory:
    """A utility to experiment with and compare the performance of different models."""

    def __init__(self, chains: Sequence[Chain], names: Optional[List[str]] = None):
        """Set up the laboratory from a sequence of chains.

        Args:
            chains (Sequence[Chain]): Chains to compare against each other. Every
                chain is required to expose exactly one input key and exactly one
                output key.
            names (Optional[List[str]]): Display names, one per chain. When given,
                the list must be the same length as ``chains``.

        Raises:
            ValueError: If an element of ``chains`` is not a `Chain` instance.
            ValueError: If a chain has more or fewer than one input variable.
            ValueError: If a chain has more or fewer than one output variable.
            ValueError: If ``names`` is provided with a mismatched length.
        """
        # Validate each chain up front so misconfiguration fails fast.
        for candidate in chains:
            if not isinstance(candidate, Chain):
                raise ValueError(
                    "ModelLaboratory should now be initialized with Chains. "
                    "If you want to initialize with LLMs, use the `from_llms` method "
                    "instead (`ModelLaboratory.from_llms(...)`)"
                )
            if len(candidate.input_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one input variable, "
                    f"got {candidate.input_keys}"
                )
            if len(candidate.output_keys) != 1:
                raise ValueError(
                    "Currently only support chains with one output variable, "
                    f"got {candidate.output_keys}"
                )
        if names is not None and len(names) != len(chains):
            raise ValueError("Length of chains does not match length of names.")
        self.chains = chains
        # One distinct color per chain index, keyed by the stringified index.
        indices = [str(position) for position in range(len(self.chains))]
        self.chain_colors = get_color_mapping(indices)
        self.names = names

    @classmethod
    def from_llms(
        cls, llms: List[BaseLLM], prompt: Optional[PromptTemplate] = None
    ) -> ModelLaboratory:
        """Build a laboratory directly from LLMs, wrapping each in an `LLMChain`.

        Args:
            llms (List[BaseLLM]): The LLMs to compare.
            prompt (Optional[PromptTemplate]): Prompt shared by all LLMs; it must
                take exactly one input variable. When omitted, a pass-through
                prompt is used so the input text is fed to the model verbatim.

        Returns:
            ModelLaboratory: A laboratory whose chains wrap the given LLMs.
        """
        if prompt is None:
            # Identity prompt: the user's text becomes the entire model input.
            prompt = PromptTemplate(input_variables=["_input"], template="{_input}")
        wrapped = [LLMChain(llm=llm, prompt=prompt) for llm in llms]
        labels = [str(llm) for llm in llms]
        return cls(wrapped, names=labels)

    def compare(self, text: str) -> None:
        """Run every chain on *text* and print each output in its own color.

        If the laboratory was created with a prompt, *text* is substituted into
        it; otherwise *text* itself is the full prompt.

        Args:
            text: Input text to run all models on.
        """
        print(f"\033[1mInput:\033[0m\n{text}\n")  # noqa: T201
        for index, chain in enumerate(self.chains):
            label = self.names[index] if self.names is not None else str(chain)
            print_text(label, end="\n")
            result = chain.run(text)
            print_text(result, color=self.chain_colors[str(index)], end="\n\n")