Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/replit_river/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import grpc
from aiochannel import Channel, ChannelClosed
from opentelemetry.propagators.textmap import Setter
from opentelemetry.propagators.textmap import Getter, Setter
from pydantic import BaseModel, ConfigDict, Field

from replit_river.error_schema import (
Expand Down Expand Up @@ -126,6 +126,36 @@ def set(self, carrier: TransportMessage, key: str, value: str) -> None:
logger.warning("unknown trace propagation key", extra={"key": key})


class TransportMessageTracingGetter(Getter[TransportMessage]):
    """
    Reads trace-propagation fields off an incoming transport message.

    Only the ``traceparent`` and ``tracestate`` keys are recognised; both are
    looked up on the message's ``tracing`` attribute.
    """

    def get(self, carrier: TransportMessage, key: str) -> list[str] | None:
        """Return the value for ``key`` wrapped in a one-element list.

        ``None`` is returned when the message carries no tracing data, when
        ``key`` is not one of the recognised keys, or when the stored value
        is empty.
        """
        tracing = carrier.tracing
        if not tracing:
            return None
        if key == "traceparent":
            found = tracing.traceparent
        elif key == "tracestate":
            found = tracing.tracestate
        else:
            # Unrecognised propagation key: nothing to extract.
            return None
        return [found] if found else None

    def keys(self, carrier: TransportMessage) -> list[str]:
        """List the recognised keys that hold a non-empty value on ``carrier``."""
        tracing = carrier.tracing
        if not tracing:
            return []
        candidates = (
            ("traceparent", tracing.traceparent),
            ("tracestate", tracing.tracestate),
        )
        return [name for name, stored in candidates if stored]


class GrpcContext(grpc.aio.ServicerContext, Generic[RequestType, ResponseType]):
"""Represents a gRPC-compatible ServicerContext for River interop."""

Expand Down
53 changes: 51 additions & 2 deletions src/replit_river/server_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import websockets
from aiochannel import Channel, ChannelClosed
from opentelemetry import context, trace
from opentelemetry.trace import SpanKind
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from websockets.exceptions import ConnectionClosed

Expand All @@ -24,6 +26,7 @@
STREAM_OPEN_BIT,
GenericRpcHandlerBuilder,
TransportMessage,
TransportMessageTracingGetter,
TransportMessageTracingSetter,
)

Expand All @@ -32,9 +35,11 @@

logger = logging.getLogger(__name__)

# Module-level tracer; used below to start the server-side span for each
# incoming RPC stream.
tracer = trace.get_tracer(__name__)

# Propagator that extracts trace context from incoming transport messages.
trace_propagator = TraceContextTextMapPropagator()
# Carrier adapter for writing trace fields onto a TransportMessage.
# NOTE(review): its call site is not visible in this hunk — presumably used
# when sending outgoing messages; confirm.
trace_setter = TransportMessageTracingSetter()
# Carrier adapter handed to `trace_propagator.extract` so it can read the
# trace fields of an incoming TransportMessage.
trace_getter = TransportMessageTracingGetter()


class ServerSession(Session):
Expand Down Expand Up @@ -216,6 +221,23 @@ async def _open_stream_and_call_handler(
"upload-stream", # subscription
"stream",
)

# Extract trace context from the incoming message and create a server span.
extracted_context = trace_propagator.extract(
carrier=msg, getter=trace_getter
)
span = tracer.start_span(
f"river.server.{method_type}.{msg.serviceName}.{msg.procedureName}",
context=extracted_context,
kind=SpanKind.SERVER,
)
span.set_attribute("river.service_name", msg.serviceName)
span.set_attribute("river.procedure_name", msg.procedureName)
span.set_attribute("river.method_type", method_type)
span.set_attribute("river.stream_id", msg.streamId)
span.set_attribute("river.client_id", msg.from_)
handler_ctx = trace.set_span_in_context(span, extracted_context)

# New channel pair.
input_stream: Channel[Any] = Channel(
MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1
Expand All @@ -231,9 +253,13 @@ async def _open_stream_and_call_handler(
await input_stream.put(msg.payload)
except (RuntimeError, ChannelClosed) as e:
raise InvalidMessageException(e) from e
# Start the handler.
# Start the handler with the extracted trace context.
self._task_manager.create_task(
handler_func(msg.from_, input_stream, output_stream), tg
self._run_handler_with_tracing(
handler_func, msg.from_, input_stream, output_stream,
span, handler_ctx,
),
tg,
)
self._task_manager.create_task(
self._send_responses_from_output_stream(
Expand All @@ -243,6 +269,29 @@ async def _open_stream_and_call_handler(
)
return input_stream

async def _run_handler_with_tracing(
    self,
    handler_func: GenericRpcHandlerBuilder,
    peer: str,
    input_stream: Channel[Any],
    output_stream: Channel[Any],
    span: trace.Span,
    handler_ctx: context.Context,
) -> None:
    """Await an RPC handler with ``handler_ctx`` attached as the current context.

    The span is marked OK when the handler returns normally. If the handler
    raises an ``Exception``, the span is marked ERROR (with the exception text
    as description), the exception is recorded on the span, and it is
    re-raised to the caller. Anything not derived from ``Exception`` bypasses
    the status update entirely.

    In every case the span is ended and the attached context is detached
    before this coroutine finishes.
    """
    attach_token = context.attach(handler_ctx)
    try:
        await handler_func(peer, input_stream, output_stream)
        span.set_status(trace.StatusCode.OK)
    except Exception as err:
        span.set_status(trace.StatusCode.ERROR, str(err))
        span.record_exception(err)
        raise
    finally:
        # End the span first, then restore whatever context was current
        # before the attach above.
        span.end()
        context.detach(attach_token)

async def _send_responses_from_output_stream(
self,
stream_id: str,
Expand Down
Loading
Loading