Skip to content
48 changes: 46 additions & 2 deletions server/router/api/v1/idp_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,17 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId
response := &v1pb.ListIdentityProvidersResponse{
IdentityProviders: []*v1pb.IdentityProvider{},
}

// Default to lowest-privilege role, update later based on real role
currentUserRole := store.RoleUser
currentUser, err := s.GetCurrentUser(ctx)
if err == nil && currentUser != nil {
currentUserRole = currentUser.Role
}

for _, identityProvider := range identityProviders {
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
response.IdentityProviders = append(response.IdentityProviders, redactIdentityProviderResponse(identityProviderConverted, currentUserRole))
}
return response, nil
}
Expand All @@ -58,10 +67,27 @@ func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.Ge
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
return convertIdentityProviderFromStore(identityProvider), nil

// Default to lowest-privilege role, update later based on real role
currentUserRole := store.RoleUser
currentUser, err := s.GetCurrentUser(ctx)
if err == nil && currentUser != nil {
currentUserRole = currentUser.Role
}

identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
return redactIdentityProviderResponse(identityProviderConverted, currentUserRole), nil
}

func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil || currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}

if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
Expand Down Expand Up @@ -95,6 +121,14 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
}

func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
currentUser, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser == nil || currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}

id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
Expand Down Expand Up @@ -183,3 +217,13 @@ func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProv
}
return nil
}

func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
if userRole != store.RoleHost {
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
identityProvider.Config.GetOauth2Config().ClientSecret = ""
}
}

return identityProvider
}
10 changes: 10 additions & 0 deletions server/router/api/v1/memo_attachment_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ import (
)

func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
Expand All @@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
MemoID: &memo.ID,
})
Expand Down
10 changes: 10 additions & 0 deletions server/router/api/v1/memo_relation_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ import (
)

func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
memoUID, err := ExtractMemoUIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
Expand All @@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo")
}
if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
referenceType := store.MemoRelationReference
// Delete all reference relations first.
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
Expand Down
24 changes: 24 additions & 0 deletions server/router/api/v1/reaction_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,35 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
}

func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}

reactionID, err := ExtractReactionIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
}

// Get reaction and check ownership
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
ID: &reactionID,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
if len(reactions) == 0 {
return nil, status.Errorf(codes.NotFound, "reaction not found")
}

reaction := reactions[0]
if reaction.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}

if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
ID: reactionID,
}); err != nil {
Expand Down
41 changes: 37 additions & 4 deletions server/router/api/v1/test/idp_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ func TestGetIdentityProvider(t *testing.T) {
Name: created.Name,
}

// Test unauthenticated, should not contain client secret
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
require.NoError(t, err)
require.NotNil(t, resp)
Expand All @@ -241,7 +242,18 @@ func TestGetIdentityProvider(t *testing.T) {
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
require.NotNil(t, resp.Config.GetOauth2Config())
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
require.Equal(t, "test-secret", resp.Config.GetOauth2Config().ClientSecret)
require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret)

// Test as host user, should contain client secret
respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq)
require.NoError(t, err)
require.NotNil(t, respHostUser)
require.Equal(t, created.Name, respHostUser.Name)
require.Equal(t, "Test Provider", respHostUser.Title)
require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type)
require.NotNil(t, respHostUser.Config.GetOauth2Config())
require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId)
require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret)
})

t.Run("GetIdentityProvider not found", func(t *testing.T) {
Expand Down Expand Up @@ -353,14 +365,21 @@ func TestUpdateIdentityProvider(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()

// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)

// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)

req := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: "identityProviders/1",
Title: "Updated Provider",
},
}

_, err := ts.Service.UpdateIdentityProvider(ctx, req)
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "update_mask is required")
})
Expand All @@ -369,6 +388,13 @@ func TestUpdateIdentityProvider(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()

// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)

// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)

req := &v1pb.UpdateIdentityProviderRequest{
IdentityProvider: &v1pb.IdentityProvider{
Name: "invalid-name",
Expand All @@ -379,7 +405,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
},
}

_, err := ts.Service.UpdateIdentityProvider(ctx, req)
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
Expand Down Expand Up @@ -445,11 +471,18 @@ func TestDeleteIdentityProvider(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()

// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)

// Set user context
userCtx := ts.CreateUserContext(ctx, hostUser.ID)

req := &v1pb.DeleteIdentityProviderRequest{
Name: "invalid-name",
}

_, err := ts.Service.DeleteIdentityProvider(ctx, req)
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid identity provider name")
})
Expand Down
11 changes: 11 additions & 0 deletions server/router/api/v1/user_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR
// Unauthenticated or non-HOST users can only create normal users
roleToAssign = store.RoleUser
}

// Only allow user registration if it is enabled in the settings, or if the user is a superuser
if currentUser == nil || !isSuperUser(currentUser) {
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
}
if workspaceGeneralSetting.DisallowUserRegistration {
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
}
}
}

if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {
Expand Down