Skip to content

Commit 769dcd0

Browse files
authored
fix(security): add missing authorization checks to various services (#5217)
1 parent df93120 commit 769dcd0

File tree

6 files changed

+138
-6
lines changed

6 files changed

+138
-6
lines changed

server/router/api/v1/idp_service.go

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,17 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId
3838
response := &v1pb.ListIdentityProvidersResponse{
3939
IdentityProviders: []*v1pb.IdentityProvider{},
4040
}
41+
42+
// Default to lowest-privilege role, update later based on real role
43+
currentUserRole := store.RoleUser
44+
currentUser, err := s.GetCurrentUser(ctx)
45+
if err == nil && currentUser != nil {
46+
currentUserRole = currentUser.Role
47+
}
48+
4149
for _, identityProvider := range identityProviders {
42-
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
50+
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
51+
response.IdentityProviders = append(response.IdentityProviders, redactIdentityProviderResponse(identityProviderConverted, currentUserRole))
4352
}
4453
return response, nil
4554
}
@@ -58,10 +67,27 @@ func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.Ge
5867
if identityProvider == nil {
5968
return nil, status.Errorf(codes.NotFound, "identity provider not found")
6069
}
61-
return convertIdentityProviderFromStore(identityProvider), nil
70+
71+
// Default to lowest-privilege role, update later based on real role
72+
currentUserRole := store.RoleUser
73+
currentUser, err := s.GetCurrentUser(ctx)
74+
if err == nil && currentUser != nil {
75+
currentUserRole = currentUser.Role
76+
}
77+
78+
identityProviderConverted := convertIdentityProviderFromStore(identityProvider)
79+
return redactIdentityProviderResponse(identityProviderConverted, currentUserRole), nil
6280
}
6381

6482
func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb.UpdateIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
83+
currentUser, err := s.GetCurrentUser(ctx)
84+
if err != nil {
85+
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
86+
}
87+
if currentUser == nil || currentUser.Role != store.RoleHost {
88+
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
89+
}
90+
6591
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
6692
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
6793
}
@@ -95,6 +121,14 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
95121
}
96122

97123
func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb.DeleteIdentityProviderRequest) (*emptypb.Empty, error) {
124+
currentUser, err := s.GetCurrentUser(ctx)
125+
if err != nil {
126+
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
127+
}
128+
if currentUser == nil || currentUser.Role != store.RoleHost {
129+
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
130+
}
131+
98132
id, err := ExtractIdentityProviderIDFromName(request.Name)
99133
if err != nil {
100134
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
@@ -183,3 +217,13 @@ func convertIdentityProviderConfigToStore(identityProviderType v1pb.IdentityProv
183217
}
184218
return nil
185219
}
220+
221+
func redactIdentityProviderResponse(identityProvider *v1pb.IdentityProvider, userRole store.Role) *v1pb.IdentityProvider {
222+
if userRole != store.RoleHost {
223+
if identityProvider.Type == v1pb.IdentityProvider_OAUTH2 {
224+
identityProvider.Config.GetOauth2Config().ClientSecret = ""
225+
}
226+
}
227+
228+
return identityProvider
229+
}

server/router/api/v1/memo_attachment_service.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ import (
1414
)
1515

1616
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
17+
user, err := s.GetCurrentUser(ctx)
18+
if err != nil {
19+
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
20+
}
21+
if user == nil {
22+
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
23+
}
1724
memoUID, err := ExtractMemoUIDFromName(request.Name)
1825
if err != nil {
1926
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
@@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
2229
if err != nil {
2330
return nil, status.Errorf(codes.Internal, "failed to get memo")
2431
}
32+
if memo.CreatorID != user.ID && !isSuperUser(user) {
33+
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
34+
}
2535
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
2636
MemoID: &memo.ID,
2737
})

server/router/api/v1/memo_relation_service.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ import (
1414
)
1515

1616
func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMemoRelationsRequest) (*emptypb.Empty, error) {
17+
user, err := s.GetCurrentUser(ctx)
18+
if err != nil {
19+
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
20+
}
21+
if user == nil {
22+
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
23+
}
1724
memoUID, err := ExtractMemoUIDFromName(request.Name)
1825
if err != nil {
1926
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
@@ -22,6 +29,9 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe
2229
if err != nil {
2330
return nil, status.Errorf(codes.Internal, "failed to get memo")
2431
}
32+
if memo.CreatorID != user.ID && !isSuperUser(user) {
33+
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
34+
}
2535
referenceType := store.MemoRelationReference
2636
// Delete all reference relations first.
2737
if err := s.Store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{

server/router/api/v1/reaction_service.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,35 @@ func (s *APIV1Service) UpsertMemoReaction(ctx context.Context, request *v1pb.Ups
5555
}
5656

5757
func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.DeleteMemoReactionRequest) (*emptypb.Empty, error) {
58+
user, err := s.GetCurrentUser(ctx)
59+
if err != nil {
60+
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
61+
}
62+
if user == nil {
63+
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
64+
}
65+
5866
reactionID, err := ExtractReactionIDFromName(request.Name)
5967
if err != nil {
6068
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
6169
}
6270

71+
// Get reaction and check ownership
72+
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
73+
ID: &reactionID,
74+
})
75+
if err != nil {
76+
return nil, status.Errorf(codes.Internal, "failed to list reactions")
77+
}
78+
if len(reactions) == 0 {
79+
return nil, status.Errorf(codes.NotFound, "reaction not found")
80+
}
81+
82+
reaction := reactions[0]
83+
if reaction.CreatorID != user.ID && !isSuperUser(user) {
84+
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
85+
}
86+
6387
if err := s.Store.DeleteReaction(ctx, &store.DeleteReaction{
6488
ID: reactionID,
6589
}); err != nil {

server/router/api/v1/test/idp_service_test.go

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ func TestGetIdentityProvider(t *testing.T) {
233233
Name: created.Name,
234234
}
235235

236+
// Test unauthenticated, should not contain client secret
236237
resp, err := ts.Service.GetIdentityProvider(ctx, getReq)
237238
require.NoError(t, err)
238239
require.NotNil(t, resp)
@@ -241,7 +242,18 @@ func TestGetIdentityProvider(t *testing.T) {
241242
require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type)
242243
require.NotNil(t, resp.Config.GetOauth2Config())
243244
require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId)
244-
require.Equal(t, "test-secret", resp.Config.GetOauth2Config().ClientSecret)
245+
require.Equal(t, "", resp.Config.GetOauth2Config().ClientSecret)
246+
247+
// Test as host user, should contain client secret
248+
respHostUser, err := ts.Service.GetIdentityProvider(userCtx, getReq)
249+
require.NoError(t, err)
250+
require.NotNil(t, respHostUser)
251+
require.Equal(t, created.Name, respHostUser.Name)
252+
require.Equal(t, "Test Provider", respHostUser.Title)
253+
require.Equal(t, v1pb.IdentityProvider_OAUTH2, respHostUser.Type)
254+
require.NotNil(t, respHostUser.Config.GetOauth2Config())
255+
require.Equal(t, "test-client", respHostUser.Config.GetOauth2Config().ClientId)
256+
require.Equal(t, "test-secret", respHostUser.Config.GetOauth2Config().ClientSecret)
245257
})
246258

247259
t.Run("GetIdentityProvider not found", func(t *testing.T) {
@@ -353,14 +365,21 @@ func TestUpdateIdentityProvider(t *testing.T) {
353365
ts := NewTestService(t)
354366
defer ts.Cleanup()
355367

368+
// Create host user
369+
hostUser, err := ts.CreateHostUser(ctx, "admin")
370+
require.NoError(t, err)
371+
372+
// Set user context
373+
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
374+
356375
req := &v1pb.UpdateIdentityProviderRequest{
357376
IdentityProvider: &v1pb.IdentityProvider{
358377
Name: "identity-providers/1",
359378
Title: "Updated Provider",
360379
},
361380
}
362381

363-
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
382+
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
364383
require.Error(t, err)
365384
require.Contains(t, err.Error(), "update_mask is required")
366385
})
@@ -369,6 +388,13 @@ func TestUpdateIdentityProvider(t *testing.T) {
369388
ts := NewTestService(t)
370389
defer ts.Cleanup()
371390

391+
// Create host user
392+
hostUser, err := ts.CreateHostUser(ctx, "admin")
393+
require.NoError(t, err)
394+
395+
// Set user context
396+
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
397+
372398
req := &v1pb.UpdateIdentityProviderRequest{
373399
IdentityProvider: &v1pb.IdentityProvider{
374400
Name: "invalid-name",
@@ -379,7 +405,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
379405
},
380406
}
381407

382-
_, err := ts.Service.UpdateIdentityProvider(ctx, req)
408+
_, err = ts.Service.UpdateIdentityProvider(userCtx, req)
383409
require.Error(t, err)
384410
require.Contains(t, err.Error(), "invalid identity provider name")
385411
})
@@ -445,11 +471,18 @@ func TestDeleteIdentityProvider(t *testing.T) {
445471
ts := NewTestService(t)
446472
defer ts.Cleanup()
447473

474+
// Create host user
475+
hostUser, err := ts.CreateHostUser(ctx, "admin")
476+
require.NoError(t, err)
477+
478+
// Set user context
479+
userCtx := ts.CreateUserContext(ctx, hostUser.ID)
480+
448481
req := &v1pb.DeleteIdentityProviderRequest{
449482
Name: "invalid-name",
450483
}
451484

452-
_, err := ts.Service.DeleteIdentityProvider(ctx, req)
485+
_, err = ts.Service.DeleteIdentityProvider(userCtx, req)
453486
require.Error(t, err)
454487
require.Contains(t, err.Error(), "invalid identity provider name")
455488
})

server/router/api/v1/user_service.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,17 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR
169169
// Unauthenticated or non-HOST users can only create normal users
170170
roleToAssign = store.RoleUser
171171
}
172+
173+
// Only allow user registration if it is enabled in the settings, or if the user is a superuser
174+
if currentUser == nil || !isSuperUser(currentUser) {
175+
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
176+
if err != nil {
177+
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
178+
}
179+
if workspaceGeneralSetting.DisallowUserRegistration {
180+
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
181+
}
182+
}
172183
}
173184

174185
if !base.UIDMatcher.MatchString(strings.ToLower(request.User.Username)) {

0 commit comments

Comments
 (0)