<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0-alpha1/dist/css/bootstrap.min.css" rel="stylesheet"
        integrity="sha384-GLhlTQ8iRABdZLl6O3oVMWSktQOp6b7In1Zl3/Jr59b6EGGoI1aFkw7cmDA6j6gD" crossorigin="anonymous">
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.3.0/css/all.min.css"
        integrity="sha512-SzlrxWUlpfuzQ+pcUCosxcglQRNAq/DZjVsC0lE40xsADsfeQoEypE+enwcOiGjk/bSuGGKHEyjSoQ1zVisanQ=="
        crossorigin="anonymous" referrerpolicy="no-referrer" />
</head>
</html>
from typing import TYPE_CHECKING

import sentry_sdk
from sentry_sdk.consts import OP, SPANDATA
from sentry_sdk.integrations import DidNotEnable
from sentry_sdk.integrations.grpc.consts import SPAN_ORIGIN
from sentry_sdk.tracing_utils import has_span_streaming_enabled

if TYPE_CHECKING:
    from typing import Any, Callable, Iterable, Iterator, Union

try:
    import grpc
    from google.protobuf.message import Message
    from grpc import Call, ClientCallDetails
    from grpc._interceptor import _UnaryOutcome
    from grpc.aio._interceptor import UnaryStreamCall
except ImportError:
    raise DidNotEnable("grpcio is not installed")


class ClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,  # type: ignore
    grpc.UnaryStreamClientInterceptor,  # type: ignore
):
    """Sentry tracing interceptor for unary-unary and unary-stream gRPC calls.

    Every intercepted call runs inside a Sentry span with op
    ``OP.GRPC_CLIENT`` and origin ``SPAN_ORIGIN``. Before the call is
    forwarded, the current scope's trace propagation headers are appended to
    the outgoing call metadata.

    Which span API is used is decided per call from the client options via
    ``has_span_streaming_enabled``:

    * enabled  -> ``sentry_sdk.traces.start_span`` (data stored as attributes)
    * disabled -> ``sentry_sdk.start_span`` (data stored with ``set_data``)
    """

    def intercept_unary_unary(
        self: "ClientInterceptor",
        continuation: "Callable[[ClientCallDetails, Message], _UnaryOutcome]",
        client_call_details: "ClientCallDetails",
        request: "Message",
    ) -> "_UnaryOutcome":
        """Trace a unary-unary call and inject trace headers into its metadata.

        :param continuation: callable that performs the next step of the call
            chain; it receives the (copied, header-extended) call details and
            the request.
        :param client_call_details: details of the outgoing call. It is not
            modified; a copy with extended metadata is forwarded instead.
        :param request: the request message, forwarded unchanged.
        :returns: whatever ``continuation`` returned. The name of its status
            code (``.code().name``) is recorded on the span first.
        """
        rpc_method = client_call_details.method
        streaming_enabled = has_span_streaming_enabled(sentry_sdk.get_client().options)

        if not streaming_enabled:
            # Classic span API: metadata is attached with set_data().
            with sentry_sdk.start_span(
                op=OP.GRPC_CLIENT,
                name="unary unary call to %s" % rpc_method,
                origin=SPAN_ORIGIN,
            ) as legacy_span:
                legacy_span.set_data("type", "unary unary")
                legacy_span.set_data("method", rpc_method)

                outcome = continuation(
                    self._update_client_call_details_metadata_from_scope(
                        client_call_details
                    ),
                    request,
                )
                legacy_span.set_data("code", outcome.code().name)
                return outcome

        # Span-streaming API: op/origin/method travel as span attributes.
        span_attributes = {
            "sentry.op": OP.GRPC_CLIENT,
            "sentry.origin": SPAN_ORIGIN,
            SPANDATA.RPC_METHOD: rpc_method,
        }
        with sentry_sdk.traces.start_span(
            name="unary unary call to %s" % rpc_method,
            attributes=span_attributes,
        ) as span:
            outcome = continuation(
                self._update_client_call_details_metadata_from_scope(
                    client_call_details
                ),
                request,
            )
            span.set_attribute(
                SPANDATA.RPC_RESPONSE_STATUS_CODE, outcome.code().name
            )
            return outcome

    def intercept_unary_stream(
        self: "ClientInterceptor",
        continuation: "Callable[[ClientCallDetails, Message], Union[Iterable[Any], UnaryStreamCall]]",
        client_call_details: "ClientCallDetails",
        request: "Message",
    ) -> "Union[Iterator[Message], Call]":
        """Trace a unary-stream call and inject trace headers into its metadata.

        Unlike the unary-unary case, no status code is recorded on the span.

        :param continuation: callable that performs the next step of the call
            chain; it receives the (copied, header-extended) call details and
            the request.
        :param client_call_details: details of the outgoing call. It is not
            modified; a copy with extended metadata is forwarded instead.
        :param request: the request message, forwarded unchanged.
        :returns: whatever ``continuation`` returned, untouched.
        """
        rpc_method = client_call_details.method
        streaming_enabled = has_span_streaming_enabled(sentry_sdk.get_client().options)

        # The response status code is intentionally NOT recorded in either
        # branch: reading it on a unary-stream response was observed to make
        # execution get stuck. NOTE(review): presumably because grpc's
        # ``code()`` blocks until the RPC terminates, which cannot happen
        # before the caller consumes the stream — confirm before changing.
        if not streaming_enabled:
            with sentry_sdk.start_span(
                op=OP.GRPC_CLIENT,
                name="unary stream call to %s" % rpc_method,
                origin=SPAN_ORIGIN,
            ) as legacy_span:
                legacy_span.set_data("type", "unary stream")
                legacy_span.set_data("method", rpc_method)

                return continuation(
                    self._update_client_call_details_metadata_from_scope(
                        client_call_details
                    ),
                    request,
                )

        span_attributes = {
            "sentry.op": OP.GRPC_CLIENT,
            "sentry.origin": SPAN_ORIGIN,
            SPANDATA.RPC_METHOD: rpc_method,
        }
        with sentry_sdk.traces.start_span(
            name="unary stream call to %s" % rpc_method,
            attributes=span_attributes,
        ):
            return continuation(
                self._update_client_call_details_metadata_from_scope(
                    client_call_details
                ),
                request,
            )

    @staticmethod
    def _update_client_call_details_metadata_from_scope(
        client_call_details: "ClientCallDetails",
    ) -> "ClientCallDetails":
        """Return a copy of ``client_call_details`` with trace headers added.

        Any existing metadata is copied into a fresh list (``None``/empty
        becomes ``[]``). Every ``(key, value)`` pair yielded by the current
        scope's ``iter_trace_propagation_headers()`` is then appended. All
        other call details are carried over unchanged.
        """
        existing_metadata = client_call_details.metadata
        metadata = list(existing_metadata) if existing_metadata else []

        scope = sentry_sdk.get_current_scope()
        metadata.extend(
            (key, value) for key, value in scope.iter_trace_propagation_headers()
        )

        # NOTE(review): builds the copy with grpc's private
        # ``grpc._interceptor._ClientCallDetails`` type; this may break if grpc
        # changes its internals.
        return grpc._interceptor._ClientCallDetails(
            method=client_call_details.method,
            timeout=client_call_details.timeout,
            metadata=metadata,
            credentials=client_call_details.credentials,
            wait_for_ready=client_call_details.wait_for_ready,
            compression=client_call_details.compression,
        )
