"""Internal tracer to power the event stream API."""
from __future__ import annotations
import asyncio
import contextlib
import logging
from collections.abc import AsyncIterator, Iterator, Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
TypeVar,
Union,
cast,
)
from uuid import UUID, uuid4
from typing_extensions import NotRequired, TypedDict
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessageChunk, BaseMessage, BaseMessageChunk
from langchain_core.outputs import (
ChatGenerationChunk,
GenerationChunk,
LLMResult,
)
from langchain_core.runnables.schema import (
CustomStreamEvent,
EventData,
StandardStreamEvent,
StreamEvent,
)
from langchain_core.runnables.utils import (
Input,
Output,
_RootEventFilter,
)
from langchain_core.tracers._streaming import _StreamingCallbackHandler
from langchain_core.tracers.log_stream import LogEntry
from langchain_core.tracers.memory_stream import _MemoryStream
from langchain_core.utils.aiter import aclosing, py_anext
if TYPE_CHECKING:
from langchain_core.documents import Document
from langchain_core.runnables import Runnable, RunnableConfig
logger = logging.getLogger(__name__)
[docs]
class RunInfo(TypedDict):
"""Information about a run.
This is used to keep track of the metadata associated with a run.
Parameters:
name: The name of the run.
tags: The tags associated with the run.
metadata: The metadata associated with the run.
run_type: The type of the run.
inputs: The inputs to the run.
parent_run_id: The ID of the parent run.
"""
name: str
tags: list[str]
metadata: dict[str, Any]
run_type: str
inputs: NotRequired[Any]
parent_run_id: Optional[UUID]
def _assign_name(name: Optional[str], serialized: Optional[dict[str, Any]]) -> str:
"""Assign a name to a run."""
if name is not None:
return name
if serialized is not None:
if "name" in serialized:
return serialized["name"]
elif "id" in serialized:
return serialized["id"][-1]
return "Unnamed"
T = TypeVar("T")
class _AstreamEventsCallbackHandler(AsyncCallbackHandler, _StreamingCallbackHandler):
    """Async callback handler that turns run callbacks into stream events.

    Each ``on_*`` callback records or looks up the metadata of the run it
    refers to and pushes a matching event onto an in-memory stream. Iterating
    over the handler reads the events back from that stream.
    """

    def __init__(
        self,
        *args: Any,
        include_names: Optional[Sequence[str]] = None,
        include_types: Optional[Sequence[str]] = None,
        include_tags: Optional[Sequence[str]] = None,
        exclude_names: Optional[Sequence[str]] = None,
        exclude_types: Optional[Sequence[str]] = None,
        exclude_tags: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Set up run bookkeeping, the event filter and the event stream.

        Args:
            *args: Forwarded to the parent initializer.
            include_names: Forwarded to the root event filter.
            include_types: Forwarded to the root event filter.
            include_tags: Forwarded to the root event filter.
            exclude_names: Forwarded to the root event filter.
            exclude_types: Forwarded to the root event filter.
            exclude_tags: Forwarded to the root event filter.
            **kwargs: Forwarded to the parent initializer.
        """
        super().__init__(*args, **kwargs)
        # Run ID -> run info. An entry is removed by the ``*_end`` callback
        # of the corresponding run.
        self.run_map: dict[UUID, RunInfo] = {}
        # Child run ID -> parent run ID. A parent's end callback may fire
        # BEFORE a child's end callback, at which point the parent is already
        # gone from ``run_map``; this mapping is never pruned so ancestry can
        # still be resolved. It lives as long as the tracer does.
        self.parent_map: dict[UUID, Optional[UUID]] = {}
        # Run ID -> marker object of the first tap registered for that run.
        self.is_tapped: dict[UUID, Any] = {}
        # Decides which events are actually pushed onto the stream.
        self.root_event_filter = _RootEventFilter(
            include_names=include_names,
            include_types=include_types,
            include_tags=include_tags,
            exclude_names=exclude_names,
            exclude_types=exclude_types,
            exclude_tags=exclude_tags,
        )
        memory_stream = _MemoryStream[StreamEvent](asyncio.get_event_loop())
        self.send_stream = memory_stream.get_send_stream()
        self.receive_stream = memory_stream.get_receive_stream()

    def _get_parent_ids(self, run_id: UUID) -> list[str]:
        """Return the ancestor run IDs of ``run_id`` as strings.

        The chain is walked iteratively through ``parent_map``.

        Args:
            run_id: The run whose ancestors are wanted.

        Returns:
            Ancestor IDs ordered from the root down to the immediate parent.

        Raises:
            AssertionError: If the same ancestor is encountered twice.
        """
        lineage: list[str] = []
        ancestor = self.parent_map.get(run_id)
        while ancestor:
            ancestor_text = str(ancestor)
            if ancestor_text in lineage:
                msg = (
                    f"Parent ID {ancestor} is already in the parent_ids list. "
                    "This should never happen."
                )
                raise AssertionError(msg)
            lineage.append(ancestor_text)
            ancestor = self.parent_map.get(ancestor)
        # Collected nearest-first; callers expect root-first.
        lineage.reverse()
        return lineage

    def _send(self, event: StreamEvent, event_type: str) -> None:
        """Push ``event`` onto the stream unless the filter rejects it."""
        if not self.root_event_filter.include_event(event, event_type):
            return
        self.send_stream.send_nowait(event)

    def _emit_start_event(
        self,
        event_name: str,
        run_type: str,
        run_id: UUID,
        *,
        name_: str,
        tags: Optional[list[str]],
        metadata: Optional[dict[str, Any]],
        data: EventData,
    ) -> None:
        """Send a ``*_start`` event built from the raw callback arguments."""
        self._send(
            {
                "event": event_name,
                "data": data,
                "name": name_,
                "tags": tags or [],
                "run_id": str(run_id),
                "metadata": metadata or {},
                "parent_ids": self._get_parent_ids(run_id),
            },
            run_type,
        )

    def _emit_run_event(
        self,
        event_name: str,
        run_id: UUID,
        run_info: RunInfo,
        data: EventData,
        event_type: str,
    ) -> None:
        """Send an event whose name/tags/metadata come from stored run info."""
        self._send(
            {
                "event": event_name,
                "data": data,
                "run_id": str(run_id),
                "name": run_info["name"],
                "tags": run_info["tags"],
                "metadata": run_info["metadata"],
                "parent_ids": self._get_parent_ids(run_id),
            },
            event_type,
        )

    def _stream_event_template(
        self, run_id: UUID, run_info: RunInfo
    ) -> StandardStreamEvent:
        """Build the ``on_<run_type>_stream`` event with an empty ``data``."""
        return {
            "event": f"on_{run_info['run_type']}_stream",
            "run_id": str(run_id),
            "name": run_info["name"],
            "tags": run_info["tags"],
            "metadata": run_info["metadata"],
            "data": {},
            "parent_ids": self._get_parent_ids(run_id),
        }

    def __aiter__(self) -> AsyncIterator[Any]:
        """Iterate over the events available on the receive stream."""
        return self.receive_stream.__aiter__()

    async def tap_output_aiter(
        self, run_id: UUID, output: AsyncIterator[T]
    ) -> AsyncIterator[T]:
        """Pass an async output iterator through, emitting stream events.

        Only the first tap registered for a given run emits events; any later
        tap for the same run simply forwards the chunks.

        Args:
            run_id: The ID of the run.
            output: The output of the Runnable.

        Yields:
            T: The chunks of ``output``, unchanged.
        """
        marker = object()
        # setdefault both checks for and registers the tap in a single step.
        registered = self.is_tapped.setdefault(run_id, marker)
        # Wait for the first chunk before looking at the run info.
        head = await py_anext(output, default=marker)
        if head is marker:
            return
        info = self.run_map.get(run_id)
        if info is None:
            # The run has already finished: do not issue any stream events.
            yield cast(T, head)
            return
        if registered is not marker:
            # Somebody else tapped this run first: plain pass-through.
            yield cast(T, head)
            async for item in output:
                yield item
            return
        # We are the first tap for this run: emit one event per chunk.
        template = self._stream_event_template(run_id, info)
        run_type = info["run_type"]
        self._send({**template, "data": {"chunk": head}}, run_type)
        yield cast(T, head)
        async for item in output:
            self._send({**template, "data": {"chunk": item}}, run_type)
            yield item

    def tap_output_iter(self, run_id: UUID, output: Iterator[T]) -> Iterator[T]:
        """Pass a sync output iterator through, emitting stream events.

        Synchronous counterpart of ``tap_output_aiter``.

        Args:
            run_id: The ID of the run.
            output: The output of the Runnable.

        Yields:
            T: The chunks of ``output``, unchanged.
        """
        marker = object()
        # setdefault both checks for and registers the tap in a single step.
        registered = self.is_tapped.setdefault(run_id, marker)
        # Wait for the first chunk before looking at the run info.
        head = next(output, marker)
        if head is marker:
            return
        info = self.run_map.get(run_id)
        if info is None:
            # The run has already finished: do not issue any stream events.
            yield cast(T, head)
            return
        if registered is not marker:
            # Somebody else tapped this run first: plain pass-through.
            yield cast(T, head)
            for item in output:
                yield item
            return
        # We are the first tap for this run: emit one event per chunk.
        template = self._stream_event_template(run_id, info)
        run_type = info["run_type"]
        self._send({**template, "data": {"chunk": head}}, run_type)
        yield cast(T, head)
        for item in output:
            self._send({**template, "data": {"chunk": item}}, run_type)
            yield item

    def _write_run_start_info(
        self,
        run_id: UUID,
        *,
        tags: Optional[list[str]],
        metadata: Optional[dict[str, Any]],
        parent_run_id: Optional[UUID],
        name_: str,
        run_type: str,
        **kwargs: Any,
    ) -> None:
        """Record the run info and the parent link for a starting run.

        Args:
            run_id: The ID of the run that is starting.
            tags: Tags of the run; None is stored as an empty list.
            metadata: Metadata of the run; None is stored as an empty dict.
            parent_run_id: The ID of the parent run, if any.
            name_: The resolved name of the run.
            run_type: The type of the run.
            **kwargs: Only ``inputs`` is read. It is stored whenever the key
                is supplied, even with a None value, so that a missing value
                can be told apart from an explicit None.
        """
        record = RunInfo(
            tags=tags or [],
            metadata=metadata or {},
            name=name_,
            run_type=run_type,
            parent_run_id=parent_run_id,
        )
        if "inputs" in kwargs:
            record["inputs"] = kwargs["inputs"]
        self.run_map[run_id] = record
        self.parent_map[run_id] = parent_run_id

    async def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        tags: Optional[list[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Record a chat model run and emit ``on_chat_model_start``."""
        resolved_name = _assign_name(name, serialized)
        self._write_run_start_info(
            run_id,
            tags=tags,
            metadata=metadata,
            parent_run_id=parent_run_id,
            name_=resolved_name,
            run_type="chat_model",
            inputs={"messages": messages},
        )
        self._emit_start_event(
            "on_chat_model_start",
            "chat_model",
            run_id,
            name_=resolved_name,
            tags=tags,
            metadata=metadata,
            data={"input": {"messages": messages}},
        )

    async def on_llm_start(
        self,
        serialized: dict[str, Any],
        prompts: list[str],
        *,
        run_id: UUID,
        tags: Optional[list[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Record an LLM run and emit ``on_llm_start``."""
        resolved_name = _assign_name(name, serialized)
        self._write_run_start_info(
            run_id,
            tags=tags,
            metadata=metadata,
            parent_run_id=parent_run_id,
            name_=resolved_name,
            run_type="llm",
            inputs={"prompts": prompts},
        )
        self._emit_start_event(
            "on_llm_start",
            "llm",
            run_id,
            name_=resolved_name,
            tags=tags,
            metadata=metadata,
            data={"input": {"prompts": prompts}},
        )

    async def on_custom_event(
        self,
        name: str,
        data: Any,
        *,
        run_id: UUID,
        tags: Optional[list[str]] = None,
        metadata: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Emit an ``on_custom_event`` event carrying ``data``.

        The event name is also what is handed to the filter as event type.
        """
        custom: CustomStreamEvent = {
            "event": "on_custom_event",
            "run_id": str(run_id),
            "name": name,
            "tags": tags or [],
            "metadata": metadata or {},
            "data": data,
            "parent_ids": self._get_parent_ids(run_id),
        }
        self._send(custom, name)

    async def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: Optional[Union[GenerationChunk, ChatGenerationChunk]] = None,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> None:
        """Emit a stream event for a new token of an LLM / chat model run.

        Nothing is emitted when the run's output has been tapped.

        Raises:
            AssertionError: If the run is not present in the run map.
            ValueError: If the run is neither a chat model nor an LLM run.
        """
        info = self.run_map.get(run_id)
        if info is None:
            msg = f"Run ID {run_id} not found in run map."
            raise AssertionError(msg)
        if self.is_tapped.get(run_id):
            return
        run_type = info["run_type"]
        payload: Union[GenerationChunk, BaseMessageChunk]
        if run_type == "chat_model":
            event_name = "on_chat_model_stream"
            # Without a chunk object, wrap the raw token in a message chunk.
            payload = (
                AIMessageChunk(content=token)
                if chunk is None
                else cast(ChatGenerationChunk, chunk).message
            )
        elif run_type == "llm":
            event_name = "on_llm_stream"
            payload = (
                GenerationChunk(text=token)
                if chunk is None
                else cast(GenerationChunk, chunk)
            )
        else:
            msg = f"Unexpected run type: {run_type}"
            raise ValueError(msg)
        self._emit_run_event(event_name, run_id, info, {"chunk": payload}, run_type)

    async def on_llm_end(
        self, response: LLMResult, *, run_id: UUID, **kwargs: Any
    ) -> None:
        """Forget an LLM / chat model run and emit its end event.

        Raises:
            ValueError: If the run is neither a chat model nor an LLM run.
        """
        info = self.run_map.pop(run_id)
        recorded_inputs = info["inputs"]
        run_type = info["run_type"]
        result: Union[dict, BaseMessage]
        if run_type == "chat_model":
            chat_generations = cast(
                list[list[ChatGenerationChunk]], response.generations
            )
            # Message of the very first generation chunk, or {} if none.
            result = next(
                (piece.message for batch in chat_generations for piece in batch),
                {},
            )
            event_name = "on_chat_model_end"
        elif run_type == "llm":
            llm_generations = cast(list[list[GenerationChunk]], response.generations)
            result = {
                "generations": [
                    [
                        {
                            "text": piece.text,
                            "generation_info": piece.generation_info,
                            "type": piece.type,
                        }
                        for piece in batch
                    ]
                    for batch in llm_generations
                ],
                "llm_output": response.llm_output,
            }
            event_name = "on_llm_end"
        else:
            msg = f"Unexpected run type: {run_type}"
            raise ValueError(msg)
        self._emit_run_event(
            event_name,
            run_id,
            info,
            {"output": result, "input": recorded_inputs},
            run_type,
        )

    async def on_chain_start(
        self,
        serialized: dict[str, Any],
        inputs: dict[str, Any],
        *,
        run_id: UUID,
        tags: Optional[list[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[dict[str, Any]] = None,
        run_type: Optional[str] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Record a chain run and emit ``on_<run_type>_start``.

        ``run_type`` falls back to ``"chain"`` when it is not given.
        """
        resolved_name = _assign_name(name, serialized)
        resolved_type = run_type or "chain"
        data: EventData = {}
        # Work-around: Runnable core code does not send the input in some
        # cases; {"input": ""} is treated as "no input".
        if inputs != {"input": ""}:
            data["input"] = inputs
            kwargs["inputs"] = inputs
        self._write_run_start_info(
            run_id,
            tags=tags,
            metadata=metadata,
            parent_run_id=parent_run_id,
            name_=resolved_name,
            run_type=resolved_type,
            **kwargs,
        )
        self._emit_start_event(
            f"on_{resolved_type}_start",
            resolved_type,
            run_id,
            name_=resolved_name,
            tags=tags,
            metadata=metadata,
            data=data,
        )

    async def on_chain_end(
        self,
        outputs: dict[str, Any],
        *,
        run_id: UUID,
        inputs: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Forget a chain run and emit ``on_<run_type>_end``.

        The reported input is the first truthy value among ``inputs``, the
        inputs recorded at start time, and ``{}``.
        """
        info = self.run_map.pop(run_id)
        run_type = info["run_type"]
        reported_inputs = inputs or info.get("inputs") or {}
        self._emit_run_event(
            f"on_{run_type}_end",
            run_id,
            info,
            {"output": outputs, "input": reported_inputs},
            run_type,
        )

    async def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        tags: Optional[list[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[dict[str, Any]] = None,
        name: Optional[str] = None,
        inputs: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Record a tool run and emit ``on_tool_start``."""
        resolved_name = _assign_name(name, serialized)
        self._write_run_start_info(
            run_id,
            tags=tags,
            metadata=metadata,
            parent_run_id=parent_run_id,
            name_=resolved_name,
            run_type="tool",
            inputs=inputs,
        )
        self._emit_start_event(
            "on_tool_start",
            "tool",
            run_id,
            name_=resolved_name,
            tags=tags,
            metadata=metadata,
            data={"input": inputs or {}},
        )

    async def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Forget a tool run and emit ``on_tool_end``.

        Raises:
            AssertionError: If no inputs were recorded for the run.
        """
        info = self.run_map.pop(run_id)
        if "inputs" not in info:
            msg = (
                f"Run ID {run_id} is a tool call and is expected to have "
                "inputs associated with it."
            )
            raise AssertionError(msg)
        self._emit_run_event(
            "on_tool_end",
            run_id,
            info,
            {"output": output, "input": info["inputs"]},
            "tool",
        )

    async def on_retriever_start(
        self,
        serialized: dict[str, Any],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[list[str]] = None,
        metadata: Optional[dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Record a retriever run and emit ``on_retriever_start``."""
        resolved_name = _assign_name(name, serialized)
        self._write_run_start_info(
            run_id,
            tags=tags,
            metadata=metadata,
            parent_run_id=parent_run_id,
            name_=resolved_name,
            run_type="retriever",
            inputs={"query": query},
        )
        self._emit_start_event(
            "on_retriever_start",
            "retriever",
            run_id,
            name_=resolved_name,
            tags=tags,
            metadata=metadata,
            data={"input": {"query": query}},
        )

    async def on_retriever_end(
        self, documents: Sequence[Document], *, run_id: UUID, **kwargs: Any
    ) -> None:
        """Forget a retriever run and emit ``on_retriever_end``."""
        info = self.run_map.pop(run_id)
        self._emit_run_event(
            "on_retriever_end",
            run_id,
            info,
            {"output": documents, "input": info["inputs"]},
            info["run_type"],
        )

    def __deepcopy__(self, memo: dict) -> _AstreamEventsCallbackHandler:
        """Return this very instance; the tracer is never duplicated."""
        return self

    def __copy__(self) -> _AstreamEventsCallbackHandler:
        """Return this very instance; the tracer is never duplicated."""
        return self
async def _astream_events_implementation_v1(
    runnable: Runnable[Input, Output],
    input: Any,
    config: Optional[RunnableConfig] = None,
    *,
    include_names: Optional[Sequence[str]] = None,
    include_types: Optional[Sequence[str]] = None,
    include_tags: Optional[Sequence[str]] = None,
    exclude_names: Optional[Sequence[str]] = None,
    exclude_types: Optional[Sequence[str]] = None,
    exclude_tags: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> AsyncIterator[StandardStreamEvent]:
    """Implementation of the astream events API for V1 runnables.

    The events are derived from the log patches produced by
    ``_astream_log_implementation``: every patch is folded into an
    accumulated run log, and the entries touched by the patch are translated
    into ``start`` / ``stream`` / ``end`` events.

    Args:
        runnable: The runnable to stream events from.
        input: The input passed to the runnable; it is also reported in the
            root start event.
        config: Optional config; root tags, metadata and run name are read
            from it (the run name falls back to ``runnable.get_name()``).
        include_names: Forwarded to the log stream handler and root filter.
        include_types: Forwarded to the log stream handler and root filter.
        include_tags: Forwarded to the log stream handler and root filter.
        exclude_names: Forwarded to the log stream handler and root filter.
        exclude_types: Forwarded to the log stream handler and root filter.
        exclude_tags: Forwarded to the log stream handler and root filter.
        **kwargs: Forwarded to ``_astream_log_implementation``.

    Yields:
        Stream events. ``parent_ids`` is always empty (not supported in v1).

    Raises:
        AssertionError: If a log entry (or the root) holds a number of
            streamed chunks other than exactly one when a stream event is
            about to be produced.
    """
    from langchain_core.runnables import ensure_config
    from langchain_core.runnables.utils import _RootEventFilter
    from langchain_core.tracers.log_stream import (
        LogStreamCallbackHandler,
        RunLog,
        _astream_log_implementation,
    )

    stream = LogStreamCallbackHandler(
        auto_close=False,
        include_names=include_names,
        include_types=include_types,
        include_tags=include_tags,
        exclude_names=exclude_names,
        exclude_types=exclude_types,
        exclude_tags=exclude_tags,
        _schema_format="streaming_events",
    )

    run_log = RunLog(state=None)  # type: ignore[arg-type]
    encountered_start_event = False

    # The root runnable's own events are filtered here.
    _root_event_filter = _RootEventFilter(
        include_names=include_names,
        include_types=include_types,
        include_tags=include_tags,
        exclude_names=exclude_names,
        exclude_types=exclude_types,
        exclude_tags=exclude_tags,
    )

    config = ensure_config(config)
    root_tags = config.get("tags", [])
    root_metadata = config.get("metadata", {})
    root_name = config.get("run_name", runnable.get_name())

    # Ignoring mypy complaint about too many different union combinations
    # This arises because many of the argument types are unions
    async for log in _astream_log_implementation(  # type: ignore[misc]
        runnable,
        input,
        config=config,
        stream=stream,
        diff=True,
        with_streamed_output_list=True,
        **kwargs,
    ):
        run_log = run_log + log

        if not encountered_start_event:
            # Yield the start event for the root runnable.
            encountered_start_event = True
            state = run_log.state.copy()

            event = StandardStreamEvent(
                event=f"on_{state['type']}_start",
                run_id=state["id"],
                name=root_name,
                tags=root_tags,
                metadata=root_metadata,
                data={
                    "input": input,
                },
                parent_ids=[],  # Not supported in v1
            )

            if _root_event_filter.include_event(event, state["type"]):
                yield event

        # Names of the log entries touched by this patch, de-duplicated.
        # A set would NOT do here: sets have no defined iteration order, so
        # events of different entries could come out in arbitrary order.
        # dict.fromkeys keeps the order in which the entries first appear
        # in the patch operations.
        paths = dict.fromkeys(
            op["path"].split("/")[2]
            for op in log.ops
            if op["path"].startswith("/logs/")
        )

        for path in paths:
            data: EventData = {}
            log_entry: LogEntry = run_log.state["logs"][path]
            if log_entry["end_time"] is None:
                event_type = "stream" if log_entry["streamed_output"] else "start"
            else:
                event_type = "end"

            if event_type == "start":
                # Include the inputs with the start event if they are available.
                # Usually they will NOT be available for components that operate
                # on streams, since those components stream the input and
                # don't know its final value until the end of the stream.
                inputs = log_entry["inputs"]
                if inputs is not None:
                    data["input"] = inputs

            if event_type == "end":
                inputs = log_entry["inputs"]
                if inputs is not None:
                    data["input"] = inputs

                # None is a VALID output for an end event
                data["output"] = log_entry["final_output"]

            if event_type == "stream":
                num_chunks = len(log_entry["streamed_output"])
                if num_chunks != 1:
                    msg = (
                        f"Expected exactly one chunk of streamed output, "
                        f"got {num_chunks} instead. This is impossible. "
                        f"Encountered in: {log_entry['name']}"
                    )
                    raise AssertionError(msg)

                data = {"chunk": log_entry["streamed_output"][0]}
                # Clean up the stream, we don't need it anymore.
                # And this avoids duplicates as well!
                log_entry["streamed_output"] = []

            yield StandardStreamEvent(
                event=f"on_{log_entry['type']}_{event_type}",
                name=log_entry["name"],
                run_id=log_entry["id"],
                tags=log_entry["tags"],
                metadata=log_entry["metadata"],
                data=data,
                parent_ids=[],  # Not supported in v1
            )

        # Finally, we take care of the streaming output from the root chain
        # if there is any.
        state = run_log.state
        if state["streamed_output"]:
            num_chunks = len(state["streamed_output"])
            if num_chunks != 1:
                msg = (
                    f"Expected exactly one chunk of streamed output, "
                    f"got {num_chunks} instead. This is impossible. "
                    f"Encountered in: {state['name']}"
                )
                raise AssertionError(msg)

            data = {"chunk": state["streamed_output"][0]}
            # Clean up the stream, we don't need it anymore.
            state["streamed_output"] = []

            event = StandardStreamEvent(
                event=f"on_{state['type']}_stream",
                run_id=state["id"],
                tags=root_tags,
                metadata=root_metadata,
                name=root_name,
                data=data,
                parent_ids=[],  # Not supported in v1
            )
            if _root_event_filter.include_event(event, state["type"]):
                yield event

    state = run_log.state

    # Finally yield the end event for the root runnable.
    event = StandardStreamEvent(
        event=f"on_{state['type']}_end",
        name=root_name,
        run_id=state["id"],
        tags=root_tags,
        metadata=root_metadata,
        data={
            "output": state["final_output"],
        },
        parent_ids=[],  # Not supported in v1
    )
    if _root_event_filter.include_event(event, state["type"]):
        yield event
async def _astream_events_implementation_v2(
    runnable: Runnable[Input, Output],
    input: Any,
    config: Optional[RunnableConfig] = None,
    *,
    include_names: Optional[Sequence[str]] = None,
    include_types: Optional[Sequence[str]] = None,
    include_tags: Optional[Sequence[str]] = None,
    exclude_names: Optional[Sequence[str]] = None,
    exclude_types: Optional[Sequence[str]] = None,
    exclude_tags: Optional[Sequence[str]] = None,
    **kwargs: Any,
) -> AsyncIterator[StandardStreamEvent]:
    """Implementation of the astream events API for V2 runnables.

    An ``_AstreamEventsCallbackHandler`` is attached to the config's
    callbacks, the runnable is streamed in a background task, and the events
    collected by the handler are yielded as they arrive.

    Args:
        runnable: The runnable to stream events from.
        input: The input passed to ``runnable.astream``; it is also written
            into the ``data`` of the first event.
        config: Optional config. A ``run_id`` is generated if it has none.
        include_names: Forwarded to the event handler's filter.
        include_types: Forwarded to the event handler's filter.
        include_tags: Forwarded to the event handler's filter.
        exclude_names: Forwarded to the event handler's filter.
        exclude_types: Forwarded to the event handler's filter.
        exclude_tags: Forwarded to the event handler's filter.
        **kwargs: Forwarded to ``runnable.astream``.

    Yields:
        The events produced while the runnable is running.

    Raises:
        ValueError: If ``config["callbacks"]`` is neither None, a list, nor a
            ``BaseCallbackManager``.
    """
    from langchain_core.callbacks.base import BaseCallbackManager
    from langchain_core.runnables import ensure_config

    event_streamer = _AstreamEventsCallbackHandler(
        include_names=include_names,
        include_types=include_types,
        include_tags=include_tags,
        exclude_names=exclude_names,
        exclude_types=exclude_types,
        exclude_tags=exclude_tags,
    )

    # Assign the stream handler to the config
    config = ensure_config(config)
    run_id = cast(UUID, config.setdefault("run_id", uuid4()))
    callbacks = config.get("callbacks")
    if callbacks is None:
        config["callbacks"] = [event_streamer]
    elif isinstance(callbacks, list):
        # Build a new list so the caller's list is left untouched.
        config["callbacks"] = callbacks + [event_streamer]
    elif isinstance(callbacks, BaseCallbackManager):
        # Work on a copy so the caller's manager is left untouched.
        callbacks = callbacks.copy()
        callbacks.add_handler(event_streamer, inherit=True)
        config["callbacks"] = callbacks
    else:
        msg = (
            f"Unexpected type for callbacks: {callbacks}. "
            "Expected None, list or BaseCallbackManager."
        )
        raise ValueError(msg)

    # Call the runnable in streaming mode,
    # add each chunk to the output stream
    async def consume_astream() -> None:
        try:
            # if astream also calls tap_output_aiter this will be a no-op
            async with aclosing(runnable.astream(input, config, **kwargs)) as stream:
                async for _ in event_streamer.tap_output_aiter(run_id, stream):
                    # All the content will be picked up
                    pass
        finally:
            # Closing the send side is done whether or not astream failed.
            await event_streamer.send_stream.aclose()

    # Start the runnable in a task, so we can start consuming output
    task = asyncio.create_task(consume_astream())

    first_event_sent = False
    first_event_run_id = None

    try:
        async for event in event_streamer:
            if not first_event_sent:
                first_event_sent = True
                # This is a work-around an issue where the inputs into the
                # chain are not available until the entire input is consumed.
                # As a temporary solution, we'll modify the input to be the input
                # that was passed into the chain.
                event["data"]["input"] = input
                first_event_run_id = event["run_id"]
                yield event
                continue

            # If it's the end event corresponding to the root runnable
            # we dont include the input in the event since it's guaranteed
            # to be included in the first event.
            if (
                event["run_id"] == first_event_run_id
                and event["event"].endswith("_end")
                and "input" in event["data"]
            ):
                del event["data"]["input"]

            yield event
    except asyncio.CancelledError as exc:
        # Forward the cancellation (and its message, if any) to the task.
        task.cancel(exc.args[0] if exc.args else None)
        raise
    finally:
        # Cancel the task if it's still running
        task.cancel()
        # Await it anyway, to run any cleanup code, and propagate any exceptions
        with contextlib.suppress(asyncio.CancelledError):
            await task