Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/config/sso.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ type SSOProviderConfig struct {
ClientID string `yaml:"client_id" json:"clientId" mapstructure:"client_id"`
ClientSecret string `yaml:"client_secret" json:"clientSecret" mapstructure:"client_secret"`
IssuerURL string `yaml:"issuer_url" json:"issuerUrl" mapstructure:"issuer_url"`
WellKnownURL string `yaml:"well_known_url" json:"wellKnownUrl" mapstructure:"well_known_url"`
AuthURL string `yaml:"auth_url" json:"authUrl" mapstructure:"auth_url"`
TokenURL string `yaml:"token_url" json:"tokenUrl" mapstructure:"token_url"`
UserInfoURL string `yaml:"user_info_url" json:"userInfoUrl" mapstructure:"user_info_url"`
Expand Down
100 changes: 99 additions & 1 deletion internal/service/sso/providers/base_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package providers

import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"

"github.com/compliance-framework/api/internal/config"
Expand All @@ -23,7 +27,7 @@ type BaseOIDCProvider struct {

// NewBaseOIDCProvider creates a new generic OIDC provider
func NewBaseOIDCProvider(ctx context.Context, cfg *config.SSOProviderConfig, callbackURL string, logger *zap.SugaredLogger) (*BaseOIDCProvider, error) {
provider, err := oidc.NewProvider(ctx, cfg.IssuerURL)
provider, err := newOIDCProvider(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("failed to create OIDC provider: %w", err)
}
Expand Down Expand Up @@ -52,6 +56,100 @@ func NewBaseOIDCProvider(ctx context.Context, cfg *config.SSOProviderConfig, cal
}, nil
}

func newOIDCProvider(ctx context.Context, cfg *config.SSOProviderConfig) (*oidc.Provider, error) {
wellKnownURL := strings.TrimSpace(cfg.WellKnownURL)
if wellKnownURL == "" {
return oidc.NewProvider(ctx, cfg.IssuerURL)
}

providerConfig, err := fetchOIDCProviderConfig(ctx, wellKnownURL)
if err != nil {
return nil, err
}
if providerConfig.IssuerURL != cfg.IssuerURL {
return nil, fmt.Errorf("oidc: configured issuer URL %q did not match the issuer URL returned by provider %q", cfg.IssuerURL, providerConfig.IssuerURL)
}
internalIssuerURL, err := issuerURLFromWellKnownURL(wellKnownURL)
if err != nil {
return nil, err
}
rewriteServerSideOIDCEndpoints(providerConfig, cfg.IssuerURL, internalIssuerURL)

return providerConfig.NewProvider(ctx), nil
}

func fetchOIDCProviderConfig(ctx context.Context, wellKnownURL string) (*oidc.ProviderConfig, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL, nil)
if err != nil {
return nil, err
}

client := http.DefaultClient
if configuredClient, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok && configuredClient != nil {
client = configuredClient
}

resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() {
_ = resp.Body.Close()
}()

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("unable to read response body: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%s: %s", resp.Status, body)
}

var providerConfig oidc.ProviderConfig
if err := json.Unmarshal(body, &providerConfig); err != nil {
return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %w", err)
}
return &providerConfig, nil
}

func issuerURLFromWellKnownURL(wellKnownURL string) (string, error) {
parsedURL, err := url.Parse(wellKnownURL)
if err != nil {
return "", err
}

wellKnownPath := "/.well-known/openid-configuration"
if !strings.HasSuffix(parsedURL.Path, wellKnownPath) {
return "", fmt.Errorf("well_known_url %q must end with %s", wellKnownURL, wellKnownPath)
}

parsedURL.Path = strings.TrimSuffix(parsedURL.Path, wellKnownPath)
parsedURL.RawQuery = ""
parsedURL.Fragment = ""
return strings.TrimSuffix(parsedURL.String(), "/"), nil
}

func rewriteServerSideOIDCEndpoints(providerConfig *oidc.ProviderConfig, publicIssuerURL string, internalIssuerURL string) {
providerConfig.TokenURL = rewriteIssuerURL(providerConfig.TokenURL, publicIssuerURL, internalIssuerURL)
providerConfig.UserInfoURL = rewriteIssuerURL(providerConfig.UserInfoURL, publicIssuerURL, internalIssuerURL)
providerConfig.JWKSURL = rewriteIssuerURL(providerConfig.JWKSURL, publicIssuerURL, internalIssuerURL)
}

func rewriteIssuerURL(value string, publicIssuerURL string, internalIssuerURL string) string {
if value == "" {
return ""
}

publicIssuerURL = strings.TrimSuffix(publicIssuerURL, "/")
if value == publicIssuerURL {
return internalIssuerURL
}
if strings.HasPrefix(value, publicIssuerURL+"/") {
return internalIssuerURL + strings.TrimPrefix(value, publicIssuerURL)
}
return value
}

func (p *BaseOIDCProvider) GetAuthURL(state string) string {
return p.oauth2Config.AuthCodeURL(state)
}
Expand Down
64 changes: 64 additions & 0 deletions internal/service/sso/providers/base_oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ package providers

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/compliance-framework/api/internal/config"
Expand Down Expand Up @@ -169,3 +173,63 @@ func TestBaseOIDCProvider_GetUserInfoMissingIDToken(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "no id_token")
}

func TestBaseOIDCProvider_UsesWellKnownURLForDiscovery(t *testing.T) {
mock := testutil.NewMockOIDCServer(t)
defer mock.Close()

publicIssuerURL := "https://dex.example.com/dex"
mock.IssuerURL = publicIssuerURL

cfg := &config.SSOProviderConfig{
Name: "dex",
ClientID: "test-client",
ClientSecret: "test-secret",
IssuerURL: publicIssuerURL,
WellKnownURL: fmt.Sprintf("%s/.well-known/openid-configuration", mock.Server.URL),
}
logger := zap.NewNop().Sugar()

provider, err := NewBaseOIDCProvider(context.Background(), cfg, "https://app.example.com/callback", logger)
require.NoError(t, err)

require.Equal(t, fmt.Sprintf("%s/auth", publicIssuerURL), provider.oauth2Config.Endpoint.AuthURL)
require.Equal(t, fmt.Sprintf("%s/token", mock.Server.URL), provider.oauth2Config.Endpoint.TokenURL)

rawIDToken, err := mock.SignIDToken(map[string]any{"email": "dev@example.com"})
require.NoError(t, err)

token := (&oauth2.Token{AccessToken: "token"}).WithExtra(map[string]any{
"id_token": rawIDToken,
})

userInfo, err := provider.GetUserInfo(context.Background(), token)
require.NoError(t, err)
require.Equal(t, "dev@example.com", userInfo.Email)
}

func TestBaseOIDCProvider_RejectsWellKnownURLIssuerMismatch(t *testing.T) {
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := json.NewEncoder(w).Encode(map[string]any{
"issuer": "https://unexpected.example.com/dex",
"jwks_uri": "https://unexpected.example.com/dex/keys",
"authorization_endpoint": "https://unexpected.example.com/dex/auth",
"token_endpoint": "https://unexpected.example.com/dex/token",
})
require.NoError(t, err)
}))
defer discoveryServer.Close()

cfg := &config.SSOProviderConfig{
Name: "dex",
ClientID: "test-client",
ClientSecret: "test-secret",
IssuerURL: "https://dex.example.com/dex",
WellKnownURL: fmt.Sprintf("%s/.well-known/openid-configuration", discoveryServer.URL),
}
logger := zap.NewNop().Sugar()

_, err := NewBaseOIDCProvider(context.Background(), cfg, "https://app.example.com/callback", logger)
require.Error(t, err)
require.Contains(t, err.Error(), "configured issuer URL")
}
2 changes: 2 additions & 0 deletions sso.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ providers:
# client_id: "${OIDC_CLIENT_ID}"
# client_secret: "${OIDC_CLIENT_SECRET}"
# issuer_url: "${OIDC_ISSUER_URL}"
# # Optional: fetch discovery from an internal URL while preserving issuer_url for ID token validation.
# well_known_url: "${OIDC_WELL_KNOWN_URL}"
# scopes:
# - "openid"
# - "email"
Expand Down
Loading