diff --git a/auth/oauth/oauth.go b/auth/oauth/oauth.go index 0df9d5c4..0525e8cd 100644 --- a/auth/oauth/oauth.go +++ b/auth/oauth/oauth.go @@ -2,14 +2,28 @@ package oauth import ( "context" + "encoding/json" "errors" "fmt" + "net/http" + "net/url" "strings" + "time" "github.com/coreos/go-oidc/v3/oidc" + "github.com/rs/zerolog/log" "golang.org/x/oauth2" ) +// hostConfigTimeout bounds the /.well-known/databricks-config lookup (and the +// subsequent OIDC discovery) so they cannot stall connection setup; on any +// failure we fall back to bare-host OIDC discovery. +const hostConfigTimeout = 10 * time.Second + +// accountIDPlaceholder is the token the host-metadata oidc_endpoint uses for the +// account id, e.g. "https:///oidc/accounts/{account_id}". +const accountIDPlaceholder = "{account_id}" + var azureTenants = map[string]string{ ".dev.azuredatabricks.net": "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc", ".staging.azuredatabricks.net": "4a67d088-db5c-48f1-9ff2-0aace800ae68", @@ -35,7 +49,19 @@ func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) return oauth2.Endpoint{AuthURL: authURL, TokenURL: tokenURL}, nil } - issuerURL := fmt.Sprintf("https://%s/oidc", hostName) + // AWS / GCP. Resolve the OIDC issuer via /.well-known/databricks-config so that + // unified / SPOG hosts (one host fronting workspaces across multiple accounts) + // use their account-rooted endpoint instead of the account-agnostic console login. + // For normal workspace hosts this resolves to https:///oidc, identical to + // the previous behavior. + // + // NOTE: this client uses the default transport, matching the existing + // oidc.NewProvider discovery below. A connector-supplied transport / TLS config + // (WithTransport, WithSkipTLSHostVerify) is not yet threaded into the OAuth + // endpoint-resolution path; that is a pre-existing limitation, tracked separately. + client := &http.Client{Timeout: hostConfigTimeout} + issuerURL := resolveOIDCIssuer(ctx, client, hostName) + ctx = oidc.ClientContext(ctx, client) ctx = oidc.InsecureIssuerURLContext(ctx, issuerURL) provider, err := oidc.NewProvider(ctx, issuerURL) if err != nil { @@ -47,6 +73,126 @@ func GetEndpoint(ctx context.Context, hostName string) (oauth2.Endpoint, error) return endpoint, err } +// hostMetadata is the subset of /.well-known/databricks-config we consume. +type hostMetadata struct { + OIDCEndpoint string `json:"oidc_endpoint"` + AccountID string `json:"account_id"` +} + +// resolveOIDCIssuer returns the OIDC issuer URL to use for AWS/GCP OAuth discovery. +// +// On a unified / SPOG host, the bare-host OIDC discovery doc points at the +// account-agnostic account-console login. That mints a token for the caller's +// default account, which the target workspace rejects ("Invalid Token") when the +// workspace belongs to a different account. Such hosts advertise the correct, +// account-rooted OIDC endpoint via /.well-known/databricks-config (with an +// {account_id} placeholder); we consult it and substitute the account id. +// +// For a normal workspace host the advertised endpoint is just https:///oidc, +// so the result is identical to the historical bare-host issuer. Any failure or +// unusable value (endpoint absent, non-200, unparseable, missing/empty account_id, +// non-https, non-Databricks host, timeout) falls back to the bare-host issuer, +// preserving existing behavior. +func resolveOIDCIssuer(ctx context.Context, client *http.Client, hostName string) string { + fallback := fmt.Sprintf("https://%s/oidc", hostName) + + cfgURL := fmt.Sprintf("https://%s/.well-known/databricks-config", hostName) + meta, ok := fetchHostMetadata(ctx, client, cfgURL) + if !ok || meta.OIDCEndpoint == "" { + log.Debug().Msgf("oauth: no usable databricks-config for %q; using bare-host OIDC issuer", hostName) + return fallback + } + + // An account-rooted endpoint needs a non-empty account_id; otherwise the + // placeholder would resolve to a malformed ".../accounts/" issuer. Fall back + // rather than emit it (the function's documented contract). + if strings.Contains(meta.OIDCEndpoint, accountIDPlaceholder) && meta.AccountID == "" { + log.Warn().Msgf("oauth: databricks-config for %q has an %s placeholder but empty account_id; using bare-host OIDC issuer", hostName, accountIDPlaceholder) + return fallback + } + + issuer := substituteAccountID(meta) + if !isValidDatabricksIssuer(issuer) { + log.Warn().Msgf("oauth: databricks-config for %q advertised an unusable oidc_endpoint %q; using bare-host OIDC issuer", hostName, issuer) + return fallback + } + return issuer +} + +// substituteAccountID resolves the {account_id} placeholder in the advertised +// oidc_endpoint. Workspace hosts have no placeholder and are returned unchanged. +func substituteAccountID(meta hostMetadata) string { + return strings.ReplaceAll(meta.OIDCEndpoint, accountIDPlaceholder, meta.AccountID) +} + +// isValidDatabricksIssuer guards the metadata-supplied OIDC issuer before it is +// passed to discovery: it must be an https URL on a recognized Databricks domain +// with the {account_id} placeholder fully resolved. This bounds the trust placed +// in the host-supplied document (the issuer-match check is disabled via +// InsecureIssuerURLContext because the discovered issuer is cross-host) and avoids +// cleartext OAuth from an http:// endpoint. +func isValidDatabricksIssuer(issuer string) bool { + if strings.Contains(issuer, accountIDPlaceholder) { + return false + } + u, err := url.Parse(issuer) + if err != nil || u.Scheme != "https" || u.Hostname() == "" { + return false + } + return isDatabricksHost(u.Hostname()) +} + +// databricksHostSuffixes is the set of DNS suffixes treated as first-party +// Databricks hosts when deciding whether to trust a metadata-supplied OIDC issuer. +// +// Suffix matching (not the substring matching InferCloudFromHost uses for cloud +// routing) is deliberate here: this is a trust gate for a cross-host issuer, so +// "databricks.com.evil.example" must NOT pass. ".databricks.com" covers the +// .cloud/.dev/.gcp variants and the bare unified/SPOG custom URLs. +var databricksHostSuffixes = []string{ + ".databricks.com", + ".cloud.databricks.us", + ".azuredatabricks.net", + ".databricks.azure.us", + ".databricks.azure.cn", +} + +func isDatabricksHost(host string) bool { + host = strings.ToLower(host) + for _, suffix := range databricksHostSuffixes { + if strings.HasSuffix(host, suffix) { + return true + } + } + return false +} + +// fetchHostMetadata GETs /.well-known/databricks-config and decodes it. The bool +// is false on any failure (request error, non-200, unparseable body) so callers +// fall back to bare-host discovery. +func fetchHostMetadata(ctx context.Context, client *http.Client, url string) (hostMetadata, bool) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return hostMetadata{}, false + } + + resp, err := client.Do(req) + if err != nil { + return hostMetadata{}, false + } + defer resp.Body.Close() //nolint:errcheck + + if resp.StatusCode != http.StatusOK { + return hostMetadata{}, false + } + + var meta hostMetadata + if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil { + return hostMetadata{}, false + } + return meta, true +} + func GetScopes(hostName string, scopes []string) []string { for _, s := range []string{oidc.ScopeOfflineAccess} { if !HasScope(scopes, s) { @@ -135,6 +281,16 @@ func InferCloudFromHost(hostname string) CloudType { return GCP } } + + // Unified / SPOG (Single Pane of Glass) AWS hosts use bare *.databricks.com + // custom URLs (e.g. .databricks.com, .staging.databricks.com) that + // match none of the lists above. Treat them as AWS. This is checked last so the + // more specific Azure (.azuredatabricks.net) and GCP (.gcp.databricks.com) hosts + // are classified first. + if strings.Contains(hostname, "databricks.com") { + return AWS + } + return Unknown } diff --git a/auth/oauth/oauth_test.go b/auth/oauth/oauth_test.go new file mode 100644 index 00000000..e9e64291 --- /dev/null +++ b/auth/oauth/oauth_test.go @@ -0,0 +1,258 @@ +package oauth + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestInferCloudFromHost(t *testing.T) { + cases := []struct { + host string + want CloudType + }{ + // Standard per-workspace hosts. + {"dbc-1234.cloud.databricks.com", AWS}, + {"example.cloud.databricks.us", AWS}, + {"foo.dev.databricks.com", AWS}, + {"adb-123.azuredatabricks.net", Azure}, + {"x.databricks.azure.us", Azure}, + {"y.databricks.azure.cn", Azure}, + {"ws.gcp.databricks.com", GCP}, + // SPOG / unified custom-URL AWS hosts (the fix): must classify as AWS, + // not Unknown, and must NOT be swallowed by the GCP/Azure checks. + {"pecoaws.databricks.com", AWS}, + {"dogfood.staging.databricks.com", AWS}, + // Azure SPOG stays Azure. + {"dogfood-spog.staging.azuredatabricks.net", Azure}, + // GCP custom host must remain GCP even though it contains "databricks.com". + {"foo.gcp.databricks.com", GCP}, + // Truly unrelated host stays Unknown. + {"example.com", Unknown}, + } + + for _, tc := range cases { + t.Run(tc.host, func(t *testing.T) { + if got := InferCloudFromHost(tc.host); got != tc.want { + t.Fatalf("InferCloudFromHost(%q) = %v, want %v", tc.host, got, tc.want) + } + }) + } +} + +func TestGetAzureDnsZone(t *testing.T) { + // Documents current behavior: the generic suffix is matched first, so staging + // and dev Azure hosts resolve to the prod tenant. (Separate fix tracked.) + cases := []struct { + host string + want string + }{ + {"adb-123.azuredatabricks.net", ".azuredatabricks.net"}, + {"x.databricks.azure.us", ".databricks.azure.us"}, + {"nope.example.com", ""}, + } + for _, tc := range cases { + t.Run(tc.host, func(t *testing.T) { + if got := GetAzureDnsZone(tc.host); got != tc.want { + t.Fatalf("GetAzureDnsZone(%q) = %q, want %q", tc.host, got, tc.want) + } + }) + } +} + +// TestResolveOIDCIssuer drives resolveOIDCIssuer end-to-end against an httptest +// server: the server stands in for the connection host, so a 200 with a given +// databricks-config body exercises the real fetch + substitution + validation + +// fallback wiring. +func TestResolveOIDCIssuer(t *testing.T) { + const fallbackForUnreachable = "https://no-such-host.invalid/oidc" + + cases := []struct { + name string + // body served at /.well-known/databricks-config (status 200). Empty body + // with status!=200 simulates a missing endpoint. + status int + body string + // want is the exact resolved issuer. If wantFallback is true we instead + // assert the bare-host fallback (https:///oidc). + want string + wantFallback bool + }{ + { + name: "unified host substitutes account_id", + status: 200, + body: `{"oidc_endpoint":"https://spog.databricks.com/oidc/accounts/{account_id}","account_id":"acc-123","host_type":"unified"}`, + want: "https://spog.databricks.com/oidc/accounts/acc-123", + }, + { + name: "workspace host returned unchanged", + status: 200, + body: `{"oidc_endpoint":"https://ws.cloud.databricks.com/oidc","account_id":"acc-123","host_type":"workspace"}`, + want: "https://ws.cloud.databricks.com/oidc", + }, + { + name: "placeholder with empty account_id falls back", + status: 200, + body: `{"oidc_endpoint":"https://spog.databricks.com/oidc/accounts/{account_id}","account_id":""}`, + wantFallback: true, + }, + { + name: "empty oidc_endpoint falls back", + status: 200, + body: `{"account_id":"acc-123"}`, + wantFallback: true, + }, + { + name: "non-https endpoint falls back", + status: 200, + body: `{"oidc_endpoint":"http://spog.databricks.com/oidc","account_id":"acc-123"}`, + wantFallback: true, + }, + { + name: "non-databricks host falls back", + status: 200, + body: `{"oidc_endpoint":"https://evil.example.com/oidc","account_id":"acc-123"}`, + wantFallback: true, + }, + { + name: "suffix-spoof host falls back", + status: 200, + body: `{"oidc_endpoint":"https://databricks.com.evil.example/oidc/accounts/{account_id}","account_id":"acc-123"}`, + wantFallback: true, + }, + { + name: "404 falls back", + status: 404, + wantFallback: true, + }, + { + name: "garbage body falls back", + status: 200, + body: `not json`, + wantFallback: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/databricks-config" { + w.WriteHeader(http.StatusNotFound) + return + } + if tc.status != 200 { + w.WriteHeader(tc.status) + return + } + _, _ = w.Write([]byte(tc.body)) + })) + defer srv.Close() + + host := strings.TrimPrefix(srv.URL, "https://") + got := resolveOIDCIssuer(context.Background(), srv.Client(), host) + + if tc.wantFallback { + if want := "https://" + host + "/oidc"; got != want { + t.Fatalf("issuer = %q, want fallback %q", got, want) + } + return + } + if got != tc.want { + t.Fatalf("issuer = %q, want %q", got, tc.want) + } + }) + } + + // Sanity: an unreachable host falls back without panicking. + if got := resolveOIDCIssuer(context.Background(), &http.Client{}, "no-such-host.invalid"); got != fallbackForUnreachable { + t.Fatalf("unreachable host issuer = %q, want %q", got, fallbackForUnreachable) + } +} + +func TestIsValidDatabricksIssuer(t *testing.T) { + cases := []struct { + issuer string + want bool + }{ + // Real hosts from the validated SPOG/non-SPOG flows must stay valid. + {"https://dogfood.staging.databricks.com/oidc/accounts/7a99b43c", true}, // SPOG (unified) issuer + {"https://e2-dogfood.staging.cloud.databricks.com/oidc", true}, // non-SPOG workspace + {"https://accounts.staging.cloud.databricks.com/oidc/accounts/x", true}, // account-rooted + {"https://spog.databricks.com/oidc/accounts/acc-123", true}, + {"https://ws.cloud.databricks.com/oidc", true}, + {"https://adb-1.azuredatabricks.net/oidc", true}, + {"https://x.cloud.databricks.us/oidc", true}, + // Rejections. + {"https://spog.databricks.com/oidc/accounts/{account_id}", false}, // unresolved placeholder + {"http://spog.databricks.com/oidc", false}, // not https + {"https://evil.example.com/oidc", false}, // not a databricks host + {"https://databricks.com.evil.example/oidc", false}, // suffix-spoof: must NOT pass + {"https://notdatabricks.com/oidc", false}, // "-databricks.com" is not ".databricks.com" + {"https://databricks.com/oidc", false}, // bare apex, no subdomain + {"://bad", false}, // unparseable + {"", false}, + } + for _, tc := range cases { + t.Run(tc.issuer, func(t *testing.T) { + if got := isValidDatabricksIssuer(tc.issuer); got != tc.want { + t.Fatalf("isValidDatabricksIssuer(%q) = %v, want %v", tc.issuer, got, tc.want) + } + }) + } +} + +func TestIsDatabricksHost(t *testing.T) { + cases := []struct { + host string + want bool + }{ + {"dogfood.staging.databricks.com", true}, + {"e2-dogfood.staging.cloud.databricks.com", true}, + {"foo.gcp.databricks.com", true}, + {"adb-1.azuredatabricks.net", true}, + {"x.cloud.databricks.us", true}, + {"DOGFOOD.STAGING.DATABRICKS.COM", true}, // case-insensitive + {"databricks.com.evil.example", false}, + {"notdatabricks.com", false}, + {"databricks.com", false}, + {"evil.example", false}, + {"", false}, + } + for _, tc := range cases { + t.Run(tc.host, func(t *testing.T) { + if got := isDatabricksHost(tc.host); got != tc.want { + t.Fatalf("isDatabricksHost(%q) = %v, want %v", tc.host, got, tc.want) + } + }) + } +} + +func TestFetchHostMetadata_failuresFallBack(t *testing.T) { + t.Run("404", func(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + if _, ok := fetchHostMetadata(context.Background(), srv.Client(), srv.URL); ok { + t.Fatal("ok=true on 404, want false (fallback)") + } + }) + + t.Run("garbage body", func(t *testing.T) { + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("not json")) + })) + defer srv.Close() + if _, ok := fetchHostMetadata(context.Background(), srv.Client(), srv.URL); ok { + t.Fatal("ok=true on garbage body, want false (fallback)") + } + }) + + t.Run("unreachable", func(t *testing.T) { + if _, ok := fetchHostMetadata(context.Background(), &http.Client{}, "https://127.0.0.1:1/nope"); ok { + t.Fatal("ok=true on unreachable host, want false (fallback)") + } + }) +}