diff --git a/.gitignore b/.gitignore index f6e1955edc7..775b65d5164 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,4 @@ # Ignoring AI agent files .agents/ +/server diff --git a/api/common/v1/principal.go-helpers.pb.go b/api/common/v1/principal.go-helpers.pb.go new file mode 100644 index 00000000000..412642d48d5 --- /dev/null +++ b/api/common/v1/principal.go-helpers.pb.go @@ -0,0 +1,43 @@ +// Code generated by protoc-gen-go-helpers. DO NOT EDIT. +package commonspb + +import ( + "google.golang.org/protobuf/proto" +) + +// Marshal an object of type AttributedPrincipal to the protobuf v3 wire format +func (val *AttributedPrincipal) Marshal() ([]byte, error) { + return proto.Marshal(val) +} + +// Unmarshal an object of type AttributedPrincipal from the protobuf v3 wire format +func (val *AttributedPrincipal) Unmarshal(buf []byte) error { + return proto.Unmarshal(buf, val) +} + +// Size returns the size of the object, in bytes, once serialized +func (val *AttributedPrincipal) Size() int { + return proto.Size(val) +} + +// Equal returns whether two AttributedPrincipal values are equivalent by recursively +// comparing the message's fields. +// For more information see the documentation for +// https://pkg.go.dev/google.golang.org/protobuf/proto#Equal +func (this *AttributedPrincipal) Equal(that interface{}) bool { + if that == nil { + return this == nil + } + + var that1 *AttributedPrincipal + switch t := that.(type) { + case *AttributedPrincipal: + that1 = t + case AttributedPrincipal: + that1 = &t + default: + return false + } + + return proto.Equal(this, that1) +} diff --git a/api/common/v1/principal.pb.go b/api/common/v1/principal.pb.go new file mode 100644 index 00000000000..1b7d5510e44 --- /dev/null +++ b/api/common/v1/principal.pb.go @@ -0,0 +1,148 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// plugins: +// protoc-gen-go +// protoc +// source: temporal/server/api/common/v1/principal.proto + +package commonspb + +import ( + reflect "reflect" + sync "sync" + unsafe "unsafe" + + v1 "go.temporal.io/api/common/v1" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// AttributedPrincipal is the at-rest representation of a captured caller +// identity used by Nexus storage. It embeds the stable identity +// (temporal.api.common.v1.Principal) verbatim and adds a write-time snapshot of +// the human-readable name, for audit fidelity when the identity is an opaque ID +// that may later be renamed or deleted. +type AttributedPrincipal struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The stable identity, copied verbatim from the authenticated Principal. + // In OSS, principal.name is already human-readable (e.g. jwt/, + // mtls/). In Cloud, principal.name is an opaque ID (never surfaced to + // users) and principal.type marks it as a cloud identity. + Principal *v1.Principal `protobuf:"bytes,1,opt,name=principal,proto3" json:"principal,omitempty"` + // resolved_name is a snapshot of the human-readable identity name at write + // time. Populated only where principal.name is an opaque ID (Cloud); EMPTY in + // OSS, where principal.name is already human-readable. Readers MUST fall back + // to principal.name when this is empty. 
+ ResolvedName string `protobuf:"bytes,2,opt,name=resolved_name,json=resolvedName,proto3" json:"resolved_name,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AttributedPrincipal) Reset() { + *x = AttributedPrincipal{} + mi := &file_temporal_server_api_common_v1_principal_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AttributedPrincipal) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AttributedPrincipal) ProtoMessage() {} + +func (x *AttributedPrincipal) ProtoReflect() protoreflect.Message { + mi := &file_temporal_server_api_common_v1_principal_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AttributedPrincipal.ProtoReflect.Descriptor instead. +func (*AttributedPrincipal) Descriptor() ([]byte, []int) { + return file_temporal_server_api_common_v1_principal_proto_rawDescGZIP(), []int{0} +} + +func (x *AttributedPrincipal) GetPrincipal() *v1.Principal { + if x != nil { + return x.Principal + } + return nil +} + +func (x *AttributedPrincipal) GetResolvedName() string { + if x != nil { + return x.ResolvedName + } + return "" +} + +var File_temporal_server_api_common_v1_principal_proto protoreflect.FileDescriptor + +const file_temporal_server_api_common_v1_principal_proto_rawDesc = "" + + "\n" + + "-temporal/server/api/common/v1/principal.proto\x12\x1dtemporal.server.api.common.v1\x1a$temporal/api/common/v1/message.proto\"{\n" + + "\x13AttributedPrincipal\x12?\n" + + "\tprincipal\x18\x01 \x01(\v2!.temporal.api.common.v1.PrincipalR\tprincipal\x12#\n" + + "\rresolved_name\x18\x02 \x01(\tR\fresolvedNameB/Z-go.temporal.io/server/api/common/v1;commonspbb\x06proto3" + +var ( + file_temporal_server_api_common_v1_principal_proto_rawDescOnce sync.Once + 
file_temporal_server_api_common_v1_principal_proto_rawDescData []byte +) + +func file_temporal_server_api_common_v1_principal_proto_rawDescGZIP() []byte { + file_temporal_server_api_common_v1_principal_proto_rawDescOnce.Do(func() { + file_temporal_server_api_common_v1_principal_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_temporal_server_api_common_v1_principal_proto_rawDesc), len(file_temporal_server_api_common_v1_principal_proto_rawDesc))) + }) + return file_temporal_server_api_common_v1_principal_proto_rawDescData +} + +var file_temporal_server_api_common_v1_principal_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_temporal_server_api_common_v1_principal_proto_goTypes = []any{ + (*AttributedPrincipal)(nil), // 0: temporal.server.api.common.v1.AttributedPrincipal + (*v1.Principal)(nil), // 1: temporal.api.common.v1.Principal +} +var file_temporal_server_api_common_v1_principal_proto_depIdxs = []int32{ + 1, // 0: temporal.server.api.common.v1.AttributedPrincipal.principal:type_name -> temporal.api.common.v1.Principal + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_temporal_server_api_common_v1_principal_proto_init() } +func file_temporal_server_api_common_v1_principal_proto_init() { + if File_temporal_server_api_common_v1_principal_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_temporal_server_api_common_v1_principal_proto_rawDesc), len(file_temporal_server_api_common_v1_principal_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: 
file_temporal_server_api_common_v1_principal_proto_goTypes, + DependencyIndexes: file_temporal_server_api_common_v1_principal_proto_depIdxs, + MessageInfos: file_temporal_server_api_common_v1_principal_proto_msgTypes, + }.Build() + File_temporal_server_api_common_v1_principal_proto = out.File + file_temporal_server_api_common_v1_principal_proto_goTypes = nil + file_temporal_server_api_common_v1_principal_proto_depIdxs = nil +} diff --git a/api/matchingservice/v1/request_response.pb.go b/api/matchingservice/v1/request_response.pb.go index da685a405d2..01280098219 100644 --- a/api/matchingservice/v1/request_response.pb.go +++ b/api/matchingservice/v1/request_response.pb.go @@ -3828,8 +3828,13 @@ type DispatchNexusTaskRequest struct { NamespaceId string `protobuf:"bytes,1,opt,name=namespace_id,json=namespaceId,proto3" json:"namespace_id,omitempty"` TaskQueue *v14.TaskQueue `protobuf:"bytes,2,opt,name=task_queue,json=taskQueue,proto3" json:"task_queue,omitempty"` // Nexus request extracted by the frontend and translated into Temporal API format. - Request *v113.Request `protobuf:"bytes,3,opt,name=request,proto3" json:"request,omitempty"` - ForwardInfo *v18.TaskForwardInfo `protobuf:"bytes,4,opt,name=forward_info,json=forwardInfo,proto3" json:"forward_info,omitempty"` + Request *v113.Request `protobuf:"bytes,3,opt,name=request,proto3" json:"request,omitempty"` + ForwardInfo *v18.TaskForwardInfo `protobuf:"bytes,4,opt,name=forward_info,json=forwardInfo,proto3" json:"forward_info,omitempty"` + // Verified caller identity the frontend attributed to this request, carried + // with the task so matching can surface it on PollNexusTaskQueueResponse for + // the handler worker. Server-set and trusted; never sourced from the inbound + // request. Unset when caller attribution is not configured. 
+ CallerInfo *v113.NexusCallerInfo `protobuf:"bytes,5,opt,name=caller_info,json=callerInfo,proto3" json:"caller_info,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -3892,6 +3897,13 @@ func (x *DispatchNexusTaskRequest) GetForwardInfo() *v18.TaskForwardInfo { return nil } +func (x *DispatchNexusTaskRequest) GetCallerInfo() *v113.NexusCallerInfo { + if x != nil { + return x.CallerInfo + } + return nil +} + type DispatchNexusTaskResponse struct { state protoimpl.MessageState `protogen:"open.v1"` // Types that are valid to be assigned to Outcome: @@ -6105,13 +6117,15 @@ const file_temporal_server_api_matchingservice_v1_request_response_proto_rawDesc "\n" + "task_queue\x18\x02 \x01(\tR\ttaskQueue\x12\x18\n" + "\aversion\x18\x03 \x01(\x03R\aversion\"+\n" + - ")CheckTaskQueueUserDataPropagationResponse\"\x92\x02\n" + + ")CheckTaskQueueUserDataPropagationResponse\"\xdb\x02\n" + "\x18DispatchNexusTaskRequest\x12!\n" + "\fnamespace_id\x18\x01 \x01(\tR\vnamespaceId\x12C\n" + "\n" + "task_queue\x18\x02 \x01(\v2$.temporal.api.taskqueue.v1.TaskQueueR\ttaskQueue\x128\n" + "\arequest\x18\x03 \x01(\v2\x1e.temporal.api.nexus.v1.RequestR\arequest\x12T\n" + - "\fforward_info\x18\x04 \x01(\v21.temporal.server.api.taskqueue.v1.TaskForwardInfoR\vforwardInfo\"\xf4\x02\n" + + "\fforward_info\x18\x04 \x01(\v21.temporal.server.api.taskqueue.v1.TaskForwardInfoR\vforwardInfo\x12G\n" + + "\vcaller_info\x18\x05 \x01(\v2&.temporal.api.nexus.v1.NexusCallerInfoR\n" + + "callerInfo\"\xf4\x02\n" + "\x19DispatchNexusTaskResponse\x12N\n" + "\rhandler_error\x18\x01 \x01(\v2#.temporal.api.nexus.v1.HandlerErrorB\x02\x18\x01H\x00R\fhandlerError\x12=\n" + "\bresponse\x18\x02 \x01(\v2\x1f.temporal.api.nexus.v1.ResponseH\x00R\bresponse\x12t\n" + @@ -6357,27 +6371,28 @@ var file_temporal_server_api_matchingservice_v1_request_response_proto_goTypes = (*v112.RoutingConfig)(nil), // 132: temporal.api.deployment.v1.RoutingConfig (*v111.TaskQueueUserData)(nil), // 133: 
temporal.server.api.persistence.v1.TaskQueueUserData (*v113.Request)(nil), // 134: temporal.api.nexus.v1.Request - (*v113.HandlerError)(nil), // 135: temporal.api.nexus.v1.HandlerError - (*v113.Response)(nil), // 136: temporal.api.nexus.v1.Response - (*v114.Failure)(nil), // 137: temporal.api.failure.v1.Failure - (*v1.PollNexusTaskQueueRequest)(nil), // 138: temporal.api.workflowservice.v1.PollNexusTaskQueueRequest - (*v1.PollNexusTaskQueueResponse)(nil), // 139: temporal.api.workflowservice.v1.PollNexusTaskQueueResponse - (*v1.RespondNexusTaskCompletedRequest)(nil), // 140: temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest - (*v1.RespondNexusTaskFailedRequest)(nil), // 141: temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest - (*v111.NexusEndpointSpec)(nil), // 142: temporal.server.api.persistence.v1.NexusEndpointSpec - (*v111.NexusEndpointEntry)(nil), // 143: temporal.server.api.persistence.v1.NexusEndpointEntry - (*v1.RecordWorkerHeartbeatRequest)(nil), // 144: temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest - (*v1.ListWorkersRequest)(nil), // 145: temporal.api.workflowservice.v1.ListWorkersRequest - (*v115.WorkerInfo)(nil), // 146: temporal.api.worker.v1.WorkerInfo - (*v115.WorkerListInfo)(nil), // 147: temporal.api.worker.v1.WorkerListInfo - (*v1.UpdateTaskQueueConfigRequest)(nil), // 148: temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest - (*v14.TaskQueueConfig)(nil), // 149: temporal.api.taskqueue.v1.TaskQueueConfig - (*v1.DescribeWorkerRequest)(nil), // 150: temporal.api.workflowservice.v1.DescribeWorkerRequest - (v116.FairnessState)(0), // 151: temporal.server.api.enums.v1.FairnessState - (*v14.TaskQueueStats)(nil), // 152: temporal.api.taskqueue.v1.TaskQueueStats - (*v18.TaskQueueVersionInfoInternal)(nil), // 153: temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal - (*v1.UpdateWorkerBuildIdCompatibilityRequest)(nil), // 154: temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest 
- (*v110.WorkerDeploymentVersionData)(nil), // 155: temporal.server.api.deployment.v1.WorkerDeploymentVersionData + (*v113.NexusCallerInfo)(nil), // 135: temporal.api.nexus.v1.NexusCallerInfo + (*v113.HandlerError)(nil), // 136: temporal.api.nexus.v1.HandlerError + (*v113.Response)(nil), // 137: temporal.api.nexus.v1.Response + (*v114.Failure)(nil), // 138: temporal.api.failure.v1.Failure + (*v1.PollNexusTaskQueueRequest)(nil), // 139: temporal.api.workflowservice.v1.PollNexusTaskQueueRequest + (*v1.PollNexusTaskQueueResponse)(nil), // 140: temporal.api.workflowservice.v1.PollNexusTaskQueueResponse + (*v1.RespondNexusTaskCompletedRequest)(nil), // 141: temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest + (*v1.RespondNexusTaskFailedRequest)(nil), // 142: temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest + (*v111.NexusEndpointSpec)(nil), // 143: temporal.server.api.persistence.v1.NexusEndpointSpec + (*v111.NexusEndpointEntry)(nil), // 144: temporal.server.api.persistence.v1.NexusEndpointEntry + (*v1.RecordWorkerHeartbeatRequest)(nil), // 145: temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest + (*v1.ListWorkersRequest)(nil), // 146: temporal.api.workflowservice.v1.ListWorkersRequest + (*v115.WorkerInfo)(nil), // 147: temporal.api.worker.v1.WorkerInfo + (*v115.WorkerListInfo)(nil), // 148: temporal.api.worker.v1.WorkerListInfo + (*v1.UpdateTaskQueueConfigRequest)(nil), // 149: temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest + (*v14.TaskQueueConfig)(nil), // 150: temporal.api.taskqueue.v1.TaskQueueConfig + (*v1.DescribeWorkerRequest)(nil), // 151: temporal.api.workflowservice.v1.DescribeWorkerRequest + (v116.FairnessState)(0), // 152: temporal.server.api.enums.v1.FairnessState + (*v14.TaskQueueStats)(nil), // 153: temporal.api.taskqueue.v1.TaskQueueStats + (*v18.TaskQueueVersionInfoInternal)(nil), // 154: temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal + (*v1.UpdateWorkerBuildIdCompatibilityRequest)(nil), 
// 155: temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest + (*v110.WorkerDeploymentVersionData)(nil), // 156: temporal.server.api.deployment.v1.WorkerDeploymentVersionData } var file_temporal_server_api_matchingservice_v1_request_response_proto_depIdxs = []int32{ 92, // 0: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueRequest.poll_request:type_name -> temporal.api.workflowservice.v1.PollWorkflowTaskQueueRequest @@ -6489,49 +6504,50 @@ var file_temporal_server_api_matchingservice_v1_request_response_proto_depIdxs = 97, // 106: temporal.server.api.matchingservice.v1.DispatchNexusTaskRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue 134, // 107: temporal.server.api.matchingservice.v1.DispatchNexusTaskRequest.request:type_name -> temporal.api.nexus.v1.Request 111, // 108: temporal.server.api.matchingservice.v1.DispatchNexusTaskRequest.forward_info:type_name -> temporal.server.api.taskqueue.v1.TaskForwardInfo - 135, // 109: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.handler_error:type_name -> temporal.api.nexus.v1.HandlerError - 136, // 110: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.response:type_name -> temporal.api.nexus.v1.Response - 91, // 111: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.request_timeout:type_name -> temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.Timeout - 137, // 112: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.failure:type_name -> temporal.api.failure.v1.Failure - 138, // 113: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.request:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueRequest - 81, // 114: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.conditions:type_name -> temporal.server.api.matchingservice.v1.PollConditions - 139, // 115: temporal.server.api.matchingservice.v1.PollNexusTaskQueueResponse.response:type_name -> 
temporal.api.workflowservice.v1.PollNexusTaskQueueResponse - 97, // 116: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue - 140, // 117: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest - 97, // 118: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue - 141, // 119: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest - 142, // 120: temporal.server.api.matchingservice.v1.CreateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec - 143, // 121: temporal.server.api.matchingservice.v1.CreateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry - 142, // 122: temporal.server.api.matchingservice.v1.UpdateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec - 143, // 123: temporal.server.api.matchingservice.v1.UpdateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry - 143, // 124: temporal.server.api.matchingservice.v1.ListNexusEndpointsResponse.entries:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry - 144, // 125: temporal.server.api.matchingservice.v1.RecordWorkerHeartbeatRequest.heartbeart_request:type_name -> temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest - 145, // 126: temporal.server.api.matchingservice.v1.ListWorkersRequest.list_request:type_name -> temporal.api.workflowservice.v1.ListWorkersRequest - 146, // 127: temporal.server.api.matchingservice.v1.ListWorkersResponse.workers_info:type_name -> temporal.api.worker.v1.WorkerInfo - 147, // 128: 
temporal.server.api.matchingservice.v1.ListWorkersResponse.workers:type_name -> temporal.api.worker.v1.WorkerListInfo - 148, // 129: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigRequest.update_taskqueue_config:type_name -> temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest - 149, // 130: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigResponse.updated_taskqueue_config:type_name -> temporal.api.taskqueue.v1.TaskQueueConfig - 150, // 131: temporal.server.api.matchingservice.v1.DescribeWorkerRequest.request:type_name -> temporal.api.workflowservice.v1.DescribeWorkerRequest - 146, // 132: temporal.server.api.matchingservice.v1.DescribeWorkerResponse.worker_info:type_name -> temporal.api.worker.v1.WorkerInfo - 115, // 133: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType - 151, // 134: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.fairness_state:type_name -> temporal.server.api.enums.v1.FairnessState - 115, // 135: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType - 117, // 136: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.version:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersion - 95, // 137: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponse.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery - 95, // 138: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponseWithRawHistory.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery - 115, // 139: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesRequest.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType - 115, // 140: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.type:type_name -> 
temporal.api.enums.v1.TaskQueueType - 152, // 141: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats:type_name -> temporal.api.taskqueue.v1.TaskQueueStats - 86, // 142: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats_by_priority_key:type_name -> temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry - 152, // 143: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry.value:type_name -> temporal.api.taskqueue.v1.TaskQueueStats - 153, // 144: temporal.server.api.matchingservice.v1.DescribeTaskQueuePartitionResponse.VersionsInfoInternalEntry.value:type_name -> temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal - 154, // 145: temporal.server.api.matchingservice.v1.UpdateWorkerBuildIdCompatibilityRequest.ApplyPublicRequest.request:type_name -> temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest - 155, // 146: temporal.server.api.matchingservice.v1.SyncDeploymentUserDataRequest.UpsertVersionsDataEntry.value:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersionData - 147, // [147:147] is the sub-list for method output_type - 147, // [147:147] is the sub-list for method input_type - 147, // [147:147] is the sub-list for extension type_name - 147, // [147:147] is the sub-list for extension extendee - 0, // [0:147] is the sub-list for field type_name + 135, // 109: temporal.server.api.matchingservice.v1.DispatchNexusTaskRequest.caller_info:type_name -> temporal.api.nexus.v1.NexusCallerInfo + 136, // 110: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.handler_error:type_name -> temporal.api.nexus.v1.HandlerError + 137, // 111: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.response:type_name -> temporal.api.nexus.v1.Response + 91, // 112: 
temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.request_timeout:type_name -> temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.Timeout + 138, // 113: temporal.server.api.matchingservice.v1.DispatchNexusTaskResponse.failure:type_name -> temporal.api.failure.v1.Failure + 139, // 114: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.request:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueRequest + 81, // 115: temporal.server.api.matchingservice.v1.PollNexusTaskQueueRequest.conditions:type_name -> temporal.server.api.matchingservice.v1.PollConditions + 140, // 116: temporal.server.api.matchingservice.v1.PollNexusTaskQueueResponse.response:type_name -> temporal.api.workflowservice.v1.PollNexusTaskQueueResponse + 97, // 117: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue + 141, // 118: temporal.server.api.matchingservice.v1.RespondNexusTaskCompletedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskCompletedRequest + 97, // 119: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.task_queue:type_name -> temporal.api.taskqueue.v1.TaskQueue + 142, // 120: temporal.server.api.matchingservice.v1.RespondNexusTaskFailedRequest.request:type_name -> temporal.api.workflowservice.v1.RespondNexusTaskFailedRequest + 143, // 121: temporal.server.api.matchingservice.v1.CreateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec + 144, // 122: temporal.server.api.matchingservice.v1.CreateNexusEndpointResponse.entry:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry + 143, // 123: temporal.server.api.matchingservice.v1.UpdateNexusEndpointRequest.spec:type_name -> temporal.server.api.persistence.v1.NexusEndpointSpec + 144, // 124: temporal.server.api.matchingservice.v1.UpdateNexusEndpointResponse.entry:type_name -> 
temporal.server.api.persistence.v1.NexusEndpointEntry + 144, // 125: temporal.server.api.matchingservice.v1.ListNexusEndpointsResponse.entries:type_name -> temporal.server.api.persistence.v1.NexusEndpointEntry + 145, // 126: temporal.server.api.matchingservice.v1.RecordWorkerHeartbeatRequest.heartbeart_request:type_name -> temporal.api.workflowservice.v1.RecordWorkerHeartbeatRequest + 146, // 127: temporal.server.api.matchingservice.v1.ListWorkersRequest.list_request:type_name -> temporal.api.workflowservice.v1.ListWorkersRequest + 147, // 128: temporal.server.api.matchingservice.v1.ListWorkersResponse.workers_info:type_name -> temporal.api.worker.v1.WorkerInfo + 148, // 129: temporal.server.api.matchingservice.v1.ListWorkersResponse.workers:type_name -> temporal.api.worker.v1.WorkerListInfo + 149, // 130: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigRequest.update_taskqueue_config:type_name -> temporal.api.workflowservice.v1.UpdateTaskQueueConfigRequest + 150, // 131: temporal.server.api.matchingservice.v1.UpdateTaskQueueConfigResponse.updated_taskqueue_config:type_name -> temporal.api.taskqueue.v1.TaskQueueConfig + 151, // 132: temporal.server.api.matchingservice.v1.DescribeWorkerRequest.request:type_name -> temporal.api.workflowservice.v1.DescribeWorkerRequest + 147, // 133: temporal.server.api.matchingservice.v1.DescribeWorkerResponse.worker_info:type_name -> temporal.api.worker.v1.WorkerInfo + 115, // 134: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType + 152, // 135: temporal.server.api.matchingservice.v1.UpdateFairnessStateRequest.fairness_state:type_name -> temporal.server.api.enums.v1.FairnessState + 115, // 136: temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.task_queue_type:type_name -> temporal.api.enums.v1.TaskQueueType + 117, // 137: 
temporal.server.api.matchingservice.v1.CheckTaskQueueVersionMembershipRequest.version:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersion + 95, // 138: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponse.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery + 95, // 139: temporal.server.api.matchingservice.v1.PollWorkflowTaskQueueResponseWithRawHistory.QueriesEntry.value:type_name -> temporal.api.query.v1.WorkflowQuery + 115, // 140: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesRequest.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType + 115, // 141: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.type:type_name -> temporal.api.enums.v1.TaskQueueType + 153, // 142: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats:type_name -> temporal.api.taskqueue.v1.TaskQueueStats + 86, // 143: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.stats_by_priority_key:type_name -> temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry + 153, // 144: temporal.server.api.matchingservice.v1.DescribeVersionedTaskQueuesResponse.VersionTaskQueue.StatsByPriorityKeyEntry.value:type_name -> temporal.api.taskqueue.v1.TaskQueueStats + 154, // 145: temporal.server.api.matchingservice.v1.DescribeTaskQueuePartitionResponse.VersionsInfoInternalEntry.value:type_name -> temporal.server.api.taskqueue.v1.TaskQueueVersionInfoInternal + 155, // 146: temporal.server.api.matchingservice.v1.UpdateWorkerBuildIdCompatibilityRequest.ApplyPublicRequest.request:type_name -> temporal.api.workflowservice.v1.UpdateWorkerBuildIdCompatibilityRequest + 156, // 147: temporal.server.api.matchingservice.v1.SyncDeploymentUserDataRequest.UpsertVersionsDataEntry.value:type_name -> temporal.server.api.deployment.v1.WorkerDeploymentVersionData 
+ 148, // [148:148] is the sub-list for method output_type + 148, // [148:148] is the sub-list for method input_type + 148, // [148:148] is the sub-list for extension type_name + 148, // [148:148] is the sub-list for extension extendee + 0, // [0:148] is the sub-list for field type_name } func init() { file_temporal_server_api_matchingservice_v1_request_response_proto_init() } diff --git a/chasm/lib/nexusoperation/fx.go b/chasm/lib/nexusoperation/fx.go index ec5caeb11ad..cd80942d73d 100644 --- a/chasm/lib/nexusoperation/fx.go +++ b/chasm/lib/nexusoperation/fx.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "os" "go.temporal.io/api/serviceerror" persistencespb "go.temporal.io/server/api/persistence/v1" @@ -12,11 +13,13 @@ import ( "go.temporal.io/server/common" "go.temporal.io/server/common/cluster" "go.temporal.io/server/common/collection" + "go.temporal.io/server/common/config" "go.temporal.io/server/common/dynamicconfig" "go.temporal.io/server/common/log" "go.temporal.io/server/common/metrics" commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/common/nexus/nexusrpc" + "go.temporal.io/server/common/nexus/principaltoken" "go.temporal.io/server/common/persistence" "go.temporal.io/server/common/resource" "go.temporal.io/server/common/rpc" @@ -28,6 +31,11 @@ const nexusCallbackSourceHeader = "Nexus-Callback-Source" var Module = fx.Module( "chasm.lib.nexusoperation", fx.Provide(configProvider), + // Signed caller-identity propagation across the Nexus hop (Signer/Verifier/ + // KeyProvider). Feature-off by default; the decorator drives it from + // config.Global.NexusPrincipalPropagation. Cloud may override the seams. 
+ principaltoken.Module, + fx.Decorate(principalTokenConfigFromServerConfig), fx.Provide(commonnexus.NewCallbackTokenGenerator), fx.Provide(endpointRegistryProvider), fx.Invoke(endpointRegistryLifetimeHooks), @@ -65,6 +73,54 @@ func register( return registry.Register(library) } +// principalTokenConfigFromServerConfig decorates the principaltoken.Module's +// default (feature-off) Config with the operator-supplied settings from the +// official server config. This is what makes the feature configurable via YAML +// / temporal.WithConfig rather than test-only fx injection. An empty +// config.Global.NexusPrincipalPropagation yields an empty principaltoken.Config +// (feature stays off). +func principalTokenConfigFromServerConfig(_ principaltoken.Config, cfg *config.Config) (principaltoken.Config, error) { + src := cfg.Global.NexusPrincipalPropagation + out := principaltoken.Config{ + Issuer: src.Issuer, + SigningKeyID: src.SigningKeyID, + TrustMode: principaltoken.TrustMode(src.TrustMode), + TTL: src.TTL, + Leeway: src.Leeway, + } + signingKey, err := readPEM(src.SigningKeyData, src.SigningKeyFile) + if err != nil { + return principaltoken.Config{}, fmt.Errorf("nexus principal signing key: %w", err) + } + out.SigningKeyPEM = signingKey + + for _, ti := range src.TrustedIssuers { + pub, err := readPEM(ti.PublicKeyData, ti.PublicKeyFile) + if err != nil { + return principaltoken.Config{}, fmt.Errorf("nexus trusted issuer %s/%s: %w", ti.Issuer, ti.KeyID, err) + } + if out.TrustedIssuers == nil { + out.TrustedIssuers = map[string]map[string][]byte{} + } + if out.TrustedIssuers[ti.Issuer] == nil { + out.TrustedIssuers[ti.Issuer] = map[string][]byte{} + } + out.TrustedIssuers[ti.Issuer][ti.KeyID] = pub + } + return out, nil +} + +// readPEM returns inline PEM data if present, else reads it from file, else nil. 
+func readPEM(data, file string) ([]byte, error) { + if data != "" { + return []byte(data), nil + } + if file != "" { + return os.ReadFile(file) + } + return nil, nil +} + func endpointRegistryProvider( matchingClient resource.MatchingClient, endpointManager persistence.NexusEndpointManager, diff --git a/chasm/lib/nexusoperation/fx_test.go b/chasm/lib/nexusoperation/fx_test.go new file mode 100644 index 00000000000..58f32287740 --- /dev/null +++ b/chasm/lib/nexusoperation/fx_test.go @@ -0,0 +1,48 @@ +package nexusoperation + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.temporal.io/server/common/config" + "go.temporal.io/server/common/nexus/principaltoken" + "go.uber.org/fx" +) + +// TestPrincipalTokenConfigDecorator_FxGraph validates that the official-config +// decorator composes with principaltoken.Module: given a *config.Config in +// scope, the module's default Config is overridden and Signer/Verifier still +// resolve. +func TestPrincipalTokenConfigDecorator_FxGraph(t *testing.T) { + require.NoError(t, fx.ValidateApp( + fx.Supply(&config.Config{}), + principaltoken.Module, + fx.Decorate(principalTokenConfigFromServerConfig), + fx.Invoke(func(principaltoken.Signer, principaltoken.Verifier, principaltoken.KeyProvider) {}), + )) +} + +func TestPrincipalTokenConfigFromServerConfig(t *testing.T) { + // Empty server config => feature off (no key material). + out, err := principalTokenConfigFromServerConfig(principaltoken.Config{}, &config.Config{}) + require.NoError(t, err) + require.Empty(t, out.SigningKeyPEM) + require.Nil(t, out.TrustedIssuers) + + // Populated inline => mapped through verbatim. 
+ cfg := &config.Config{} + cfg.Global.NexusPrincipalPropagation = config.NexusPrincipalPropagation{ + Issuer: "iss", + SigningKeyID: "k1", + SigningKeyData: "PRIV-PEM", + TrustedIssuers: []config.NexusTrustedIssuer{ + {Issuer: "iss", KeyID: "k1", PublicKeyData: "PUB-PEM"}, + }, + } + out, err = principalTokenConfigFromServerConfig(principaltoken.Config{}, cfg) + require.NoError(t, err) + require.Equal(t, "iss", out.Issuer) + require.Equal(t, "k1", out.SigningKeyID) + require.Equal(t, []byte("PRIV-PEM"), out.SigningKeyPEM) + require.Equal(t, []byte("PUB-PEM"), out.TrustedIssuers["iss"]["k1"]) +} diff --git a/chasm/lib/nexusoperation/gen/nexusoperationpb/v1/operation.pb.go b/chasm/lib/nexusoperation/gen/nexusoperationpb/v1/operation.pb.go index 4156df1d5ec..976ed30ab8b 100644 --- a/chasm/lib/nexusoperation/gen/nexusoperationpb/v1/operation.pb.go +++ b/chasm/lib/nexusoperation/gen/nexusoperationpb/v1/operation.pb.go @@ -15,6 +15,7 @@ import ( v11 "go.temporal.io/api/common/v1" v1 "go.temporal.io/api/failure/v1" v12 "go.temporal.io/api/sdk/v1" + v13 "go.temporal.io/server/api/common/v1" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" anypb "google.golang.org/protobuf/types/known/anypb" @@ -689,13 +690,30 @@ func (x *CancellationState) GetReason() string { } type OperationRequestData struct { - state protoimpl.MessageState `protogen:"open.v1"` - Input *v11.Payload `protobuf:"bytes,1,opt,name=input,proto3" json:"input,omitempty"` - NexusHeader map[string]string `protobuf:"bytes,2,rep,name=nexus_header,json=nexusHeader,proto3" json:"nexus_header,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` - UserMetadata *v12.UserMetadata `protobuf:"bytes,3,opt,name=user_metadata,json=userMetadata,proto3" json:"user_metadata,omitempty"` - Identity string `protobuf:"bytes,4,opt,name=identity,proto3" json:"identity,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache 
protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Input *v11.Payload `protobuf:"bytes,1,opt,name=input,proto3" json:"input,omitempty"` + NexusHeader map[string]string `protobuf:"bytes,2,rep,name=nexus_header,json=nexusHeader,proto3" json:"nexus_header,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + UserMetadata *v12.UserMetadata `protobuf:"bytes,3,opt,name=user_metadata,json=userMetadata,proto3" json:"user_metadata,omitempty"` + Identity string `protobuf:"bytes,4,opt,name=identity,proto3" json:"identity,omitempty"` + // Principal of the immediate caller that scheduled this Nexus operation, + // captured at schedule time from the inbound RPC's authenticated context. + // For workflow-initiated operations this is the worker's service-account + // principal; for standalone Nexus operations this is the SDK client's. + // + // Server-set, immutable from user code. Propagated as metadata on the + // outbound HTTP dispatch so the handler frontend can authorize on the + // immediate caller. + ServiceCallerPrincipal *v13.AttributedPrincipal `protobuf:"bytes,5,opt,name=service_caller_principal,json=serviceCallerPrincipal,proto3" json:"service_caller_principal,omitempty"` + // Principal that originated this workflow chain at the edge (captured + // from the root workflow's WorkflowExecutionInfo.RootCallerPrincipal at + // schedule time). Propagated alongside the service caller principal so + // the handler can audit and authorize on the end-user identity. + // + // Nil when no end-user principal is available (e.g. workflow started + // before the field existed, or chain broken at an activity boundary). 
+ EndUserCallerPrincipal *v13.AttributedPrincipal `protobuf:"bytes,6,opt,name=end_user_caller_principal,json=endUserCallerPrincipal,proto3" json:"end_user_caller_principal,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *OperationRequestData) Reset() { @@ -756,6 +774,20 @@ func (x *OperationRequestData) GetIdentity() string { return "" } +func (x *OperationRequestData) GetServiceCallerPrincipal() *v13.AttributedPrincipal { + if x != nil { + return x.ServiceCallerPrincipal + } + return nil +} + +func (x *OperationRequestData) GetEndUserCallerPrincipal() *v13.AttributedPrincipal { + if x != nil { + return x.EndUserCallerPrincipal + } + return nil +} + type OperationOutcome_Successful struct { state protoimpl.MessageState `protogen:"open.v1"` Result *v11.Payload `protobuf:"bytes,1,opt,name=result,proto3" json:"result,omitempty"` @@ -848,7 +880,7 @@ var File_temporal_server_chasm_lib_nexusoperation_proto_v1_operation_proto proto const file_temporal_server_chasm_lib_nexusoperation_proto_v1_operation_proto_rawDesc = "" + "\n" + - "Atemporal/server/chasm/lib/nexusoperation/proto/v1/operation.proto\x121temporal.server.chasm.lib.nexusoperation.proto.v1\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a$temporal/api/common/v1/message.proto\x1a%temporal/api/failure/v1/message.proto\x1a'temporal/api/sdk/v1/user_metadata.proto\"\xe9\t\n" + + "Atemporal/server/chasm/lib/nexusoperation/proto/v1/operation.proto\x121temporal.server.chasm.lib.nexusoperation.proto.v1\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a$temporal/api/common/v1/message.proto\x1a%temporal/api/failure/v1/message.proto\x1a'temporal/api/sdk/v1/user_metadata.proto\x1a-temporal/server/api/common/v1/principal.proto\"\xe9\t\n" + "\x0eOperationState\x12Z\n" + "\x06status\x18\x01 
\x01(\x0e2B.temporal.server.chasm.lib.nexusoperation.proto.v1.OperationStatusR\x06status\x12\x1f\n" + "\vendpoint_id\x18\x02 \x01(\tR\n" + @@ -902,12 +934,14 @@ const file_temporal_server_chasm_lib_nexusoperation_proto_v1_operation_proto_raw "request_id\x18\b \x01(\tR\trequestId\x12\x1a\n" + "\bidentity\x18\t \x01(\tR\bidentity\x12\x16\n" + "\x06reason\x18\n" + - " \x01(\tR\x06reason\"\xee\x02\n" + + " \x01(\tR\x06reason\"\xcb\x04\n" + "\x14OperationRequestData\x125\n" + "\x05input\x18\x01 \x01(\v2\x1f.temporal.api.common.v1.PayloadR\x05input\x12{\n" + "\fnexus_header\x18\x02 \x03(\v2X.temporal.server.chasm.lib.nexusoperation.proto.v1.OperationRequestData.NexusHeaderEntryR\vnexusHeader\x12F\n" + "\ruser_metadata\x18\x03 \x01(\v2!.temporal.api.sdk.v1.UserMetadataR\fuserMetadata\x12\x1a\n" + - "\bidentity\x18\x04 \x01(\tR\bidentity\x1a>\n" + + "\bidentity\x18\x04 \x01(\tR\bidentity\x12l\n" + + "\x18service_caller_principal\x18\x05 \x01(\v22.temporal.server.api.common.v1.AttributedPrincipalR\x16serviceCallerPrincipal\x12m\n" + + "\x19end_user_caller_principal\x18\x06 \x01(\v22.temporal.server.api.common.v1.AttributedPrincipalR\x16endUserCallerPrincipal\x1a>\n" + "\x10NexusHeaderEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01*\xb0\x02\n" + @@ -962,6 +996,7 @@ var file_temporal_server_chasm_lib_nexusoperation_proto_v1_operation_proto_goTyp (*v11.Link)(nil), // 14: temporal.api.common.v1.Link (*v11.Payload)(nil), // 15: temporal.api.common.v1.Payload (*v12.UserMetadata)(nil), // 16: temporal.api.sdk.v1.UserMetadata + (*v13.AttributedPrincipal)(nil), // 17: temporal.server.api.common.v1.AttributedPrincipal } var file_temporal_server_chasm_lib_nexusoperation_proto_v1_operation_proto_depIdxs = []int32{ 0, // 0: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationState.status:type_name -> temporal.server.chasm.lib.nexusoperation.proto.v1.OperationStatus @@ -988,13 +1023,15 @@ var 
file_temporal_server_chasm_lib_nexusoperation_proto_v1_operation_proto_depId 15, // 21: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationRequestData.input:type_name -> temporal.api.common.v1.Payload 9, // 22: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationRequestData.nexus_header:type_name -> temporal.server.chasm.lib.nexusoperation.proto.v1.OperationRequestData.NexusHeaderEntry 16, // 23: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationRequestData.user_metadata:type_name -> temporal.api.sdk.v1.UserMetadata - 15, // 24: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationOutcome.Successful.result:type_name -> temporal.api.common.v1.Payload - 13, // 25: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationOutcome.Failed.failure:type_name -> temporal.api.failure.v1.Failure - 26, // [26:26] is the sub-list for method output_type - 26, // [26:26] is the sub-list for method input_type - 26, // [26:26] is the sub-list for extension type_name - 26, // [26:26] is the sub-list for extension extendee - 0, // [0:26] is the sub-list for field type_name + 17, // 24: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationRequestData.service_caller_principal:type_name -> temporal.server.api.common.v1.AttributedPrincipal + 17, // 25: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationRequestData.end_user_caller_principal:type_name -> temporal.server.api.common.v1.AttributedPrincipal + 15, // 26: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationOutcome.Successful.result:type_name -> temporal.api.common.v1.Payload + 13, // 27: temporal.server.chasm.lib.nexusoperation.proto.v1.OperationOutcome.Failed.failure:type_name -> temporal.api.failure.v1.Failure + 28, // [28:28] is the sub-list for method output_type + 28, // [28:28] is the sub-list for method input_type + 28, // [28:28] is the sub-list for extension type_name + 28, // [28:28] is the sub-list for extension extendee + 0, // [0:28] is the sub-list for field 
type_name } func init() { file_temporal_server_chasm_lib_nexusoperation_proto_v1_operation_proto_init() } diff --git a/chasm/lib/nexusoperation/invocation.go b/chasm/lib/nexusoperation/invocation.go index 42f5d35b52f..9da96d46f98 100644 --- a/chasm/lib/nexusoperation/invocation.go +++ b/chasm/lib/nexusoperation/invocation.go @@ -11,6 +11,7 @@ import ( enumspb "go.temporal.io/api/enums/v1" nexuspb "go.temporal.io/api/nexus/v1" "go.temporal.io/api/serviceerror" + commonspb "go.temporal.io/server/api/common/v1" "go.temporal.io/server/api/historyservice/v1" persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/chasm" @@ -38,6 +39,8 @@ type startArgs struct { payload *commonpb.Payload nexusLinks []nexus.Link serializedRef []byte + serviceCallerPrincipal *commonspb.AttributedPrincipal + endUserCallerPrincipal *commonspb.AttributedPrincipal } // invocationTraceContext captures per-call contextual information needed to set up HTTP tracing. diff --git a/chasm/lib/nexusoperation/operation.go b/chasm/lib/nexusoperation/operation.go index 9947373f8a2..a8128afd44d 100644 --- a/chasm/lib/nexusoperation/operation.go +++ b/chasm/lib/nexusoperation/operation.go @@ -13,6 +13,7 @@ import ( nexuspb "go.temporal.io/api/nexus/v1" "go.temporal.io/api/serviceerror" "go.temporal.io/api/workflowservice/v1" + commonspb "go.temporal.io/server/api/common/v1" persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/chasm" nexusoperationpb "go.temporal.io/server/chasm/lib/nexusoperation/gen/nexusoperationpb/v1" @@ -121,11 +122,21 @@ func newStandaloneOperation( ScheduledTime: timestamppb.New(ctx.Now(nil)), RequestId: uuid.NewString(), }) + // Capture the inbound RPC's principal as both the service caller and + // the end-user principal: for standalone Nexus operations the SDK + // client that invoked us is both the immediate caller and the + // originator. 
The auth interceptor wrote the principal headers onto + // the gRPC incoming metadata (and stripped any spoofed inbound + // values), so reading them via the chasm context's RequestHeader is + // safe. + callerPrincipal := newAttributedPrincipal(PrincipalFromContext(ctx), resolvedNameFromChasmContext(ctx)) op.RequestData = chasm.NewDataField(ctx, &nexusoperationpb.OperationRequestData{ - Input: frontendReq.GetInput(), - NexusHeader: frontendReq.GetNexusHeader(), - UserMetadata: frontendReq.GetUserMetadata(), - Identity: frontendReq.GetIdentity(), + Input: frontendReq.GetInput(), + NexusHeader: frontendReq.GetNexusHeader(), + UserMetadata: frontendReq.GetUserMetadata(), + Identity: frontendReq.GetIdentity(), + ServiceCallerPrincipal: callerPrincipal, + EndUserCallerPrincipal: callerPrincipal, }) op.Visibility = chasm.NewComponentField(ctx, chasm.NewVisibilityWithData( ctx, @@ -289,6 +300,11 @@ func (o *Operation) loadStartArgs( var ( invocationData InvocationData err error + // Principals captured at schedule time (AttributedPrincipal = identity + + // resolved-name snapshot). Read from RequestData below, independently of + // the input source. + serviceCaller *commonspb.AttributedPrincipal + endUserCaller *commonspb.AttributedPrincipal ) if store, ok := o.Store.TryGet(ctx); ok { invocationData, err = store.NexusOperationInvocationData(ctx, o) @@ -302,6 +318,16 @@ func (o *Operation) loadStartArgs( Header: requestData.GetNexusHeader(), } } + // Principals live on RequestData for both standalone operations and + // workflow-initiated operations. For the latter, SetCallerPrincipals writes + // a RequestData carrying only the principals alongside the parent Store that + // supplies the input — so read them here via TryGet regardless of the branch + // taken above. Nil principals are the graceful-degradation case (feature + // off, or no authenticated identity). 
+ if requestData, ok := o.RequestData.TryGet(ctx); ok { + serviceCaller = requestData.GetServiceCallerPrincipal() + endUserCaller = requestData.GetEndUserCallerPrincipal() + } invocationData.NexusLinks = append(invocationData.NexusLinks, commonnexus.ConvertLinkNexusOperationToNexusLink(&commonpb.Link_NexusOperation{ Namespace: ctx.NamespaceEntry().Name().String(), @@ -329,6 +355,8 @@ func (o *Operation) loadStartArgs( header: invocationData.Header, nexusLinks: invocationData.NexusLinks, serializedRef: serializedRef, + serviceCallerPrincipal: serviceCaller, + endUserCallerPrincipal: endUserCaller, }, nil } diff --git a/chasm/lib/nexusoperation/operation_tasks.go b/chasm/lib/nexusoperation/operation_tasks.go index 681ea278297..646775c36a0 100644 --- a/chasm/lib/nexusoperation/operation_tasks.go +++ b/chasm/lib/nexusoperation/operation_tasks.go @@ -22,6 +22,7 @@ import ( "go.temporal.io/server/common/namespace" commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/common/nexus/nexusrpc" + "go.temporal.io/server/common/nexus/principaltoken" "go.temporal.io/server/common/resource" queueserrors "go.temporal.io/server/service/history/queues/errors" "go.uber.org/fx" @@ -51,6 +52,13 @@ type operationInvocationTaskHandlerOptions struct { HTTPTraceProvider commonnexus.HTTPClientTraceProvider HistoryClient resource.HistoryClient ChasmRegistry *chasm.Registry + // PrincipalSigner mints the signed identity carrier attached to outbound + // Nexus dispatches. Optional: when absent, no caller identity is propagated + // (there is no raw-header fallback). Cloud may supply a KMS-backed signer. + PrincipalSigner principaltoken.Signer `optional:"true"` + // PrincipalResolver snapshots the human-readable name for a principal at + // dispatch time (noop in OSS). Optional; nil resolves to no snapshot.
+ PrincipalResolver principaltoken.PrincipalResolver `optional:"true"` } type operationInvocationTaskHandler struct { @@ -66,6 +74,8 @@ type operationInvocationTaskHandler struct { httpTraceProvider commonnexus.HTTPClientTraceProvider historyClient resource.HistoryClient chasmRegistry *chasm.Registry + principalSigner principaltoken.Signer + principalResolver principaltoken.PrincipalResolver } func newOperationInvocationTaskHandler(opts operationInvocationTaskHandlerOptions) *operationInvocationTaskHandler { @@ -80,6 +90,8 @@ func newOperationInvocationTaskHandler(opts operationInvocationTaskHandlerOption httpTraceProvider: opts.HTTPTraceProvider, historyClient: opts.HistoryClient, chasmRegistry: opts.ChasmRegistry, + principalSigner: opts.PrincipalSigner, + principalResolver: opts.PrincipalResolver, } } @@ -168,6 +180,17 @@ func (h *operationInvocationTaskHandler) Execute( if h.config.UseNewFailureWireFormat(ns.Name().String()) { header.Set(nexusrpc.HeaderTemporalNexusFailureSupport, "true") } + // Propagate the captured caller principals to the handler (no signer ⇒ no + // token ⇒ nothing propagated). The namespace caller is this cluster's own + // namespace, self-asserted in the token. A minting failure must not fail the + // dispatch: log and proceed (equivalent to the feature being off). 
+ namespaceCaller := &commonpb.Principal{Type: namespacePrincipalType, Name: ns.Name().String()} + if err := attachPrincipalIdentity( + ctx, header, h.principalSigner, h.principalResolver, + args.serviceCallerPrincipal, args.endUserCallerPrincipal, namespaceCaller, + ); err != nil { + h.logger.Warn("failed to mint principal token for nexus dispatch", tag.Error(err)) + } callCtx, cancel := context.WithTimeout(ctx, callTimeout) defer cancel() diff --git a/chasm/lib/nexusoperation/principal.go b/chasm/lib/nexusoperation/principal.go new file mode 100644 index 00000000000..631c5d01470 --- /dev/null +++ b/chasm/lib/nexusoperation/principal.go @@ -0,0 +1,124 @@ +package nexusoperation + +import ( + "context" + + "github.com/nexus-rpc/sdk-go/nexus" + commonpb "go.temporal.io/api/common/v1" + commonspb "go.temporal.io/server/api/common/v1" + "go.temporal.io/server/chasm" + nexusoperationpb "go.temporal.io/server/chasm/lib/nexusoperation/gen/nexusoperationpb/v1" + "go.temporal.io/server/common/headers" + "go.temporal.io/server/common/nexus/principaltoken" +) + +// namespacePrincipalType is the Principal.Type for a namespace caller. The +// Name is the caller's namespace; it is only meaningful scoped to the token's +// verified issuer (namespace names are not globally unique across clusters). +const namespacePrincipalType = "namespace" + +// newAttributedPrincipal wraps an authenticated principal as the at-rest +// AttributedPrincipal, with the display-name snapshot captured at write time +// (empty when the principal's Name is already human-readable). Returns nil for a +// nil principal so absence stays absence. 
+func newAttributedPrincipal(p *commonpb.Principal, resolvedName string) *commonspb.AttributedPrincipal { + if p == nil { + return nil + } + return &commonspb.AttributedPrincipal{Principal: p, ResolvedName: resolvedName} +} + +// resolvedNameFromChasmContext reads the immediate-caller display-name snapshot +// set by the auth interceptor (empty when the Name is already human-readable). +func resolvedNameFromChasmContext(ctx chasm.Context) string { + return ctx.RequestHeader(headers.PrincipalResolvedNameHeaderName) +} + +// PrincipalFromContext reads the immediate-caller principal from the chasm +// context's incoming metadata. Server-trusted: the auth interceptor wrote it +// after authentication and stripped any spoofed inbound values. Nil when no +// principal header is present (OSS without principal derivation, etc.). +func PrincipalFromContext(ctx chasm.Context) *commonpb.Principal { + typ := ctx.RequestHeader(headers.PrincipalTypeHeaderName) + name := ctx.RequestHeader(headers.PrincipalNameHeaderName) + if typ == "" && name == "" { + return nil + } + return &commonpb.Principal{Type: typ, Name: name} +} + +// SetCallerPrincipals records the caller identity on a workflow-initiated +// operation so the outbound dispatch can propagate it. serviceCaller is the +// worker that issued the schedule command; endUserCaller is the chain's +// originating identity (RootCallerPrincipal). No-op when both are nil. +func SetCallerPrincipals(ctx chasm.MutableContext, op *Operation, serviceCaller, endUserCaller *commonpb.Principal) { + if serviceCaller == nil && endUserCaller == nil { + return + } + op.RequestData = chasm.NewDataField(ctx, &nexusoperationpb.OperationRequestData{ + // The resolved name pairs with the immediate caller; the end-user + // (RootCallerPrincipal) carries no snapshot. 
+ ServiceCallerPrincipal: newAttributedPrincipal(serviceCaller, resolvedNameFromChasmContext(ctx)), + EndUserCallerPrincipal: newAttributedPrincipal(endUserCaller, ""), + }) +} + +// attachPrincipalIdentity mints a signed principal token and attaches it as a +// single header on the outbound Nexus request. The token is the sole carrier of +// caller identity across the Nexus hop — there is no raw-header fallback, so +// when no signer is configured nothing is propagated (the feature is off). +// +// The handler verifies the token (signature, or trusted transport peer) and +// promotes the principals; it strips any spoofed inbound principal headers, so +// the signed token is the only path by which identity survives ingress. +func attachPrincipalIdentity( + ctx context.Context, + h nexus.Header, + signer principaltoken.Signer, + resolver principaltoken.PrincipalResolver, + serviceCaller, endUserCaller *commonspb.AttributedPrincipal, + namespaceCaller *commonpb.Principal, +) error { + if signer == nil || + (serviceCaller.GetPrincipal() == nil && endUserCaller.GetPrincipal() == nil) { + return nil + } + svcName, err := resolveDisplayName(ctx, resolver, serviceCaller) + if err != nil { + return err + } + euName, err := resolveDisplayName(ctx, resolver, endUserCaller) + if err != nil { + return err + } + token, err := signer.Sign(ctx, principaltoken.Content{ + ServiceCaller: serviceCaller.GetPrincipal(), + EndUser: endUserCaller.GetPrincipal(), + NamespaceCaller: namespaceCaller, + ServiceCallerResolvedName: svcName, + EndUserResolvedName: euName, + }) + if err != nil { + return err + } + h.Set(principaltoken.Header, token) + return nil +} + +// resolveDisplayName returns the human-readable name snapshot for a stored +// principal: the one captured at write time if present, otherwise resolved now +// (a noop in OSS, returning ""). Resolver errors are surfaced to the caller, +// which logs and proceeds without propagated identity. 
+func resolveDisplayName( + ctx context.Context, + resolver principaltoken.PrincipalResolver, + sp *commonspb.AttributedPrincipal, +) (string, error) { + if sp.GetResolvedName() != "" { + return sp.GetResolvedName(), nil + } + if resolver == nil || sp.GetPrincipal() == nil { + return "", nil + } + return resolver.Resolve(ctx, sp.GetPrincipal()) +} diff --git a/chasm/lib/nexusoperation/principal_test.go b/chasm/lib/nexusoperation/principal_test.go new file mode 100644 index 00000000000..27847443413 --- /dev/null +++ b/chasm/lib/nexusoperation/principal_test.go @@ -0,0 +1,61 @@ +package nexusoperation + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + + "github.com/nexus-rpc/sdk-go/nexus" + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" + commonspb "go.temporal.io/server/api/common/v1" + "go.temporal.io/server/common/nexus/principaltoken" +) + +func testSigner(t *testing.T) principaltoken.Signer { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + s, err := principaltoken.NewECDSASigner(principaltoken.ECDSASignerOptions{ + Key: key, KID: "k1", Issuer: "test", + }) + require.NoError(t, err) + return s +} + +func stored(typ, name string) *commonspb.AttributedPrincipal { + return &commonspb.AttributedPrincipal{Principal: &commonpb.Principal{Type: typ, Name: name}} +} + +func TestAttachPrincipalIdentity_MintsToken(t *testing.T) { + t.Parallel() + + h := nexus.Header{} + err := attachPrincipalIdentity(context.Background(), h, testSigner(t), principaltoken.NoopResolver{}, + stored("service-accounts", "sa-worker"), stored("users", "alice@example.com"), + &commonpb.Principal{Type: namespacePrincipalType, Name: "caller-ns"}) + require.NoError(t, err) + require.NotEmpty(t, h.Get(principaltoken.Header), "a signed token should be attached") +} + +func TestAttachPrincipalIdentity_NoSignerNoPropagation(t *testing.T) { + t.Parallel() + + // Without a signer 
the token is the only carrier, so nothing is attached. + h := nexus.Header{} + err := attachPrincipalIdentity(context.Background(), h, nil, principaltoken.NoopResolver{}, + stored("users", "alice"), stored("users", "alice"), nil) + require.NoError(t, err) + require.Empty(t, h.Get(principaltoken.Header)) +} + +func TestAttachPrincipalIdentity_NoPrincipalsNoToken(t *testing.T) { + t.Parallel() + + h := nexus.Header{} + err := attachPrincipalIdentity(context.Background(), h, testSigner(t), principaltoken.NoopResolver{}, nil, nil, nil) + require.NoError(t, err) + require.Empty(t, h.Get(principaltoken.Header)) +} diff --git a/chasm/lib/nexusoperation/proto/v1/operation.proto b/chasm/lib/nexusoperation/proto/v1/operation.proto index 48b03c98f58..8ec22f2b736 100644 --- a/chasm/lib/nexusoperation/proto/v1/operation.proto +++ b/chasm/lib/nexusoperation/proto/v1/operation.proto @@ -8,6 +8,7 @@ import "google/protobuf/timestamp.proto"; import "temporal/api/common/v1/message.proto"; import "temporal/api/failure/v1/message.proto"; import "temporal/api/sdk/v1/user_metadata.proto"; +import "temporal/server/api/common/v1/principal.proto"; option go_package = "go.temporal.io/server/chasm/lib/nexusoperation/gen/nexusoperationpb;nexusoperationpb"; @@ -146,4 +147,23 @@ message OperationRequestData { map nexus_header = 2; temporal.api.sdk.v1.UserMetadata user_metadata = 3; string identity = 4; + + // Principal of the immediate caller that scheduled this Nexus operation, + // captured at schedule time from the inbound RPC's authenticated context. + // For workflow-initiated operations this is the worker's service-account + // principal; for standalone Nexus operations this is the SDK client's. + // + // Server-set, immutable from user code. Propagated as metadata on the + // outbound HTTP dispatch so the handler frontend can authorize on the + // immediate caller. 
+ temporal.server.api.common.v1.AttributedPrincipal service_caller_principal = 5; + + // Principal that originated this workflow chain at the edge (captured + // from the root workflow's WorkflowExecutionInfo.RootCallerPrincipal at + // schedule time). Propagated alongside the service caller principal so + // the handler can audit and authorize on the end-user identity. + // + // Nil when no end-user principal is available (e.g. workflow started + // before the field existed, or chain broken at an activity boundary). + temporal.server.api.common.v1.AttributedPrincipal end_user_caller_principal = 6; } diff --git a/chasm/lib/workflow/nexus_events.go b/chasm/lib/workflow/nexus_events.go index 8e6296fc71b..7c5b5fb2fd9 100644 --- a/chasm/lib/workflow/nexus_events.go +++ b/chasm/lib/workflow/nexus_events.go @@ -52,6 +52,19 @@ func (d ScheduledEventDefinition) Apply(ctx chasm.MutableContext, wf *Workflow, Attempt: 0, }) + // Bridge the caller identity onto the operation so the outbound dispatch can + // propagate it. The service caller is the worker that issued this + // ScheduleNexusOperation command (present on the chasm context's incoming + // metadata); the end-user is the workflow chain's originating identity, + // inherited from the workflow's RootCallerPrincipal. This is the + // workflow-initiated analogue of newStandaloneOperation's principal capture. + nexusoperation.SetCallerPrincipals( + ctx, + op, + nexusoperation.PrincipalFromContext(ctx), + wf.GetRootCallerPrincipal(ctx), + ) + if err := nexusoperation.TransitionScheduled.Apply(op, ctx, nexusoperation.EventScheduled{}); err != nil { return err } diff --git a/chasm/lib/workflow/workflow.go b/chasm/lib/workflow/workflow.go index 2b5409a4147..62fe82690c4 100644 --- a/chasm/lib/workflow/workflow.go +++ b/chasm/lib/workflow/workflow.go @@ -39,15 +39,51 @@ type Workflow struct { // Updates indexed by update ID, used to store the update components. 
Updates chasm.Map[string, *WorkflowUpdate] + + // RootCallerPrincipal is the identity that originated this workflow chain + // at the edge — the principal that started the root workflow. Set once + // at workflow creation and inherited unchanged through child workflow + // starts and continue-as-new so the chain's originating identity is + // preserved end-to-end without per-event walks. + // + // Read at Nexus operation schedule time to populate the outbound + // dispatch's end-user principal headers. Nil for workflows started + // before the feature shipped or when the authorizer didn't derive a + // principal (OSS without a configured Authorizer). + RootCallerPrincipal chasm.Field[*commonpb.Principal] } +// NewWorkflow constructs the chasm Workflow root component for a new +// workflow execution. If rootCallerPrincipal is non-nil it is stored on +// the component as a Data field, capturing the originator of this +// workflow chain. For top-level starts, callers should pass the principal +// derived by the inbound RPC's authorizer; for child workflows and +// continue-as-new, callers should pass the parent / previous run's +// RootCallerPrincipal so the chain's identity is preserved. func NewWorkflow( - _ chasm.MutableContext, + ctx chasm.MutableContext, msPointer chasm.MSPointer, + rootCallerPrincipal *commonpb.Principal, ) *Workflow { - return &Workflow{ + wf := &Workflow{ MSPointer: msPointer, } + if rootCallerPrincipal != nil { + wf.RootCallerPrincipal = chasm.NewDataField(ctx, rootCallerPrincipal) + } + return wf +} + +// GetRootCallerPrincipal returns the originating end-user identity for +// this workflow chain, or nil if no principal was captured (workflow +// predates the feature; OSS without a configured Authorizer; chain +// broke at an activity boundary upstream). 
+func (w *Workflow) GetRootCallerPrincipal(ctx chasm.Context) *commonpb.Principal { + principal, ok := w.RootCallerPrincipal.TryGet(ctx) + if !ok { + return nil + } + return principal } func (w *Workflow) LifecycleState( diff --git a/common/authorization/authorizer.go b/common/authorization/authorizer.go index ac79cc837e0..029006506e5 100644 --- a/common/authorization/authorizer.go +++ b/common/authorization/authorizer.go @@ -43,6 +43,9 @@ type ( Reason string // Principal is the server-computed identity of the caller. Can be nil when not computed. Principal *commonpb.Principal + // PrincipalResolvedName is the display-name snapshot for Principal when its + // Name is an opaque ID (Cloud). Empty when Name is already human-readable. + PrincipalResolvedName string } // Decision is enum type for auth decision diff --git a/common/authorization/interceptor.go b/common/authorization/interceptor.go index 48a106f87b5..235f88cc8ae 100644 --- a/common/authorization/interceptor.go +++ b/common/authorization/interceptor.go @@ -168,12 +168,15 @@ func (a *Interceptor) Intercept( APIName: info.FullMethod, Request: req, } - principal, err := a.Authorize(ctx, claims, ct) + result, err := a.authorizeResult(ctx, claims, ct) if err != nil { return nil, err } - if a.enablePrincipalPropagation != nil && a.enablePrincipalPropagation(namespace) && principal != nil { - ctx = headers.SetPrincipal(ctx, principal) + if a.enablePrincipalPropagation != nil && a.enablePrincipalPropagation(namespace) && result.Principal != nil { + ctx = headers.SetPrincipal(ctx, result.Principal) + if result.PrincipalResolvedName != "" { + ctx = headers.SetPrincipalResolvedName(ctx, result.PrincipalResolvedName) + } } // Authorize target namespaces in cross-namespace commands @@ -302,8 +305,15 @@ func (a *Interceptor) EnhanceContext(ctx context.Context, authInfo *AuthInfo, cl // Logs and emits metrics when unauthorized. // Returns the principal identity and any authorization error. 
func (a *Interceptor) Authorize(ctx context.Context, claims *Claims, ct *CallTarget) (*commonpb.Principal, error) { + result, err := a.authorizeResult(ctx, claims, ct) + return result.Principal, err +} + +// authorizeResult is Authorize but returns the full Result (incl. the resolved-name +// snapshot) for callers that promote the principal onto the context. +func (a *Interceptor) authorizeResult(ctx context.Context, claims *Claims, ct *CallTarget) (Result, error) { if a.authorizer == nil { - return nil, nil + return Result{}, nil } mh := a.getMetricsHandler(ct.Namespace) @@ -315,19 +325,19 @@ func (a *Interceptor) Authorize(ctx context.Context, claims *Claims, ct *CallTar metrics.ServiceErrAuthorizeFailedCounter.With(mh).Record(1) a.logger.Error("Authorization error", tag.Error(err)) if a.exposeAuthorizerErrors() { - return nil, err + return Result{}, err } - return nil, errUnauthorized // return a generic error to the caller without disclosing details + return Result{}, errUnauthorized // return a generic error to the caller without disclosing details } if result.Decision != DecisionAllow { metrics.ServiceErrUnauthorizedCounter.With(mh).Record(1) // if a reason is included in the result, include it in the error message if result.Reason != "" { - return nil, serviceerror.NewPermissionDenied(RequestUnauthorized, result.Reason) + return Result{}, serviceerror.NewPermissionDenied(RequestUnauthorized, result.Reason) } - return nil, errUnauthorized // return a generic error to the caller without disclosing details + return Result{}, errUnauthorized // return a generic error to the caller without disclosing details } - return result.Principal, nil + return result, nil } // getMetricsHandler returns a metrics handler with a namespace tag diff --git a/common/config/config.go b/common/config/config.go index 75af925f53a..ca9519d9c72 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -138,6 +138,10 @@ type ( Metrics *metrics.Config `yaml:"metrics"` // 
Settings for authentication and authorization Authorization Authorization `yaml:"authorization"` + // NexusPrincipalPropagation configures signed caller-identity + // propagation across the Nexus hop. Empty (the default) disables the + // feature: nothing is signed and no inbound token is trusted. + NexusPrincipalPropagation NexusPrincipalPropagation `yaml:"nexusPrincipalPropagation"` } // RootTLS contains all TLS settings for the Temporal server @@ -635,6 +639,39 @@ type ( PersistenceCustomSearchAttributes map[string]int `yaml:"persistenceCustomSearchAttributes" validate:"persistence_custom_search_attributes"` } + // NexusPrincipalPropagation configures the signed token used to propagate + // caller identity across the Nexus hop. Keys are PEM-encoded EC keys, + // supplied inline (Data) or by file path (File). Leaving SigningKey empty + // disables minting; leaving TrustedIssuers empty disables verification. + NexusPrincipalPropagation struct { + // Issuer identifies this cluster as the "iss" of minted tokens. + Issuer string `yaml:"issuer"` + // SigningKeyData / SigningKeyFile is this cluster's PEM EC private key. + SigningKeyData string `yaml:"signingKeyData"` + SigningKeyFile string `yaml:"signingKeyFile"` + // SigningKeyID is the "kid" advertised in minted tokens and the JWKS. + SigningKeyID string `yaml:"signingKeyId"` + // TTL is the minted-token lifetime (default applied downstream if zero). + TTL time.Duration `yaml:"ttl"` + // Leeway tolerates clock skew on verification. + Leeway time.Duration `yaml:"leeway"` + // TrustMode selects how inbound tokens are trusted: "signature" (default, + // JWS verified against TrustedIssuers) or "transport" (trust the + // connection peer, e.g. cell-to-cell mTLS — no key distribution needed). + TrustMode string `yaml:"trustMode"` + // TrustedIssuers lists the issuer/kid public keys this cluster will + // accept inbound tokens from (signature mode). 
+ TrustedIssuers []NexusTrustedIssuer `yaml:"trustedIssuers"` + } + + // NexusTrustedIssuer is one trusted (issuer, kid) -> PEM EC public key entry. + NexusTrustedIssuer struct { + Issuer string `yaml:"issuer"` + KeyID string `yaml:"keyId"` + PublicKeyData string `yaml:"publicKeyData"` + PublicKeyFile string `yaml:"publicKeyFile"` + } + Authorization struct { // Signing key provider for validating JWT tokens JWTKeyProvider JWTKeyProvider `yaml:"jwtKeyProvider"` diff --git a/common/headers/headers.go b/common/headers/headers.go index 37d0d6bd92d..d5abb756d05 100644 --- a/common/headers/headers.go +++ b/common/headers/headers.go @@ -2,6 +2,7 @@ package headers import ( "context" + "net/http" "strings" commonpb "go.temporal.io/api/common/v1" @@ -21,14 +22,29 @@ const ( CallerTypeHeaderName = "caller-type" CallOriginHeaderName = "call-initiation" + // Principal of the immediate caller (the worker / SDK client that issued + // this RPC). Set server-side by the auth interceptor; stripped on ingress + // to prevent spoofing by external callers. PrincipalTypeHeaderName = "temporal-principal-type" PrincipalNameHeaderName = "temporal-principal-name" + // Display-name snapshot for the immediate caller when its Name is an opaque + // ID (Cloud). Captured at auth time; empty when Name is already readable. + PrincipalResolvedNameHeaderName = "temporal-principal-resolved-name" + + // End-user principal: the identity that originated the request at the edge + // (e.g. who started the root workflow). Propagated frontend->history (see + // propagateHeaders) so it reaches NewWorkflow's RootCallerPrincipal seeding. + EndUserPrincipalTypeHeaderName = "temporal-end-user-principal-type" + EndUserPrincipalNameHeaderName = "temporal-end-user-principal-name" ExperimentHeaderName = "temporal-experiment" ) var ( - // propagateHeaders are the headers to propagate from the frontend to other services. 
+ // propagateHeaders are the headers to propagate from the frontend to + // other services via gRPC metadata. The end-user principal pair is + // included so the chain-originating identity reaches history's + // RootCallerPrincipal seeding (see EndUserPrincipalTypeHeaderName). propagateHeaders = []string{ ClientNameHeaderName, ClientVersionHeaderName, @@ -39,6 +55,23 @@ var ( CallOriginHeaderName, PrincipalTypeHeaderName, PrincipalNameHeaderName, + PrincipalResolvedNameHeaderName, + EndUserPrincipalTypeHeaderName, + EndUserPrincipalNameHeaderName, + } + + // principalHeaderNames is the set of headers that must be stripped from + // inbound metadata / HTTP requests to prevent external callers from + // spoofing identity. Includes both immediate-caller and end-user pairs: + // all of them are in propagateHeaders, so a spoofed value that survived + // ingress would be forwarded to other services. The end-user pair can + // also arrive on the Nexus dispatch HTTP boundary. + principalHeaderNames = []string{ + PrincipalTypeHeaderName, + PrincipalNameHeaderName, + PrincipalResolvedNameHeaderName, + EndUserPrincipalTypeHeaderName, + EndUserPrincipalNameHeaderName, } ) @@ -122,19 +155,36 @@ func IsExperimentRequested(ctx context.Context, experiment string) bool { return false } -// StripPrincipal removes principal headers from incoming metadata to prevent -// external callers from spoofing principal identity. +// StripPrincipalHTTP removes both the immediate-caller and end-user principal +// HTTP headers from an inbound request. This complements StripPrincipal at +// HTTP ingress boundaries (e.g. the Nexus dispatch and completion HTTP +// handlers) where the request never passes through the gRPC interceptor +// chain, so principal-bearing HTTP headers would otherwise survive into the +// authorized context. 
+func StripPrincipalHTTP(h http.Header) { + for _, name := range principalHeaderNames { + h.Del(name) + } +} + +// StripPrincipal removes both the immediate-caller and end-user principal +// headers from incoming metadata to prevent external callers from spoofing +// identity. Callers must invoke this on every external ingress boundary +// (gRPC frontend interceptor, Nexus dispatch HTTP handler, Nexus completion +// HTTP handler) before authorizing the request. func StripPrincipal(ctx context.Context) context.Context { mdIncoming, ok := metadata.FromIncomingContext(ctx) if !ok { return ctx } - mdIncoming.Delete(PrincipalTypeHeaderName) - mdIncoming.Delete(PrincipalNameHeaderName) + for _, h := range principalHeaderNames { + mdIncoming.Delete(h) + } return metadata.NewIncomingContext(ctx, mdIncoming) } -// SetPrincipal sets the principal type and name headers in the incoming metadata. +// SetPrincipal sets the immediate-caller principal headers in the incoming +// metadata. func SetPrincipal(ctx context.Context, principal *commonpb.Principal) context.Context { return setIncomingMD(ctx, map[string]string{ PrincipalTypeHeaderName: principal.GetType(), @@ -142,7 +192,21 @@ func SetPrincipal(ctx context.Context, principal *commonpb.Principal) context.Co }) } -// GetPrincipal retrieves the principal from the context headers. Returns nil if principal is not set. +// SetPrincipalResolvedName sets the immediate-caller display-name snapshot. +func SetPrincipalResolvedName(ctx context.Context, name string) context.Context { + return setIncomingMD(ctx, map[string]string{PrincipalResolvedNameHeaderName: name}) +} + +// GetPrincipalResolvedName retrieves the immediate-caller display-name snapshot, +// or "" if absent (Name is already human-readable). +func GetPrincipalResolvedName(ctx context.Context) string { + return GetValues(ctx, PrincipalResolvedNameHeaderName)[0] +} + +// GetPrincipal retrieves the immediate-caller principal from the context +// headers. 
Returns nil if no principal-carrying header is present (e.g. the +// caller-side did not opt into propagation, or the request crossed a worker +// boundary so the chain broke). func GetPrincipal(ctx context.Context) *commonpb.Principal { values := GetValues(ctx, PrincipalTypeHeaderName, PrincipalNameHeaderName) if values[0] == "" && values[1] == "" { @@ -151,6 +215,27 @@ func GetPrincipal(ctx context.Context) *commonpb.Principal { return &commonpb.Principal{Type: values[0], Name: values[1]} } +// SetEndUserPrincipal sets the end-user principal headers in the incoming +// metadata. The end-user principal identifies the original initiator of the +// request chain (e.g. the API-key holder who started the root workflow), +// distinct from the immediate caller whose RPC is currently being processed. +func SetEndUserPrincipal(ctx context.Context, principal *commonpb.Principal) context.Context { + return setIncomingMD(ctx, map[string]string{ + EndUserPrincipalTypeHeaderName: principal.GetType(), + EndUserPrincipalNameHeaderName: principal.GetName(), + }) +} + +// GetEndUserPrincipal retrieves the end-user principal from the context +// headers. Returns nil if no end-user principal-carrying header is present. +func GetEndUserPrincipal(ctx context.Context) *commonpb.Principal { + values := GetValues(ctx, EndUserPrincipalTypeHeaderName, EndUserPrincipalNameHeaderName) + if values[0] == "" && values[1] == "" { + return nil + } + return &commonpb.Principal{Type: values[0], Name: values[1]} +} + // setIncomingMD sets the key-value pairs in the incoming metadata. // Empty values are ignored. 
func setIncomingMD(ctx context.Context, kv map[string]string) context.Context { diff --git a/common/headers/headers_test.go b/common/headers/headers_test.go index f04684921da..5f4b76dcaa8 100644 --- a/common/headers/headers_test.go +++ b/common/headers/headers_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" "google.golang.org/grpc/metadata" ) @@ -198,3 +199,115 @@ func TestIsExperimentRequested(t *testing.T) { }) } } + +func TestSetGetPrincipal_RoundTrip(t *testing.T) { + t.Parallel() + + principal := &commonpb.Principal{Type: "users", Name: "alice@example.com"} + ctx := SetPrincipal(context.Background(), principal) + + got := GetPrincipal(ctx) + require.NotNil(t, got) + require.Equal(t, "users", got.GetType()) + require.Equal(t, "alice@example.com", got.GetName()) +} + +func TestSetGetEndUserPrincipal_RoundTrip(t *testing.T) { + t.Parallel() + + principal := &commonpb.Principal{Type: "service-accounts", Name: "sa-prod-payments"} + ctx := SetEndUserPrincipal(context.Background(), principal) + + got := GetEndUserPrincipal(ctx) + require.NotNil(t, got) + require.Equal(t, "service-accounts", got.GetType()) + require.Equal(t, "sa-prod-payments", got.GetName()) +} + +func TestSetPrincipal_DoesNotCollideWithEndUserPrincipal(t *testing.T) { + t.Parallel() + + caller := &commonpb.Principal{Type: "service-accounts", Name: "sa-worker"} + endUser := &commonpb.Principal{Type: "users", Name: "alice"} + + ctx := SetPrincipal(context.Background(), caller) + ctx = SetEndUserPrincipal(ctx, endUser) + + gotCaller := GetPrincipal(ctx) + require.Equal(t, "sa-worker", gotCaller.GetName()) + + gotEndUser := GetEndUserPrincipal(ctx) + require.Equal(t, "alice", gotEndUser.GetName()) +} + +func TestGetPrincipal_ReturnsNilWhenAbsent(t *testing.T) { + t.Parallel() + + require.Nil(t, GetPrincipal(context.Background())) + require.Nil(t, GetEndUserPrincipal(context.Background())) +} + +func 
TestStripPrincipal_RemovesBothPrincipalPairs(t *testing.T) { + t.Parallel() + + // Simulate an external caller attempting to inject identity headers. + ctx := metadata.NewIncomingContext(context.Background(), metadata.New(map[string]string{ + PrincipalTypeHeaderName: "attacker", + PrincipalNameHeaderName: "spoof", + EndUserPrincipalTypeHeaderName: "attacker", + EndUserPrincipalNameHeaderName: "spoof", + // A non-principal header should survive stripping. + ClientNameHeaderName: "legitimate-client", + })) + + ctx = StripPrincipal(ctx) + + require.Nil(t, GetPrincipal(ctx)) + require.Nil(t, GetEndUserPrincipal(ctx)) + + // Sanity check: non-principal metadata is unaffected. + md, ok := metadata.FromIncomingContext(ctx) + require.True(t, ok) + require.Equal(t, "legitimate-client", md.Get(ClientNameHeaderName)[0]) +} + +func TestPropagate_CarriesImmediateCallerPrincipal(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctx = metadata.NewIncomingContext(ctx, metadata.New(map[string]string{ + PrincipalTypeHeaderName: "service-accounts", + PrincipalNameHeaderName: "sa-worker", + })) + + ctx = Propagate(ctx) + + md, ok := metadata.FromOutgoingContext(ctx) + require.True(t, ok) + require.Equal(t, "service-accounts", md.Get(PrincipalTypeHeaderName)[0]) + require.Equal(t, "sa-worker", md.Get(PrincipalNameHeaderName)[0]) +} + +// The end-user principal pair is propagated frontend->history so the +// chain-originating identity reaches NewWorkflow's RootCallerPrincipal seeding. 
+func TestPropagate_CarriesEndUserPrincipal(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ctx = metadata.NewIncomingContext(ctx, metadata.New(map[string]string{ + PrincipalTypeHeaderName: "service-accounts", + PrincipalNameHeaderName: "sa-worker", + EndUserPrincipalTypeHeaderName: "users", + EndUserPrincipalNameHeaderName: "alice", + })) + + ctx = Propagate(ctx) + + md, ok := metadata.FromOutgoingContext(ctx) + require.True(t, ok) + + require.Equal(t, "service-accounts", md.Get(PrincipalTypeHeaderName)[0]) + require.Equal(t, "sa-worker", md.Get(PrincipalNameHeaderName)[0]) + require.Equal(t, "users", md.Get(EndUserPrincipalTypeHeaderName)[0]) + require.Equal(t, "alice", md.Get(EndUserPrincipalNameHeaderName)[0]) +} diff --git a/common/nexus/principaltoken/config.go b/common/nexus/principaltoken/config.go new file mode 100644 index 00000000000..5c275f330e4 --- /dev/null +++ b/common/nexus/principaltoken/config.go @@ -0,0 +1,148 @@ +package principaltoken + +import ( + "crypto" + "crypto/ecdsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "time" +) + +// TrustMode selects how the handler establishes trust in a propagated token. +type TrustMode string + +const ( + // TrustModeSignature (the default) verifies a JWS signature against a + // trusted issuer's public key. Topology-independent; needs key distribution. + TrustModeSignature TrustMode = "signature" + // TrustModeTransport trusts the token's claims because the *connection* is + // trusted (a sibling cell behind mutual TLS), without checking a signature. + // The Cloud P0 path: no key infrastructure, reuses the cell-to-cell mTLS + // trust the broker already runs. + TrustModeTransport TrustMode = "transport" +) + +// Config is the operator-facing configuration that materializes a signer and +// verifier from PEM key material and a trust mode. Keys are supplied as PEM +// bytes so they can come from a file, secret, or env var without this package +// knowing the source. 
+// +// Verification is gated by TrustMode: +// - signature (default): a Verifier is built only if TrustedIssuers is set. +// - transport: a Verifier is built that trusts the connection peer (the +// PeerTrustFunc, supplied separately by the host). +// +// An empty Config is the safe default (feature off: no Signer, no Verifier). +type Config struct { + // Issuer identifies this cluster as the "iss" in minted tokens. + Issuer string + // SigningKeyPEM is the PEM-encoded EC private key used to mint tokens + // (SEC1 "EC PRIVATE KEY" or PKCS#8 "PRIVATE KEY"). Empty → no Signer. + SigningKeyPEM []byte + // SigningKeyID is the "kid" advertised in minted tokens and in the JWKS. + SigningKeyID string + // TTL is the minted-token lifetime (default 60s if zero). + TTL time.Duration + // Leeway tolerates clock skew on verification (default 30s if zero). + Leeway time.Duration + // TrustMode selects the verifier; defaults to signature. + TrustMode TrustMode + // TrustedIssuers maps issuer -> kid -> PEM-encoded EC public key, used by + // the signature verifier. Ignored in transport mode. + TrustedIssuers map[string]map[string][]byte +} + +// Bundle is the materialized signing side of a Config: the Signer (nil when no +// signing key is configured) and the KeyProvider (always non-nil; serves this +// cluster's public JWKS, possibly empty). The Verifier is built separately by +// the host (it needs the trust mode and, for transport mode, a PeerTrustFunc). +type Bundle struct { + Signer Signer + KeyProvider KeyProvider +} + +// New builds a Bundle from cfg. Returns an error only on malformed key material. 
+func New(cfg Config) (*Bundle, error) { + own := map[string]crypto.PublicKey{} + + var signer Signer + if len(cfg.SigningKeyPEM) > 0 { + priv, err := ParseECPrivateKeyPEM(cfg.SigningKeyPEM) + if err != nil { + return nil, fmt.Errorf("principaltoken: signing key: %w", err) + } + s, err := NewECDSASigner(ECDSASignerOptions{ + Key: priv, KID: cfg.SigningKeyID, Issuer: cfg.Issuer, TTL: cfg.TTL, + }) + if err != nil { + return nil, err + } + signer = s + own[cfg.SigningKeyID] = priv.Public() + } + + trusted := map[string]map[string]crypto.PublicKey{} + for issuer, byKID := range cfg.TrustedIssuers { + m := map[string]crypto.PublicKey{} + for kid, pemBytes := range byKID { + pub, err := ParseECPublicKeyPEM(pemBytes) + if err != nil { + return nil, fmt.Errorf("principaltoken: trusted key %s/%s: %w", issuer, kid, err) + } + m[kid] = pub + } + trusted[issuer] = m + } + + return &Bundle{Signer: signer, KeyProvider: NewStaticKeyProvider(trusted, own)}, nil +} + +// NewVerifier builds the verifier selected by cfg.TrustMode, or nil when the +// feature is off for verification (signature mode with no trusted issuers). +// peerTrust is required for transport mode. +func NewVerifier(cfg Config, keys KeyProvider, peerTrust PeerTrustFunc) Verifier { + if cfg.TrustMode == TrustModeTransport { + return NewTransportVerifier(peerTrust) + } + if len(cfg.TrustedIssuers) > 0 { + return NewJWSVerifier(JWSVerifierOptions{Keys: keys, Leeway: cfg.Leeway}) + } + return nil +} + +// ParseECPrivateKeyPEM decodes a PEM-encoded EC private key (SEC1 or PKCS#8). 
+func ParseECPrivateKeyPEM(pemBytes []byte) (*ecdsa.PrivateKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("no PEM block found") + } + if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + if ec, ok := key.(*ecdsa.PrivateKey); ok { + return ec, nil + } + return nil, errors.New("PKCS#8 key is not an EC private key") + } + return nil, errors.New("not a valid EC private key (SEC1 or PKCS#8)") +} + +// ParseECPublicKeyPEM decodes a PEM-encoded EC public key (PKIX/SPKI). +func ParseECPublicKeyPEM(pemBytes []byte) (*ecdsa.PublicKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, errors.New("no PEM block found") + } + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, err + } + ec, ok := pub.(*ecdsa.PublicKey) + if !ok { + return nil, errors.New("PEM key is not an EC public key") + } + return ec, nil +} diff --git a/common/nexus/principaltoken/config_test.go b/common/nexus/principaltoken/config_test.go new file mode 100644 index 00000000000..58d385dbeb4 --- /dev/null +++ b/common/nexus/principaltoken/config_test.go @@ -0,0 +1,144 @@ +package principaltoken + +import ( + "context" + "crypto/ecdsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" +) + +func privPEM(t *testing.T, k *ecdsa.PrivateKey) []byte { + t.Helper() + der, err := x509.MarshalECPrivateKey(k) + require.NoError(t, err) + return pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: der}) +} + +func pkcs8PEM(t *testing.T, k *ecdsa.PrivateKey) []byte { + t.Helper() + der, err := x509.MarshalPKCS8PrivateKey(k) + require.NoError(t, err) + return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: der}) +} + +func pubPEM(t *testing.T, k *ecdsa.PrivateKey) 
[]byte { + t.Helper() + der, err := x509.MarshalPKIXPublicKey(&k.PublicKey) + require.NoError(t, err) + return pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der}) +} + +func TestConfig_EndToEnd_SEC1(t *testing.T) { + key := newKey(t) + // Cluster A signs; cluster B (this config) trusts A and also signs itself. + cfg := Config{ + Issuer: testIssuer, + SigningKeyPEM: privPEM(t, key), + SigningKeyID: testKID, + TrustedIssuers: map[string]map[string][]byte{ + testIssuer: {testKID: pubPEM(t, key)}, + }, + } + bundle, err := New(cfg) + require.NoError(t, err) + require.NotNil(t, bundle.Signer) + verifier := NewVerifier(cfg, bundle.KeyProvider, nil) + require.NotNil(t, verifier) + + eu := &commonpb.Principal{Type: "users", Name: "alice@example.com"} + svc := &commonpb.Principal{Type: "service-accounts", Name: "worker"} + tok, err := bundle.Signer.Sign(context.Background(), Content{ServiceCaller: svc, EndUser: eu}) + require.NoError(t, err) + + got, err := verifier.Verify(context.Background(), tok) + require.NoError(t, err) + require.True(t, eq(eu, got.EndUser)) + require.True(t, eq(svc, got.ServiceCaller)) +} + +func TestConfig_AcceptsPKCS8SigningKey(t *testing.T) { + key := newKey(t) + cfg := Config{ + Issuer: testIssuer, + SigningKeyPEM: pkcs8PEM(t, key), + SigningKeyID: testKID, + TrustedIssuers: map[string]map[string][]byte{ + testIssuer: {testKID: pubPEM(t, key)}, + }, + } + bundle, err := New(cfg) + require.NoError(t, err) + tok, err := bundle.Signer.Sign(context.Background(), Content{ + EndUser: &commonpb.Principal{Type: "users", Name: "a"}, + }) + require.NoError(t, err) + _, err = NewVerifier(cfg, bundle.KeyProvider, nil).Verify(context.Background(), tok) + require.NoError(t, err) +} + +func TestConfig_EmptyIsFeatureOff(t *testing.T) { + bundle, err := New(Config{}) + require.NoError(t, err) + require.Nil(t, bundle.Signer, "no signing key => no signer") + require.NotNil(t, bundle.KeyProvider, "key provider always present (serves empty JWKS)") + 
require.Nil(t, NewVerifier(Config{}, bundle.KeyProvider, nil), "no trusted issuers => no verifier") +} + +func TestConfig_SignOnlyAndVerifyOnly(t *testing.T) { + key := newKey(t) + + signCfg := Config{Issuer: testIssuer, SigningKeyPEM: privPEM(t, key), SigningKeyID: testKID} + signOnly, err := New(signCfg) + require.NoError(t, err) + require.NotNil(t, signOnly.Signer) + require.Nil(t, NewVerifier(signCfg, signOnly.KeyProvider, nil)) + + verifyCfg := Config{TrustedIssuers: map[string]map[string][]byte{testIssuer: {testKID: pubPEM(t, key)}}} + verifyOnly, err := New(verifyCfg) + require.NoError(t, err) + require.Nil(t, verifyOnly.Signer) + require.NotNil(t, NewVerifier(verifyCfg, verifyOnly.KeyProvider, nil)) +} + +func TestConfig_MalformedKeysError(t *testing.T) { + _, err := New(Config{SigningKeyPEM: []byte("not a pem"), SigningKeyID: "k"}) + require.Error(t, err) + + _, err = New(Config{TrustedIssuers: map[string]map[string][]byte{"iss": {"k": []byte("nope")}}}) + require.Error(t, err) +} + +func TestJWKSHandler_ServesSigningKey(t *testing.T) { + key := newKey(t) + bundle, err := New(Config{Issuer: testIssuer, SigningKeyPEM: privPEM(t, key), SigningKeyID: testKID}) + require.NoError(t, err) + + srv := httptest.NewServer(JWKSHandler(bundle.KeyProvider)) + defer srv.Close() + + resp, err := http.Get(srv.URL) + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, "application/json", resp.Header.Get("Content-Type")) + + var doc struct { + Keys []struct { + Kid string `json:"kid"` + Kty string `json:"kty"` + Use string `json:"use"` + } `json:"keys"` + } + require.NoError(t, json.NewDecoder(resp.Body).Decode(&doc)) + require.Len(t, doc.Keys, 1) + require.Equal(t, testKID, doc.Keys[0].Kid) + require.Equal(t, "EC", doc.Keys[0].Kty) + require.Equal(t, "sig", doc.Keys[0].Use) +} diff --git a/common/nexus/principaltoken/fx.go b/common/nexus/principaltoken/fx.go new file mode 100644 index 
00000000000..6e408fc6301 --- /dev/null +++ b/common/nexus/principaltoken/fx.go @@ -0,0 +1,64 @@ +package principaltoken + +import ( + "context" + + "go.uber.org/fx" +) + +// Module materializes the principal-token components from a Config and exposes +// them all at once: the Signer (outbound dispatch), the Verifier (inbound Nexus +// HTTP handler), and the KeyProvider (JWKS endpoint). Include it once in any +// service graph that needs to sign and/or verify propagated identity. +// +// The default Config is empty, so the feature is OFF: Signer and Verifier +// resolve to nil and callers fall back to their no-propagation behavior. +// Operators / Cloud activate it by overriding the Config — supply real key +// material (a PEM signing key and/or trusted-issuer public keys) via +// fx.Decorate: +// +// fx.Decorate(func(principaltoken.Config) principaltoken.Config { +// return principaltoken.Config{Issuer: ..., SigningKeyPEM: ..., ...} +// }) +// +// No call-site change is needed to turn it on. +var Module = fx.Module( + "common.nexus.principaltoken", + fx.Provide(defaultConfig), + fx.Provide(New), // Config -> *Bundle + fx.Provide(signerFromBundle), + fx.Provide(keyProviderFromBundle), + fx.Provide(defaultResolver), + fx.Provide(defaultPeerTrust), + fx.Provide(NewVerifier), // (Config, KeyProvider, PeerTrustFunc) -> Verifier +) + +// defaultConfig is the feature-off default. Override via fx.Decorate. +func defaultConfig() Config { + return Config{} +} + +// defaultResolver is the OSS default: names are already human-readable, so +// nothing is resolved. Cloud overrides this provider with an ID->name resolver. +func defaultResolver() PrincipalResolver { + return NoopResolver{} +} + +// defaultPeerTrust is the OSS default for transport mode: trust nothing. A host +// that uses TrustModeTransport (e.g. Cloud cell-to-cell mTLS) MUST override this +// with a function that recognizes trusted server peers from the request context. 
+func defaultPeerTrust() PeerTrustFunc { + return func(context.Context) bool { return false } +} + +// signerFromBundle exposes the bundle's signer (nil when no signing key is +// configured) for the outbound dispatch task handler. +func signerFromBundle(b *Bundle) Signer { + return b.Signer +} + +// keyProviderFromBundle exposes the key provider (always non-nil) for serving +// this cluster's public JWKS. +func keyProviderFromBundle(b *Bundle) KeyProvider { + return b.KeyProvider +} diff --git a/common/nexus/principaltoken/fx_test.go b/common/nexus/principaltoken/fx_test.go new file mode 100644 index 00000000000..0773ca6a9b3 --- /dev/null +++ b/common/nexus/principaltoken/fx_test.go @@ -0,0 +1,28 @@ +package principaltoken + +import ( + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/fx" +) + +// TestModule_FxGraph validates that the module's provider chain +// (Config -> Bundle -> Signer/Verifier/KeyProvider) resolves, so it wires +// cleanly into both the frontend (verify) and history (sign) service graphs. +func TestModule_FxGraph(t *testing.T) { + require.NoError(t, fx.ValidateApp( + Module, + fx.Invoke(func(Signer, Verifier, KeyProvider, PrincipalResolver) {}), + )) +} + +// TestModule_DefaultIsFeatureOff confirms the default config yields nil signer +// and verifier (feature off, safe default) and an always-present key provider. 
+func TestModule_DefaultIsFeatureOff(t *testing.T) { + b, err := New(defaultConfig()) + require.NoError(t, err) + require.Nil(t, signerFromBundle(b)) + require.NotNil(t, keyProviderFromBundle(b)) + require.Nil(t, NewVerifier(defaultConfig(), keyProviderFromBundle(b), defaultPeerTrust())) +} diff --git a/common/nexus/principaltoken/jwks.go b/common/nexus/principaltoken/jwks.go new file mode 100644 index 00000000000..e24518fbf55 --- /dev/null +++ b/common/nexus/principaltoken/jwks.go @@ -0,0 +1,22 @@ +package principaltoken + +import ( + "net/http" +) + +// JWKSHandler returns an HTTP handler that serves this cluster's public +// verification keys as a JWKS document. Peer clusters fetch it to verify tokens +// this cluster mints. It exposes only public keys and is safe to serve +// unauthenticated. +func JWKSHandler(kp KeyProvider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + jwks, err := kp.PublicJWKS(r.Context()) + if err != nil { + http.Error(w, "failed to render JWKS", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "public, max-age=300") + _, _ = w.Write(jwks) + } +} diff --git a/common/nexus/principaltoken/keyprovider.go b/common/nexus/principaltoken/keyprovider.go new file mode 100644 index 00000000000..b1f0ffb6ec4 --- /dev/null +++ b/common/nexus/principaltoken/keyprovider.go @@ -0,0 +1,72 @@ +package principaltoken + +import ( + "context" + "crypto" + "encoding/json" + + jose "github.com/go-jose/go-jose/v4" +) + +// KeyProvider resolves public keys for verification and publishes this +// cluster's own public keys for distribution. It deliberately does NOT expose +// the signing private key: minting is the Signer's concern, so a Cloud +// KMS-backed Signer never has to hand a raw key through this interface. 
+// +// Verification keys are resolved by (issuer, kid): issuer identifies the +// minting cluster, kid the specific key (enables rotation and multi-key). +type KeyProvider interface { + // VerificationKey returns the public key a token from issuer signed with + // kid should be verified against, or an error if no such trusted key. + VerificationKey(ctx context.Context, issuer, kid string) (crypto.PublicKey, error) + // PublicJWKS returns this cluster's published public keys as a JWKS + // document, for serving at a JWKS endpoint so peers can verify our tokens. + PublicJWKS(ctx context.Context) ([]byte, error) +} + +// StaticKeyProvider is the OSS default: trusted issuer keys and this cluster's +// own published keys are supplied at construction (loaded from config). It has +// no network dependency, so it works in air-gapped on-prem deployments where a +// peer JWKS endpoint may be unreachable. +type StaticKeyProvider struct { + // trusted maps issuer -> kid -> public key, used for verification. + trusted map[string]map[string]crypto.PublicKey + // own maps kid -> public key for this cluster, published via PublicJWKS. + own map[string]crypto.PublicKey +} + +// NewStaticKeyProvider builds a provider from a trusted-issuer table and this +// cluster's own public keys (by kid). Either map may be nil/empty: an empty +// trusted table means "verify nothing" (safe default); an empty own map means +// "publish an empty JWKS". 
+func NewStaticKeyProvider( + trusted map[string]map[string]crypto.PublicKey, + own map[string]crypto.PublicKey, +) *StaticKeyProvider { + return &StaticKeyProvider{trusted: trusted, own: own} +} + +func (p *StaticKeyProvider) VerificationKey(_ context.Context, issuer, kid string) (crypto.PublicKey, error) { + byKid, ok := p.trusted[issuer] + if !ok { + return nil, ErrVerification + } + key, ok := byKid[kid] + if !ok { + return nil, ErrVerification + } + return key, nil +} + +func (p *StaticKeyProvider) PublicJWKS(_ context.Context) ([]byte, error) { + set := jose.JSONWebKeySet{} + for kid, pub := range p.own { + set.Keys = append(set.Keys, jose.JSONWebKey{ + Key: pub, + KeyID: kid, + Algorithm: signingAlg, + Use: "sig", + }) + } + return json.Marshal(set) +} diff --git a/common/nexus/principaltoken/principaltoken.go b/common/nexus/principaltoken/principaltoken.go new file mode 100644 index 00000000000..b50f5386ed1 --- /dev/null +++ b/common/nexus/principaltoken/principaltoken.go @@ -0,0 +1,143 @@ +// Package principaltoken carries caller identity across the Nexus hop as a +// compact JWS (ES256): trust comes from the signature, not the transport, so it +// works regardless of network topology. The durable artifact is the Principal; +// the token is minted fresh per hop and never persisted. Sign/verify live behind +// interfaces so Cloud can swap implementations (KMS signer, JWKS, transport +// trust). It carries two principals (RFC 8693 style): the end-user (root) to +// preserve end-to-end and the immediate service caller, as structured claims +// (a Principal Name may contain "/", so it is not round-trippable as a "sub"). +package principaltoken + +import ( + "context" + "errors" + + "github.com/golang-jwt/jwt/v4" + commonpb "go.temporal.io/api/common/v1" +) + +// Header carries the signed principal token across the Nexus hop. Stripped from +// inbound external requests (headers.StripPrincipalHTTP), so only the +// server-to-server caller-side write can set it. 
+const Header = "Temporal-Nexus-Principal-Token" + +// signingAlg is the only accepted JWS algorithm. Asymmetric so verifiers hold +// only public keys and cannot mint tokens. +const signingAlg = "ES256" + +var ( + // ErrNoToken: carrier header absent; treated as "no propagated identity". + ErrNoToken = errors.New("principaltoken: no token present") + // ErrVerification: any verification failure. Opaque on purpose — does not + // leak which check failed. + ErrVerification = errors.New("principaltoken: token verification failed") + // ErrNoSigningKey: no signing key configured (feature off). + ErrNoSigningKey = errors.New("principaltoken: no signing key configured") +) + +// Content is the identity payload to mint into a token. +type Content struct { + // ServiceCaller is the immediate actor (the worker / SDK client that issued + // the outbound Nexus call). May be nil. + ServiceCaller *commonpb.Principal + // EndUser is the chain-originating identity to preserve end-to-end. May be + // nil. For a single-hop standalone operation it equals ServiceCaller. + EndUser *commonpb.Principal + // Display-name snapshots for the respective principals (empty in OSS). They + // travel in the token because the handler often cannot resolve the caller's + // identity itself (e.g. cross-account). + ServiceCallerResolvedName string + EndUserResolvedName string + // NamespaceCaller is the caller's own namespace, self-asserted by the minting + // cluster; only meaningful scoped to the issuer. May be nil. + NamespaceCaller *commonpb.Principal +} + +// Verified holds the trusted principals extracted from a token that passed the +// verifier's trust check. +type Verified struct { + ServiceCaller *commonpb.Principal + EndUser *commonpb.Principal + NamespaceCaller *commonpb.Principal + ServiceCallerResolvedName string + EndUserResolvedName string + Issuer string +} + +// Signer mints a signed carrier token for the given content. 
The OSS default +// is ECDSASigner; Cloud may provide a KMS-backed implementation. +type Signer interface { + Sign(ctx context.Context, content Content) (string, error) +} + +// Verifier validates a carrier token and returns the trusted principals. +// Implementations must return ErrVerification (not a descriptive error) on any +// failure, and must never return principals from an untrusted token. +type Verifier interface { + Verify(ctx context.Context, token string) (*Verified, error) +} + +// principalClaim is the structured, round-trippable representation of a +// Principal inside the token (avoids the lossy display string), plus the +// resolved-name snapshot. +type principalClaim struct { + Type string `json:"type,omitempty"` + Name string `json:"name,omitempty"` + ResolvedName string `json:"resolved_name,omitempty"` +} + +func toPrincipalClaim(p *commonpb.Principal, resolvedName string) *principalClaim { + empty := p.GetType() == "" && p.GetName() == "" + if empty && resolvedName == "" { + return nil + } + return &principalClaim{ + Type: p.GetType(), + Name: p.GetName(), + ResolvedName: resolvedName, + } +} + +func (c *principalClaim) toPrincipal() *commonpb.Principal { + if c == nil || (c.Type == "" && c.Name == "") { + return nil + } + return &commonpb.Principal{Type: c.Type, Name: c.Name} +} + +func (c *principalClaim) resolvedName() string { + if c == nil { + return "" + } + return c.ResolvedName +} + +// displayString renders a principal as type/name. Informational only — set as +// "sub" for debugging/auditing, never parsed back. +func displayString(p *commonpb.Principal) string { + if p == nil { + return "" + } + return p.GetType() + "/" + p.GetName() +} + +// tokenClaims is the JWS payload: end-user (subject), service caller, and +// namespace caller as structured claims. 
+type tokenClaims struct { + jwt.RegisteredClaims + EndUser *principalClaim `json:"https://temporal.io/end_user,omitempty"` + ServiceCaller *principalClaim `json:"https://temporal.io/service_caller,omitempty"` + NamespaceCaller *principalClaim `json:"https://temporal.io/namespace_caller,omitempty"` +} + +// toVerified projects the claims into a Verified (shared by both verifiers). +func (c *tokenClaims) toVerified() *Verified { + return &Verified{ + ServiceCaller: c.ServiceCaller.toPrincipal(), + EndUser: c.EndUser.toPrincipal(), + NamespaceCaller: c.NamespaceCaller.toPrincipal(), + ServiceCallerResolvedName: c.ServiceCaller.resolvedName(), + EndUserResolvedName: c.EndUser.resolvedName(), + Issuer: c.Issuer, + } +} diff --git a/common/nexus/principaltoken/principaltoken_test.go b/common/nexus/principaltoken/principaltoken_test.go new file mode 100644 index 00000000000..8e38fc833d2 --- /dev/null +++ b/common/nexus/principaltoken/principaltoken_test.go @@ -0,0 +1,220 @@ +package principaltoken + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "testing" + "time" + + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" +) + +const ( + testIssuer = "cluster-a" + testKID = "key-1" +) + +func newKey(t *testing.T) *ecdsa.PrivateKey { + t.Helper() + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + return k +} + +// harness wires a signer + JWS verifier sharing a trusted key, with a +// controllable clock shared between them. 
+type harness struct { + signer *ECDSASigner + verifier *JWSVerifier + now time.Time +} + +func newHarness(t *testing.T, key *ecdsa.PrivateKey, ttl, leeway time.Duration) *harness { + t.Helper() + h := &harness{now: time.Date(2026, 6, 5, 12, 0, 0, 0, time.UTC)} + nowFn := func() time.Time { return h.now } + + signer, err := NewECDSASigner(ECDSASignerOptions{Key: key, KID: testKID, Issuer: testIssuer, TTL: ttl, NowFn: nowFn}) + require.NoError(t, err) + h.signer = signer + + kp := NewStaticKeyProvider( + map[string]map[string]crypto.PublicKey{testIssuer: {testKID: key.Public()}}, + map[string]crypto.PublicKey{testKID: key.Public()}, + ) + h.verifier = NewJWSVerifier(JWSVerifierOptions{Keys: kp, Leeway: leeway, NowFn: nowFn}) + return h +} + +func TestSignVerify_RoundTrip(t *testing.T) { + h := newHarness(t, newKey(t), time.Minute, time.Second) + svc := &commonpb.Principal{Type: "service-accounts", Name: "worker"} + eu := &commonpb.Principal{Type: "users", Name: "alice@example.com"} + + tok, err := h.signer.Sign(context.Background(), Content{ServiceCaller: svc, EndUser: eu}) + require.NoError(t, err) + got, err := h.verifier.Verify(context.Background(), tok) + require.NoError(t, err) + require.True(t, eq(svc, got.ServiceCaller)) + require.True(t, eq(eu, got.EndUser)) + require.Equal(t, testIssuer, got.Issuer) +} + +// Guards the reason we use structured claims instead of a slash-joined "sub": +// a Name may contain "/". 
+func TestSignVerify_PreservesSlashInName(t *testing.T) { + h := newHarness(t, newKey(t), time.Minute, time.Second) + eu := &commonpb.Principal{Type: "service-accounts", Name: "my-project/sa-prod"} + tok, err := h.signer.Sign(context.Background(), Content{EndUser: eu}) + require.NoError(t, err) + got, err := h.verifier.Verify(context.Background(), tok) + require.NoError(t, err) + require.True(t, eq(eu, got.EndUser)) +} + +// Cloud case: principal.Name is an opaque ID and the human-readable snapshot +// rides in the token so the handler (which can't resolve the ID) can surface it. +func TestSignVerify_CarriesResolvedName(t *testing.T) { + h := newHarness(t, newKey(t), time.Minute, time.Second) + eu := &commonpb.Principal{Type: "cloud-identity", Name: "id-uuid-123"} + tok, err := h.signer.Sign(context.Background(), Content{ + EndUser: eu, EndUserResolvedName: "alice@example.com", + }) + require.NoError(t, err) + got, err := h.verifier.Verify(context.Background(), tok) + require.NoError(t, err) + require.Equal(t, "alice@example.com", got.EndUserResolvedName) +} + +func TestVerify_RejectsTamperedToken(t *testing.T) { + h := newHarness(t, newKey(t), time.Minute, time.Second) + tok, err := h.signer.Sign(context.Background(), Content{EndUser: &commonpb.Principal{Type: "users", Name: "alice"}}) + require.NoError(t, err) + b := []byte(tok) + b[len(b)/2] ^= 0x01 + _, err = h.verifier.Verify(context.Background(), string(b)) + require.ErrorIs(t, err, ErrVerification) +} + +func TestVerify_RejectsWrongKey(t *testing.T) { + h := newHarness(t, newKey(t), time.Minute, time.Second) + other, err := NewECDSASigner(ECDSASignerOptions{Key: newKey(t), KID: testKID, Issuer: testIssuer, NowFn: func() time.Time { return h.now }}) + require.NoError(t, err) + tok, err := other.Sign(context.Background(), Content{EndUser: &commonpb.Principal{Type: "users", Name: "mallory"}}) + require.NoError(t, err) + _, err = h.verifier.Verify(context.Background(), tok) + require.ErrorIs(t, err, 
ErrVerification) +} + +func TestVerify_RejectsUnknownIssuerOrKID(t *testing.T) { + key := newKey(t) + nowFn := func() time.Time { return time.Date(2026, 6, 5, 12, 0, 0, 0, time.UTC) } + tok, err := mint(t, key, "cluster-unknown", "key-x", nowFn) + require.NoError(t, err) + kp := NewStaticKeyProvider(map[string]map[string]crypto.PublicKey{testIssuer: {testKID: key.Public()}}, nil) + v := NewJWSVerifier(JWSVerifierOptions{Keys: kp, NowFn: nowFn}) + _, err = v.Verify(context.Background(), tok) + require.ErrorIs(t, err, ErrVerification) +} + +func TestVerify_RejectsExpired(t *testing.T) { + h := newHarness(t, newKey(t), time.Minute, 5*time.Second) + tok, err := h.signer.Sign(context.Background(), Content{EndUser: &commonpb.Principal{Type: "users", Name: "alice"}}) + require.NoError(t, err) + h.now = h.now.Add(time.Minute + 6*time.Second) + _, err = h.verifier.Verify(context.Background(), tok) + require.ErrorIs(t, err, ErrVerification) +} + +func TestVerify_ToleratesClockSkewWithinLeeway(t *testing.T) { + h := newHarness(t, newKey(t), 10*time.Second, 30*time.Second) + tok, err := h.signer.Sign(context.Background(), Content{EndUser: &commonpb.Principal{Type: "users", Name: "alice"}}) + require.NoError(t, err) + h.now = h.now.Add(20 * time.Second) // 10s past exp, within 30s leeway + _, err = h.verifier.Verify(context.Background(), tok) + require.NoError(t, err) +} + +func TestVerify_EmptyTokenIsErrNoToken(t *testing.T) { + h := newHarness(t, newKey(t), time.Minute, time.Second) + _, err := h.verifier.Verify(context.Background(), "") + require.ErrorIs(t, err, ErrNoToken) +} + +func TestSign_NoKeyIsErr(t *testing.T) { + _, err := NewECDSASigner(ECDSASignerOptions{Issuer: testIssuer}) + require.ErrorIs(t, err, ErrNoSigningKey) +} + +func TestStaticKeyProvider_PublicJWKS(t *testing.T) { + key := newKey(t) + kp := NewStaticKeyProvider(nil, map[string]crypto.PublicKey{testKID: key.Public()}) + jwks, err := kp.PublicJWKS(context.Background()) + require.NoError(t, err) + 
require.Contains(t, string(jwks), testKID) + require.Contains(t, string(jwks), "\"kty\":\"EC\"") +} + +// TransportVerifier trusts the claims because the peer is trusted, without a +// signature check — the Cloud cell-to-cell mTLS path. +func TestTransportVerifier(t *testing.T) { + key := newKey(t) + nowFn := func() time.Time { return time.Date(2026, 6, 5, 12, 0, 0, 0, time.UTC) } + svc := &commonpb.Principal{Type: "service-accounts", Name: "worker"} + eu := &commonpb.Principal{Type: "users", Name: "alice"} + signed, err := mintWith(t, key, Content{ServiceCaller: svc, EndUser: eu}, nowFn) + require.NoError(t, err) + + t.Run("trusted peer accepted without signature check", func(t *testing.T) { + v := NewTransportVerifier(func(context.Context) bool { return true }) + got, err := v.Verify(context.Background(), signed) + require.NoError(t, err) + require.True(t, eq(eu, got.EndUser)) + require.True(t, eq(svc, got.ServiceCaller)) + }) + t.Run("untrusted peer refused", func(t *testing.T) { + v := NewTransportVerifier(func(context.Context) bool { return false }) + _, err := v.Verify(context.Background(), signed) + require.ErrorIs(t, err, ErrVerification) + }) + t.Run("nil trust func refused", func(t *testing.T) { + v := NewTransportVerifier(nil) + _, err := v.Verify(context.Background(), signed) + require.ErrorIs(t, err, ErrVerification) + }) +} + +func TestNewVerifier_SelectsByTrustMode(t *testing.T) { + key := newKey(t) + kp := NewStaticKeyProvider(map[string]map[string]crypto.PublicKey{testIssuer: {testKID: key.Public()}}, nil) + trustAll := func(context.Context) bool { return true } + + require.IsType(t, &TransportVerifier{}, NewVerifier(Config{TrustMode: TrustModeTransport}, kp, trustAll)) + require.IsType(t, &JWSVerifier{}, + NewVerifier(Config{TrustedIssuers: map[string]map[string][]byte{testIssuer: {testKID: nil}}}, kp, trustAll)) + require.Nil(t, NewVerifier(Config{}, kp, trustAll), "signature mode with no trusted issuers => off") +} + +// --- helpers --- + +func 
mint(t *testing.T, key *ecdsa.PrivateKey, issuer, kid string, nowFn func() time.Time) (string, error) { + t.Helper() + s, err := NewECDSASigner(ECDSASignerOptions{Key: key, KID: kid, Issuer: issuer, NowFn: nowFn}) + require.NoError(t, err) + return s.Sign(context.Background(), Content{EndUser: &commonpb.Principal{Type: "users", Name: "alice"}}) +} + +func mintWith(t *testing.T, key *ecdsa.PrivateKey, c Content, nowFn func() time.Time) (string, error) { + t.Helper() + s, err := NewECDSASigner(ECDSASignerOptions{Key: key, KID: testKID, Issuer: testIssuer, NowFn: nowFn}) + require.NoError(t, err) + return s.Sign(context.Background(), c) +} + +func eq(a, b *commonpb.Principal) bool { + return a.GetType() == b.GetType() && a.GetName() == b.GetName() +} diff --git a/common/nexus/principaltoken/resolver.go b/common/nexus/principaltoken/resolver.go new file mode 100644 index 00000000000..75717c05db2 --- /dev/null +++ b/common/nexus/principaltoken/resolver.go @@ -0,0 +1,29 @@ +package principaltoken + +import ( + "context" + + commonpb "go.temporal.io/api/common/v1" +) + +// PrincipalResolver snapshots the human-readable name for a principal at the +// time identity is captured/propagated. In OSS principal.Name is already +// human-readable, so the default NoopResolver returns "" and readers fall back +// to principal.Name. Cloud injects a resolver that maps an opaque identity ID +// to its current display name, captured at write time so audit records survive +// later renames/deletes. +type PrincipalResolver interface { + // Resolve returns the human-readable display name for p, or "" if there is + // nothing to resolve (the name is already human-readable). It must not fail + // the caller's operation on lookup error — implementations should return + // ("", nil) to degrade to the opaque name rather than block propagation. 
+ Resolve(ctx context.Context, p *commonpb.Principal) (resolvedName string, err error) +} + +// NoopResolver is the OSS default: principal.Name is already human-readable, so +// there is nothing to resolve. +type NoopResolver struct{} + +func (NoopResolver) Resolve(context.Context, *commonpb.Principal) (string, error) { + return "", nil +} diff --git a/common/nexus/principaltoken/signer.go b/common/nexus/principaltoken/signer.go new file mode 100644 index 00000000000..3bc0fe2528c --- /dev/null +++ b/common/nexus/principaltoken/signer.go @@ -0,0 +1,78 @@ +package principaltoken + +import ( + "context" + "crypto/ecdsa" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +// ECDSASigner is the OSS default Signer: it mints ES256-signed tokens from a +// single in-process ECDSA private key. Cloud may replace it with a KMS-backed +// Signer; nothing else in the codebase depends on the concrete type. +type ECDSASigner struct { + key *ecdsa.PrivateKey + kid string + issuer string + ttl time.Duration + nowFn func() time.Time +} + +// ECDSASignerOptions configures an ECDSASigner. +type ECDSASignerOptions struct { + // Key is the ECDSA private key used to sign. Required. + Key *ecdsa.PrivateKey + // KID identifies the key in the token header so verifiers can select the + // matching public key (and so keys can be rotated). Required. + KID string + // Issuer identifies this minting cluster; verifiers resolve keys by it. + // Required. + Issuer string + // TTL is the token lifetime. Short by design (covers the hop RPC, not the + // async operation); re-mint on retry. Defaults to 60s if zero. + TTL time.Duration + // NowFn overrides the clock (tests). Defaults to time.Now. + NowFn func() time.Time +} + +// NewECDSASigner constructs an ECDSASigner. Returns ErrNoSigningKey if no key +// is supplied (lets call sites treat "feature off" uniformly). 
+func NewECDSASigner(opts ECDSASignerOptions) (*ECDSASigner, error) { + if opts.Key == nil { + return nil, ErrNoSigningKey + } + ttl := opts.TTL + if ttl <= 0 { + ttl = 60 * time.Second + } + nowFn := opts.NowFn + if nowFn == nil { + nowFn = time.Now + } + return &ECDSASigner{ + key: opts.Key, + kid: opts.KID, + issuer: opts.Issuer, + ttl: ttl, + nowFn: nowFn, + }, nil +} + +func (s *ECDSASigner) Sign(_ context.Context, content Content) (string, error) { + now := s.nowFn() + claims := tokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: s.issuer, + Subject: displayString(content.EndUser), + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(s.ttl)), + }, + EndUser: toPrincipalClaim(content.EndUser, content.EndUserResolvedName), + ServiceCaller: toPrincipalClaim(content.ServiceCaller, content.ServiceCallerResolvedName), + NamespaceCaller: toPrincipalClaim(content.NamespaceCaller, ""), + } + token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + token.Header["kid"] = s.kid + return token.SignedString(s.key) +} diff --git a/common/nexus/principaltoken/transport_verifier.go b/common/nexus/principaltoken/transport_verifier.go new file mode 100644 index 00000000000..d93ce457458 --- /dev/null +++ b/common/nexus/principaltoken/transport_verifier.go @@ -0,0 +1,44 @@ +package principaltoken + +import ( + "context" + + "github.com/golang-jwt/jwt/v4" +) + +// PeerTrustFunc reports whether the inbound connection is a trusted server peer +// (e.g. cell-to-cell mTLS). +type PeerTrustFunc func(ctx context.Context) bool + +// TransportVerifier trusts a token's claims because the connection is trusted, +// not because it is signed — it reads claims unverified, gated on PeerTrustFunc. +// Not the OSS default (Cloud selects it via config); must never accept tokens +// from an unauthenticated peer. +type TransportVerifier struct { + trusted PeerTrustFunc + parser *jwt.Parser +} + +// NewTransportVerifier constructs a TransportVerifier. 
trusted is required. +func NewTransportVerifier(trusted PeerTrustFunc) *TransportVerifier { + return &TransportVerifier{ + trusted: trusted, + parser: jwt.NewParser(jwt.WithoutClaimsValidation()), + } +} + +func (v *TransportVerifier) Verify(ctx context.Context, token string) (*Verified, error) { + if token == "" { + return nil, ErrNoToken + } + // Channel trust is the gate. If the peer is not trusted we refuse outright, + // regardless of token contents. + if v.trusted == nil || !v.trusted(ctx) { + return nil, ErrVerification + } + var claims tokenClaims + if _, _, err := v.parser.ParseUnverified(token, &claims); err != nil { + return nil, ErrVerification + } + return claims.toVerified(), nil +} diff --git a/common/nexus/principaltoken/verifier.go b/common/nexus/principaltoken/verifier.go new file mode 100644 index 00000000000..a883da02eae --- /dev/null +++ b/common/nexus/principaltoken/verifier.go @@ -0,0 +1,87 @@ +package principaltoken + +import ( + "context" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +// JWSVerifier is the OSS default Verifier: it validates the ES256 signature +// against the (issuer, kid) public key from the KeyProvider, then enforces +// expiry (with leeway) and issuer presence. +type JWSVerifier struct { + keys KeyProvider + leeway time.Duration + nowFn func() time.Time + parser *jwt.Parser +} + +// JWSVerifierOptions configures a JWSVerifier. +type JWSVerifierOptions struct { + // Keys resolves trusted public keys. Required. + Keys KeyProvider + // Leeway tolerates clock skew between caller and handler clusters when + // checking exp/iat. Defaults to 30s. + Leeway time.Duration + // NowFn overrides the clock (tests). Defaults to time.Now. + NowFn func() time.Time +} + +// NewJWSVerifier constructs a JWSVerifier. 
+func NewJWSVerifier(opts JWSVerifierOptions) *JWSVerifier { + leeway := opts.Leeway + if leeway <= 0 { + leeway = 30 * time.Second + } + nowFn := opts.NowFn + if nowFn == nil { + nowFn = time.Now + } + return &JWSVerifier{ + keys: opts.Keys, + leeway: leeway, + nowFn: nowFn, + // Restrict to the asymmetric algorithm and do claim validation + // ourselves (golang-jwt/v4 has no leeway option). + parser: jwt.NewParser( + jwt.WithValidMethods([]string{signingAlg}), + jwt.WithoutClaimsValidation(), + ), + } +} + +func (v *JWSVerifier) Verify(ctx context.Context, token string) (*Verified, error) { + if token == "" { + return nil, ErrNoToken + } + var claims tokenClaims + keyFunc := func(t *jwt.Token) (any, error) { + kid, ok := t.Header["kid"].(string) + if !ok || kid == "" || claims.Issuer == "" { + return nil, ErrVerification + } + key, err := v.keys.VerificationKey(ctx, claims.Issuer, kid) + if err != nil { + return nil, ErrVerification + } + return key, nil + } + + if _, err := v.parser.ParseWithClaims(token, &claims, keyFunc); err != nil { + // Collapse all parse/signature errors into the opaque sentinel. + return nil, ErrVerification + } + + now := v.nowFn() + // Expiry is required and checked with leeway. + if claims.ExpiresAt == nil || now.Add(-v.leeway).After(claims.ExpiresAt.Time) { + return nil, ErrVerification + } + // iat, if present, must not be in the future beyond leeway. 
+ if claims.IssuedAt != nil && now.Add(v.leeway).Before(claims.IssuedAt.Time) { + return nil, ErrVerification + } + + return claims.toVerified(), nil +} diff --git a/go.mod b/go.mod index 083721514d4..f706122264f 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( go.opentelemetry.io/otel/sdk v1.43.0 go.opentelemetry.io/otel/sdk/metric v1.43.0 go.opentelemetry.io/otel/trace v1.43.0 - go.temporal.io/api v1.62.13 + go.temporal.io/api v1.62.14-0.20260606022533-fc01c810514c go.temporal.io/auto-scaled-workers v0.0.0-20260407181057-edd947d743d2 go.temporal.io/sdk v1.41.1 go.uber.org/fx v1.24.0 diff --git a/go.sum b/go.sum index 1f83be49354..d4229462639 100644 --- a/go.sum +++ b/go.sum @@ -471,8 +471,10 @@ go.opentelemetry.io/proto/slim/otlp/collector/profiles/v1development v0.3.0 h1:R go.opentelemetry.io/proto/slim/otlp/collector/profiles/v1development v0.3.0/go.mod h1:I89cynRj8y+383o7tEQVg2SVA6SRgDVIouWPUVXjx0U= go.opentelemetry.io/proto/slim/otlp/profiles/v1development v0.3.0 h1:CQvJSldHRUN6Z8jsUeYv8J0lXRvygALXIzsmAeCcZE0= go.opentelemetry.io/proto/slim/otlp/profiles/v1development v0.3.0/go.mod h1:xSQ+mEfJe/GjK1LXEyVOoSI1N9JV9ZI923X5kup43W4= -go.temporal.io/api v1.62.13 h1:xMa8Nt5oAMX+LvlCJA44wjTCc1H09i2rG9poB1/xvH4= -go.temporal.io/api v1.62.13/go.mod h1:0k75tRljEuELWGeXjEZZO7zYqBln4+1FrG6+IMOMy7Q= +go.temporal.io/api v1.62.14-0.20260606004703-c1686522a401 h1:dz1t3hwLY2JNZzRpdW41ihXtLpb7INX+Qn6IPU4Re5E= +go.temporal.io/api v1.62.14-0.20260606004703-c1686522a401/go.mod h1:0k75tRljEuELWGeXjEZZO7zYqBln4+1FrG6+IMOMy7Q= +go.temporal.io/api v1.62.14-0.20260606022533-fc01c810514c h1:WcWcOtKUvwLGSkmJizWmSmft6MmZLcULebgYa6EzXoM= +go.temporal.io/api v1.62.14-0.20260606022533-fc01c810514c/go.mod h1:0k75tRljEuELWGeXjEZZO7zYqBln4+1FrG6+IMOMy7Q= go.temporal.io/auto-scaled-workers v0.0.0-20260407181057-edd947d743d2 h1:1hKeH3GyR6YD6LKMHGCZ76t6h1Sgha0hXVQBxWi3dlQ= go.temporal.io/auto-scaled-workers v0.0.0-20260407181057-edd947d743d2/go.mod 
h1:T8dnzVPeO+gaUTj9eDgm/lT2lZH4+JXNvrGaQGyVi50= go.temporal.io/sdk v1.41.1 h1:yOpvsHyDD1lNuwlGBv/SUodCPhjv9nDeC9lLHW/fJUA= diff --git a/proto/internal/temporal/server/api/common/v1/principal.proto b/proto/internal/temporal/server/api/common/v1/principal.proto new file mode 100644 index 00000000000..3565ec786f5 --- /dev/null +++ b/proto/internal/temporal/server/api/common/v1/principal.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package temporal.server.api.common.v1; + +import "temporal/api/common/v1/message.proto"; + +option go_package = "go.temporal.io/server/api/common/v1;commonspb"; + +// AttributedPrincipal is the at-rest representation of a captured caller +// identity used by Nexus storage. It embeds the stable identity +// (temporal.api.common.v1.Principal) verbatim and adds a write-time snapshot of +// the human-readable name, for audit fidelity when the identity is an opaque ID +// that may later be renamed or deleted. +message AttributedPrincipal { + // The stable identity, copied verbatim from the authenticated Principal. + // In OSS, principal.name is already human-readable (e.g. jwt/, + // mtls/). In Cloud, principal.name is an opaque ID (never surfaced to + // users) and principal.type marks it as a cloud identity. + temporal.api.common.v1.Principal principal = 1; + + // resolved_name is a snapshot of the human-readable identity name at write + // time. Populated only where principal.name is an opaque ID (Cloud); EMPTY in + // OSS, where principal.name is already human-readable. Readers MUST fall back + // to principal.name when this is empty. 
+ string resolved_name = 2; +} diff --git a/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto b/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto index afa08ab7dfc..894655a3e0b 100644 --- a/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto +++ b/proto/internal/temporal/server/api/matchingservice/v1/request_response.proto @@ -549,6 +549,11 @@ message DispatchNexusTaskRequest { // Nexus request extracted by the frontend and translated into Temporal API format. temporal.api.nexus.v1.Request request = 3; temporal.server.api.taskqueue.v1.TaskForwardInfo forward_info = 4; + // Verified caller identity the frontend attributed to this request, carried + // with the task so matching can surface it on PollNexusTaskQueueResponse for + // the handler worker. Server-set and trusted; never sourced from the inbound + // request. Unset when caller attribution is not configured. + temporal.api.nexus.v1.NexusCallerInfo caller_info = 5; } message DispatchNexusTaskResponse { diff --git a/service/frontend/fx.go b/service/frontend/fx.go index 2e70184a3e9..f7bcf4ca48c 100644 --- a/service/frontend/fx.go +++ b/service/frontend/fx.go @@ -97,6 +97,7 @@ var Module = fx.Options( fx.Provide(SlowRequestLoggerInterceptorProvider), fx.Provide(MaskInternalErrorDetailsInterceptorProvider), fx.Provide(ContextMetadataInterceptorProvider), + fx.Provide(NewPrincipalTokenInterceptor), fx.Provide(GrpcServerOptionsProvider), fx.Provide(VisibilityManagerProvider), fx.Provide(ThrottledLoggerRpsFnProvider), @@ -239,6 +240,7 @@ func GrpcServerOptionsProvider( sdkVersionInterceptor *interceptor.SDKVersionInterceptor, callerInfoInterceptor *interceptor.CallerInfoInterceptor, authInterceptor *authorization.Interceptor, + principalTokenInterceptor *PrincipalTokenInterceptor, maskInternalErrorDetailsInterceptor *interceptor.MaskInternalErrorDetailsInterceptor, contextMetadataInterceptor *interceptor.ContextMetadataInterceptor, 
slowRequestLoggerInterceptor *interceptor.SlowRequestLoggerInterceptor, @@ -284,6 +286,10 @@ func GrpcServerOptionsProvider( namespaceLogInterceptor.Intercept, // TODO: Deprecate this with a outer custom interceptor metrics.NewServerMetricsContextInjectorInterceptor(), authInterceptor.Intercept, + // Promote a verified forwarded principal token (after auth has stripped + // any spoofed principal headers) so a workflow started in response to a + // Nexus operation inherits the original end-user as RootCallerPrincipal. + principalTokenInterceptor.Intercept, // Handover interceptor has to above redirection because the request will route to the correct cluster after handover completed. // And retry cannot be performed before customInterceptors. namespaceHandoverInterceptor.Intercept, diff --git a/service/frontend/nexus_completion_http_handler.go b/service/frontend/nexus_completion_http_handler.go index b9c514837be..2c25fc102e3 100644 --- a/service/frontend/nexus_completion_http_handler.go +++ b/service/frontend/nexus_completion_http_handler.go @@ -393,10 +393,14 @@ func (h *nexusCompletionHandler) forwardCompleteOperation(ctx context.Context, r func (h *nexusCompletionHTTPHandler) RegisterRoutes(r *mux.Router) { r.Path("/" + commonnexus.RouteCompletionCallback.Representation()).HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Strip principal-bearing HTTP headers at this external ingress + // boundary so completion callers cannot spoof identity. 
+ headers.StripPrincipalHTTP(r.Header) r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes) h.httpHandler.ServeHTTP(w, r) }) r.Path(commonnexus.PathCompletionCallbackNoIdentifier).HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headers.StripPrincipalHTTP(r.Header) r.Body = http.MaxBytesReader(w, r.Body, rpc.MaxNexusAPIRequestBodyBytes) h.httpHandler.ServeHTTP(w, r) }) diff --git a/service/frontend/nexus_handler.go b/service/frontend/nexus_handler.go index 98710d504ba..80de7a39dc7 100644 --- a/service/frontend/nexus_handler.go +++ b/service/frontend/nexus_handler.go @@ -15,6 +15,7 @@ import ( "time" "github.com/nexus-rpc/sdk-go/nexus" + commonpb "go.temporal.io/api/common/v1" enumspb "go.temporal.io/api/enums/v1" nexuspb "go.temporal.io/api/nexus/v1" "go.temporal.io/api/serviceerror" @@ -32,6 +33,7 @@ import ( "go.temporal.io/server/common/namespace" commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/common/nexus/nexusrpc" + "go.temporal.io/server/common/nexus/principaltoken" "go.temporal.io/server/common/rpc/interceptor" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/known/timestamppb" @@ -84,6 +86,7 @@ type operationContext struct { forwardingEnabledForNamespace dynamicconfig.BoolPropertyFnWithNamespaceFilter headersBlacklist dynamicconfig.TypedPropertyFn[*regexp.Regexp] metricTagConfig dynamicconfig.TypedPropertyFn[chasmnexus.NexusMetricTagConfig] + principalVerifier principaltoken.Verifier cleanupFunctions []func(map[string]string, error) } @@ -119,15 +122,53 @@ func (c *operationContext) capturePanicAndRecordMetrics(ctxPtr *context.Context, } } -func (c *operationContext) matchingRequest(req *nexuspb.Request) *matchingservice.DispatchNexusTaskRequest { +func (c *operationContext) matchingRequest(ctx context.Context, req *nexuspb.Request) *matchingservice.DispatchNexusTaskRequest { req.Endpoint = c.endpointName return &matchingservice.DispatchNexusTaskRequest{ NamespaceId: 
c.namespace.ID().String(), TaskQueue: &taskqueuepb.TaskQueue{Name: c.taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, Request: req, + CallerInfo: nexusCallerInfo(ctx), } } +// nexusCallerInfo builds the verified caller identity to attach to the outbound +// matching task: the immediate (service) caller and end-user (root) come from +// the context (promoted by promotePrincipalToken, or set by the auth interceptor +// for direct callers); the namespace caller comes from the verified token +// (stashed by promotePrincipalToken). Returns nil when nothing is attributed, +// leaving the field unset. +func nexusCallerInfo(ctx context.Context) *nexuspb.NexusCallerInfo { + service := headers.GetPrincipal(ctx) + root := headers.GetEndUserPrincipal(ctx) + nsCaller := namespaceCallerFromContext(ctx) + if service == nil && root == nil && nsCaller == nil { + return nil + } + return &nexuspb.NexusCallerInfo{Root: root, Service: service, Namespace: nsCaller} +} + +// namespaceCallerContextKey carries the verified namespace caller from +// promotePrincipalToken to nexusCallerInfo without putting it on gRPC metadata +// (it is not needed downstream of the poll response, unlike the service/root +// principals which the authorizer and history seeding read from headers). 
+type namespaceCallerContextKey struct{} + +func withNamespaceCaller(ctx context.Context, p *commonpb.Principal) context.Context { + if p == nil { + return ctx + } + return context.WithValue(ctx, namespaceCallerContextKey{}, p) +} + +func namespaceCallerFromContext(ctx context.Context) *commonpb.Principal { + p, ok := ctx.Value(namespaceCallerContextKey{}).(*commonpb.Principal) + if !ok { + return nil + } + return p +} + func (c *operationContext) augmentContext(ctx context.Context, header nexus.Header) context.Context { ctx = metrics.AddMetricsContext(ctx) ctx = interceptor.AddTelemetryContext(ctx, c.metricsHandlerForInterceptors) @@ -153,9 +194,51 @@ func (c *operationContext) augmentContext(ctx context.Context, header nexus.Head } } } + ctx = c.promotePrincipalToken(ctx, header) return headers.Propagate(ctx) } +// promotePrincipalToken verifies the propagated token and promotes the trusted +// principals into the incoming metadata, where the authorizer and history +// stamping read them. The token isn't stripped because its signature (not its +// position) establishes trust; on any verification failure nothing is promoted. +func (c *operationContext) promotePrincipalToken(ctx context.Context, header nexus.Header) context.Context { + if c.principalVerifier == nil { + return ctx + } + token := header.Get(principaltoken.Header) + if token == "" { + return ctx + } + verified, err := c.principalVerifier.Verify(ctx, token) + if err != nil { + // Don't surface details; an invalid/forged token is simply dropped. + c.logger.Warn("rejected nexus principal token", tag.Error(err)) + return ctx + } + // Surface the human-readable name (resolved-name snapshot); principalForDisplay + // falls back to the principal's own Name when empty (OSS). 
+ if verified.ServiceCaller != nil { + ctx = headers.SetPrincipal(ctx, principalForDisplay(verified.ServiceCaller, verified.ServiceCallerResolvedName)) + } + if verified.EndUser != nil { + ctx = headers.SetEndUserPrincipal(ctx, principalForDisplay(verified.EndUser, verified.EndUserResolvedName)) + } + ctx = withNamespaceCaller(ctx, verified.NamespaceCaller) + return ctx +} + +// principalForDisplay returns the principal with its Name replaced by the +// resolved-name snapshot when one is present (Cloud), so audit surfaces a +// human-readable name rather than an opaque ID. Empty snapshot (OSS) leaves the +// principal unchanged. +func principalForDisplay(p *commonpb.Principal, resolvedName string) *commonpb.Principal { + if p == nil || resolvedName == "" { + return p + } + return &commonpb.Principal{Type: p.GetType(), Name: resolvedName} +} + func (c *operationContext) interceptRequest( ctx context.Context, request *matchingservice.DispatchNexusTaskRequest, @@ -330,6 +413,7 @@ type nexusHandler struct { useForwardByEndpoint dynamicconfig.BoolPropertyFn metricTagConfig dynamicconfig.TypedPropertyFn[chasmnexus.NexusMetricTagConfig] httpTraceProvider commonnexus.HTTPClientTraceProvider + principalVerifier principaltoken.Verifier } // Extracts a nexusContext from the given ctx and returns an operationContext with tagged metrics and logging. 
@@ -351,6 +435,7 @@ func (h *nexusHandler) getOperationContext(ctx context.Context, method string) ( forwardingEnabledForNamespace: h.forwardingEnabledForNamespace, headersBlacklist: h.headersBlacklist, metricTagConfig: h.metricTagConfig, + principalVerifier: h.principalVerifier, cleanupFunctions: make([]func(map[string]string, error), 0), } oc.metricsHandlerForInterceptors = h.metricsHandler.WithTags( @@ -414,7 +499,7 @@ func (h *nexusHandler) StartOperation( RequestId: options.RequestID, Links: links, } - request := oc.matchingRequest(&nexuspb.Request{ + request := oc.matchingRequest(ctx, &nexuspb.Request{ ScheduledTime: timestamppb.New(oc.requestStartTime), Header: options.Header, Variant: &nexuspb.Request_StartOperation{ @@ -637,7 +722,7 @@ func (h *nexusHandler) CancelOperation(ctx context.Context, service, operation, oc.enrichNexusOperationMetrics(service, operation, options.Header) defer oc.capturePanicAndRecordMetrics(&ctx, &retErr) - request := oc.matchingRequest(&nexuspb.Request{ + request := oc.matchingRequest(ctx, &nexuspb.Request{ Header: options.Header, ScheduledTime: timestamppb.New(oc.requestStartTime), Variant: &nexuspb.Request_CancelOperation{ diff --git a/service/frontend/nexus_operation_http_handler.go b/service/frontend/nexus_operation_http_handler.go index e47ef3825ef..a9f3cd74f86 100644 --- a/service/frontend/nexus_operation_http_handler.go +++ b/service/frontend/nexus_operation_http_handler.go @@ -15,12 +15,14 @@ import ( persistencespb "go.temporal.io/server/api/persistence/v1" "go.temporal.io/server/common/authorization" "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/headers" "go.temporal.io/server/common/log" "go.temporal.io/server/common/log/tag" "go.temporal.io/server/common/metrics" "go.temporal.io/server/common/namespace" commonnexus "go.temporal.io/server/common/nexus" "go.temporal.io/server/common/nexus/nexusrpc" + "go.temporal.io/server/common/nexus/principaltoken" "go.temporal.io/server/common/resource" 
"go.temporal.io/server/common/routing" "go.temporal.io/server/common/rpc" @@ -44,6 +46,7 @@ type NexusOperationHTTPHandler struct { namespaceRateLimitInterceptor interceptor.NamespaceRateLimitInterceptor namespaceConcurrencyLimitInterceptor *interceptor.ConcurrentRequestLimitInterceptor rateLimitInterceptor *interceptor.RateLimitInterceptor + principalKeys principaltoken.KeyProvider } func NewNexusOperationHTTPHandler( @@ -64,8 +67,11 @@ func NewNexusOperationHTTPHandler( rateLimitInterceptor *interceptor.RateLimitInterceptor, logger log.Logger, httpTraceProvider commonnexus.HTTPClientTraceProvider, + principalVerifier principaltoken.Verifier, + principalKeys principaltoken.KeyProvider, ) *NexusOperationHTTPHandler { return &NexusOperationHTTPHandler{ + principalKeys: principalKeys, base: nexusrpc.BaseHTTPHandler{ Logger: log.NewSlogLogger(logger), FailureConverter: nexusrpc.DefaultFailureConverter(), @@ -97,6 +103,7 @@ func NewNexusOperationHTTPHandler( useForwardByEndpoint: serviceConfig.NexusForwardRequestUseEndpoint, metricTagConfig: serviceConfig.NexusOperationsMetricTagConfig, httpTraceProvider: httpTraceProvider, + principalVerifier: principalVerifier, }, GetResultTimeout: serviceConfig.KeepAliveMaxConnectionIdle(), Logger: log.NewSlogLogger(logger), @@ -110,8 +117,20 @@ func (h *NexusOperationHTTPHandler) RegisterRoutes(r *mux.Router) { HandlerFunc(h.dispatchNexusTaskByNamespaceAndTaskQueue) r.PathPrefix("/" + commonnexus.RouteDispatchNexusTaskByEndpoint.Representation() + "/"). HandlerFunc(h.dispatchNexusTaskByEndpoint) + // Public JWKS endpoint exposing this cluster's principal-token verification + // keys, so peer clusters using signature trust can validate tokens this + // cluster mints. Public, unauthenticated (exposes only public keys); serves + // an empty key set when the feature is off. + if h.principalKeys != nil { + r.HandleFunc(principalTokenJWKSPath, principaltoken.JWKSHandler(h.principalKeys)). 
+ Methods(http.MethodGet) + } } +// principalTokenJWKSPath is the public JWKS endpoint for Nexus principal-token +// verification keys, served on the frontend HTTP listener. +const principalTokenJWKSPath = "/_temporal/nexus/principal-token/jwks" + func (h *NexusOperationHTTPHandler) writeFailure(writer http.ResponseWriter, r *http.Request, err error) { h.preprocessErrorCounter.Record(1) h.base.WriteFailure(writer, r, err) @@ -119,6 +138,12 @@ func (h *NexusOperationHTTPHandler) writeFailure(writer http.ResponseWriter, r * // Handler for [nexushttp.RouteSet.DispatchNexusTaskByNamespaceAndTaskQueue]. func (h *NexusOperationHTTPHandler) dispatchNexusTaskByNamespaceAndTaskQueue(w http.ResponseWriter, r *http.Request) { + // Strip principal-bearing HTTP headers before any handler code observes + // them. The caller-side server is the only authorized writer of these + // headers; if they arrived from an external caller they must be discarded + // before authorization, otherwise the trust model breaks. + headers.StripPrincipalHTTP(r.Header) + var err error nc := h.baseNexusContext(configs.DispatchNexusTaskByNamespaceAndTaskQueueAPIName, r.Header) params := prepareRequest(commonnexus.RouteDispatchNexusTaskByNamespaceAndTaskQueue, w, r) @@ -159,6 +184,10 @@ func (h *NexusOperationHTTPHandler) dispatchNexusTaskByNamespaceAndTaskQueue(w h // Handler for [nexushttp.RouteSet.DispatchNexusTaskByEndpoint]. func (h *NexusOperationHTTPHandler) dispatchNexusTaskByEndpoint(w http.ResponseWriter, r *http.Request) { + // Strip principal-bearing HTTP headers before any handler code observes + // them. See dispatchNexusTaskByNamespaceAndTaskQueue for rationale. 
+ headers.StripPrincipalHTTP(r.Header) + endpointIDEscaped := prepareRequest(commonnexus.RouteDispatchNexusTaskByEndpoint, w, r) endpointID, err := url.PathUnescape(endpointIDEscaped) diff --git a/service/frontend/principal_token_interceptor.go b/service/frontend/principal_token_interceptor.go new file mode 100644 index 00000000000..1be1d65bedc --- /dev/null +++ b/service/frontend/principal_token_interceptor.go @@ -0,0 +1,59 @@ +package frontend + +import ( + "context" + + "go.temporal.io/server/common/headers" + "go.temporal.io/server/common/nexus/principaltoken" + "google.golang.org/grpc" +) + +// PrincipalTokenInterceptor verifies a forwarded principal token on an inbound +// gRPC request and promotes the end-user onto the context — the gRPC analogue of +// promotePrincipalToken — so a workflow started from a Nexus operation inherits +// the original end-user as RootCallerPrincipal. Must run after the auth +// interceptor (which strips spoofed principal headers); trust derives from the token +// itself (its signature, or a trusted peer), so honoring it on the external edge is safe. +type PrincipalTokenInterceptor struct { + verifier principaltoken.Verifier +} + +// NewPrincipalTokenInterceptor constructs the interceptor. verifier may be nil +// (feature off), in which case Intercept is a pass-through. +func NewPrincipalTokenInterceptor(verifier principaltoken.Verifier) *PrincipalTokenInterceptor { + return &PrincipalTokenInterceptor{verifier: verifier} +} + +func (i *PrincipalTokenInterceptor) Intercept( + ctx context.Context, + req any, + _ *grpc.UnaryServerInfo, + handler grpc.UnaryHandler, +) (any, error) { + return handler(i.promote(ctx), req) +} + +// promote verifies the token (if any) and, on success, sets the end-user +// principal on the context. It deliberately does NOT touch the immediate-caller +// Principal: that was set by the auth interceptor from THIS RPC's authenticated +// identity (e.g.
the worker), which is the correct immediate caller — the +// token's service caller describes the upstream Nexus hop, not this RPC. +func (i *PrincipalTokenInterceptor) promote(ctx context.Context) context.Context { + if i.verifier == nil { + return ctx + } + token := headers.GetValues(ctx, principaltoken.Header)[0] + if token == "" { + return ctx + } + verified, err := i.verifier.Verify(ctx, token) + if err != nil { + // An invalid/forged token is simply dropped; the request proceeds with + // no propagated end-user (equivalent to the feature being off). + return ctx + } + if verified.EndUser != nil { + ctx = headers.SetEndUserPrincipal(ctx, principalForDisplay(verified.EndUser, verified.EndUserResolvedName)) + } + return ctx +} diff --git a/service/history/api/describeworkflow/api.go b/service/history/api/describeworkflow/api.go index 938e007ed82..79affbdcda8 100644 --- a/service/history/api/describeworkflow/api.go +++ b/service/history/api/describeworkflow/api.go @@ -343,6 +343,12 @@ func Invoke( return nil, serviceerror.NewInternal("failed to construct describe response") } + // Surface the chain-originating identity from the CHASM workflow + // component. It is stored only on CHASM (not on WorkflowExecutionInfo + // at rest), so describe projects it here. Empty/nil when caller + // attribution is not configured or the workflow predates the feature. 
+ result.WorkflowExecutionInfo.RootCallerPrincipal = wf.GetRootCallerPrincipal(chasmCtx) + outboundCB := func(endpoint string) bool { cb := outboundQueueCBPool.Get(tasks.TaskGroupNamespaceIDAndDestination{ TaskGroup: nexusoperation.TaskGroupName, diff --git a/service/history/api/respondworkflowtaskcompleted/workflow_task_completed_handler_test.go b/service/history/api/respondworkflowtaskcompleted/workflow_task_completed_handler_test.go index 431393d5113..d22ad891f5b 100644 --- a/service/history/api/respondworkflowtaskcompleted/workflow_task_completed_handler_test.go +++ b/service/history/api/respondworkflowtaskcompleted/workflow_task_completed_handler_test.go @@ -85,7 +85,7 @@ func TestCommandProtocolMessage(t *testing.T) { if opts.chasmEnabled { out.chasmWorkflowRegistry = chasmworkflow.NewRegistry() mockCtx := &chasm.MockMutableContext{} - wf := chasmworkflow.NewWorkflow(mockCtx, chasm.MSPointer{}) + wf := chasmworkflow.NewWorkflow(mockCtx, chasm.MSPointer{}, nil) out.ms.EXPECT().ChasmEnabled().Return(true).AnyTimes() out.ms.EXPECT().EnsureChasmWorkflowComponent(gomock.Any()).AnyTimes() out.ms.EXPECT().ChasmWorkflowComponent(gomock.Any()).Return(wf, mockCtx, nil) diff --git a/service/history/workflow/mutable_state_impl.go b/service/history/workflow/mutable_state_impl.go index a0b1f302229..7b4edf435b4 100644 --- a/service/history/workflow/mutable_state_impl.go +++ b/service/history/workflow/mutable_state_impl.go @@ -710,14 +710,28 @@ func (ms *MutableStateImpl) ChasmWorkflowComponent(ctx context.Context) (*chasmw func (ms *MutableStateImpl) EnsureChasmWorkflowComponent(ctx context.Context) { // Initialize chasm tree once for new workflows. - // Using context.Background() because this is done outside an actual request context and the - // chasmworkflow.NewWorkflow does not actually use it currently. 
root, ok := ms.chasmTree.(*chasm.Node) softassert.That(ms.logger, ok, "chasmTree cast failed") if root.ArchetypeID() == chasm.UnspecifiedArchetypeID { mutableContext := chasm.NewMutableContext(ctx, root) - if err := root.SetRootComponent(chasmworkflow.NewWorkflow(mutableContext, chasm.NewMSPointer(ms))); err != nil { + // Source the chain-originating principal from the inbound RPC, in + // priority order: + // 1. The propagated end-user principal, if a verified Nexus principal + // token was promoted onto the context (headers.GetEndUserPrincipal): + // a workflow started in response to a Nexus dispatch inherits the + // ORIGINAL end-user identity rather than the worker/service identity + // that issued the start RPC. + // 2. The inbound RPC's immediate-caller principal (top-level starts + // where the workflow chain originates with this RPC). + // + // Both may be nil; the chasm Workflow tolerates an empty + // RootCallerPrincipal as the graceful-degradation state. + rootCallerPrincipal := headers.GetEndUserPrincipal(ctx) + if rootCallerPrincipal == nil { + rootCallerPrincipal = headers.GetPrincipal(ctx) + } + if err := root.SetRootComponent(chasmworkflow.NewWorkflow(mutableContext, chasm.NewMSPointer(ms), rootCallerPrincipal)); err != nil { softassert.Fail(ms.logger, "SetRootComponent failed", tag.Error(err)) } } @@ -7697,6 +7711,13 @@ func (ms *MutableStateImpl) closeTransaction( if event.Principal == nil { event.Principal = principal } + // The chain-originating principal (RootCallerPrincipal) is + // now stored on the chasm Workflow component rather than + // on WorkflowExecutionInfo. See + // chasm/lib/workflow.NewWorkflow — it captures the inbound + // RPC's principal at component creation time, and child / + // continue-as-new paths inherit from their predecessor. + // No additional stamping work is needed here. 
} } for _, event := range bufferEvents { diff --git a/service/matching/matching_engine.go b/service/matching/matching_engine.go index d8c3ebc7e16..8c517c6fe78 100644 --- a/service/matching/matching_engine.go +++ b/service/matching/matching_engine.go @@ -2619,6 +2619,9 @@ pollLoop: TaskToken: serializedToken, Request: nexusReq, PollerScalingDecision: task.pollerScalingDecision, + // Verified caller identity carried with the task; surfaced to the + // handler worker so it can authorize/audit the caller. + CallerInfo: task.nexus.request.GetCallerInfo(), }, }, nil } diff --git a/tests/nexus_propagation_test.go b/tests/nexus_propagation_test.go new file mode 100644 index 00000000000..cdab9d5af76 --- /dev/null +++ b/tests/nexus_propagation_test.go @@ -0,0 +1,391 @@ +package tests + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + commonpb "go.temporal.io/api/common/v1" + enumspb "go.temporal.io/api/enums/v1" + nexuspb "go.temporal.io/api/nexus/v1" + "go.temporal.io/api/operatorservice/v1" + taskqueuepb "go.temporal.io/api/taskqueue/v1" + "go.temporal.io/api/workflowservice/v1" + sdkclient "go.temporal.io/sdk/client" + sdkworker "go.temporal.io/sdk/worker" + "go.temporal.io/sdk/workflow" + "go.temporal.io/server/common/authorization" + "go.temporal.io/server/common/cluster" + "go.temporal.io/server/common/config" + "go.temporal.io/server/common/dynamicconfig" + "go.temporal.io/server/common/log" + cnexus "go.temporal.io/server/common/nexus" + "go.temporal.io/server/common/nexus/principaltoken" + "go.temporal.io/server/common/primitives" + "go.temporal.io/server/common/testing/await" + "go.temporal.io/server/tests/testcore" + "go.uber.org/fx" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/types/known/durationpb" +) + 
+const ( + propIssuer = "test-issuer" + propKID = "test-kid" + propCallerSA = "nexus-caller" + propEndUser = "end-user-alice" + propJWKSPath = "/_temporal/nexus/principal-token/jwks" + propTestDuration = 60 * time.Second +) + +// TestNexusPrincipalPropagation_TwoClusters exercises signed caller-identity +// propagation across the Nexus hop over two real, independent full clusters +// (built on temporal/temporal via testcore — not TestEnv, no xdc replication): +// +// caller cluster (mint) ── HTTP ──▶ handler cluster (verify) ──▶ fake worker +// +// It runs both verifier trust modes: +// - signature: the handler verifies the JWS against the caller's public key. +// - transport: the handler trusts the connection peer (cell-mTLS analogue) +// and reads the claims without a signature check. +// +// It asserts the propagated token carries BOTH caller identities kept distinct: +// the service caller (the worker that issued the ScheduleNexusOperation command) +// and the end user (the identity that started the root workflow, sourced from the +// persisted RootCallerPrincipal — not carried through the worker). It then +// asserts those identities surface to the handler worker as caller_info on +// the poll response, and that the handler's JWKS endpoint serves its public key. +// +// The feature is configured the official way (config.Global.NexusPrincipalPropagation +// on TestClusterConfig), exercising the production config mapper.
+func TestNexusPrincipalPropagation_TwoClusters(t *testing.T) { + for _, mode := range []principaltoken.TrustMode{principaltoken.TrustModeSignature, principaltoken.TrustModeTransport} { + t.Run(string(mode), func(t *testing.T) { + runPropagation(t, mode) + }) + } +} + +func runPropagation(t *testing.T, mode principaltoken.TrustMode) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + privDER, err := x509.MarshalECPrivateKey(key) + require.NoError(t, err) + pubDER, err := x509.MarshalPKIXPublicKey(&key.PublicKey) + require.NoError(t, err) + + // Both clusters share the signing key (the caller signs; in signature mode + // the handler trusts that key, in transport mode it trusts the peer). + global := config.Global{ + NexusPrincipalPropagation: config.NexusPrincipalPropagation{ + Issuer: propIssuer, + SigningKeyID: propKID, + TrustMode: string(mode), + SigningKeyData: string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER})), + TrustedIssuers: []config.NexusTrustedIssuer{ + {Issuer: propIssuer, KeyID: propKID, PublicKeyData: string(pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: pubDER}))}, + }, + }, + } + dc := map[dynamicconfig.Key]any{ + dynamicconfig.EnablePrincipalPropagation.Key(): true, + dynamicconfig.RefreshNexusEndpointsMinWait.Key(): time.Millisecond, + // RootCallerPrincipal lives on the CHASM workflow component, so the + // handler-started workflow must be CHASM-backed for the attach assertion. + dynamicconfig.EnableChasm.Key(): true, + } + + // In transport mode the handler frontend trusts the connection peer. Tests + // have no cell mTLS, so override the (deny-all) default PeerTrustFunc with a + // trust-all one on the handler frontend. 
+ var handlerFx map[primitives.ServiceName][]fx.Option + if mode == principaltoken.TrustModeTransport { + handlerFx = map[primitives.ServiceName][]fx.Option{ + primitives.FrontendService: {fx.Decorate(func(principaltoken.PeerTrustFunc) principaltoken.PeerTrustFunc { + return func(context.Context) bool { return true } + })}, + } + } + + caller := startPropagationCluster(t, "caller", global, dc, nil) + defer func() { _ = caller.TearDownCluster() }() + handler := startPropagationCluster(t, "handler", global, dc, handlerFx) + defer func() { _ = handler.TearDownCluster() }() + + ctx, cancel := context.WithTimeout(testcore.NewContext(), propTestDuration) + defer cancel() + + ns := "nexus-prop-" + uuid.NewString()[:8] + registerPropagationNamespace(ctx, t, caller, ns) + registerPropagationNamespace(ctx, t, handler, ns) + + // Two distinct identities so the propagated token can be asserted to keep + // them separate: the end user starts the root workflow (becomes its + // RootCallerPrincipal), while the worker that later completes the workflow + // task and issues the ScheduleNexusOperation command is the service caller. + // The authorizer attributes by API: StartWorkflowExecution -> end user; + // everything else (poll/respond task) -> worker. If the caller-side end-user + // sourcing were broken and fell back to the immediate caller, EndUser would + // equal the worker and the assertion below would fail. 
+ endUserPrincipal := &commonpb.Principal{Type: "users", Name: propEndUser} + workerPrincipal := &commonpb.Principal{Type: "service-accounts", Name: propCallerSA} + caller.Host().SetOnAuthorize(func(_ context.Context, _ *authorization.Claims, ct *authorization.CallTarget) (authorization.Result, error) { + p := workerPrincipal + if strings.HasSuffix(ct.APIName, "/StartWorkflowExecution") { + p = endUserPrincipal + } + return authorization.Result{Decision: authorization.DecisionAllow, Principal: p}, nil + }) + handler.Host().SetOnAuthorize(func(context.Context, *authorization.Claims, *authorization.CallTarget) (authorization.Result, error) { + return authorization.Result{Decision: authorization.DecisionAllow}, nil + }) + + handlerTaskQueue := "handler-tq-" + uuid.NewString() + dispatchURL := fmt.Sprintf("http://%s/%s", + handler.Host().FrontendHTTPAddress(), + cnexus.RouteDispatchNexusTaskByNamespaceAndTaskQueue.Path(cnexus.NamespaceAndTaskQueue{Namespace: ns, TaskQueue: handlerTaskQueue}), + ) + endpointName := "ep-" + uuid.NewString() + _, err = caller.OperatorClient().CreateNexusEndpoint(ctx, &operatorservice.CreateNexusEndpointRequest{ + Spec: &nexuspb.EndpointSpec{ + Name: endpointName, + Target: &nexuspb.EndpointTarget{ + Variant: &nexuspb.EndpointTarget_External_{External: &nexuspb.EndpointTarget_External{Url: dispatchURL}}, + }, + }, + }) + require.NoError(t, err) + + // Mock handler worker. It plays the role the SDK would: it (1) observes the + // server-attributed caller_info on the poll response, and (2) forwards the + // received principal token on a StartWorkflowExecution, so we can show the + // server re-validates it and attaches the original end-user to that workflow + // (the cross-user-code-boundary leg) — all without changing the SDK. 
+ tokenCh := make(chan string, 1) + callerInfoCh := make(chan *nexuspb.NexusCallerInfo, 1) + startedWFCh := make(chan string, 1) + go pollNexusTaskOnce(ctx, t, handler.FrontendClient(), ns, handlerTaskQueue, + func(res *workflowservice.PollNexusTaskQueueResponse) *nexuspb.Response { + select { + case tokenCh <- res.GetRequest().GetHeader()[principaltoken.Header]: + callerInfoCh <- res.GetCallerInfo() + // Forward the token on a workflow start, as a real SDK handler + // would; the server re-validates it and seeds RootCallerPrincipal. + wfID := "handler-wf-" + uuid.NewString() + fwdCtx := metadata.AppendToOutgoingContext(ctx, principaltoken.Header, + res.GetRequest().GetHeader()[principaltoken.Header]) + _, startErr := handler.FrontendClient().StartWorkflowExecution(fwdCtx, &workflowservice.StartWorkflowExecutionRequest{ + Namespace: ns, + WorkflowId: wfID, + WorkflowType: &commonpb.WorkflowType{Name: "handler-wf"}, + TaskQueue: &taskqueuepb.TaskQueue{Name: "handler-wf-tq", Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, + RequestId: uuid.NewString(), + }) + assert.NoError(t, startErr) + startedWFCh <- wfID + default: + } + return &nexuspb.Response{Variant: &nexuspb.Response_StartOperation{ + StartOperation: &nexuspb.StartOperationResponse{Variant: &nexuspb.StartOperationResponse_SyncSuccess{ + SyncSuccess: &nexuspb.StartOperationResponse_Sync{Payload: res.GetRequest().GetStartOperation().GetPayload()}, + }}, + }} + }) + + callerWF := func(wctx workflow.Context) (string, error) { + c := workflow.NewNexusClient(endpointName, "test-service") + var result string + return result, c.ExecuteOperation(wctx, "test-operation", "input", workflow.NexusOperationOptions{}).Get(wctx, &result) + } + sdkClient, err := sdkclient.Dial(sdkclient.Options{HostPort: caller.Host().FrontendGRPCAddress(), Namespace: ns}) + require.NoError(t, err) + defer sdkClient.Close() + callerTaskQueue := "caller-tq-" + uuid.NewString() + w := sdkworker.New(sdkClient, callerTaskQueue, sdkworker.Options{}) + 
w.RegisterWorkflow(callerWF) + require.NoError(t, w.Start()) + defer w.Stop() + _, err = sdkClient.ExecuteWorkflow(ctx, sdkclient.StartWorkflowOptions{TaskQueue: callerTaskQueue}, callerWF) + require.NoError(t, err) + + var token string + select { + case token = <-tokenCh: + case <-time.After(30 * time.Second): + t.Fatal("handler did not receive a Nexus task with a principal token") + } + require.NotEmpty(t, token, "expected a signed principal token propagated to the handler") + + // Verify the propagated token the same way the handler's configured verifier + // would, and assert it carries the caller identity. + verified, err := propagationVerifier(t, mode, key).Verify(ctx, token) + require.NoError(t, err, "propagated token must verify under %s trust mode", mode) + + // Service caller is the worker that issued the ScheduleNexusOperation command. + require.NotNil(t, verified.ServiceCaller) + require.Equal(t, "service-accounts", verified.ServiceCaller.GetType()) + require.Equal(t, propCallerSA, verified.ServiceCaller.GetName()) + + // End user is the identity that started the root workflow, sourced from the + // workflow's persisted RootCallerPrincipal (NOT carried through the worker). + // It must be present and distinct from the service caller. + require.NotNil(t, verified.EndUser, "end-user principal must propagate (from RootCallerPrincipal)") + require.Equal(t, "users", verified.EndUser.GetType()) + require.Equal(t, propEndUser, verified.EndUser.GetName()) + require.NotEqual(t, verified.ServiceCaller.GetName(), verified.EndUser.GetName(), + "end-user must be distinct from the service caller, not a fallback to the immediate caller") + + // The worker-observable caller_info on the poll response must carry the same + // verified identities (service caller + originating end user). This is the + // end-to-end check that the server attributes and surfaces the identity to + // handler workers (independent of the raw token). 
+ var callerInfo *nexuspb.NexusCallerInfo + select { + case callerInfo = <-callerInfoCh: + case <-time.After(5 * time.Second): + t.Fatal("handler poll response did not include caller_info") + } + require.NotNil(t, callerInfo, "expected caller_info on the poll response") + require.Equal(t, "service-accounts", callerInfo.GetService().GetType()) + require.Equal(t, propCallerSA, callerInfo.GetService().GetName()) + require.Equal(t, "users", callerInfo.GetRoot().GetType()) + require.Equal(t, propEndUser, callerInfo.GetRoot().GetName()) + + // Cross-boundary leg: the workflow the mock worker started by forwarding the + // token must inherit the ORIGINAL end-user as its RootCallerPrincipal (the + // server re-validated the token and attached it), surfaced on describe. + var startedWF string + select { + case startedWF = <-startedWFCh: + case <-time.After(10 * time.Second): + t.Fatal("mock worker did not start a handler workflow") + } + desc, err := handler.FrontendClient().DescribeWorkflowExecution(ctx, &workflowservice.DescribeWorkflowExecutionRequest{ + Namespace: ns, + Execution: &commonpb.WorkflowExecution{WorkflowId: startedWF}, + }) + require.NoError(t, err) + rootCaller := desc.GetWorkflowExecutionInfo().GetRootCallerPrincipal() + require.NotNil(t, rootCaller, "handler-started workflow must inherit the original end-user") + require.Equal(t, "users", rootCaller.GetType()) + require.Equal(t, propEndUser, rootCaller.GetName(), + "the forwarded token's end-user must be re-validated and attached as RootCallerPrincipal") + + // The handler cluster's JWKS endpoint must serve its public verification key. + assertJWKSServed(t, handler.Host().FrontendHTTPAddress()) +} + +// propagationVerifier builds the verifier the handler uses for the given mode: +// signature → JWS against the public key; transport → trust-all peer. 
+func propagationVerifier(t *testing.T, mode principaltoken.TrustMode, key *ecdsa.PrivateKey) principaltoken.Verifier { + if mode == principaltoken.TrustModeTransport { + return principaltoken.NewTransportVerifier(func(context.Context) bool { return true }) + } + return principaltoken.NewJWSVerifier(principaltoken.JWSVerifierOptions{ + Keys: principaltoken.NewStaticKeyProvider( + map[string]map[string]crypto.PublicKey{propIssuer: {propKID: key.Public()}}, nil), + }) +} + +func assertJWKSServed(t *testing.T, frontendHTTPAddr string) { + resp, err := http.Get("http://" + frontendHTTPAddr + propJWKSPath) //nolint:noctx + require.NoError(t, err) + defer func() { _ = resp.Body.Close() }() + require.Equal(t, http.StatusOK, resp.StatusCode) + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Contains(t, string(body), propKID, "JWKS must advertise the cluster signing key") + require.Contains(t, string(body), "\"kty\":\"EC\"") +} + +// startPropagationCluster brings up a single independent full cluster (built on +// temporal/temporal via testcore), with the official Global config, dynamic +// config, and optional per-service fx options. No replication — the clusters +// only interact over the Nexus HTTP hop. 
+func startPropagationCluster( + t *testing.T, + name string, + global config.Global, + dc map[dynamicconfig.Key]any, + fxOpts map[primitives.ServiceName][]fx.Option, +) *testcore.TestCluster { + cfg := &testcore.TestClusterConfig{ + ClusterMetadata: cluster.Config{ + EnableGlobalNamespace: false, + FailoverVersionIncrement: 10, + MasterClusterName: name, + CurrentClusterName: name, + ClusterInformation: map[string]cluster.ClusterInformation{ + name: {Enabled: true, InitialFailoverVersion: 1}, + }, + }, + HistoryConfig: testcore.HistoryConfig{NumHistoryShards: 1}, + Persistence: testcore.GetPersistenceTestDefaults(), + GlobalConfig: global, + DynamicConfigOverrides: dc, + ServiceFxOptions: fxOpts, + EnableMetricsCapture: true, + } + cfg.Persistence.DBName += "_" + name + c, err := testcore.NewTestClusterFactory().NewCluster(t, cfg, log.NewTestLogger()) + require.NoError(t, err) + return c +} + +func registerPropagationNamespace(ctx context.Context, t *testing.T, c *testcore.TestCluster, ns string) { + _, err := c.FrontendClient().RegisterNamespace(ctx, &workflowservice.RegisterNamespaceRequest{ + Namespace: ns, + WorkflowExecutionRetentionPeriod: durationpb.New(24 * time.Hour), + }) + require.NoError(t, err) + await.RequireTruef(t, func() bool { + _, err := c.FrontendClient().DescribeNamespace(ctx, &workflowservice.DescribeNamespaceRequest{Namespace: ns}) + return err == nil + }, 15*time.Second, 100*time.Millisecond, "namespace %q did not become available", ns) +} + +// pollNexusTaskOnce is a minimal taskpoller "fake worker": it polls a single +// Nexus task and responds, giving the test direct access to the forwarded +// request headers (where the propagated token lands) without an SDK worker. 
+func pollNexusTaskOnce( + ctx context.Context, + t *testing.T, + frontendClient workflowservice.WorkflowServiceClient, + ns, taskQueue string, + handler func(*workflowservice.PollNexusTaskQueueResponse) *nexuspb.Response, +) { + res, err := frontendClient.PollNexusTaskQueue(ctx, &workflowservice.PollNexusTaskQueueRequest{ + Namespace: ns, + Identity: uuid.NewString(), + TaskQueue: &taskqueuepb.TaskQueue{Name: taskQueue, Kind: enumspb.TASK_QUEUE_KIND_NORMAL}, + }) + if ctx.Err() != nil { + return + } + // assert (not require) inside this goroutine: require's FailNow must only run + // in the test's own goroutine. + assert.NoError(t, err) + resp := handler(res) + _, err = frontendClient.RespondNexusTaskCompleted(ctx, &workflowservice.RespondNexusTaskCompletedRequest{ + Namespace: ns, + Identity: uuid.NewString(), + TaskToken: res.TaskToken, + Response: resp, + }) + if err != nil && ctx.Err() == nil { + assert.NoError(t, err) + } +} diff --git a/tests/testcore/onebox.go b/tests/testcore/onebox.go index 379ff74d021..ea5128236e2 100644 --- a/tests/testcore/onebox.go +++ b/tests/testcore/onebox.go @@ -91,6 +91,7 @@ type ( testHooks testhooks.TestHooks logger log.Logger clusterMetadataConfig *cluster.Config + globalConfig config.Global persistenceConfig config.Persistence metadataMgr persistence.MetadataManager clusterMetadataMgr persistence.ClusterMetadataManager @@ -175,9 +176,13 @@ type ( MockAdminClient map[string]adminservice.AdminServiceClient NamespaceReplicationTaskExecutor nsreplication.TaskExecutor DCRedirectionPolicy config.DCRedirectionPolicy - DynamicConfigOverrides map[dynamicconfig.Key]any - TLSConfigProvider *encryption.FixedTLSConfigProvider - CaptureMetricsHandler *metricstest.CaptureHandler + // GlobalConfig is merged into each service's config.Global, so tests can + // drive process-wide settings (e.g. NexusPrincipalPropagation) through + // the official server config rather than fx injection. 
+ GlobalConfig config.Global + DynamicConfigOverrides map[dynamicconfig.Key]any + TLSConfigProvider *encryption.FixedTLSConfigProvider + CaptureMetricsHandler *metricstest.CaptureHandler // ServiceFxOptions is populated by WithFxOptionsForService. ServiceFxOptions map[primitives.ServiceName][]fx.Option TaskCategoryRegistry tasks.TaskCategoryRegistry @@ -197,6 +202,7 @@ func newTemporal(t *testing.T, params *TemporalParams) *TemporalImpl { impl := &TemporalImpl{ logger: params.Logger, clusterMetadataConfig: params.ClusterMetadataConfig, + globalConfig: params.GlobalConfig, persistenceConfig: params.PersistenceConfig, metadataMgr: params.MetadataMgr, clusterMetadataMgr: params.ClusterMetadataManager, @@ -763,6 +769,7 @@ func (c *TemporalImpl) frontendConfigProvider() *config.Config { func (c *TemporalImpl) configProvider(serviceName primitives.ServiceName) *config.Config { return &config.Config{ + Global: c.globalConfig, Services: map[string]config.Service{ string(serviceName): { RPC: config.RPC{}, diff --git a/tests/testcore/test_cluster.go b/tests/testcore/test_cluster.go index 49c0261ed7e..456bb49af2a 100644 --- a/tests/testcore/test_cluster.go +++ b/tests/testcore/test_cluster.go @@ -76,9 +76,13 @@ type ( // TestClusterConfig are config for a test cluster TestClusterConfig struct { - EnableArchival bool - IsMasterCluster bool - ClusterMetadata cluster.Config + EnableArchival bool + IsMasterCluster bool + ClusterMetadata cluster.Config + // GlobalConfig populates config.Global for every service in the cluster, + // letting tests configure process-wide settings (e.g. + // NexusPrincipalPropagation) via the official server config. 
+ GlobalConfig config.Global Persistence persistencetests.TestBaseOptions FrontendConfig FrontendConfig HistoryConfig HistoryConfig @@ -348,6 +352,7 @@ func newClusterWithPersistenceTestBaseFactory( MockAdminClient: clusterConfig.MockAdminClient, NamespaceReplicationTaskExecutor: nsreplication.NewTaskExecutor(clusterConfig.ClusterMetadata.CurrentClusterName, testBase.MetadataManager, nsreplication.NewNoopDataMerger(), nsreplication.NewDefaultAdmitter(), logger, testhooks.TestHooks{}), DCRedirectionPolicy: clusterConfig.DCRedirectionPolicy, + GlobalConfig: clusterConfig.GlobalConfig, DynamicConfigOverrides: clusterConfig.DynamicConfigOverrides, TLSConfigProvider: tlsConfigProvider, ServiceFxOptions: clusterConfig.ServiceFxOptions,