Source code for langchain_community.utilities.nvidia_riva
"""A common module for NVIDIA Riva Runnables."""

import asyncio
import logging
import pathlib
import queue
import tempfile
import threading
import wave
from enum import Enum
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    AsyncIterator,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
    cast,
)

from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import (
    AnyHttpUrl,
    BaseModel,
    Field,
    parse_obj_as,
    root_validator,
    validator,
)
from langchain_core.runnables import RunnableConfig, RunnableSerializable

if TYPE_CHECKING:
    import riva.client
    import riva.client.proto.riva_asr_pb2 as rasr

_LOGGER = logging.getLogger(__name__)

# How long (seconds) blocking queue reads wait before re-checking for hangup.
_QUEUE_GET_TIMEOUT = 0.5
# Maximum characters per TTS request; longer buffers are split into segments.
_MAX_TEXT_LENGTH = 400
# Characters that terminate a sentence and trigger an early TTS flush.
# BUG FIX: the last two entries were mojibake ("ยก", "ยฟ"); restored to the
# Spanish inverted punctuation marks they were meant to be.
_SENTENCE_TERMINATORS = ("\n", ".", "!", "?", "¡", "¿")


# COMMON utilities used by all Riva Runnables
def _import_riva_client() -> "riva.client":
    """Import the riva client and raise an error on failure.

    Returns:
        The imported ``riva.client`` module.

    Raises:
        ImportError: If the optional ``nvidia-riva-client`` package is not
            installed.
    """
    try:
        # pylint: disable-next=import-outside-toplevel # this client library is optional
        import riva.client
    except ImportError as err:
        raise ImportError(
            "Could not import the NVIDIA Riva client library. "
            "Please install it with `pip install nvidia-riva-client`."
        ) from err
    return riva.client
class RivaAudioEncoding(str, Enum):
    """An enum of the possible choices for Riva audio encoding.

    The list of types exposed by the Riva GRPC Protobuf files can be found
    with the following commands:
    ```python
    import riva.client
    print(riva.client.AudioEncoding.keys())  # noqa: T201
    ```
    """

    ALAW = "ALAW"
    ENCODING_UNSPECIFIED = "ENCODING_UNSPECIFIED"
    FLAC = "FLAC"
    LINEAR_PCM = "LINEAR_PCM"
    MULAW = "MULAW"
    OGGOPUS = "OGGOPUS"

    @classmethod
    def from_wave_format_code(cls, format_code: int) -> "RivaAudioEncoding":
        """Return the audio encoding specified by the format code in the wave file.

        ref: https://mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
        """
        # Only the uncompressed / companded PCM variants map onto Riva codecs.
        wave_to_riva = {1: cls.LINEAR_PCM, 6: cls.ALAW, 7: cls.MULAW}
        try:
            return wave_to_riva[format_code]
        except KeyError as err:
            raise NotImplementedError(
                "The following wave file format code is "
                f"not supported by Riva: {format_code}"
            ) from err

    @property
    def riva_pb2(self) -> "riva.client.AudioEncoding":
        """Returns the Riva API object for the encoding."""
        riva_client = _import_riva_client()
        # Members share their names with riva.client.AudioEncoding attributes.
        return getattr(riva_client.AudioEncoding, self.value)
class RivaAuthMixin(BaseModel):
    """Configuration for the authentication to a Riva service connection."""

    url: Union[AnyHttpUrl, str] = Field(
        AnyHttpUrl("http://localhost:50051", scheme="http"),
        description="The full URL where the Riva service can be found.",
        examples=["http://localhost:50051", "https://user@pass:riva.example.com"],
    )
    ssl_cert: Optional[str] = Field(
        None,
        description="A full path to the file where Riva's public ssl key can be read.",
    )

    @property
    def auth(self) -> "riva.client.Auth":
        """Build and return a riva client auth object from this configuration."""
        client = _import_riva_client()
        parsed = cast(AnyHttpUrl, self.url)
        # The Riva client wants a bare host:port, so strip the scheme part.
        host_port = str(self.url).split("/")[2]
        secure = parsed.scheme == "https"  # pylint: disable=no-member # false positive
        return client.Auth(ssl_cert=self.ssl_cert, use_ssl=secure, uri=host_port)

    @validator("url", pre=True, allow_reuse=True)
    @classmethod
    def _validate_url(cls, val: Any) -> AnyHttpUrl:
        """Coerce a raw string URL into an ``AnyHttpUrl`` before validation."""
        if not isinstance(val, str):
            return cast(AnyHttpUrl, val)
        return cast(AnyHttpUrl, parse_obj_as(AnyHttpUrl, val))
class RivaCommonConfigMixin(BaseModel):
    """A collection of common Riva settings."""

    # Audio codec used on the stream; defaults to uncompressed 16-bit PCM.
    encoding: RivaAudioEncoding = Field(
        default=RivaAudioEncoding.LINEAR_PCM,
        description="The encoding on the audio stream.",
    )
    # Sample rate of the audio stream in Hz (default: 8 kHz telephony rate).
    sample_rate_hertz: int = Field(
        default=8000, description="The sample rate frequency of audio stream."
    )
    # BCP-47 language tag forwarded to the Riva ASR/TTS requests.
    language_code: str = Field(
        default="en-US",
        description=(
            "The [BCP-47 language code]"
            "(https://www.rfc-editor.org/rfc/bcp/bcp47.txt) for "
            "the target language."
        ),
    )
class _Event:
    """A combined event that is threadsafe and async safe.

    Wraps a ``threading.Event`` and an ``asyncio.Event`` and keeps them in
    lock-step so both synchronous threads and coroutines can wait on it.
    """

    _event: threading.Event
    _aevent: asyncio.Event

    def __init__(self) -> None:
        """Initialize the event."""
        self._event = threading.Event()
        self._aevent = asyncio.Event()

    def set(self) -> None:
        """Set the event."""
        self._event.set()
        self._aevent.set()

    def clear(self) -> None:
        """Clear the event."""
        # DOC FIX: docstring previously said "Set the event."
        self._event.clear()
        self._aevent.clear()

    def is_set(self) -> bool:
        """Indicate if the event is set."""
        return self._event.is_set()

    def wait(self) -> None:
        """Wait for the event to be set."""
        self._event.wait()

    async def async_wait(self) -> None:
        """Async wait for the event to be set."""
        await self._aevent.wait()


def _mk_wave_file(
    output_directory: Optional[str], sample_rate: float
) -> Tuple[Optional[str], Optional[wave.Wave_write]]:
    """Create a new wave file and return the wave write object and filename.

    Args:
        output_directory: Directory for the file; when falsy, no file is
            created and ``(None, None)`` is returned.
        sample_rate: Frame rate written into the wave header.
    """
    if output_directory:
        # NamedTemporaryFile is used only to reserve a unique filename;
        # the actual writing is done through the wave module.
        with tempfile.NamedTemporaryFile(
            mode="bx", suffix=".wav", delete=False, dir=output_directory
        ) as f:
            wav_file_name = f.name
        wav_file = wave.open(wav_file_name, "wb")
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(sample_rate)
        return (wav_file_name, wav_file)
    return (None, None)


def _coerce_string(val: "TTSInputType") -> str:
    """Attempt to coerce the input value to a string.

    This is particularly useful for converting LangChain message to strings.
    """
    if isinstance(val, PromptValue):
        return val.to_string()
    if isinstance(val, BaseMessage):
        return str(val.content)
    return str(val)


def _process_chunks(inputs: Iterator["TTSInputType"]) -> Generator[str, None, None]:
    """Filter the input chunks and yield strings ready for TTS.

    Buffers incoming text, flushing on sentence terminators or when the
    buffer exceeds ``_MAX_TEXT_LENGTH`` characters.
    """
    buffer = ""
    for chunk in inputs:
        chunk = _coerce_string(chunk)

        # return the buffer if an end of sentence character is detected
        for terminator in _SENTENCE_TERMINATORS:
            while terminator in chunk:
                last_sentence, chunk = chunk.split(terminator, 1)
                yield buffer + last_sentence + terminator
                buffer = ""
        buffer += chunk

        # return the buffer if it is too long
        if len(buffer) > _MAX_TEXT_LENGTH:
            # BUG FIX: the slice previously used `idx + 5`, which emitted only
            # the first 5 characters of every 400-character window and silently
            # dropped the rest of the text. Emit full-size segments instead.
            for idx in range(0, len(buffer), _MAX_TEXT_LENGTH):
                yield buffer[idx : idx + _MAX_TEXT_LENGTH]
            buffer = ""

    # return remaining buffer
    if buffer:
        yield buffer


# Riva AudioStream Type
StreamInputType = Union[bytes, SentinelT]
StreamOutputType = str
    def __init__(self, maxsize: int = 0) -> None:
        """Initialize the queue.

        Args:
            maxsize: Maximum size of the input queue; 0 means unbounded
                (forwarded directly to ``queue.Queue``).
        """
        # Serializes put() so the hung-up check and the enqueue are atomic.
        self._put_lock = threading.Lock()
        # Inbound audio chunks (plus the HANGUP sentinel).
        self._queue = queue.Queue(maxsize=maxsize)
        # Transcripts produced by the worker thread created in register().
        self.output = queue.Queue()
        # Set once the stream has been hung up; checked by put()/__iter__.
        self.hangup = _Event()
        # Voice-activity flags toggled by the ASR response worker.
        self.user_quiet = _Event()
        self.user_talking = _Event()
        # Background thread draining ASR responses; assigned by register().
        self._worker: Optional[threading.Thread] = None
    def __iter__(self) -> Generator[bytes, None, None]:
        """Iterate over queued audio chunks until the HANGUP sentinel arrives."""
        while True:
            # get next item; the timeout lets the loop re-check periodically
            # instead of blocking forever on an idle queue
            try:
                next_val = self._queue.get(True, _QUEUE_GET_TIMEOUT)
            except queue.Empty:
                continue

            # hangup when requested
            # NOTE(review): compared with `==` here but `is` in put() —
            # presumably equivalent for the sentinel instance; verify.
            if next_val == HANGUP:
                break

            # yield next item
            yield next_val
            self._queue.task_done()

    async def __aiter__(self) -> AsyncIterator[StreamInputType]:
        """Iterate through all items in the queue until HANGUP."""
        # NOTE(review): an `async def __aiter__` returns an async generator
        # when called, which itself supports __anext__, so `async for` works.
        while True:
            # get next item without blocking the event loop: the blocking
            # Queue.get runs in the default executor
            try:
                next_val = await asyncio.get_event_loop().run_in_executor(
                    None, self._queue.get, True, _QUEUE_GET_TIMEOUT
                )
            except queue.Empty:
                continue

            # hangup when requested
            if next_val == HANGUP:
                break

            # yield next item
            yield next_val
            self._queue.task_done()

    @property
    def hungup(self) -> bool:
        """Indicate if the audio stream has hungup."""
        return self.hangup.is_set()

    @property
    def empty(self) -> bool:
        """Indicate if the input stream buffer is empty."""
        return self._queue.empty()

    @property
    def complete(self) -> bool:
        """Indicate if the audio stream has hungup and been processed."""
        # Input side: hangup received and all chunks consumed.
        input_done = self.hungup and self.empty
        # Output side: worker thread finished and all transcripts drained.
        output_done = (
            self._worker is not None
            and not self._worker.is_alive()
            and self.output.empty()
        )
        return input_done and output_done

    @property
    def running(self) -> bool:
        """Indicate if the ASR stream is running."""
        if self._worker:
            return self._worker.is_alive()
        return False
[docs]defput(self,item:StreamInputType,timeout:Optional[int]=None)->None:"""Put a new item into the queue."""withself._put_lock:ifself.hungup:raiseRuntimeError("The audio stream has already been hungup. Cannot put more data.")ifitemisHANGUP:self.hangup.set()self._queue.put(item,timeout=timeout)
[docs]asyncdefaput(self,item:StreamInputType,timeout:Optional[int]=None)->None:"""Async put a new item into the queue."""loop=asyncio.get_event_loop()awaitasyncio.wait_for(loop.run_in_executor(None,self.put,item),timeout)
[docs]defclose(self,timeout:Optional[int]=None)->None:"""Send the hangup signal."""self.put(HANGUP,timeout)
[docs]asyncdefaclose(self,timeout:Optional[int]=None)->None:"""Async send the hangup signal."""awaitself.aput(HANGUP,timeout)
    def register(self, responses: Iterator["rasr.StreamingRecognizeResponse"]) -> None:
        """Drain the responses from the provided iterator and put them into a queue.

        Spawns a daemon thread that consumes the streaming ASR responses:
        final transcripts are pushed onto ``self.output`` and the
        ``user_talking``/``user_quiet`` events are toggled on interim results.

        Raises:
            RuntimeError: If a worker has already been registered and is alive.
        """
        if self.running:
            raise RuntimeError("An ASR instance has already been registered.")

        # Two-party barrier: register() returns only after the worker thread
        # has actually started executing (5 s safety timeout).
        has_started = threading.Barrier(2, timeout=5)

        def worker() -> None:
            """Consume the ASR Generator."""
            has_started.wait()
            for response in responses:
                if not response.results:
                    continue
                for result in response.results:
                    if not result.alternatives:
                        continue
                    if result.is_final:
                        # Final result: user finished speaking; publish the
                        # top alternative's transcript.
                        self.user_talking.clear()
                        self.user_quiet.set()
                        transcript = cast(str, result.alternatives[0].transcript)
                        self.output.put(transcript)
                    elif not self.user_talking.is_set():
                        # First interim result: mark the user as talking.
                        self.user_talking.set()
                        self.user_quiet.clear()

        # Daemon thread so a lingering ASR stream never blocks process exit.
        self._worker = threading.Thread(target=worker)
        self._worker.daemon = True
        self._worker.start()
        has_started.wait()
class RivaASR(
    RivaAuthMixin,
    RivaCommonConfigMixin,
    RunnableSerializable[ASRInputType, ASROutputType],
):
    """A runnable that performs Automatic Speech Recognition (ASR) using NVIDIA Riva."""

    name: str = "nvidia_riva_asr"
    description: str = (
        "A Runnable for converting audio bytes to a string."
        "This is useful for feeding an audio stream into a chain and"
        "preprocessing that audio to create an LLM prompt."
    )

    # riva options
    audio_channel_count: int = Field(
        1, description="The number of audio channels in the input audio stream."
    )
    profanity_filter: bool = Field(
        True,
        description=(
            "Controls whether or not Riva should attempt to filter "
            "profanity out of the transcribed text."
        ),
    )
    enable_automatic_punctuation: bool = Field(
        True,
        # TYPO FIX: description previously read "senetence puncuation".
        description=(
            "Controls whether Riva should attempt to correct "
            "sentence punctuation in the transcribed text."
        ),
    )

    @root_validator(pre=True)
    @classmethod
    def _validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the Python environment and input arguments."""
        # Fails fast at construction time if nvidia-riva-client is missing.
        _ = _import_riva_client()
        return values

    @property
    def config(self) -> "riva.client.StreamingRecognitionConfig":
        """Create and return the riva config object."""
        riva_client = _import_riva_client()
        return riva_client.StreamingRecognitionConfig(
            interim_results=True,
            config=riva_client.RecognitionConfig(
                encoding=self.encoding,
                sample_rate_hertz=self.sample_rate_hertz,
                audio_channel_count=self.audio_channel_count,
                max_alternatives=1,
                profanity_filter=self.profanity_filter,
                enable_automatic_punctuation=self.enable_automatic_punctuation,
                language_code=self.language_code,
            ),
        )

    def _get_service(self) -> "riva.client.ASRService":
        """Connect to the riva service and return a client object.

        Raises:
            ValueError: If the connection to the Riva ASR server fails.
        """
        riva_client = _import_riva_client()
        try:
            return riva_client.ASRService(self.auth)
        except Exception as err:
            raise ValueError(
                "Error raised while connecting to the Riva ASR server."
            ) from err
    def invoke(
        self,
        input: ASRInputType,
        _: Optional[RunnableConfig] = None,
    ) -> ASROutputType:
        """Transcribe the audio bytes into a string with Riva.

        Registers an ASR worker on the stream (if not already running), then
        waits for the first batch of final transcripts and returns them joined
        with spaces. Returns "" if the stream completes without output.
        """
        # create an output text generator with Riva
        if not input.running:
            service = self._get_service()
            responses = service.streaming_response_generator(
                audio_chunks=input,
                streaming_config=self.config,
            )
            input.register(responses)

        # return the first valid result
        full_response: List[str] = []
        while not input.complete:
            # NOTE(review): uses queue.Queue's internal `not_empty` condition
            # to wait up to 100 ms for a transcript — relies on CPython's
            # Queue implementation details.
            with input.output.not_empty:
                ready = input.output.not_empty.wait(0.1)

            if ready:
                # Drain everything currently queued and return it as one string.
                while not input.output.empty():
                    try:
                        full_response += [input.output.get_nowait()]
                    except queue.Empty:
                        continue
                    input.output.task_done()
                _LOGGER.debug("Riva ASR returning: %s", repr(full_response))
                return " ".join(full_response).strip()

        # Stream hung up and fully processed without producing a transcript.
        return ""
class RivaTTS(
    RivaAuthMixin,
    RivaCommonConfigMixin,
    RunnableSerializable[TTSInputType, TTSOutputType],
):
    """A runnable that performs Text-to-Speech (TTS) with NVIDIA Riva."""

    name: str = "nvidia_riva_tts"
    description: str = (
        "A tool for converting text to speech."
        "This is useful for converting LLM output into audio bytes."
    )

    # riva options
    voice_name: str = Field(
        "English-US.Female-1",
        description=(
            "The voice model in Riva to use for speech. "
            "Pre-trained models are documented in "
            "[the Riva documentation]"
            "(https://docs.nvidia.com/deeplearning/riva/user-guide/docs/tts/tts-overview.html)."
        ),
    )
    output_directory: Optional[str] = Field(
        None,
        description=(
            "The directory where all audio files should be saved. "
            "A null value indicates that wave files should not be saved. "
            "This is useful for debugging purposes."
        ),
    )

    @root_validator(pre=True)
    @classmethod
    def _validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Ensure the optional Riva client library is importable."""
        _ = _import_riva_client()
        return values

    @validator("output_directory")
    @classmethod
    def _output_directory_validator(cls, v: str) -> str:
        """Create the output directory if needed and normalize it to an absolute path."""
        if not v:
            return v
        target = pathlib.Path(v)
        target.mkdir(parents=True, exist_ok=True)
        return str(target.absolute())

    def _get_service(self) -> "riva.client.SpeechSynthesisService":
        """Connect to the riva service and return a client object."""
        client = _import_riva_client()
        try:
            return client.SpeechSynthesisService(self.auth)
        except Exception as err:
            raise ValueError(
                "Error raised while connecting to the Riva TTS server."
            ) from err
[docs]definvoke(self,input:TTSInputType,_:Union[RunnableConfig,None]=None)->TTSOutputType:"""Perform TTS by taking a string and outputting the entire audio file."""returnb"".join(self.transform(iter([input])))
    def transform(
        self,
        input: Iterator[TTSInputType],
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> Iterator[TTSOutputType]:
        """Perform TTS by taking a stream of characters and streaming output bytes.

        Optionally mirrors all synthesized audio into a debug wave file when
        ``output_directory`` is configured.
        """
        service = self._get_service()

        # create an output wave file (no-op when output_directory is unset)
        wav_file_name, wav_file = _mk_wave_file(
            self.output_directory, self.sample_rate_hertz
        )

        # split the input text and perform tts
        for chunk in _process_chunks(input):
            _LOGGER.debug("Riva TTS chunk: %s", chunk)

            # start riva tts streaming
            responses = service.synthesize_online(
                text=chunk,
                voice_name=self.voice_name,
                language_code=self.language_code,
                encoding=self.encoding.riva_pb2,
                sample_rate_hz=self.sample_rate_hertz,
            )

            # stream audio bytes out, teeing them into the wave file if open
            for resp in responses:
                audio = cast(bytes, resp.audio)
                if wav_file:
                    wav_file.writeframesraw(audio)
                yield audio

        # close the wave file when we are done
        if wav_file:
            wav_file.close()
            _LOGGER.debug("Riva TTS wrote file: %s", wav_file_name)

    async def atransform(
        self,
        input: AsyncIterator[TTSInputType],
        config: Optional[RunnableConfig] = None,
        **kwargs: Optional[Any],
    ) -> AsyncGenerator[TTSOutputType, None]:
        """Intercept async transforms and route them to the synchronous transform.

        Bridges the async input iterator to the blocking ``transform`` via a
        thread-safe input queue, and streams results back through an asyncio
        queue. ``_TRANSFORM_END`` sentinels mark the end of each direction.
        """
        loop = asyncio.get_running_loop()
        input_queue: queue.Queue = queue.Queue()
        out_queue: asyncio.Queue = asyncio.Queue()

        async def _producer() -> None:
            """Produce input into the input queue."""
            async for val in input:
                input_queue.put_nowait(val)
            # Signal the synchronous side that the input is exhausted.
            input_queue.put_nowait(_TRANSFORM_END)

        def _input_iterator() -> Iterator[TTSInputType]:
            """Iterate over the input_queue."""
            while True:
                # Timed get so the worker can keep polling instead of
                # blocking forever if the producer stalls.
                try:
                    val = input_queue.get(timeout=0.5)
                except queue.Empty:
                    continue
                if val == _TRANSFORM_END:
                    break
                yield val

        def _consumer() -> None:
            """Consume the input with transform."""
            # Runs in an executor thread; put_nowait on an asyncio.Queue is
            # safe here because the loop only reads after being awakened.
            for val in self.transform(_input_iterator()):
                out_queue.put_nowait(val)
            out_queue.put_nowait(_TRANSFORM_END)

        async def _consumer_coro() -> None:
            """Coroutine that wraps the consumer."""
            await loop.run_in_executor(None, _consumer)

        producer = loop.create_task(_producer())
        consumer = loop.create_task(_consumer_coro())

        # Relay audio bytes to the caller until the end sentinel arrives.
        while True:
            try:
                val = await asyncio.wait_for(out_queue.get(), 0.5)
            except asyncio.exceptions.TimeoutError:
                continue
            out_queue.task_done()
            if val is _TRANSFORM_END:
                break
            yield val

        # Surface any exceptions raised by the background tasks.
        await producer
        await consumer