class OCIAuthType(Enum):
    """OCI authentication types as enumerator."""

    API_KEY = 1
    SECURITY_TOKEN = 2
    INSTANCE_PRINCIPAL = 3
    RESOURCE_PRINCIPAL = 4
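
# A minimal sketch of how validate_environment (below) matches the user-supplied
# auth_type string against this enum; the values shown are illustrative:
#
#     OCIAuthType(1).name        # -> "API_KEY"
#     OCIAuthType["API_KEY"]     # -> OCIAuthType.API_KEY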
class OCIGenAIBase(BaseModel, ABC):
    """Base class for OCI GenAI models"""

    client: Any = Field(default=None, exclude=True)  #: :meta private:

    auth_type: Optional[str] = "API_KEY"
    """Authentication type. Could be API_KEY, SECURITY_TOKEN,
    INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL.
    If not specified, API_KEY will be used.
    """

    auth_profile: Optional[str] = "DEFAULT"
    """The name of the profile in ~/.oci/config.
    If not specified, DEFAULT will be used.
    """

    auth_file_location: Optional[str] = "~/.oci/config"
    """Path to the config file.
    If not specified, ~/.oci/config will be used.
    """

    model_id: Optional[str] = None
    """Id of the model to call, e.g., cohere.command"""

    provider: Optional[str] = None
    """Provider name of the model. Defaults to None; if not set, it is
    derived from the model_id, otherwise it must be supplied by the user.
    """

    model_kwargs: Optional[Dict] = None
    """Keyword arguments to pass to the model"""

    service_endpoint: Optional[str] = None
    """service endpoint url"""

    compartment_id: Optional[str] = None
    """OCID of compartment"""

    is_stream: bool = False
    """Whether to stream back partial progress"""

    model_config = ConfigDict(
        extra="forbid",
        arbitrary_types_allowed=True,
        protected_namespaces=(),
    )
    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that OCI config and python package exists in environment."""

        # Skip creating a new client if one was passed in the constructor
        if values["client"] is not None:
            return values

        try:
            import oci

            client_kwargs = {
                "config": {},
                "signer": None,
                "service_endpoint": values["service_endpoint"],
                "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
                "timeout": (10, 240),  # default timeout config for OCI Gen AI service
            }

            if values["auth_type"] == OCIAuthType(1).name:
                # API_KEY: authenticate from the config file alone, no signer
                client_kwargs["config"] = oci.config.from_file(
                    file_location=values["auth_file_location"],
                    profile_name=values["auth_profile"],
                )
                client_kwargs.pop("signer", None)
            elif values["auth_type"] == OCIAuthType(2).name:
                # SECURITY_TOKEN: config file plus a session-token-based signer

                def make_security_token_signer(oci_config):  # type: ignore[no-untyped-def]
                    pk = oci.signer.load_private_key_from_file(
                        oci_config.get("key_file"), None
                    )
                    with open(
                        oci_config.get("security_token_file"), encoding="utf-8"
                    ) as f:
                        st_string = f.read()
                    return oci.auth.signers.SecurityTokenSigner(st_string, pk)

                client_kwargs["config"] = oci.config.from_file(
                    file_location=values["auth_file_location"],
                    profile_name=values["auth_profile"],
                )
                client_kwargs["signer"] = make_security_token_signer(
                    oci_config=client_kwargs["config"]
                )
            elif values["auth_type"] == OCIAuthType(3).name:
                # INSTANCE_PRINCIPAL: signer only, no config file is read
                client_kwargs["signer"] = (
                    oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
                )
            elif values["auth_type"] == OCIAuthType(4).name:
                # RESOURCE_PRINCIPAL: signer only, no config file is read
                client_kwargs["signer"] = (
                    oci.auth.signers.get_resource_principals_signer()
                )
            else:
                raise ValueError(
                    "Please provide valid value to auth_type, "
                    f"{values['auth_type']} is not valid."
                )

            values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient(
                **client_kwargs
            )

        except ImportError as ex:
            raise ModuleNotFoundError(
                "Could not import oci python package. "
                "Please make sure you have the oci package installed."
            ) from ex
        except Exception as e:
            raise ValueError(
                "Could not authenticate with OCI client. "
                "If INSTANCE_PRINCIPAL or RESOURCE_PRINCIPAL is used, "
                "please check the specified "
                "auth_profile, auth_file_location and auth_type are valid.",
                e,
            ) from e

        return values
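
    # A hedged sketch of selecting an auth path at construction time; the
    # profile name "my-token-profile" is an assumption, not part of this module.
    # The concrete OCIGenAI subclass defined below would be built as:
    #
    #     llm = OCIGenAI(
    #         auth_type="SECURITY_TOKEN",
    #         auth_profile="my-token-profile",
    #         model_id="cohere.command",
    #         service_endpoint="...",
    #         compartment_id="...",
    #     )
    #
    # API_KEY and SECURITY_TOKEN read ~/.oci/config (or auth_file_location);
    # INSTANCE_PRINCIPAL and RESOURCE_PRINCIPAL rely on a signer only.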
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        _model_kwargs = self.model_kwargs or {}
        return {
            **{"model_kwargs": _model_kwargs},
        }

    def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
        if self.provider is not None:
            provider = self.provider
        else:
            if self.model_id is None:
                raise ValueError(
                    "model_id is required to derive the provider, "
                    "please provide the provider explicitly or specify "
                    "the model_id to derive the provider."
                )
            provider = self.model_id.split(".")[0].lower()
            if provider not in provider_map:
                raise ValueError(
                    f"Invalid provider derived from model_id: {self.model_id} "
                    "Please explicitly pass in the supported provider "
                    "when using custom endpoint"
                )
        return provider_map[provider]
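
# A minimal sketch of the provider derivation in _get_provider above, using the
# provider map that OCIGenAI (below) supplies; "cohere.command" is illustrative:
#
#     model_id = "cohere.command"
#     model_id.split(".")[0].lower()   # -> "cohere" -> provider_map["cohere"]
#
# Passing provider="meta" explicitly skips the derivation, which is needed for
# dedicated endpoints whose model_id is an OCID rather than "<provider>.<name>".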
class OCIGenAI(LLM, OCIGenAIBase):
    """OCI large language models.

    To authenticate, the OCI client uses the methods described in
    https://docs.oracle.com/en-us/iaas/Content/API/Concepts/sdk_authentication_methods.htm

    The authentication method is passed through auth_type and should be one of:
    API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL.

    Make sure you have the required policies (profile/roles) to
    access the OCI Generative AI service.
    If a specific config profile is used, you must pass
    the name of the profile (from ~/.oci/config) through auth_profile.
    If a specific config file location is used, you must pass
    the file location where the profile name configs are present
    through auth_file_location.

    To use, you must provide the compartment id
    along with the endpoint url, and model id
    as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms import OCIGenAI

            llm = OCIGenAI(
                model_id="MY_MODEL_ID",
                service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
                compartment_id="MY_OCID",
            )
    """

    model_config = ConfigDict(
        extra="forbid",
        arbitrary_types_allowed=True,
    )

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "oci_generative_ai_completion"

    @property
    def _provider_map(self) -> Mapping[str, Any]:
        """Get the provider map"""
        return {
            "cohere": CohereProvider(),
            "meta": MetaProvider(),
        }

    @property
    def _provider(self) -> Any:
        """Get the internal provider object"""
        return self._get_provider(provider_map=self._provider_map)

    def _prepare_invocation_object(
        self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        from oci.generative_ai_inference import models

        _model_kwargs = self.model_kwargs or {}
        if stop is not None:
            _model_kwargs[self._provider.stop_sequence_key] = stop

        if self.model_id is None:
            raise ValueError(
                "model_id is required to call the model, please provide the model_id."
            )

        if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
            serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
        else:
            serving_mode = models.OnDemandServingMode(model_id=self.model_id)

        inference_params = {**_model_kwargs, **kwargs}
        inference_params["prompt"] = prompt
        inference_params["is_stream"] = self.is_stream

        invocation_obj = models.GenerateTextDetails(
            compartment_id=self.compartment_id,
            serving_mode=serving_mode,
            inference_request=self._provider.llm_inference_request(**inference_params),
        )

        return invocation_obj

    def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
        text = self._provider.completion_response_to_text(response)

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to OCIGenAI generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = llm.invoke("Tell me a joke.")
        """
        if self.is_stream:
            text = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                text += chunk.text
            if stop is not None:
                text = enforce_stop_tokens(text, stop)
            return text

        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
        response = self.client.generate_text(invocation_obj)

        return self._process_response(response, stop)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream OCIGenAI LLM on given prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            An iterator of GenerationChunks.

        Example:
            .. code-block:: python

                response = llm.stream("Tell me a joke.")
        """
        self.is_stream = True
        invocation_obj = self._prepare_invocation_object(prompt, stop, kwargs)
        response = self.client.generate_text(invocation_obj)

        for event in response.data.events():
            json_load = json.loads(event.data)
            if "text" in json_load:
                event_data_text = json_load["text"]
            else:
                event_data_text = ""

            chunk = GenerationChunk(text=event_data_text)

            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
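

# A hedged end-to-end sketch: the compartment OCID and model_kwargs values are
# placeholders, and running it requires the `oci` package plus valid OCI
# credentials (e.g., an API_KEY profile in ~/.oci/config).
if __name__ == "__main__":
    llm = OCIGenAI(
        model_id="cohere.command",
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id="MY_OCID",
        model_kwargs={"max_tokens": 200, "temperature": 0.7},
    )

    # Blocking call: returns the full completion as a string.
    print(llm.invoke("Tell me a joke."))

    # Streaming call: LLM.stream yields text chunks as the service emits them.
    for chunk in llm.stream("Tell me a joke."):
        print(chunk, end="", flush=True)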