diff --git a/internal/service/sso/providers/base_oidc.go b/internal/service/sso/providers/base_oidc.go index 3796f0c4..5fad5f48 100644 --- a/internal/service/sso/providers/base_oidc.go +++ b/internal/service/sso/providers/base_oidc.go @@ -3,6 +3,7 @@ package providers import ( "context" "fmt" + "strings" "github.com/compliance-framework/api/internal/config" "github.com/compliance-framework/api/internal/service/sso/types" @@ -92,6 +93,9 @@ func (p *BaseOIDCProvider) GetUserInfo(ctx context.Context, token *oauth2.Token) if familyName, ok := claims["family_name"].(string); ok { userInfo.LastName = familyName } + if userInfo.FirstName == "" && userInfo.LastName == "" { + userInfo.FirstName, userInfo.LastName = splitDisplayName(userInfo.Name) + } if hd, ok := claims["hd"].(string); ok { userInfo.HostedDomain = hd } @@ -102,6 +106,17 @@ func (p *BaseOIDCProvider) GetUserInfo(ctx context.Context, token *oauth2.Token) return userInfo, nil } +func splitDisplayName(name string) (string, string) { + parts := strings.Fields(name) + if len(parts) == 0 { + return "", "" + } + if len(parts) == 1 { + return parts[0], "" + } + return parts[0], strings.Join(parts[1:], " ") +} + func (p *BaseOIDCProvider) extractGroups(claims map[string]interface{}) []string { claimGroups := buildClaimGroups(claims) diff --git a/internal/service/sso/providers/base_oidc_test.go b/internal/service/sso/providers/base_oidc_test.go index f3ab63fb..316f62dd 100644 --- a/internal/service/sso/providers/base_oidc_test.go +++ b/internal/service/sso/providers/base_oidc_test.go @@ -58,6 +58,100 @@ func TestBaseOIDCProvider_GetUserInfo(t *testing.T) { require.ElementsMatch(t, []string{"ccf-admins", "ccf-engineering"}, userInfo.Groups) } +func TestBaseOIDCProvider_GetUserInfoSplitsNameWhenStructuredNameMissing(t *testing.T) { + mock := testutil.NewMockOIDCServer(t) + defer mock.Close() + + claims := map[string]any{ + "email": "alice@example.com", + "name": "Alice Dex Admin", + } + + rawIDToken, err := mock.SignIDToken(claims) + require.NoError(t, err) + + cfg := &config.SSOProviderConfig{ + Name: "test-oidc", + ClientID: "test-client", + IssuerURL: mock.IssuerURL, + } + logger := zap.NewNop().Sugar() + provider, err := NewBaseOIDCProvider(context.Background(), cfg, "https://app.example.com/callback", logger) + require.NoError(t, err) + + token := (&oauth2.Token{AccessToken: "token"}).WithExtra(map[string]any{ + "id_token": rawIDToken, + }) + + userInfo, err := provider.GetUserInfo(context.Background(), token) + require.NoError(t, err) + + require.Equal(t, "Alice Dex Admin", userInfo.Name) + require.Equal(t, "Alice", userInfo.FirstName) + require.Equal(t, "Dex Admin", userInfo.LastName) +} + +func TestBaseOIDCProvider_GetUserInfoKeepsStructuredNameClaims(t *testing.T) { + mock := testutil.NewMockOIDCServer(t) + defer mock.Close() + + claims := map[string]any{ + "email": "dev@example.com", + "name": "Display Name", + "given_name": "Structured", + "family_name": "Person", + } + + rawIDToken, err := mock.SignIDToken(claims) + require.NoError(t, err) + + cfg := &config.SSOProviderConfig{ + Name: "test-oidc", + ClientID: "test-client", + IssuerURL: mock.IssuerURL, + } + logger := zap.NewNop().Sugar() + provider, err := NewBaseOIDCProvider(context.Background(), cfg, "https://app.example.com/callback", logger) + require.NoError(t, err) + + token := (&oauth2.Token{AccessToken: "token"}).WithExtra(map[string]any{ + "id_token": rawIDToken, + }) + + userInfo, err := provider.GetUserInfo(context.Background(), token) + require.NoError(t, err) + + require.Equal(t, "Display Name", userInfo.Name) + require.Equal(t, "Structured", userInfo.FirstName) + require.Equal(t, "Person", userInfo.LastName) +} + +func TestSplitDisplayName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + firstName string + lastName string + }{ + {name: "", firstName: "", lastName: ""}, + {name: "Alice", firstName: "Alice", lastName: ""}, + {name: "Alice Admin", firstName: "Alice", lastName: "Admin"}, + {name: " Alice Dex Admin ", firstName: "Alice", lastName: "Dex Admin"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + firstName, lastName := splitDisplayName(tt.name) + + require.Equal(t, tt.firstName, firstName) + require.Equal(t, tt.lastName, lastName) + }) + } +} + func TestBaseOIDCProvider_GetUserInfoMissingIDToken(t *testing.T) { cfg := &config.SSOProviderConfig{ Name: "test-oidc",