diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index c2ae5fb18..db98cce7d 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -25,8 +25,10 @@ This is a C# based repository that produces several CLIs that are used by custom ### Go Port Directories - `cmd/gei/`, `cmd/ado2gh/`, `cmd/bbs2gh/`: Go CLI entry points - `pkg/scriptgen/`: PowerShell script generation (ported from C#) +- `pkg/github/`: GitHub API client (REST + GraphQL) - `pkg/logger/`, `pkg/env/`: Shared Go packages - `internal/cmdutil/`: Command utility helpers +- `internal/sharedcmd/`: Shared commands (download-logs, version, wait-for-migration, etc.) ## Key Guidelines 1. Follow C# best practices and idiomatic patterns @@ -38,14 +40,18 @@ This is a C# based repository that produces several CLIs that are used by custom ## Go Port Sync Requirements -**Current state:** The Go port has the base framework and `generate-script` commands for all three CLIs. Script generation has full behavioral parity with C#. +**Current state:** The Go port has `generate-script` commands, the GitHub API client, and shared commands (download-logs, version, wait-for-migration, grant-migrator-role, revoke-migrator-role, create-team, add-team-members, lock-ado-repo, disable-ado-repo, configure-autolink). 
-**When making C# changes to script generation logic:** -- If you modify `GenerateScriptCommandHandler.cs` in any of the three CLIs, you MUST make the corresponding change in Go: - - `src/gei/Commands/GenerateScript/` → `cmd/gei/generate_script.go` + `pkg/scriptgen/generator.go` - - `src/ado2gh/Commands/GenerateScript/` → `cmd/ado2gh/generate_script.go` - - `src/bbs2gh/Commands/GenerateScript/` → `cmd/bbs2gh/generate_script.go` -- Run `go test ./...` to verify the Go changes compile and tests pass -- Generated PowerShell scripts must be identical between C# and Go +**When making C# changes, check if the Go port needs updating:** -**When making other C# changes:** No Go sync required yet. The remaining commands are not yet ported. +| C# Area | Go Equivalent | Sync Required? | +|----------|--------------|----------------| +| `GenerateScriptCommandHandler.cs` (any CLI) | `cmd/{cli}/generate_script.go` + `pkg/scriptgen/generator.go` | **Yes** — scripts must be identical | +| `src/Octoshift/Services/GithubApi.cs` | `pkg/github/client.go` | **Yes** — API behavior must match | +| `src/Octoshift/Services/GithubClient.cs` | `pkg/github/client.go` | **Yes** — HTTP/auth behavior must match | +| Shared commands in `src/Octoshift/Commands/` | `internal/sharedcmd/` | **Yes** — command behavior must match | +| `src/gei/Commands/DownloadLogs/` | `cmd/gei/download_logs.go` | **Yes** | +| ADO/BBS API clients or commands | Not yet ported | No | +| `migrate-repo` commands | Not yet ported | No | + +**Testing:** Run `go test ./...` to verify Go changes. Run `golangci-lint run` to check for lint issues. 
diff --git a/.gitignore b/.gitignore index 6bbde164b..4e5d918a0 100644 --- a/.gitignore +++ b/.gitignore @@ -360,6 +360,11 @@ MigrationBackup/ /src/OctoshiftCLI.IntegrationTests/Properties/launchSettings.json /src/ado2gh/Properties/launchSettings.json +# Go binaries (built from cmd/) +/gei +/ado2gh +/bbs2gh + # Go coverage reports coverage/ *.out diff --git a/cmd/ado2gh/main.go b/cmd/ado2gh/main.go index 5da395a52..90ff6db0a 100644 --- a/cmd/ado2gh/main.go +++ b/cmd/ado2gh/main.go @@ -2,13 +2,22 @@ package main import ( "context" + "net/http" "os" + "strings" "github.com/github/gh-gei/pkg/env" "github.com/github/gh-gei/pkg/logger" + "github.com/github/gh-gei/pkg/status" + versionpkg "github.com/github/gh-gei/pkg/version" "github.com/spf13/cobra" ) +// contextKey is an unexported type for context keys in this package. +type contextKey string + +const loggerKey contextKey = "logger" + var ( version = "dev" verbose bool @@ -25,11 +34,16 @@ func newRootCmd() *cobra.Command { Use: "ado2gh", Short: "Azure DevOps to GitHub migration CLI", Long: "Automate end-to-end Azure DevOps Repos to GitHub migrations.", - PersistentPreRun: func(cmd *cobra.Command, args []string) { + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { log := logger.New(verbose) - ctx := context.WithValue(cmd.Context(), "logger", log) + ctx := context.WithValue(cmd.Context(), loggerKey, log) cmd.SetContext(ctx) log.Debug("Execution started") + + checkVersion(ctx, log) + checkGitHubStatus(ctx, log) + + return nil }, SilenceUsage: true, SilenceErrors: true, @@ -64,7 +78,7 @@ func newRootCmd() *cobra.Command { } func getLogger(cmd *cobra.Command) *logger.Logger { - if log, ok := cmd.Context().Value("logger").(*logger.Logger); ok { + if log, ok := cmd.Context().Value(loggerKey).(*logger.Logger); ok { return log } return logger.New(false) @@ -76,17 +90,41 @@ func getEnvProvider() *env.Provider { func checkVersion(ctx context.Context, log *logger.Logger) { envProvider := getEnvProvider() - if 
envProvider.SkipVersionCheck() == "true" || envProvider.SkipVersionCheck() == "1" { + skip := envProvider.SkipVersionCheck() + if strings.EqualFold(skip, "true") || skip == "1" { log.Info("Skipped latest version check due to GEI_SKIP_VERSION_CHECK environment variable") return } - log.Info("You are running ado2gh CLI version %s", version) + + checker := versionpkg.NewChecker(&http.Client{}, log, version) + isLatest, err := checker.IsLatest(ctx) + if err != nil { + log.Debug("Version check failed: %v", err) + return + } + + if !isLatest { + latest, _ := checker.GetLatestVersion(ctx) + log.Info("New version available: %s", latest) + log.Info("You are running ado2gh CLI version %s", version) + } } func checkGitHubStatus(ctx context.Context, log *logger.Logger) { envProvider := getEnvProvider() - if envProvider.SkipStatusCheck() == "true" || envProvider.SkipStatusCheck() == "1" { + skip := envProvider.SkipStatusCheck() + if strings.EqualFold(skip, "true") || skip == "1" { log.Info("Skipped GitHub status check due to GEI_SKIP_STATUS_CHECK environment variable") return } + + count, err := status.GetUnresolvedIncidentsCount(ctx, &http.Client{}, "https://www.githubstatus.com") + if err != nil { + log.Debug("GitHub status check failed: %v", err) + return + } + + if count > 0 { + log.Warning("GitHub is currently experiencing %d incident(s). Check https://www.githubstatus.com for details.", count) + } } diff --git a/cmd/bbs2gh/main.go b/cmd/bbs2gh/main.go index f3369aa09..3b0ab491d 100644 --- a/cmd/bbs2gh/main.go +++ b/cmd/bbs2gh/main.go @@ -2,13 +2,22 @@ package main import ( "context" + "net/http" "os" + "strings" "github.com/github/gh-gei/pkg/env" "github.com/github/gh-gei/pkg/logger" + "github.com/github/gh-gei/pkg/status" + versionpkg "github.com/github/gh-gei/pkg/version" "github.com/spf13/cobra" ) +// contextKey is an unexported type for context keys in this package. 
+type contextKey string + +const loggerKey contextKey = "logger" + var ( version = "dev" verbose bool @@ -25,11 +34,16 @@ func newRootCmd() *cobra.Command { Use: "bbs2gh", Short: "Bitbucket Server to GitHub migration CLI", Long: "Migrate repositories from Bitbucket Server and Data Center to GitHub Enterprise Cloud.", - PersistentPreRun: func(cmd *cobra.Command, args []string) { + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { log := logger.New(verbose) - ctx := context.WithValue(cmd.Context(), "logger", log) + ctx := context.WithValue(cmd.Context(), loggerKey, log) cmd.SetContext(ctx) log.Debug("Execution started") + + checkVersion(ctx, log) + checkGitHubStatus(ctx, log) + + return nil }, SilenceUsage: true, SilenceErrors: true, @@ -57,7 +71,7 @@ func newRootCmd() *cobra.Command { } func getLogger(cmd *cobra.Command) *logger.Logger { - if log, ok := cmd.Context().Value("logger").(*logger.Logger); ok { + if log, ok := cmd.Context().Value(loggerKey).(*logger.Logger); ok { return log } return logger.New(false) @@ -69,17 +83,41 @@ func getEnvProvider() *env.Provider { func checkVersion(ctx context.Context, log *logger.Logger) { envProvider := getEnvProvider() - if envProvider.SkipVersionCheck() == "true" || envProvider.SkipVersionCheck() == "1" { + skip := envProvider.SkipVersionCheck() + if strings.EqualFold(skip, "true") || skip == "1" { log.Info("Skipped latest version check due to GEI_SKIP_VERSION_CHECK environment variable") return } - log.Info("You are running bbs2gh CLI version %s", version) + + checker := versionpkg.NewChecker(&http.Client{}, log, version) + isLatest, err := checker.IsLatest(ctx) + if err != nil { + log.Debug("Version check failed: %v", err) + return + } + + if !isLatest { + latest, _ := checker.GetLatestVersion(ctx) + log.Info("New version available: %s", latest) + log.Info("You are running bbs2gh CLI version %s", version) + } } func checkGitHubStatus(ctx context.Context, log *logger.Logger) { envProvider := 
getEnvProvider() - if envProvider.SkipStatusCheck() == "true" || envProvider.SkipStatusCheck() == "1" { + skip := envProvider.SkipStatusCheck() + if strings.EqualFold(skip, "true") || skip == "1" { log.Info("Skipped GitHub status check due to GEI_SKIP_STATUS_CHECK environment variable") return } + + count, err := status.GetUnresolvedIncidentsCount(ctx, &http.Client{}, "https://www.githubstatus.com") + if err != nil { + log.Debug("GitHub status check failed: %v", err) + return + } + + if count > 0 { + log.Warning("GitHub is currently experiencing %d incident(s). Check https://www.githubstatus.com for details.", count) + } } diff --git a/cmd/gei/abort_migration.go b/cmd/gei/abort_migration.go new file mode 100644 index 000000000..4c717d923 --- /dev/null +++ b/cmd/gei/abort_migration.go @@ -0,0 +1,63 @@ +package main + +import ( + "context" + "strings" + + "github.com/github/gh-gei/internal/cmdutil" + "github.com/github/gh-gei/pkg/logger" + "github.com/spf13/cobra" +) + +// migrationAborter is the consumer-defined interface for aborting migrations. +type migrationAborter interface { + AbortMigration(ctx context.Context, id string) (bool, error) +} + +// newAbortMigrationCmd creates the abort-migration cobra command. +func newAbortMigrationCmd(gh migrationAborter, log *logger.Logger) *cobra.Command { + var migrationID string + + cmd := &cobra.Command{ + Use: "abort-migration", + Short: "Aborts a repository migration that is queued or in progress", + Long: "Aborts a repository migration that is queued or in progress.", + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateAbortMigrationID(migrationID); err != nil { + return err + } + return runAbortMigration(cmd.Context(), gh, log, migrationID) + }, + } + + cmd.Flags().StringVar(&migrationID, "migration-id", "", + "The ID of the migration to abort, starting with RM_. 
Organization migrations, where the ID starts with OM_, are not supported.") + cmd.Flags().String("github-target-pat", "", "Personal access token for the target GitHub instance") + cmd.Flags().String("target-api-url", "", "API URL for the target GitHub instance") + + return cmd +} + +func validateAbortMigrationID(id string) error { + if strings.TrimSpace(id) == "" { + return cmdutil.NewUserError("--migration-id must be provided") + } + if !strings.HasPrefix(id, repoMigrationIDPrefix) { + return cmdutil.NewUserErrorf( + "Invalid migration ID: %s. Only repository migration IDs starting with RM_ are supported.", id) + } + return nil +} + +func runAbortMigration(ctx context.Context, gh migrationAborter, log *logger.Logger, migrationID string) error { + success, err := gh.AbortMigration(ctx, migrationID) + if err != nil { + return err + } + if !success { + log.Errorf("Failed to abort migration %s", migrationID) + return nil + } + log.Info("Migration %s was canceled", migrationID) + return nil +} diff --git a/cmd/gei/abort_migration_test.go b/cmd/gei/abort_migration_test.go new file mode 100644 index 000000000..092f45bdf --- /dev/null +++ b/cmd/gei/abort_migration_test.go @@ -0,0 +1,118 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/github/gh-gei/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockMigrationAborter implements the migrationAborter interface for testing. 
+type mockMigrationAborter struct { + result bool + err error + called bool + gotID string +} + +func (m *mockMigrationAborter) AbortMigration(_ context.Context, id string) (bool, error) { + m.called = true + m.gotID = id + return m.result, m.err +} + +func TestAbortMigration(t *testing.T) { + tests := []struct { + name string + migrationID string + mock *mockMigrationAborter + wantErr string + wantOutput []string // substrings that must appear in output + wantCalled bool + wantID string // expected migration ID passed to mock + }{ + { + name: "abort succeeds", + migrationID: "RM_123", + mock: &mockMigrationAborter{result: true}, + wantOutput: []string{"Migration RM_123 was canceled"}, + wantCalled: true, + wantID: "RM_123", + }, + { + name: "abort fails returns false", + migrationID: "RM_456", + mock: &mockMigrationAborter{result: false}, + wantOutput: []string{"Failed to abort migration RM_456"}, + wantCalled: true, + }, + { + name: "abort returns error", + migrationID: "RM_789", + mock: &mockMigrationAborter{err: fmt.Errorf("network failure")}, + wantErr: "network failure", + wantCalled: true, + }, + { + name: "missing migration ID", + migrationID: "", + mock: &mockMigrationAborter{}, + wantErr: "--migration-id must be provided", + wantCalled: false, + }, + { + name: "invalid migration ID no RM_ prefix", + migrationID: "XX_invalid", + mock: &mockMigrationAborter{}, + wantErr: "Invalid migration ID: XX_invalid. Only repository migration IDs starting with RM_ are supported.", + wantCalled: false, + }, + { + name: "OM_ prefix rejected for abort", + migrationID: "OM_100", + mock: &mockMigrationAborter{}, + wantErr: "Invalid migration ID: OM_100. 
Only repository migration IDs starting with RM_ are supported.", + wantCalled: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + cmd := newAbortMigrationCmd(tc.mock, log) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + + args := []string{} + if tc.migrationID != "" { + args = append(args, "--migration-id", tc.migrationID) + } + cmd.SetArgs(args) + + err := cmd.Execute() + + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + } else { + require.NoError(t, err) + } + + output := buf.String() + for _, want := range tc.wantOutput { + assert.Contains(t, output, want, "expected output to contain %q", want) + } + + assert.Equal(t, tc.wantCalled, tc.mock.called, "expected AbortMigration called=%v", tc.wantCalled) + if tc.wantID != "" { + assert.Equal(t, tc.wantID, tc.mock.gotID) + } + }) + } +} diff --git a/cmd/gei/create_team.go b/cmd/gei/create_team.go new file mode 100644 index 000000000..f2b9ce9b9 --- /dev/null +++ b/cmd/gei/create_team.go @@ -0,0 +1,115 @@ +package main + +import ( + "context" + "strings" + + "github.com/github/gh-gei/internal/cmdutil" + "github.com/github/gh-gei/pkg/github" + "github.com/github/gh-gei/pkg/logger" + "github.com/spf13/cobra" +) + +// teamCreator is the consumer-defined interface for create-team. +type teamCreator interface { + GetTeams(ctx context.Context, org string) ([]github.Team, error) + CreateTeam(ctx context.Context, org, name string) (*github.Team, error) + GetTeamMembers(ctx context.Context, org, teamSlug string) ([]string, error) + RemoveTeamMember(ctx context.Context, org, teamSlug, member string) error + GetIdpGroupId(ctx context.Context, org, groupName string) (int, error) + AddEmuGroupToTeam(ctx context.Context, org, teamSlug string, groupID int) error +} + +// newCreateTeamCmd creates the create-team cobra command. 
+func newCreateTeamCmd(gh teamCreator, log *logger.Logger) *cobra.Command { + var ( + githubOrg string + teamName string + idpGroup string + ) + + cmd := &cobra.Command{ + Use: "create-team", + Short: "Creates a GitHub team and optionally links it to an IdP group", + Long: "Creates a GitHub team and optionally links it to an IdP group.", + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateCreateTeamArgs(githubOrg, teamName); err != nil { + return err + } + return runCreateTeam(cmd.Context(), gh, log, githubOrg, teamName, idpGroup) + }, + } + + cmd.Flags().StringVar(&githubOrg, "github-org", "", "The GitHub organization to create the team in (REQUIRED)") + cmd.Flags().StringVar(&teamName, "team-name", "", "The name of the team to create (REQUIRED)") + cmd.Flags().StringVar(&idpGroup, "idp-group", "", "The name of the IdP group to link to the team") + cmd.Flags().String("github-target-pat", "", "Personal access token for the target GitHub instance") + cmd.Flags().String("target-api-url", "", "API URL for the target GitHub instance") + + return cmd +} + +func validateCreateTeamArgs(githubOrg, teamName string) error { + if strings.TrimSpace(githubOrg) == "" { + return cmdutil.NewUserError("--github-org must be provided") + } + if strings.HasPrefix(githubOrg, "http://") || strings.HasPrefix(githubOrg, "https://") { + return cmdutil.NewUserError("The --github-org option expects an organization name, not a URL. 
Please provide just the organization name.") + } + if strings.TrimSpace(teamName) == "" { + return cmdutil.NewUserError("--team-name must be provided") + } + return nil +} + +func runCreateTeam(ctx context.Context, gh teamCreator, log *logger.Logger, githubOrg, teamName, idpGroup string) error { + log.Info("Creating GitHub team...") + + teams, err := gh.GetTeams(ctx, githubOrg) + if err != nil { + return err + } + + var teamSlug string + for _, t := range teams { + if t.Name == teamName { + teamSlug = t.Slug + break + } + } + + if teamSlug != "" { + log.Success("Team '%s' already exists. New team will not be created", teamName) + } else { + team, err := gh.CreateTeam(ctx, githubOrg, teamName) + if err != nil { + return err + } + teamSlug = team.Slug + log.Success("Successfully created team") + } + + if strings.TrimSpace(idpGroup) == "" { + log.Info("No IdP Group provided, skipping the IdP linking step") + } else { + members, err := gh.GetTeamMembers(ctx, githubOrg, teamSlug) + if err != nil { + return err + } + for _, member := range members { + if err := gh.RemoveTeamMember(ctx, githubOrg, teamSlug, member); err != nil { + return err + } + } + idpGroupID, err := gh.GetIdpGroupId(ctx, githubOrg, idpGroup) + if err != nil { + return err + } + if err := gh.AddEmuGroupToTeam(ctx, githubOrg, teamSlug, idpGroupID); err != nil { + return err + } + log.Success("Successfully linked team to Idp group") + } + + return nil +} diff --git a/cmd/gei/create_team_test.go b/cmd/gei/create_team_test.go new file mode 100644 index 000000000..77f203bc2 --- /dev/null +++ b/cmd/gei/create_team_test.go @@ -0,0 +1,277 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/github/gh-gei/pkg/github" + "github.com/github/gh-gei/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockTeamCreator implements teamCreator for testing. 
+type mockTeamCreator struct { + teams []github.Team + getTeamsErr error + createdTeam *github.Team + createTeamErr error + teamMembers []string + getMembersErr error + removeMemberErr error + idpGroupID int + getIdpErr error + addEmuErr error + + // capture calls + gotCreateOrg string + gotCreateName string + gotMembersOrg string + gotMembersSlug string + removedMembers []string + gotIdpOrg string + gotIdpGroup string + gotEmuOrg string + gotEmuSlug string + gotEmuGroupID int +} + +func (m *mockTeamCreator) GetTeams(_ context.Context, org string) ([]github.Team, error) { + return m.teams, m.getTeamsErr +} + +func (m *mockTeamCreator) CreateTeam(_ context.Context, org, name string) (*github.Team, error) { + m.gotCreateOrg = org + m.gotCreateName = name + return m.createdTeam, m.createTeamErr +} + +func (m *mockTeamCreator) GetTeamMembers(_ context.Context, org, teamSlug string) ([]string, error) { + m.gotMembersOrg = org + m.gotMembersSlug = teamSlug + return m.teamMembers, m.getMembersErr +} + +func (m *mockTeamCreator) RemoveTeamMember(_ context.Context, org, teamSlug, member string) error { + m.removedMembers = append(m.removedMembers, member) + return m.removeMemberErr +} + +func (m *mockTeamCreator) GetIdpGroupId(_ context.Context, org, groupName string) (int, error) { + m.gotIdpOrg = org + m.gotIdpGroup = groupName + return m.idpGroupID, m.getIdpErr +} + +func (m *mockTeamCreator) AddEmuGroupToTeam(_ context.Context, org, teamSlug string, groupID int) error { + m.gotEmuOrg = org + m.gotEmuSlug = teamSlug + m.gotEmuGroupID = groupID + return m.addEmuErr +} + +func TestCreateTeam(t *testing.T) { + tests := []struct { + name string + args []string + mock *mockTeamCreator + wantErr string + wantOutput []string + assertArgs func(t *testing.T, m *mockTeamCreator) + }{ + { + name: "team does not exist, creates team, no IdP group", + args: []string{"--github-org", "my-org", "--team-name", "my-team"}, + mock: &mockTeamCreator{ + teams: []github.Team{}, + createdTeam: 
&github.Team{ID: "1", Name: "my-team", Slug: "my-team"}, + }, + wantOutput: []string{ + "Creating GitHub team...", + "Successfully created team", + "No IdP Group provided, skipping the IdP linking step", + }, + assertArgs: func(t *testing.T, m *mockTeamCreator) { + assert.Equal(t, "my-org", m.gotCreateOrg) + assert.Equal(t, "my-team", m.gotCreateName) + }, + }, + { + name: "team already exists, logs and skips creation", + args: []string{"--github-org", "my-org", "--team-name", "existing-team"}, + mock: &mockTeamCreator{ + teams: []github.Team{ + {ID: "10", Name: "existing-team", Slug: "existing-team-slug"}, + {ID: "20", Name: "other-team", Slug: "other-slug"}, + }, + }, + wantOutput: []string{ + "Creating GitHub team...", + "Team 'existing-team' already exists. New team will not be created", + "No IdP Group provided, skipping the IdP linking step", + }, + assertArgs: func(t *testing.T, m *mockTeamCreator) { + // CreateTeam should NOT be called + assert.Empty(t, m.gotCreateOrg) + assert.Empty(t, m.gotCreateName) + }, + }, + { + name: "team does not exist, creates team, links IdP group with members removed", + args: []string{"--github-org", "my-org", "--team-name", "my-team", "--idp-group", "my-idp-group"}, + mock: &mockTeamCreator{ + teams: []github.Team{}, + createdTeam: &github.Team{ID: "1", Name: "my-team", Slug: "my-team-slug"}, + teamMembers: []string{"user1", "user2"}, + idpGroupID: 42, + }, + wantOutput: []string{ + "Creating GitHub team...", + "Successfully created team", + "Successfully linked team to Idp group", + }, + assertArgs: func(t *testing.T, m *mockTeamCreator) { + assert.Equal(t, "my-org", m.gotCreateOrg) + assert.Equal(t, "my-team", m.gotCreateName) + // Members should be removed + assert.Equal(t, []string{"user1", "user2"}, m.removedMembers) + // IdP group looked up and linked + assert.Equal(t, "my-org", m.gotIdpOrg) + assert.Equal(t, "my-idp-group", m.gotIdpGroup) + assert.Equal(t, "my-org", m.gotEmuOrg) + assert.Equal(t, "my-team-slug", 
m.gotEmuSlug) + assert.Equal(t, 42, m.gotEmuGroupID) + }, + }, + { + name: "team already exists, links IdP group", + args: []string{"--github-org", "my-org", "--team-name", "existing-team", "--idp-group", "idp-group"}, + mock: &mockTeamCreator{ + teams: []github.Team{ + {ID: "10", Name: "existing-team", Slug: "existing-slug"}, + }, + teamMembers: []string{"member1"}, + idpGroupID: 99, + }, + wantOutput: []string{ + "Team 'existing-team' already exists. New team will not be created", + "Successfully linked team to Idp group", + }, + assertArgs: func(t *testing.T, m *mockTeamCreator) { + // Should use slug from existing team + assert.Equal(t, "existing-slug", m.gotMembersSlug) + assert.Equal(t, "existing-slug", m.gotEmuSlug) + assert.Equal(t, []string{"member1"}, m.removedMembers) + }, + }, + { + name: "GetTeams error propagates", + args: []string{"--github-org", "my-org", "--team-name", "my-team"}, + mock: &mockTeamCreator{getTeamsErr: fmt.Errorf("api error")}, + wantErr: "api error", + }, + { + name: "CreateTeam error propagates", + args: []string{"--github-org", "my-org", "--team-name", "my-team"}, + mock: &mockTeamCreator{ + teams: []github.Team{}, + createTeamErr: fmt.Errorf("create failed"), + }, + wantErr: "create failed", + }, + { + name: "GetTeamMembers error propagates", + args: []string{"--github-org", "my-org", "--team-name", "my-team", "--idp-group", "grp"}, + mock: &mockTeamCreator{ + teams: []github.Team{}, + createdTeam: &github.Team{ID: "1", Name: "my-team", Slug: "my-team"}, + getMembersErr: fmt.Errorf("members error"), + }, + wantErr: "members error", + }, + { + name: "RemoveTeamMember error propagates", + args: []string{"--github-org", "my-org", "--team-name", "my-team", "--idp-group", "grp"}, + mock: &mockTeamCreator{ + teams: []github.Team{}, + createdTeam: &github.Team{ID: "1", Name: "my-team", Slug: "my-team"}, + teamMembers: []string{"user1"}, + removeMemberErr: fmt.Errorf("remove error"), + }, + wantErr: "remove error", + }, + { + name: 
"GetIdpGroupId error propagates", + args: []string{"--github-org", "my-org", "--team-name", "my-team", "--idp-group", "grp"}, + mock: &mockTeamCreator{ + teams: []github.Team{}, + createdTeam: &github.Team{ID: "1", Name: "my-team", Slug: "my-team"}, + teamMembers: []string{}, + getIdpErr: fmt.Errorf("idp error"), + }, + wantErr: "idp error", + }, + { + name: "AddEmuGroupToTeam error propagates", + args: []string{"--github-org", "my-org", "--team-name", "my-team", "--idp-group", "grp"}, + mock: &mockTeamCreator{ + teams: []github.Team{}, + createdTeam: &github.Team{ID: "1", Name: "my-team", Slug: "my-team"}, + teamMembers: []string{}, + idpGroupID: 10, + addEmuErr: fmt.Errorf("emu error"), + }, + wantErr: "emu error", + }, + { + name: "github-org is URL", + args: []string{"--github-org", "https://github.com/my-org", "--team-name", "my-team"}, + mock: &mockTeamCreator{}, + wantErr: "The --github-org option expects an organization name, not a URL", + }, + { + name: "empty github-org", + args: []string{"--github-org", "", "--team-name", "my-team"}, + mock: &mockTeamCreator{}, + wantErr: "--github-org must be provided", + }, + { + name: "empty team-name", + args: []string{"--github-org", "my-org", "--team-name", ""}, + mock: &mockTeamCreator{}, + wantErr: "--team-name must be provided", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + cmd := newCreateTeamCmd(tc.mock, log) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs(tc.args) + + err := cmd.Execute() + + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + } else { + require.NoError(t, err) + } + + output := buf.String() + for _, want := range tc.wantOutput { + assert.Contains(t, output, want, "expected output to contain %q", want) + } + + if tc.assertArgs != nil { + tc.assertArgs(t, tc.mock) + } + }) + } +} diff --git a/cmd/gei/download_logs.go b/cmd/gei/download_logs.go new file mode 
100644 index 000000000..cbc62d3ff --- /dev/null +++ b/cmd/gei/download_logs.go @@ -0,0 +1,202 @@ +package main + +import ( + "context" + "fmt" + "time" + + "github.com/github/gh-gei/internal/cmdutil" + "github.com/github/gh-gei/pkg/github" + "github.com/github/gh-gei/pkg/logger" + "github.com/spf13/cobra" +) + +// logDownloader is the consumer-defined interface for fetching migration info. +type logDownloader interface { + GetMigration(ctx context.Context, id string) (*github.Migration, error) + GetMigrationLogUrl(ctx context.Context, org, repo string) (*github.MigrationLogResult, error) +} + +// fileDownloader is the consumer-defined interface for downloading files. +type fileDownloader interface { + DownloadToFile(ctx context.Context, url, filepath string) error +} + +// fileChecker is the consumer-defined interface for checking file existence. +type fileChecker interface { + FileExists(path string) bool +} + +// downloadLogsOptions holds tunable parameters for the download-logs command, +// allowing tests to set retries=0 and delay=0 so they don't wait. +type downloadLogsOptions struct { + maxRetries int + retryDelay time.Duration +} + +// newDownloadLogsCmd creates the download-logs cobra command. 
+func newDownloadLogsCmd(gh logDownloader, dl fileDownloader, fc fileChecker, log *logger.Logger, opts downloadLogsOptions) *cobra.Command { + var ( + migrationID string + githubTargetOrg string + targetRepo string + logFile string + overwrite bool + ) + + cmd := &cobra.Command{ + Use: "download-logs", + Short: "Downloads migration logs for a repository migration", + Long: "Downloads migration logs for a repository migration, either by migration ID or by org/repo.", + RunE: func(cmd *cobra.Command, args []string) error { + return runDownloadLogs(cmd.Context(), gh, dl, fc, log, downloadLogsParams{ + migrationID: migrationID, + githubTargetOrg: githubTargetOrg, + targetRepo: targetRepo, + logFile: logFile, + overwrite: overwrite, + maxRetries: opts.maxRetries, + retryDelay: opts.retryDelay, + }) + }, + } + + cmd.Flags().StringVar(&migrationID, "migration-id", "", "The ID of the migration") + cmd.Flags().StringVar(&githubTargetOrg, "github-target-org", "", "Target GitHub organization") + cmd.Flags().StringVar(&targetRepo, "target-repo", "", "Target repository name") + cmd.Flags().StringVar(&logFile, "migration-log-file", "", "Custom output filename for the migration log") + cmd.Flags().BoolVar(&overwrite, "overwrite", false, "Overwrite the log file if it already exists") + cmd.Flags().String("github-target-pat", "", "Personal access token for the target GitHub instance") + cmd.Flags().String("target-api-url", "", "API URL for the target GitHub instance") + + return cmd +} + +type downloadLogsParams struct { + migrationID string + githubTargetOrg string + targetRepo string + logFile string + overwrite bool + maxRetries int + retryDelay time.Duration +} + +func runDownloadLogs(ctx context.Context, gh logDownloader, dl fileDownloader, fc fileChecker, log *logger.Logger, p downloadLogsParams) error { + hasMigrationID := p.migrationID != "" + hasOrgRepo := p.githubTargetOrg != "" && p.targetRepo != "" + + if !hasMigrationID && !hasOrgRepo { + return 
cmdutil.NewUserError("must provide either --migration-id or both --github-target-org and --target-repo") + } + + // Check custom filename early + if p.logFile != "" { + if err := checkFileOverwrite(fc, log, p.logFile, p.overwrite); err != nil { + return err + } + } + + log.Warning("Migration logs are only available for 24 hours after a migration finishes!") + + var ( + logURL string + filename string + repoName string + ) + + if hasMigrationID { + if p.githubTargetOrg != "" || p.targetRepo != "" { + log.Warning("--github-target-org and --target-repo will be ignored because --migration-id was provided") + } + + m, err := waitForMigrationLogByID(ctx, gh, p.migrationID, p.maxRetries, p.retryDelay) + if err != nil { + return err + } + logURL = m.MigrationLogURL + repoName = m.RepositoryName + filename = fmt.Sprintf("migration-log-%s-%s.log", m.RepositoryName, p.migrationID) + } else { + result, err := waitForMigrationLogByOrgRepo(ctx, gh, p.githubTargetOrg, p.targetRepo, p.maxRetries, p.retryDelay) + if err != nil { + return err + } + logURL = result.MigrationLogURL + repoName = p.targetRepo + filename = fmt.Sprintf("migration-log-%s-%s-%s.log", p.githubTargetOrg, p.targetRepo, result.MigrationID) + } + + if p.logFile != "" { + filename = p.logFile + } else { + // Check default filename for overwrite + if err := checkFileOverwrite(fc, log, filename, p.overwrite); err != nil { + return err + } + } + + log.Info("Downloading migration logs...") + log.Info("Downloading log for repository %s to %s...", repoName, filename) + + if err := dl.DownloadToFile(ctx, logURL, filename); err != nil { + return err + } + + log.Success("Downloaded %s log to %s.", repoName, filename) + return nil +} + +func waitForMigrationLogByID(ctx context.Context, gh logDownloader, migrationID string, maxRetries int, retryDelay time.Duration) (*github.Migration, error) { + for attempt := 0; attempt <= maxRetries; attempt++ { + m, err := gh.GetMigration(ctx, migrationID) + if err != nil { + return nil, 
err + } + if m.MigrationLogURL != "" { + return m, nil + } + if attempt < maxRetries { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelay): + } + } + } + return nil, cmdutil.NewUserErrorf("migration log URL was not populated for migration %s after retries", migrationID) +} + +func waitForMigrationLogByOrgRepo(ctx context.Context, gh logDownloader, org, repo string, maxRetries int, retryDelay time.Duration) (*github.MigrationLogResult, error) { + for attempt := 0; attempt <= maxRetries; attempt++ { + result, err := gh.GetMigrationLogUrl(ctx, org, repo) + if err != nil { + return nil, err + } + if result == nil { + return nil, cmdutil.NewUserErrorf("no migration found for %s/%s", org, repo) + } + if result.MigrationLogURL != "" { + return result, nil + } + if attempt < maxRetries { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelay): + } + } + } + return nil, cmdutil.NewUserErrorf("migration log URL was not populated for %s/%s after retries", org, repo) +} + +func checkFileOverwrite(fc fileChecker, log *logger.Logger, filepath string, overwrite bool) error { + if !fc.FileExists(filepath) { + return nil + } + if !overwrite { + return cmdutil.NewUserErrorf("file %s already exists. Use --overwrite to overwrite it", filepath) + } + log.Warning("File %s already exists and will be overwritten", filepath) + return nil +} diff --git a/cmd/gei/download_logs_test.go b/cmd/gei/download_logs_test.go new file mode 100644 index 000000000..a46ec92c1 --- /dev/null +++ b/cmd/gei/download_logs_test.go @@ -0,0 +1,390 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/github/gh-gei/pkg/github" + "github.com/github/gh-gei/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockLogDownloader implements the logDownloader interface for testing. 
+type mockLogDownloader struct { + getMigrationResult *github.Migration + getMigrationResults []*github.Migration // if set, returns results[calls-1] per call + getMigrationErr error + getMigrationCalls int + + getMigrationLogURLResult *github.MigrationLogResult + getMigrationLogURLResults []*github.MigrationLogResult // if set, returns results[calls-1] per call + getMigrationLogURLErr error + getMigrationLogURLCalls int +} + +func (m *mockLogDownloader) GetMigration(_ context.Context, _ string) (*github.Migration, error) { + m.getMigrationCalls++ + if m.getMigrationResults != nil && m.getMigrationCalls <= len(m.getMigrationResults) { + return m.getMigrationResults[m.getMigrationCalls-1], m.getMigrationErr + } + return m.getMigrationResult, m.getMigrationErr +} + +func (m *mockLogDownloader) GetMigrationLogUrl(_ context.Context, _, _ string) (*github.MigrationLogResult, error) { + m.getMigrationLogURLCalls++ + if m.getMigrationLogURLResults != nil && m.getMigrationLogURLCalls <= len(m.getMigrationLogURLResults) { + return m.getMigrationLogURLResults[m.getMigrationLogURLCalls-1], m.getMigrationLogURLErr + } + return m.getMigrationLogURLResult, m.getMigrationLogURLErr +} + +// mockFileDownloader implements the fileDownloader interface for testing. +type mockFileDownloader struct { + err error + calls int + gotURL string + gotPath string +} + +func (m *mockFileDownloader) DownloadToFile(_ context.Context, url, filepath string) error { + m.calls++ + m.gotURL = url + m.gotPath = filepath + return m.err +} + +// mockFileChecker implements the fileChecker interface for testing. 
+type mockFileChecker struct { + exists bool +} + +func (m *mockFileChecker) FileExists(_ string) bool { + return m.exists +} + +func TestDownloadLogs_ByMigrationID_Success(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationResult: &github.Migration{ + ID: "RM_123", + RepositoryName: "my-repo", + MigrationLogURL: "https://example.com/log", + }, + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--migration-id", "RM_123"}) + + err := cmd.Execute() + require.NoError(t, err) + + assert.Equal(t, 1, gh.getMigrationCalls) + assert.Equal(t, 1, dl.calls) + assert.Equal(t, "https://example.com/log", dl.gotURL) + assert.Contains(t, dl.gotPath, "migration-log-my-repo-RM_123.log") + + output := buf.String() + assert.Contains(t, output, "Downloading migration logs") + assert.Contains(t, output, "Downloading log for repository my-repo to migration-log-my-repo-RM_123.log") + assert.Contains(t, output, "Downloaded my-repo log to migration-log-my-repo-RM_123.log") +} + +func TestDownloadLogs_ByOrgRepo_Success(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationLogURLResult: &github.MigrationLogResult{ + MigrationLogURL: "https://example.com/org-log", + MigrationID: "RM_456", + }, + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--github-target-org", "my-org", "--target-repo", "my-repo"}) + + err := cmd.Execute() + require.NoError(t, err) + + assert.Equal(t, 1, gh.getMigrationLogURLCalls) + assert.Equal(t, 1, dl.calls) + assert.Equal(t, "https://example.com/org-log", dl.gotURL) + 
assert.Contains(t, dl.gotPath, "migration-log-my-org-my-repo-RM_456.log") +} + +func TestDownloadLogs_ByMigrationID_LogURLEmpty_Error(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationResult: &github.Migration{ + ID: "RM_789", + RepositoryName: "my-repo", + MigrationLogURL: "", // empty — retry exhausted + }, + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--migration-id", "RM_789"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "migration log URL") + assert.Equal(t, 0, dl.calls) +} + +func TestDownloadLogs_ByOrgRepo_MigrationNotFound_Error(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationLogURLResult: nil, // no migration found + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--github-target-org", "my-org", "--target-repo", "my-repo"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "no migration found") + assert.Equal(t, 0, dl.calls) +} + +func TestDownloadLogs_FileExists_NoOverwrite_Error(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationResult: &github.Migration{ + ID: "RM_123", + RepositoryName: "my-repo", + MigrationLogURL: "https://example.com/log", + }, + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: true} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--migration-id", "RM_123"}) + + err := 
cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "already exists") + assert.Equal(t, 0, dl.calls) +} + +func TestDownloadLogs_FileExists_WithOverwrite_Success(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationResult: &github.Migration{ + ID: "RM_123", + RepositoryName: "my-repo", + MigrationLogURL: "https://example.com/log", + }, + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: true} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--migration-id", "RM_123", "--overwrite"}) + + err := cmd.Execute() + require.NoError(t, err) + assert.Equal(t, 1, dl.calls) + + output := buf.String() + assert.Contains(t, output, "already exists") +} + +func TestDownloadLogs_NeitherMigrationIDNorOrgRepo_Error(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{} + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "must provide either --migration-id or both --github-target-org and --target-repo") +} + +func TestDownloadLogs_MigrationIDWithOrgRepo_WarnsAndUsesMigrationID(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationResult: &github.Migration{ + ID: "RM_123", + RepositoryName: "my-repo", + MigrationLogURL: "https://example.com/log", + }, + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--migration-id", "RM_123", 
"--github-target-org", "my-org", "--target-repo", "my-repo"}) + + err := cmd.Execute() + require.NoError(t, err) + + // Should use migration ID path, not org/repo + assert.Equal(t, 1, gh.getMigrationCalls) + assert.Equal(t, 0, gh.getMigrationLogURLCalls) + + output := buf.String() + assert.Contains(t, output, "will be ignored") +} + +func TestDownloadLogs_CustomFilename(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationResult: &github.Migration{ + ID: "RM_123", + RepositoryName: "my-repo", + MigrationLogURL: "https://example.com/log", + }, + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--migration-id", "RM_123", "--migration-log-file", "custom.log"}) + + err := cmd.Execute() + require.NoError(t, err) + + assert.Equal(t, "custom.log", dl.gotPath) +} + +func TestDownloadLogs_ByOrgRepo_LogURLEmpty_Error(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationLogURLResult: &github.MigrationLogResult{ + MigrationLogURL: "", // empty — retry exhausted + MigrationID: "RM_456", + }, + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--github-target-org", "my-org", "--target-repo", "my-repo"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "migration log URL") + assert.Equal(t, 0, dl.calls) +} + +func TestDownloadLogs_DownloadError_PropagatesError(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationResult: &github.Migration{ + ID: "RM_123", + RepositoryName: "my-repo", + MigrationLogURL: 
"https://example.com/log", + }, + } + dl := &mockFileDownloader{err: fmt.Errorf("download failed: network error")} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--migration-id", "RM_123"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "network error") +} + +func TestDownloadLogs_ByMigrationID_RetrySucceedsOnSecondAttempt(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{ + getMigrationResults: []*github.Migration{ + {ID: "RM_123", RepositoryName: "my-repo", MigrationLogURL: ""}, // first call: empty + {ID: "RM_123", RepositoryName: "my-repo", MigrationLogURL: "https://example.com/log-retry"}, // second call: populated + }, + } + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 1, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--migration-id", "RM_123"}) + + err := cmd.Execute() + require.NoError(t, err) + + assert.Equal(t, 2, gh.getMigrationCalls) + assert.Equal(t, 1, dl.calls) + assert.Equal(t, "https://example.com/log-retry", dl.gotURL) +} + +func TestDownloadLogs_PartialOrgRepo_Error(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + gh := &mockLogDownloader{} + dl := &mockFileDownloader{} + fc := &mockFileChecker{exists: false} + + // Only --github-target-org, missing --target-repo + cmd := newDownloadLogsCmd(gh, dl, fc, log, downloadLogsOptions{maxRetries: 0, retryDelay: 0}) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs([]string{"--github-target-org", "my-org"}) + + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "must provide either --migration-id or both --github-target-org and --target-repo") +} diff --git 
a/cmd/gei/generate_mannequin_csv.go b/cmd/gei/generate_mannequin_csv.go new file mode 100644 index 000000000..33c4c1217 --- /dev/null +++ b/cmd/gei/generate_mannequin_csv.go @@ -0,0 +1,106 @@ +package main + +import ( + "context" + "fmt" + "os" + "strings" + + "github.com/github/gh-gei/internal/cmdutil" + "github.com/github/gh-gei/pkg/github" + "github.com/github/gh-gei/pkg/logger" + "github.com/github/gh-gei/pkg/mannequin" + "github.com/spf13/cobra" +) + +// mannequinCSVGenerator is the consumer-defined interface for generate-mannequin-csv. +type mannequinCSVGenerator interface { + GetOrganizationId(ctx context.Context, org string) (string, error) + GetMannequins(ctx context.Context, orgID string) ([]github.Mannequin, error) +} + +// newGenerateMannequinCSVCmd creates the generate-mannequin-csv cobra command. +func newGenerateMannequinCSVCmd(gh mannequinCSVGenerator, log *logger.Logger, writeFile func(path, content string) error) *cobra.Command { + var ( + githubTargetOrg string + output string + includeReclaimed bool + ) + + if writeFile == nil { + writeFile = func(path, content string) error { + return os.WriteFile(path, []byte(content), 0o600) + } + } + + cmd := &cobra.Command{ + Use: "generate-mannequin-csv", + Short: "Generates a CSV file with mannequin users", + Long: "Generates a CSV file with mannequin users for an organization.", + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateGenerateMannequinCSVArgs(githubTargetOrg); err != nil { + return err + } + return runGenerateMannequinCSV(cmd.Context(), gh, log, writeFile, githubTargetOrg, output, includeReclaimed) + }, + } + + cmd.Flags().StringVar(&githubTargetOrg, "github-target-org", "", "The target GitHub organization (REQUIRED)") + cmd.Flags().StringVar(&output, "output", "mannequins.csv", "Output file path") + cmd.Flags().BoolVar(&includeReclaimed, "include-reclaimed", false, "Include mannequins that have already been reclaimed") + cmd.Flags().String("github-target-pat", "", 
"Personal access token for the target GitHub instance") + cmd.Flags().String("target-api-url", "", "API URL for the target GitHub instance") + + return cmd +} + +func validateGenerateMannequinCSVArgs(githubTargetOrg string) error { + if strings.TrimSpace(githubTargetOrg) == "" { + return cmdutil.NewUserError("--github-target-org must be provided") + } + if strings.HasPrefix(githubTargetOrg, "http://") || strings.HasPrefix(githubTargetOrg, "https://") { + return cmdutil.NewUserError("The --github-target-org option expects an organization name, not a URL. Please provide just the organization name.") + } + return nil +} + +func runGenerateMannequinCSV(ctx context.Context, gh mannequinCSVGenerator, log *logger.Logger, writeFile func(path, content string) error, org, output string, includeReclaimed bool) error { + log.Info("Generating CSV...") + + orgID, err := gh.GetOrganizationId(ctx, org) + if err != nil { + return err + } + + mannequins, err := gh.GetMannequins(ctx, orgID) + if err != nil { + return err + } + + reclaimedCount := 0 + for _, m := range mannequins { + if m.MappedUser != nil { + reclaimedCount++ + } + } + + log.Info(" # Mannequins Found: %d", len(mannequins)) + log.Info(" # Mannequins Previously Reclaimed: %d", reclaimedCount) + + var sb strings.Builder + sb.WriteString(mannequin.CSVHeader) + sb.WriteString("\n") + + for _, m := range mannequins { + if !includeReclaimed && m.MappedUser != nil { + continue + } + mappedLogin := "" + if m.MappedUser != nil { + mappedLogin = m.MappedUser.Login + } + fmt.Fprintf(&sb, "%s,%s,%s\n", m.Login, m.ID, mappedLogin) + } + + return writeFile(output, sb.String()) +} diff --git a/cmd/gei/generate_mannequin_csv_test.go b/cmd/gei/generate_mannequin_csv_test.go new file mode 100644 index 000000000..f31e72008 --- /dev/null +++ b/cmd/gei/generate_mannequin_csv_test.go @@ -0,0 +1,152 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/github/gh-gei/pkg/github" + 
"github.com/github/gh-gei/pkg/logger" + "github.com/github/gh-gei/pkg/mannequin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockMannequinCSVGenerator struct { + orgID string + orgIDErr error + + mannequins []github.Mannequin + mannequinsErr error +} + +func (m *mockMannequinCSVGenerator) GetOrganizationId(_ context.Context, org string) (string, error) { + return m.orgID, m.orgIDErr +} + +func (m *mockMannequinCSVGenerator) GetMannequins(_ context.Context, orgID string) ([]github.Mannequin, error) { + return m.mannequins, m.mannequinsErr +} + +func TestGenerateMannequinCSV(t *testing.T) { + tests := []struct { + name string + args []string + mock *mockMannequinCSVGenerator + wantErr string + wantOutput []string + wantCSVContent string + }{ + { + name: "no mannequins generates CSV with header only", + args: []string{"--github-target-org", "FooOrg", "--output", "test.csv"}, + mock: &mockMannequinCSVGenerator{ + orgID: "org-id", + mannequins: []github.Mannequin{}, + }, + wantOutput: []string{ + "Generating CSV", + "# Mannequins Found: 0", + "# Mannequins Previously Reclaimed: 0", + }, + wantCSVContent: mannequin.CSVHeader + "\n", + }, + { + name: "mannequins without reclaimed, exclude reclaimed", + args: []string{"--github-target-org", "FooOrg"}, + mock: &mockMannequinCSVGenerator{ + orgID: "org-id", + mannequins: []github.Mannequin{ + {ID: "monaid", Login: "mona"}, + {ID: "monalisaid", Login: "monalisa", MappedUser: &github.MannequinUser{ID: "mapped-id", Login: "monalisa_gh"}}, + }, + }, + wantOutput: []string{ + "# Mannequins Found: 2", + "# Mannequins Previously Reclaimed: 1", + }, + wantCSVContent: mannequin.CSVHeader + "\n" + + "mona,monaid,\n", + }, + { + name: "include reclaimed mannequins", + args: []string{"--github-target-org", "FooOrg", "--include-reclaimed"}, + mock: &mockMannequinCSVGenerator{ + orgID: "org-id", + mannequins: []github.Mannequin{ + {ID: "monaid", Login: "mona"}, + {ID: "monalisaid", Login: 
"monalisa", MappedUser: &github.MannequinUser{ID: "mapped-id", Login: "monalisa_gh"}}, + }, + }, + wantCSVContent: mannequin.CSVHeader + "\n" + + "mona,monaid,\n" + + "monalisa,monalisaid,monalisa_gh\n", + }, + { + name: "missing github-target-org flag", + args: []string{}, + mock: &mockMannequinCSVGenerator{}, + wantErr: "--github-target-org must be provided", + }, + { + name: "github-target-org is URL", + args: []string{"--github-target-org", "https://github.com/my-org"}, + mock: &mockMannequinCSVGenerator{}, + wantErr: "expects an organization name, not a URL", + }, + { + name: "GetOrganizationId error propagates", + args: []string{"--github-target-org", "FooOrg"}, + mock: &mockMannequinCSVGenerator{ + orgIDErr: fmt.Errorf("org not found"), + }, + wantErr: "org not found", + }, + { + name: "GetMannequins error propagates", + args: []string{"--github-target-org", "FooOrg"}, + mock: &mockMannequinCSVGenerator{ + orgID: "org-id", + mannequinsErr: fmt.Errorf("api error"), + }, + wantErr: "api error", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + var writtenContent string + writeFile := func(path, content string) error { + writtenContent = content + return nil + } + + cmd := newGenerateMannequinCSVCmd(tc.mock, log, writeFile) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs(tc.args) + + err := cmd.Execute() + + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + } else { + require.NoError(t, err) + } + + output := buf.String() + for _, want := range tc.wantOutput { + assert.Contains(t, output, want, "expected output to contain %q", want) + } + + if tc.wantCSVContent != "" { + assert.Equal(t, tc.wantCSVContent, writtenContent) + } + }) + } +} diff --git a/cmd/gei/generate_script.go b/cmd/gei/generate_script.go index 88ad76f67..45d05ca61 100644 --- a/cmd/gei/generate_script.go +++ b/cmd/gei/generate_script.go @@ -10,7 +10,6 @@ import 
( "github.com/github/gh-gei/pkg/env" "github.com/github/gh-gei/pkg/github" - "github.com/github/gh-gei/pkg/http" "github.com/github/gh-gei/pkg/logger" "github.com/github/gh-gei/pkg/scriptgen" "github.com/spf13/cobra" @@ -104,16 +103,15 @@ func runGenerateScript(ctx context.Context, opts *generateScriptOptions, log *lo sourceAPIURL = "https://api.github.com" } - httpCfg := http.DefaultConfig() - httpCfg.NoSSLVerify = opts.noSSLVerify - httpClient := http.NewClient(httpCfg, log) - - githubCfg := github.Config{ - APIURL: sourceAPIURL, - PAT: githubPAT, - NoSSLVerify: opts.noSSLVerify, + clientOpts := []github.Option{ + github.WithAPIURL(sourceAPIURL), + github.WithLogger(log), + github.WithVersion(version), + } + if opts.noSSLVerify { + clientOpts = append(clientOpts, github.WithNoSSLVerify()) } - githubClient := github.NewClient(githubCfg, httpClient, log) + githubClient := github.NewClient(githubPAT, clientOpts...) // Get repositories from source org log.Info("GITHUB ORG: %s", opts.githubSourceOrg) diff --git a/cmd/gei/grant_migrator_role.go b/cmd/gei/grant_migrator_role.go new file mode 100644 index 000000000..8574d943c --- /dev/null +++ b/cmd/gei/grant_migrator_role.go @@ -0,0 +1,95 @@ +package main + +import ( + "context" + "strings" + + "github.com/github/gh-gei/internal/cmdutil" + "github.com/github/gh-gei/pkg/logger" + "github.com/spf13/cobra" +) + +// migratorRoleGranter is the consumer-defined interface for granting migrator roles. +type migratorRoleGranter interface { + GetOrganizationId(ctx context.Context, org string) (string, error) + GrantMigratorRole(ctx context.Context, orgID, actor, actorType string) (bool, error) +} + +// newGrantMigratorRoleCmd creates the grant-migrator-role cobra command. 
+func newGrantMigratorRoleCmd(gh migratorRoleGranter, log *logger.Logger) *cobra.Command { + var ( + githubOrg string + actor string + actorType string + ) + + cmd := &cobra.Command{ + Use: "grant-migrator-role", + Short: "Grants the migrator role to a user or team for a GitHub organization", + Long: "Grants the migrator role to a user or team for a GitHub organization.", + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateMigratorRoleArgs(githubOrg, actor, actorType, cmd); err != nil { + return err + } + actorType = strings.ToUpper(actorType) + return runGrantMigratorRole(cmd.Context(), gh, log, githubOrg, actor, actorType) + }, + } + + cmd.Flags().StringVar(&githubOrg, "github-org", "", "The GitHub organization to grant the migrator role for (REQUIRED)") + cmd.Flags().StringVar(&actor, "actor", "", "The user or team to grant the migrator role to (REQUIRED)") + cmd.Flags().StringVar(&actorType, "actor-type", "", "The type of the actor (USER or TEAM) (REQUIRED)") + cmd.Flags().String("github-target-pat", "", "Personal access token for the target GitHub instance") + cmd.Flags().String("target-api-url", "", "API URL for the target GitHub instance") + cmd.Flags().String("ghes-api-url", "", "API URL for the source GHES instance") + + return cmd +} + +func runGrantMigratorRole(ctx context.Context, gh migratorRoleGranter, log *logger.Logger, githubOrg, actor, actorType string) error { + log.Info("Granting migrator role ...") + + orgID, err := gh.GetOrganizationId(ctx, githubOrg) + if err != nil { + return err + } + + success, err := gh.GrantMigratorRole(ctx, orgID, actor, actorType) + if err != nil { + return err + } + + if success { + log.Success("Migrator role successfully set for the %s \"%s\"", actorType, actor) + } else { + log.Errorf("Migrator role couldn't be set for the %s \"%s\"", actorType, actor) + } + + return nil +} + +// validateMigratorRoleArgs validates the shared arguments for grant/revoke migrator role commands. 
+func validateMigratorRoleArgs(githubOrg, actor, actorType string, cmd *cobra.Command) error { + if strings.TrimSpace(githubOrg) == "" { + return cmdutil.NewUserError("--github-org must be provided") + } + if strings.TrimSpace(actor) == "" { + return cmdutil.NewUserError("--actor must be provided") + } + if strings.HasPrefix(githubOrg, "http://") || strings.HasPrefix(githubOrg, "https://") { + return cmdutil.NewUserError("The --github-org option expects an organization name, not a URL. Please provide just the organization name.") + } + + upper := strings.ToUpper(actorType) + if upper != "TEAM" && upper != "USER" { + return cmdutil.NewUserError("Actor type must be either TEAM or USER.") + } + + ghesAPIURL, _ := cmd.Flags().GetString("ghes-api-url") + targetAPIURL, _ := cmd.Flags().GetString("target-api-url") + if ghesAPIURL != "" && targetAPIURL != "" { + return cmdutil.NewUserError("Only one of --ghes-api-url or --target-api-url can be set at a time.") + } + + return nil +} diff --git a/cmd/gei/grant_migrator_role_test.go b/cmd/gei/grant_migrator_role_test.go new file mode 100644 index 000000000..97bac1013 --- /dev/null +++ b/cmd/gei/grant_migrator_role_test.go @@ -0,0 +1,151 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/github/gh-gei/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockMigratorRoleGranter implements migratorRoleGranter for testing. 
+type mockMigratorRoleGranter struct { + orgID string + orgIDErr error + grantResult bool + grantErr error + gotOrg string + gotOrgID string + gotActor string + gotType string +} + +func (m *mockMigratorRoleGranter) GetOrganizationId(_ context.Context, org string) (string, error) { + m.gotOrg = org + return m.orgID, m.orgIDErr +} + +func (m *mockMigratorRoleGranter) GrantMigratorRole(_ context.Context, orgID, actor, actorType string) (bool, error) { + m.gotOrgID = orgID + m.gotActor = actor + m.gotType = actorType + return m.grantResult, m.grantErr +} + +func TestGrantMigratorRole(t *testing.T) { + tests := []struct { + name string + args []string + mock *mockMigratorRoleGranter + wantErr string + wantOutput []string + assertArgs func(t *testing.T, m *mockMigratorRoleGranter) + }{ + { + name: "grant succeeds", + args: []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER"}, + mock: &mockMigratorRoleGranter{orgID: "ORG_ID_123", grantResult: true}, + wantOutput: []string{ + "Granting migrator role ...", + `Migrator role successfully set for the USER "monalisa"`, + }, + assertArgs: func(t *testing.T, m *mockMigratorRoleGranter) { + assert.Equal(t, "my-org", m.gotOrg) + assert.Equal(t, "ORG_ID_123", m.gotOrgID) + assert.Equal(t, "monalisa", m.gotActor) + assert.Equal(t, "USER", m.gotType) + }, + }, + { + name: "grant succeeds with lowercase actor type", + args: []string{"--github-org", "my-org", "--actor", "my-team", "--actor-type", "team"}, + mock: &mockMigratorRoleGranter{orgID: "ORG_ID_123", grantResult: true}, + wantOutput: []string{ + `Migrator role successfully set for the TEAM "my-team"`, + }, + }, + { + name: "grant fails returns false", + args: []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER"}, + mock: &mockMigratorRoleGranter{orgID: "ORG_ID_123", grantResult: false}, + wantOutput: []string{ + `Migrator role couldn't be set for the USER "monalisa"`, + }, + }, + { + name: "GetOrganizationId error 
propagates", + args: []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER"}, + mock: &mockMigratorRoleGranter{orgIDErr: fmt.Errorf("org not found")}, + wantErr: "org not found", + }, + { + name: "invalid actor type", + args: []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "INVALID"}, + mock: &mockMigratorRoleGranter{}, + wantErr: "Actor type must be either TEAM or USER.", + }, + { + name: "github-org is URL", + args: []string{"--github-org", "https://github.com/my-org", "--actor", "monalisa", "--actor-type", "USER"}, + mock: &mockMigratorRoleGranter{}, + wantErr: "The --github-org option expects an organization name, not a URL", + }, + { + name: "ghes-api-url and target-api-url both set", + args: []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER", "--ghes-api-url", "https://ghes.example.com", "--target-api-url", "https://api.github.com"}, + mock: &mockMigratorRoleGranter{}, + wantErr: "Only one of --ghes-api-url or --target-api-url can be set at a time.", + }, + { + name: "GrantMigratorRole error propagates", + args: []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER"}, + mock: &mockMigratorRoleGranter{orgID: "ORG_ID_123", grantErr: fmt.Errorf("permission denied")}, + wantErr: "permission denied", + }, + { + name: "empty github-org", + args: []string{"--github-org", "", "--actor", "monalisa", "--actor-type", "USER"}, + mock: &mockMigratorRoleGranter{}, + wantErr: "--github-org must be provided", + }, + { + name: "empty actor", + args: []string{"--github-org", "my-org", "--actor", "", "--actor-type", "USER"}, + mock: &mockMigratorRoleGranter{}, + wantErr: "--actor must be provided", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + var buf bytes.Buffer + log := logger.New(false, &buf) + + cmd := newGrantMigratorRoleCmd(tc.mock, log) + cmd.SetOut(&buf) + cmd.SetErr(&buf) + cmd.SetArgs(tc.args) + + err := cmd.Execute() + + if 
tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + } else { + require.NoError(t, err) + } + + output := buf.String() + for _, want := range tc.wantOutput { + assert.Contains(t, output, want, "expected output to contain %q", want) + } + + if tc.assertArgs != nil { + tc.assertArgs(t, tc.mock) + } + }) + } +} diff --git a/cmd/gei/main.go b/cmd/gei/main.go index 4898f2349..7cc5c6662 100644 --- a/cmd/gei/main.go +++ b/cmd/gei/main.go @@ -2,13 +2,22 @@ package main import ( "context" + "net/http" "os" + "strings" "github.com/github/gh-gei/pkg/env" "github.com/github/gh-gei/pkg/logger" + "github.com/github/gh-gei/pkg/status" + versionpkg "github.com/github/gh-gei/pkg/version" "github.com/spf13/cobra" ) +// contextKey is an unexported type for context keys in this package. +type contextKey string + +const loggerKey contextKey = "logger" + var ( version = "dev" verbose bool @@ -25,13 +34,18 @@ func newRootCmd() *cobra.Command { Use: "gei", Short: "GitHub Enterprise Importer CLI", Long: "CLI for migrating repositories between GitHub instances using GitHub Enterprise Importer.", - PersistentPreRun: func(cmd *cobra.Command, args []string) { + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { // Initialize logger log := logger.New(verbose) - ctx := context.WithValue(cmd.Context(), "logger", log) + ctx := context.WithValue(cmd.Context(), loggerKey, log) cmd.SetContext(ctx) log.Debug("Execution started") + + checkVersion(ctx, log) + checkGitHubStatus(ctx, log) + + return nil }, SilenceUsage: true, SilenceErrors: true, @@ -62,7 +76,7 @@ func newRootCmd() *cobra.Command { // getLogger retrieves the logger from the command context func getLogger(cmd *cobra.Command) *logger.Logger { - if log, ok := cmd.Context().Value("logger").(*logger.Logger); ok { + if log, ok := cmd.Context().Value(loggerKey).(*logger.Logger); ok { return log } return logger.New(false) @@ -76,24 +90,42 @@ func getEnvProvider() *env.Provider { // 
checkVersion checks if a newer version is available func checkVersion(ctx context.Context, log *logger.Logger) { envProvider := getEnvProvider() - - if envProvider.SkipVersionCheck() == "true" || envProvider.SkipVersionCheck() == "1" { + skip := envProvider.SkipVersionCheck() + if strings.EqualFold(skip, "true") || skip == "1" { log.Info("Skipped latest version check due to GEI_SKIP_VERSION_CHECK environment variable") return } - // TODO: Implement version check - log.Info("You are running gei CLI version %s", version) + checker := versionpkg.NewChecker(&http.Client{}, log, version) + isLatest, err := checker.IsLatest(ctx) + if err != nil { + log.Debug("Version check failed: %v", err) + return + } + + if !isLatest { + latest, _ := checker.GetLatestVersion(ctx) + log.Info("New version available: %s", latest) + log.Info("You are running gei CLI version %s", version) + } } // checkGitHubStatus checks if GitHub is experiencing incidents func checkGitHubStatus(ctx context.Context, log *logger.Logger) { envProvider := getEnvProvider() - - if envProvider.SkipStatusCheck() == "true" || envProvider.SkipStatusCheck() == "1" { + skip := envProvider.SkipStatusCheck() + if strings.EqualFold(skip, "true") || skip == "1" { log.Info("Skipped GitHub status check due to GEI_SKIP_STATUS_CHECK environment variable") return } - // TODO: Implement GitHub status check + count, err := status.GetUnresolvedIncidentsCount(ctx, &http.Client{}, "https://www.githubstatus.com") + if err != nil { + log.Debug("GitHub status check failed: %v", err) + return + } + + if count > 0 { + log.Warning("GitHub is currently experiencing %d incident(s). 
Check https://www.githubstatus.com for details.", count) + } } diff --git a/cmd/gei/reclaim_mannequin.go b/cmd/gei/reclaim_mannequin.go new file mode 100644 index 000000000..38ca83bf4 --- /dev/null +++ b/cmd/gei/reclaim_mannequin.go @@ -0,0 +1,159 @@ +package main + +import ( + "bufio" + "context" + "os" + "strings" + + "github.com/github/gh-gei/internal/cmdutil" + "github.com/github/gh-gei/pkg/logger" + "github.com/spf13/cobra" +) + +// mannequinReclaimer is the consumer-defined interface for the reclaim service. +type mannequinReclaimer interface { + ReclaimMannequin(ctx context.Context, mannequinUser, mannequinID, targetUser, org string, force, skipInvitation bool) error + ReclaimMannequins(ctx context.Context, lines []string, org string, force, skipInvitation bool) error +} + +// mannequinReclaimAPI is the consumer-defined interface for direct GitHub API calls +// needed by the reclaim-mannequin command (skip-invitation admin check). +type mannequinReclaimAPI interface { + GetLoginName(ctx context.Context) (string, error) + GetOrgMembershipForUser(ctx context.Context, org, member string) (string, error) +} + +// newReclaimMannequinCmd creates the reclaim-mannequin cobra command. 
+func newReclaimMannequinCmd( + svc mannequinReclaimer, + api mannequinReclaimAPI, + log *logger.Logger, + fileExists func(string) bool, + readFile func(string) ([]string, error), +) *cobra.Command { + var ( + githubTargetOrg string + csv string + mannequinUser string + mannequinID string + targetUser string + force bool + skipInvitation bool + noPrompt bool + ) + + if fileExists == nil { + fileExists = func(path string) bool { + _, err := os.Stat(path) + return err == nil + } + } + if readFile == nil { + readFile = readFileLines + } + + cmd := &cobra.Command{ + Use: "reclaim-mannequin", + Short: "Reclaims one or more mannequin users", + Long: "Reclaims one or more mannequin users by mapping them to real GitHub users.", + RunE: func(cmd *cobra.Command, args []string) error { + if err := validateReclaimMannequinArgs(githubTargetOrg, csv, mannequinUser, targetUser); err != nil { + return err + } + return runReclaimMannequin(cmd.Context(), svc, api, log, fileExists, readFile, + githubTargetOrg, csv, mannequinUser, mannequinID, targetUser, force, skipInvitation, noPrompt) + }, + } + + cmd.Flags().StringVar(&githubTargetOrg, "github-target-org", "", "The target GitHub organization (REQUIRED)") + cmd.Flags().StringVar(&csv, "csv", "", "Path to a CSV file with mannequin mappings") + cmd.Flags().StringVar(&mannequinUser, "mannequin-user", "", "The login of the mannequin user to reclaim") + cmd.Flags().StringVar(&mannequinID, "mannequin-id", "", "The ID of the mannequin user to reclaim") + cmd.Flags().StringVar(&targetUser, "target-user", "", "The login of the target user to map the mannequin to") + cmd.Flags().BoolVar(&force, "force", false, "Reclaim even if the mannequin is already mapped") + cmd.Flags().BoolVar(&skipInvitation, "skip-invitation", false, "Skip sending an invitation email (EMU orgs only)") + cmd.Flags().BoolVar(&noPrompt, "no-prompt", false, "Skip confirmation prompt for skip-invitation") + cmd.Flags().String("github-target-pat", "", "Personal access token 
for the target GitHub instance") + cmd.Flags().String("target-api-url", "", "API URL for the target GitHub instance") + + return cmd +} + +func validateReclaimMannequinArgs(githubTargetOrg, csv, mannequinUser, targetUser string) error { + if strings.TrimSpace(githubTargetOrg) == "" { + return cmdutil.NewUserError("--github-target-org must be provided") + } + if strings.HasPrefix(githubTargetOrg, "http://") || strings.HasPrefix(githubTargetOrg, "https://") { + return cmdutil.NewUserError("The --github-target-org option expects an organization name, not a URL. Please provide just the organization name.") + } + if csv == "" && (mannequinUser == "" || targetUser == "") { + return cmdutil.NewUserError("Either --csv or --mannequin-user and --target-user must be specified") + } + return nil +} + +func runReclaimMannequin( + ctx context.Context, + svc mannequinReclaimer, + api mannequinReclaimAPI, + log *logger.Logger, + fileExists func(string) bool, + readFile func(string) ([]string, error), + org, csv, mannequinUser, mannequinID, targetUser string, + force, skipInvitation, noPrompt bool, +) error { + if skipInvitation { + if !noPrompt { + return cmdutil.NewUserError("Reclaiming mannequins with --skip-invitation is immediate and irreversible. 
Use --no-prompt to confirm.") + } + + login, err := api.GetLoginName(ctx) + if err != nil { + return err + } + + membership, err := api.GetOrgMembershipForUser(ctx, org, login) + if err != nil { + return err + } + + if membership != "admin" { + return cmdutil.NewUserErrorf("User %s is not an org admin and is not eligible to reclaim mannequins with the --skip-invitation feature.", login) + } + } + + if csv != "" { + log.Info("Reclaiming Mannequins with CSV...") + + if !fileExists(csv) { + return cmdutil.NewUserErrorf("File %s does not exist.", csv) + } + + lines, err := readFile(csv) + if err != nil { + return err + } + + return svc.ReclaimMannequins(ctx, lines, org, force, skipInvitation) + } + + log.Info("Reclaiming Mannequin...") + return svc.ReclaimMannequin(ctx, mannequinUser, mannequinID, targetUser, org, force, skipInvitation) +} + +// readFileLines reads a file and returns its lines. +func readFileLines(path string) ([]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var lines []string + scanner := bufio.NewScanner(f) + for scanner.Scan() { + lines = append(lines, scanner.Text()) + } + return lines, scanner.Err() +} diff --git a/cmd/gei/reclaim_mannequin_test.go b/cmd/gei/reclaim_mannequin_test.go new file mode 100644 index 000000000..8d750cb39 --- /dev/null +++ b/cmd/gei/reclaim_mannequin_test.go @@ -0,0 +1,271 @@ +package main + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/github/gh-gei/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockMannequinReclaimer implements mannequinReclaimer for testing. 
type mockMannequinReclaimer struct {
	// Errors to return from the corresponding mocked methods.
	reclaimErr  error
	reclaimsErr error

	// Captured arguments of the last call to each mocked method.
	gotReclaimArgs  *reclaimSingleArgs
	gotReclaimsArgs *reclaimBulkArgs
}

// reclaimSingleArgs records the arguments of a ReclaimMannequin call.
type reclaimSingleArgs struct {
	MannequinUser  string
	MannequinID    string
	TargetUser     string
	Org            string
	Force          bool
	SkipInvitation bool
}

// reclaimBulkArgs records the arguments of a ReclaimMannequins call.
type reclaimBulkArgs struct {
	Lines          []string
	Org            string
	Force          bool
	SkipInvitation bool
}

func (m *mockMannequinReclaimer) ReclaimMannequin(_ context.Context, mannequinUser, mannequinID, targetUser, org string, force, skipInvitation bool) error {
	m.gotReclaimArgs = &reclaimSingleArgs{
		MannequinUser:  mannequinUser,
		MannequinID:    mannequinID,
		TargetUser:     targetUser,
		Org:            org,
		Force:          force,
		SkipInvitation: skipInvitation,
	}
	return m.reclaimErr
}

func (m *mockMannequinReclaimer) ReclaimMannequins(_ context.Context, lines []string, org string, force, skipInvitation bool) error {
	m.gotReclaimsArgs = &reclaimBulkArgs{
		Lines:          lines,
		Org:            org,
		Force:          force,
		SkipInvitation: skipInvitation,
	}
	return m.reclaimsErr
}

// mockMannequinReclaimAPI implements mannequinReclaimAPI for testing.
type mockMannequinReclaimAPI struct {
	loginName     string
	loginNameErr  error
	membership    string
	membershipErr error
}

func (m *mockMannequinReclaimAPI) GetLoginName(_ context.Context) (string, error) {
	return m.loginName, m.loginNameErr
}

func (m *mockMannequinReclaimAPI) GetOrgMembershipForUser(_ context.Context, org, member string) (string, error) {
	return m.membership, m.membershipErr
}

// TestReclaimMannequin covers flag validation, CSV vs single-mannequin mode
// selection, and the skip-invitation admin gate, using injected fakes for
// the service, the API, and the filesystem.
func TestReclaimMannequin(t *testing.T) {
	tests := []struct {
		name        string
		args        []string
		reclaimer   *mockMannequinReclaimer
		api         *mockMannequinReclaimAPI
		fileExists  func(string) bool
		readFile    func(string) ([]string, error)
		wantErr     string
		wantOutput  []string
		assertCalls func(t *testing.T, r *mockMannequinReclaimer)
	}{
		{
			name:       "single reclaim happy path",
			args:       []string{"--github-target-org", "FooOrg", "--mannequin-user", "mona", "--target-user", "mona_emu"},
			reclaimer:  &mockMannequinReclaimer{},
			api:        &mockMannequinReclaimAPI{},
			wantOutput: []string{"Reclaiming Mannequin..."},
			assertCalls: func(t *testing.T, r *mockMannequinReclaimer) {
				require.NotNil(t, r.gotReclaimArgs)
				assert.Equal(t, "mona", r.gotReclaimArgs.MannequinUser)
				assert.Equal(t, "", r.gotReclaimArgs.MannequinID)
				assert.Equal(t, "mona_emu", r.gotReclaimArgs.TargetUser)
				assert.Equal(t, "FooOrg", r.gotReclaimArgs.Org)
				assert.False(t, r.gotReclaimArgs.Force)
				assert.False(t, r.gotReclaimArgs.SkipInvitation)
			},
		},
		{
			name:      "single reclaim with mannequin ID",
			args:      []string{"--github-target-org", "FooOrg", "--mannequin-user", "mona", "--mannequin-id", "m1", "--target-user", "mona_emu"},
			reclaimer: &mockMannequinReclaimer{},
			api:       &mockMannequinReclaimAPI{},
			assertCalls: func(t *testing.T, r *mockMannequinReclaimer) {
				require.NotNil(t, r.gotReclaimArgs)
				assert.Equal(t, "m1", r.gotReclaimArgs.MannequinID)
			},
		},
		{
			name:       "CSV reclaim happy path",
			args:       []string{"--github-target-org", "FooOrg", "--csv", "file.csv"},
			reclaimer:  &mockMannequinReclaimer{},
			api:        &mockMannequinReclaimAPI{},
			fileExists: func(string) bool { return true },
			readFile: func(string) ([]string, error) {
				return []string{"header", "line1"}, nil
			},
			wantOutput: []string{"Reclaiming Mannequins with CSV"},
			assertCalls: func(t *testing.T, r *mockMannequinReclaimer) {
				require.NotNil(t, r.gotReclaimsArgs)
				assert.Equal(t, []string{"header", "line1"}, r.gotReclaimsArgs.Lines)
				assert.Equal(t, "FooOrg", r.gotReclaimsArgs.Org)
			},
		},
		{
			name:       "CSV takes precedence over single",
			args:       []string{"--github-target-org", "FooOrg", "--csv", "file.csv", "--mannequin-user", "mona", "--target-user", "target"},
			reclaimer:  &mockMannequinReclaimer{},
			api:        &mockMannequinReclaimAPI{},
			fileExists: func(string) bool { return true },
			readFile:   func(string) ([]string, error) { return []string{}, nil },
			assertCalls: func(t *testing.T, r *mockMannequinReclaimer) {
				require.NotNil(t, r.gotReclaimsArgs, "should use CSV mode")
				assert.Nil(t, r.gotReclaimArgs, "should not use single mode")
			},
		},
		{
			name:       "CSV file does not exist",
			args:       []string{"--github-target-org", "FooOrg", "--csv", "nonexistent.csv"},
			reclaimer:  &mockMannequinReclaimer{},
			api:        &mockMannequinReclaimAPI{},
			fileExists: func(string) bool { return false },
			wantErr:    "does not exist",
		},
		{
			name:      "skip-invitation with admin user and no-prompt",
			args:      []string{"--github-target-org", "FooOrg", "--csv", "file.csv", "--skip-invitation", "--no-prompt"},
			reclaimer: &mockMannequinReclaimer{},
			api: &mockMannequinReclaimAPI{
				loginName:  "admin_user",
				membership: "admin",
			},
			fileExists: func(string) bool { return true },
			readFile:   func(string) ([]string, error) { return []string{}, nil },
			assertCalls: func(t *testing.T, r *mockMannequinReclaimer) {
				require.NotNil(t, r.gotReclaimsArgs)
				assert.True(t, r.gotReclaimsArgs.SkipInvitation)
			},
		},
		{
			name:      "skip-invitation without no-prompt returns error",
			args:      []string{"--github-target-org", "FooOrg", "--csv", "file.csv", "--skip-invitation"},
			reclaimer: &mockMannequinReclaimer{},
			api: &mockMannequinReclaimAPI{
				loginName:  "admin_user",
				membership: "admin",
			},
			wantErr: "--no-prompt",
		},
		{
			name:      "skip-invitation non-admin returns error",
			args:      []string{"--github-target-org", "FooOrg", "--csv", "file.csv", "--skip-invitation", "--no-prompt"},
			reclaimer: &mockMannequinReclaimer{},
			api: &mockMannequinReclaimAPI{
				loginName:  "regular_user",
				membership: "member",
			},
			wantErr: "not an org admin",
		},
		{
			name:      "missing org flag",
			args:      []string{},
			reclaimer: &mockMannequinReclaimer{},
			api:       &mockMannequinReclaimAPI{},
			wantErr:   "--github-target-org must be provided",
		},
		{
			name:      "org is URL",
			args:      []string{"--github-target-org", "https://github.com/my-org", "--mannequin-user", "m", "--target-user", "t"},
			reclaimer: &mockMannequinReclaimer{},
			api:       &mockMannequinReclaimAPI{},
			wantErr:   "expects an organization name, not a URL",
		},
		{
			name:      "neither csv nor mannequin-user/target-user",
			args:      []string{"--github-target-org", "FooOrg"},
			reclaimer: &mockMannequinReclaimer{},
			api:       &mockMannequinReclaimAPI{},
			wantErr:   "Either --csv or --mannequin-user and --target-user must be specified",
		},
		{
			name:      "mannequin-user without target-user",
			args:      []string{"--github-target-org", "FooOrg", "--mannequin-user", "mona"},
			reclaimer: &mockMannequinReclaimer{},
			api:       &mockMannequinReclaimAPI{},
			wantErr:   "Either --csv or --mannequin-user and --target-user must be specified",
		},
		{
			name:      "force flag passed through",
			args:      []string{"--github-target-org", "FooOrg", "--mannequin-user", "mona", "--target-user", "target", "--force"},
			reclaimer: &mockMannequinReclaimer{},
			api:       &mockMannequinReclaimAPI{},
			assertCalls: func(t *testing.T, r *mockMannequinReclaimer) {
				require.NotNil(t, r.gotReclaimArgs)
				assert.True(t, r.gotReclaimArgs.Force)
			},
		},
		{
			name:      "reclaim service error propagates",
			args:      []string{"--github-target-org", "FooOrg", "--mannequin-user", "mona", "--target-user", "target"},
			reclaimer: &mockMannequinReclaimer{reclaimErr: fmt.Errorf("reclaim failed")},
			api:       &mockMannequinReclaimAPI{},
			wantErr:   "reclaim failed",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			var buf bytes.Buffer
			log := logger.New(false, &buf)

			// Default the filesystem fakes so cases that never reach the
			// filesystem don't have to specify them.
			fileExists := tc.fileExists
			if fileExists == nil {
				fileExists = func(string) bool { return true }
			}
			readFile := tc.readFile
			if readFile == nil {
				readFile = func(string) ([]string, error) { return nil, nil }
			}

			cmd := newReclaimMannequinCmd(tc.reclaimer, tc.api, log, fileExists, readFile)
			cmd.SetOut(&buf)
			cmd.SetErr(&buf)
			cmd.SetArgs(tc.args)

			err := cmd.Execute()

			if tc.wantErr != "" {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tc.wantErr)
			} else {
				require.NoError(t, err)
			}

			output := buf.String()
			for _, want := range tc.wantOutput {
				assert.Contains(t, output, want, "expected output to contain %q", want)
			}

			if tc.assertCalls != nil {
				tc.assertCalls(t, tc.reclaimer)
			}
		})
	}
}
diff --git a/cmd/gei/revoke_migrator_role.go b/cmd/gei/revoke_migrator_role.go
new file mode 100644
index 000000000..a571af048
--- /dev/null
+++ b/cmd/gei/revoke_migrator_role.go
@@ -0,0 +1,68 @@
package main

import (
	"context"
	"strings"

	"github.com/github/gh-gei/pkg/logger"
	"github.com/spf13/cobra"
)

// migratorRoleRevoker is the consumer-defined interface for revoking migrator roles.
type migratorRoleRevoker interface {
	GetOrganizationId(ctx context.Context, org string) (string, error)
	RevokeMigratorRole(ctx context.Context, orgID, actor, actorType string) (bool, error)
}

// newRevokeMigratorRoleCmd creates the revoke-migrator-role cobra command.
func newRevokeMigratorRoleCmd(gh migratorRoleRevoker, log *logger.Logger) *cobra.Command {
	var (
		githubOrg string
		actor     string
		actorType string
	)

	cmd := &cobra.Command{
		Use:   "revoke-migrator-role",
		Short: "Revokes the migrator role from a user or team for a GitHub organization",
		Long:  "Revokes the migrator role from a user or team for a GitHub organization.",
		RunE: func(cmd *cobra.Command, args []string) error {
			// Shared validation with grant-migrator-role (org/actor/actor-type
			// checks plus the ghes-api-url/target-api-url exclusivity rule).
			if err := validateMigratorRoleArgs(githubOrg, actor, actorType, cmd); err != nil {
				return err
			}
			// Normalize so "user"/"team" are accepted case-insensitively.
			actorType = strings.ToUpper(actorType)
			return runRevokeMigratorRole(cmd.Context(), gh, log, githubOrg, actor, actorType)
		},
	}

	cmd.Flags().StringVar(&githubOrg, "github-org", "", "The GitHub organization to revoke the migrator role for (REQUIRED)")
	cmd.Flags().StringVar(&actor, "actor", "", "The user or team to revoke the migrator role from (REQUIRED)")
	cmd.Flags().StringVar(&actorType, "actor-type", "", "The type of the actor (USER or TEAM) (REQUIRED)")
	cmd.Flags().String("github-target-pat", "", "Personal access token for the target GitHub instance")
	cmd.Flags().String("target-api-url", "", "API URL for the target GitHub instance")
	cmd.Flags().String("ghes-api-url", "", "API URL for the source GHES instance")

	return cmd
}

// runRevokeMigratorRole resolves the org ID and revokes the migrator role.
// An unsuccessful revoke (false, nil) is logged as an error but does not
// return a non-nil error to the caller.
func runRevokeMigratorRole(ctx context.Context, gh migratorRoleRevoker, log *logger.Logger, githubOrg, actor, actorType string) error {
	log.Info("Revoking migrator role ...")

	orgID, err := gh.GetOrganizationId(ctx, githubOrg)
	if err != nil {
		return err
	}

	success, err := gh.RevokeMigratorRole(ctx, orgID, actor, actorType)
	if err != nil {
		return err
	}

	if success {
		log.Success("Migrator role successfully revoked for the %s \"%s\"", actorType, actor)
	} else {
		log.Errorf("Migrator role couldn't be revoked for the %s \"%s\"", actorType, actor)
	}

	return nil
}
diff --git a/cmd/gei/revoke_migrator_role_test.go b/cmd/gei/revoke_migrator_role_test.go
new file mode 100644
index 000000000..11bcad149
--- /dev/null
+++ b/cmd/gei/revoke_migrator_role_test.go
@@ -0,0 +1,151 @@
package main

import (
	"bytes"
	"context"
	"fmt"
	"testing"

	"github.com/github/gh-gei/pkg/logger"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// mockMigratorRoleRevoker implements migratorRoleRevoker for testing.
type mockMigratorRoleRevoker struct {
	// Canned return values.
	orgID        string
	orgIDErr     error
	revokeResult bool
	revokeErr    error
	// Captured call arguments.
	gotOrg   string
	gotOrgID string
	gotActor string
	gotType  string
}

func (m *mockMigratorRoleRevoker) GetOrganizationId(_ context.Context, org string) (string, error) {
	m.gotOrg = org
	return m.orgID, m.orgIDErr
}

func (m *mockMigratorRoleRevoker) RevokeMigratorRole(_ context.Context, orgID, actor, actorType string) (bool, error) {
	m.gotOrgID = orgID
	m.gotActor = actor
	m.gotType = actorType
	return m.revokeResult, m.revokeErr
}

// TestRevokeMigratorRole covers flag validation, actor-type normalization,
// success/failure logging, and error propagation from both API calls.
func TestRevokeMigratorRole(t *testing.T) {
	tests := []struct {
		name       string
		args       []string
		mock       *mockMigratorRoleRevoker
		wantErr    string
		wantOutput []string
		assertArgs func(t *testing.T, m *mockMigratorRoleRevoker)
	}{
		{
			name: "revoke succeeds",
			args: []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER"},
			mock: &mockMigratorRoleRevoker{orgID: "ORG_ID_123", revokeResult: true},
			wantOutput: []string{
				"Revoking migrator role ...",
				`Migrator role successfully revoked for the USER "monalisa"`,
			},
			assertArgs: func(t *testing.T, m *mockMigratorRoleRevoker) {
				assert.Equal(t, "my-org", m.gotOrg)
				assert.Equal(t, "ORG_ID_123", m.gotOrgID)
				assert.Equal(t, "monalisa", m.gotActor)
				assert.Equal(t, "USER", m.gotType)
			},
		},
		{
			name: "revoke succeeds with lowercase actor type",
			args: []string{"--github-org", "my-org", "--actor", "my-team", "--actor-type", "team"},
			mock: &mockMigratorRoleRevoker{orgID: "ORG_ID_123", revokeResult: true},
			wantOutput: []string{
				`Migrator role successfully revoked for the TEAM "my-team"`,
			},
		},
		{
			name: "revoke fails returns false",
			args: []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER"},
			mock: &mockMigratorRoleRevoker{orgID: "ORG_ID_123", revokeResult: false},
			wantOutput: []string{
				`Migrator role couldn't be revoked for the USER "monalisa"`,
			},
		},
		{
			name:    "GetOrganizationId error propagates",
			args:    []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER"},
			mock:    &mockMigratorRoleRevoker{orgIDErr: fmt.Errorf("org not found")},
			wantErr: "org not found",
		},
		{
			name:    "invalid actor type",
			args:    []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "INVALID"},
			mock:    &mockMigratorRoleRevoker{},
			wantErr: "Actor type must be either TEAM or USER.",
		},
		{
			name:    "github-org is URL",
			args:    []string{"--github-org", "https://github.com/my-org", "--actor", "monalisa", "--actor-type", "USER"},
			mock:    &mockMigratorRoleRevoker{},
			wantErr: "The --github-org option expects an organization name, not a URL",
		},
		{
			name:    "ghes-api-url and target-api-url both set",
			args:    []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER", "--ghes-api-url", "https://ghes.example.com", "--target-api-url", "https://api.github.com"},
			mock:    &mockMigratorRoleRevoker{},
			wantErr: "Only one of --ghes-api-url or --target-api-url can be set at a time.",
		},
		{
			name:    "RevokeMigratorRole error propagates",
			args:    []string{"--github-org", "my-org", "--actor", "monalisa", "--actor-type", "USER"},
			mock:    &mockMigratorRoleRevoker{orgID: "ORG_ID_123", revokeErr: fmt.Errorf("permission denied")},
			wantErr: "permission denied",
		},
		{
			name:    "empty github-org",
			args:    []string{"--github-org", "", "--actor", "monalisa", "--actor-type", "USER"},
			mock:    &mockMigratorRoleRevoker{},
			wantErr: "--github-org must be provided",
		},
		{
			name:    "empty actor",
			args:    []string{"--github-org", "my-org", "--actor", "", "--actor-type", "USER"},
			mock:    &mockMigratorRoleRevoker{},
			wantErr: "--actor must be provided",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			var buf bytes.Buffer
			log := logger.New(false, &buf)

			cmd := newRevokeMigratorRoleCmd(tc.mock, log)
			cmd.SetOut(&buf)
			cmd.SetErr(&buf)
			cmd.SetArgs(tc.args)

			err := cmd.Execute()

			if tc.wantErr != "" {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tc.wantErr)
			} else {
				require.NoError(t, err)
			}

			output := buf.String()
			for _, want := range tc.wantOutput {
				assert.Contains(t, output, want, "expected output to contain %q", want)
			}

			if tc.assertArgs != nil {
				tc.assertArgs(t, tc.mock)
			}
		})
	}
}
diff --git a/cmd/gei/wait_for_migration.go b/cmd/gei/wait_for_migration.go
new file mode 100644
index 000000000..1e6dae8f0
--- /dev/null
+++ b/cmd/gei/wait_for_migration.go
@@ -0,0 +1,170 @@
package main

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/github/gh-gei/internal/cmdutil"
	"github.com/github/gh-gei/pkg/github"
	"github.com/github/gh-gei/pkg/logger"
	"github.com/github/gh-gei/pkg/migration"
	"github.com/spf13/cobra"
)

const (
	// Migration ID prefixes distinguish repo vs org migrations.
	repoMigrationIDPrefix = "RM_"
	orgMigrationIDPrefix  = "OM_"
	defaultPollInterval   = 60 * time.Second
)

// migrationWaiter is the consumer-defined interface for waiting on migrations.
type migrationWaiter interface {
	GetMigration(ctx context.Context, id string) (*github.Migration, error)
	GetOrganizationMigration(ctx context.Context, id string) (*github.OrgMigration, error)
}

// newWaitForMigrationCmd creates the wait-for-migration cobra command.
// pollInterval controls how long to sleep between status polls; pass 0 in tests.
func newWaitForMigrationCmd(gh migrationWaiter, log *logger.Logger, pollInterval time.Duration) *cobra.Command {
	var migrationID string

	cmd := &cobra.Command{
		Use:   "wait-for-migration",
		Short: "Waits for a migration to finish",
		Long:  "Polls the migration status API until a repository or organization migration completes or fails.",
		RunE: func(cmd *cobra.Command, args []string) error {
			if err := validateMigrationID(migrationID); err != nil {
				return err
			}
			return runWaitForMigration(cmd.Context(), gh, log, migrationID, pollInterval)
		},
	}

	cmd.Flags().StringVar(&migrationID, "migration-id", "", "The ID of the migration to wait for (REQUIRED)")
	cmd.Flags().String("github-target-pat", "", "Personal access token for the target GitHub instance")
	cmd.Flags().String("target-api-url", "", "API URL for the target GitHub instance")

	return cmd
}

// validateMigrationID rejects empty IDs and IDs without a recognized
// RM_/OM_ prefix before any API call is made.
func validateMigrationID(id string) error {
	if strings.TrimSpace(id) == "" {
		return cmdutil.NewUserError("--migration-id must be provided")
	}
	if !strings.HasPrefix(id, repoMigrationIDPrefix) && !strings.HasPrefix(id, orgMigrationIDPrefix) {
		return cmdutil.NewUserErrorf("Invalid migration id: %s", id)
	}
	return nil
}

// runWaitForMigration dispatches to the repo or org wait loop based on the
// ID prefix (validateMigrationID guarantees one of the two prefixes).
func runWaitForMigration(ctx context.Context, gh migrationWaiter, log *logger.Logger, migrationID string, pollInterval time.Duration) error {
	if strings.HasPrefix(migrationID, repoMigrationIDPrefix) {
		return waitForRepoMigration(ctx, gh, log, migrationID, pollInterval)
	}
	return waitForOrgMigration(ctx, gh, log, migrationID, pollInterval)
}

// waitForRepoMigration polls a repository migration until it reaches a
// terminal state. The state is fetched once up front, checked, and then
// re-fetched after each sleep; ctx cancellation aborts the wait.
func waitForRepoMigration(ctx context.Context, gh migrationWaiter, log *logger.Logger, migrationID string, pollInterval time.Duration) error {
	log.Info("Waiting for migration (ID: %s) to finish...", migrationID)

	m, err := gh.GetMigration(ctx, migrationID)
	if err != nil {
		return err
	}

	log.Info("Waiting for migration of repository %s to finish...", m.RepositoryName)

	for {
		if migration.IsRepoSucceeded(m.State) {
			log.Success("Migration %s succeeded for %s", migrationID, m.RepositoryName)
			logWarningsCount(log, m.WarningsCount)
			log.Info("Migration log available at %s or by running `gh gei download-logs`", m.MigrationLogURL)
			return nil
		}

		if migration.IsRepoFailed(m.State) {
			log.Errorf("Migration %s failed for %s", migrationID, m.RepositoryName)
			logWarningsCount(log, m.WarningsCount)
			log.Info("Migration log available at %s or by running `gh gei download-logs`", m.MigrationLogURL)
			return cmdutil.NewUserError(m.FailureReason)
		}

		log.Info("Migration %s for %s is %s", migrationID, m.RepositoryName, m.State)
		log.Info("Waiting %s...", formatPollInterval(pollInterval))

		// Sleep, but wake immediately on context cancellation.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(pollInterval):
		}

		m, err = gh.GetMigration(ctx, migrationID)
		if err != nil {
			return err
		}
	}
}

// waitForOrgMigration polls an organization migration until it reaches a
// terminal state, reporting per-repository progress while in the
// repo-migration phase.
func waitForOrgMigration(ctx context.Context, gh migrationWaiter, log *logger.Logger, migrationID string, pollInterval time.Duration) error {
	m, err := gh.GetOrganizationMigration(ctx, migrationID)
	if err != nil {
		return err
	}

	log.Info("Waiting for %s -> %s migration (ID: %s) to finish...", m.SourceOrgURL, m.TargetOrgName, migrationID)

	for {
		if migration.IsOrgSucceeded(m.State) {
			log.Success("Migration %s succeeded", migrationID)
			return nil
		}

		if migration.IsOrgFailed(m.State) {
			return cmdutil.NewUserErrorf("Migration %s failed for %s -> %s. Failure reason: %s",
				migrationID, m.SourceOrgURL, m.TargetOrgName, m.FailureReason)
		}

		if migration.IsOrgRepoMigration(m.State) {
			// Progress is derived: completed = total - remaining.
			completed := m.TotalRepositoriesCount - m.RemainingRepositoriesCount
			log.Info("Migration %s is %s - %d/%d repositories completed",
				migrationID, m.State, completed, m.TotalRepositoriesCount)
		} else {
			log.Info("Migration %s is %s", migrationID, m.State)
		}

		log.Info("Waiting %s...", formatPollInterval(pollInterval))

		// Sleep, but wake immediately on context cancellation.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(pollInterval):
		}

		m, err = gh.GetOrganizationMigration(ctx, migrationID)
		if err != nil {
			return err
		}
	}
}

// logWarningsCount logs warnings encountered during migration, matching C# WarningsCountLogger.
func logWarningsCount(log *logger.Logger, count int) {
	switch count {
	case 0:
		// no output
	case 1:
		log.Warning("1 warning encountered during this migration")
	default:
		log.Warning("%d warnings encountered during this migration", count)
	}
}

// formatPollInterval renders a duration as whole seconds for log output.
func formatPollInterval(d time.Duration) string {
	secs := int(d.Seconds())
	if secs == 0 {
		return "0 seconds"
	}
	return fmt.Sprintf("%d seconds", secs)
}
diff --git a/cmd/gei/wait_for_migration_test.go b/cmd/gei/wait_for_migration_test.go
new file mode 100644
index 000000000..9fe736b38
--- /dev/null
+++ b/cmd/gei/wait_for_migration_test.go
@@ -0,0 +1,279 @@
package main

import (
	"bytes"
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/github/gh-gei/pkg/github"
	"github.com/github/gh-gei/pkg/logger"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// mockMigrationWaiter implements the migrationWaiter interface for testing.
type mockMigrationWaiter struct {
	// Scripted per-call results; call N returns element N of each slice.
	getMigrationResults      []*github.Migration
	getMigrationErrors       []error
	getMigrationCallCount    int
	getOrgMigrationResults   []*github.OrgMigration
	getOrgMigrationErrors    []error
	getOrgMigrationCallCount int
}

func (m *mockMigrationWaiter) GetMigration(_ context.Context, _ string) (*github.Migration, error) {
	i := m.getMigrationCallCount
	m.getMigrationCallCount++
	if i < len(m.getMigrationResults) {
		var err error
		if i < len(m.getMigrationErrors) {
			err = m.getMigrationErrors[i]
		}
		return m.getMigrationResults[i], err
	}
	// More calls than scripted results indicates a test bug.
	return nil, fmt.Errorf("unexpected call to GetMigration (call %d)", i)
}

func (m *mockMigrationWaiter) GetOrganizationMigration(_ context.Context, _ string) (*github.OrgMigration, error) {
	i := m.getOrgMigrationCallCount
	m.getOrgMigrationCallCount++
	if i < len(m.getOrgMigrationResults) {
		var err error
		if i < len(m.getOrgMigrationErrors) {
			err = m.getOrgMigrationErrors[i]
		}
		return m.getOrgMigrationResults[i], err
	}
	// More calls than scripted results indicates a test bug.
	return nil, fmt.Errorf("unexpected call to GetOrganizationMigration (call %d)", i)
}

// TestWaitForMigration drives the wait loops with scripted migration states
// (poll interval 0 so the tests run instantly) and checks terminal behavior,
// progress logging, and ID validation.
func TestWaitForMigration(t *testing.T) {
	tests := []struct {
		name        string
		migrationID string
		mock        *mockMigrationWaiter
		wantErr     string
		wantOutput  []string // substrings that must appear in output
	}{
		{
			name:        "repo migration succeeds immediately",
			migrationID: "RM_123",
			mock: &mockMigrationWaiter{
				getMigrationResults: []*github.Migration{
					{State: "SUCCEEDED", RepositoryName: "my-repo", WarningsCount: 0, MigrationLogURL: "https://example.com/log"},
				},
			},
			wantOutput: []string{"succeeded for my-repo", "Migration log available at https://example.com/log"},
		},
		{
			name:        "repo migration succeeds after 2 polls",
			migrationID: "RM_456",
			mock: &mockMigrationWaiter{
				getMigrationResults: []*github.Migration{
					{State: "IN_PROGRESS", RepositoryName: "my-repo"},
					{State: "IN_PROGRESS", RepositoryName: "my-repo"},
					{State: "SUCCEEDED", RepositoryName: "my-repo", WarningsCount: 2, MigrationLogURL: "https://example.com/log"},
				},
			},
			wantOutput: []string{"succeeded for my-repo", "2 warnings encountered during this migration"},
		},
		{
			name:        "repo migration fails",
			migrationID: "RM_789",
			mock: &mockMigrationWaiter{
				getMigrationResults: []*github.Migration{
					{State: "FAILED", RepositoryName: "my-repo", FailureReason: "something broke", WarningsCount: 1, MigrationLogURL: "https://example.com/log"},
				},
			},
			wantErr:    "something broke",
			wantOutput: []string{"failed for my-repo", "1 warning encountered during this migration"},
		},
		{
			name:        "org migration succeeds immediately",
			migrationID: "OM_100",
			mock: &mockMigrationWaiter{
				getOrgMigrationResults: []*github.OrgMigration{
					{State: "SUCCEEDED", SourceOrgURL: "https://github.com/src-org", TargetOrgName: "target-org"},
				},
			},
			wantOutput: []string{"succeeded"},
		},
		{
			name:        "org migration fails",
			migrationID: "OM_200",
			mock: &mockMigrationWaiter{
				getOrgMigrationResults: []*github.OrgMigration{
					{State: "FAILED", SourceOrgURL: "https://github.com/src-org", TargetOrgName: "target-org", FailureReason: "org migration broke"},
				},
			},
			wantErr: "org migration broke",
		},
		{
			name:        "org migration in repo_migration phase shows progress",
			migrationID: "OM_300",
			mock: &mockMigrationWaiter{
				getOrgMigrationResults: []*github.OrgMigration{
					{State: "REPO_MIGRATION", SourceOrgURL: "https://github.com/src-org", TargetOrgName: "target-org", TotalRepositoriesCount: 10, RemainingRepositoriesCount: 7},
					{State: "SUCCEEDED", SourceOrgURL: "https://github.com/src-org", TargetOrgName: "target-org"},
				},
			},
			wantOutput: []string{"3/10 repositories completed", "succeeded"},
		},
		{
			name:        "invalid migration ID prefix",
			migrationID: "XX_invalid",
			mock:        &mockMigrationWaiter{},
			wantErr:     "Invalid migration id: XX_invalid",
		},
		{
			name:        "missing migration ID",
			migrationID: "",
			mock:        &mockMigrationWaiter{},
			wantErr:     "--migration-id must be provided",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			var buf bytes.Buffer
			log := logger.New(false, &buf)

			// Poll interval 0 keeps the loop from sleeping in tests.
			cmd := newWaitForMigrationCmd(tc.mock, log, 0)
			cmd.SetOut(&buf)
			cmd.SetErr(&buf)

			args := []string{}
			if tc.migrationID != "" {
				args = append(args, "--migration-id", tc.migrationID)
			}
			cmd.SetArgs(args)

			err := cmd.Execute()

			if tc.wantErr != "" {
				require.Error(t, err)
				assert.Contains(t, err.Error(), tc.wantErr)
			} else {
				require.NoError(t, err)
			}

			output := buf.String()
			for _, want := range tc.wantOutput {
				assert.Contains(t, output, want, "expected output to contain %q", want)
			}
		})
	}
}

// TestWaitForMigration_RepoMigrationWarningCounts pins the singular/plural/
// silent behavior of logWarningsCount via a successful repo migration.
func TestWaitForMigration_RepoMigrationWarningCounts(t *testing.T) {
	tests := []struct {
		name          string
		warningsCount int
		wantWarning   string
		wantNoWarning bool
	}{
		{
			name:          "zero warnings logs nothing",
			warningsCount: 0,
			wantNoWarning: true,
		},
		{
			name:          "one warning logs singular",
			warningsCount: 1,
			wantWarning:   "1 warning encountered during this migration",
		},
		{
			name:          "multiple warnings logs plural",
			warningsCount: 5,
			wantWarning:   "5 warnings encountered during this migration",
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			var buf bytes.Buffer
			log := logger.New(false, &buf)

			mock := &mockMigrationWaiter{
				getMigrationResults: []*github.Migration{
					{State: "SUCCEEDED", RepositoryName: "repo", WarningsCount: tc.warningsCount, MigrationLogURL: "https://example.com/log"},
				},
			}

			cmd := newWaitForMigrationCmd(mock, log, 0)
			cmd.SetArgs([]string{"--migration-id", "RM_test"})

			err := cmd.Execute()
			require.NoError(t, err)

			output := buf.String()
			if tc.wantNoWarning {
				assert.NotContains(t, output, "warning")
			} else {
				assert.Contains(t, output, tc.wantWarning)
			}
		})
	}
}

// Verify that polling actually happens (calls > 1 when initial state
is pending). +func TestWaitForMigration_RepoPolling(t *testing.T) { + mock := &mockMigrationWaiter{ + getMigrationResults: []*github.Migration{ + {State: "IN_PROGRESS", RepositoryName: "repo"}, + {State: "SUCCEEDED", RepositoryName: "repo", MigrationLogURL: "https://example.com/log"}, + }, + } + + var buf bytes.Buffer + log := logger.New(false, &buf) + cmd := newWaitForMigrationCmd(mock, log, time.Duration(0)) + cmd.SetArgs([]string{"--migration-id", "RM_poll"}) + + err := cmd.Execute() + require.NoError(t, err) + assert.Equal(t, 2, mock.getMigrationCallCount, "expected 2 calls to GetMigration for polling") +} + +// Verify org polling calls > 1 when initial state is pending. +func TestWaitForMigration_OrgPolling(t *testing.T) { + mock := &mockMigrationWaiter{ + getOrgMigrationResults: []*github.OrgMigration{ + {State: "IN_PROGRESS", SourceOrgURL: "https://github.com/org", TargetOrgName: "target"}, + {State: "SUCCEEDED", SourceOrgURL: "https://github.com/org", TargetOrgName: "target"}, + }, + } + + var buf bytes.Buffer + log := logger.New(false, &buf) + cmd := newWaitForMigrationCmd(mock, log, time.Duration(0)) + cmd.SetArgs([]string{"--migration-id", "OM_poll"}) + + err := cmd.Execute() + require.NoError(t, err) + assert.Equal(t, 2, mock.getOrgMigrationCallCount, "expected 2 calls to GetOrganizationMigration for polling") +} + +// Verify that context cancellation stops the polling loop. +func TestWaitForMigration_ContextCancellation(t *testing.T) { + mock := &mockMigrationWaiter{ + getMigrationResults: []*github.Migration{ + {State: "IN_PROGRESS", RepositoryName: "repo"}, + {State: "IN_PROGRESS", RepositoryName: "repo"}, + {State: "IN_PROGRESS", RepositoryName: "repo"}, + }, + } + + var buf bytes.Buffer + log := logger.New(false, &buf) + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel immediately so the poll loop exits on the first select. 
+ cancel() + + err := runWaitForMigration(ctx, mock, log, "RM_cancel", 10*time.Second) + require.Error(t, err) + assert.ErrorIs(t, err, context.Canceled) +} diff --git a/docs/specs/2026-03-30-go-port-design.md b/docs/specs/2026-03-30-go-port-design.md new file mode 100644 index 000000000..51ea85493 --- /dev/null +++ b/docs/specs/2026-03-30-go-port-design.md @@ -0,0 +1,529 @@ +# gh-gei Go Port — Design Specification + +**Date:** 2026-03-30 +**Status:** Draft +**Author:** offbyone + AI assistant + +## Overview + +Port the gh-gei CLI suite (gei, ado2gh, bbs2gh) from .NET 8 / C# 12 to Go, producing functionally identical binaries with the same command-line interface, output format, and script generation behavior. The Go and C# codebases coexist in the same repository during the transition. + +## Goals + +1. **Binary compatibility**: Same binary names (`gei`, `ado2gh`, `bbs2gh`), same subcommands, same flags, same output +2. **Script compatibility**: Generated PowerShell scripts must be identical (they invoke the CLI as `gh gei`, etc.) +3. **Test compatibility**: Existing e2e tests pass against Go binaries (only build steps change, not validation) +4. **Log compatibility**: Same log file format (`.octoshift.log`, `.octoshift.verbose.log`) so e2e log assertions pass + - this is the loosest of the goals; if the log assertions assume C# traces, then the e2e test may need to be amended. +5. **Coexistence**: Both codebases live in the same repo for side-by-side inspection +6. **Idiomatic Go**: Use Go conventions (consumer-defined interfaces, explicit wiring, stdlib patterns) rather than mirroring C# architecture +7. 
**Replacement**: When e2e compatibility is achieved, this will completely replace the C# version + +## Non-Goals + +- Changing the user-facing CLI interface +- Changing what the generated scripts do +- Rewriting the e2e test harness in Go (deferred to a later phase) +- Supporting new features not in the C# version + +## Decisions + +| Decision | Choice | Rationale | +|----------|--------|-----------| +| CLI framework | Cobra | Already established in Phase 1/2, de facto Go standard, same as `gh` CLI | +| GitHub API | `google/go-github` + custom GraphQL | go-github covers REST; migration APIs need custom GraphQL | +| ADO/BBS HTTP | `imroc/req` | Batteries-included: retry, middleware, JSON marshaling | +| Cloud storage | Official AWS SDK v2 + Azure SDK for Go | 1:1 match with C# SDKs, well-maintained | +| Testing | `testify/assert` + `httptest` + table-driven | Already established in Phase 1/2 | +| Linting | `golangci-lint` (25+ linters) | Already configured in Phase 1 | +| Interfaces | Consumer-defined | Idiomatic Go; interfaces declared where used, not where implemented | +| DI | Explicit wiring in main.go | No container; Cobra command constructors accept dependencies | +| Coexistence | Side-by-side at repo root | Go at `cmd/`, `pkg/`, `go.mod`; C# stays in `src/` | +| Integration tests | Hybrid: keep C# harness initially, port to Go later | C# tests are black-box (shell out to `gh` extensions), work against any binary | + +## Architecture + +### Package Layout + +``` +cmd/ + gei/ + main.go # Root command, subcommand registration + migrate_repo.go # migrate-repo command + migrate_org.go # migrate-org command + generate_script.go # generate-script command (exists in Phase 2) + wait_for_migration.go # wait-for-migration command + abort_migration.go # abort-migration command + download_logs.go # download-logs command + create_team.go # create-team command + grant_migrator_role.go # grant-migrator-role command + revoke_migrator_role.go # revoke-migrator-role command + 
reclaim_mannequin.go # reclaim-mannequin command + generate_mannequin_csv.go # generate-mannequin-csv command + migrate_secret_alerts.go # migrate-secret-alerts command + migrate_code_scanning.go # migrate-code-scanning-alerts command + ado2gh/ + main.go + migrate_repo.go + generate_script.go + # ... 19 commands total (see command inventory below) + bbs2gh/ + main.go + migrate_repo.go + generate_script.go + inventory_report.go + # ... 11 commands total + +pkg/ + github/ + client.go # GitHub API client (REST via go-github + GraphQL) + client_test.go + graphql.go # GraphQL mutation/query helpers + graphql_test.go + models.go # GitHub-specific types + ado/ + client.go # ADO API client (imroc/req) + client_test.go + models.go + bbs/ + client.go # BBS API client (imroc/req) + client_test.go + models.go + storage/ + azure/ + client.go # Azure Blob Storage operations + client_test.go + aws/ + client.go # AWS S3 operations + client_test.go + ghowned/ + client.go # GitHub-owned storage multipart upload + client_test.go + scriptgen/ + generator.go # Script generation engine (exists in Phase 2) + generator_test.go + templates.go # PowerShell template constants + logger/ + logger.go # OctoLogger equivalent (exists in Phase 1) + logger_test.go + env/ + env.go # Environment variable provider (exists in Phase 1) + env_test.go + retry/ + retry.go # Retry policy (exists in Phase 1) + retry_test.go + filesystem/ + filesystem.go # File system operations (exists in Phase 1) + filesystem_test.go + version/ + checker.go # CLI version checking against latest release + checker_test.go + status/ + github.go # githubstatus.com API + github_test.go + confirmation/ + prompt.go # Interactive Y/N confirmation + prompt_test.go + archive/ + uploader.go # Archive upload orchestration (Azure/AWS/GH-owned) + uploader_test.go + http/ + client.go # Shared HTTP client (Phase 1; stdlib + retry) + client_test.go + app/ + app.go # Centralized DI provider struct (Phase 1) + app_test.go + +internal/ + 
cmdutil/ + flags.go # Shared flag definitions and validation helpers + errors.go # OctoshiftCliError equivalent + testutil/ + httpmock.go # Shared HTTP test helpers + fixtures.go # Test data loading +``` + +**Migration path for Phase 1/2 packages:** + +- **`pkg/http`** — A thin stdlib wrapper providing retry, TLS config, and basic GET/POST/PUT/DELETE. The GitHub client will migrate to `google/go-github` (which manages its own HTTP transport). ADO and BBS clients will migrate to `imroc/req`, which supersedes `pkg/http` with built-in retry, middleware, and JSON marshaling. `pkg/http` will be removed once all consumers are migrated. +- **`pkg/app`** — A centralized DI provider struct (`App` with `Logger`, `Env`, `FileSystem`, `Retry` fields). The target architecture replaces this with explicit wiring in each `cmd/*/main.go`: command constructors accept their dependencies directly, and `pkg/app` is removed. No Wire or container — just constructor calls. + +### Command Pattern + +Each command is a function returning `*cobra.Command`. Dependencies are injected via function parameters. Interfaces are defined at the point of consumption. + +```go +// cmd/gei/migrate_repo.go + +// Only the methods this command actually calls +type migrationStarter interface { + StartMigration(ctx context.Context, opts github.MigrateOpts) (string, error) + GetMigrationState(ctx context.Context, migrationID string) (string, error) +} + +type migrationLogger interface { + GetMigrationLogURL(ctx context.Context, org, migrationID string) (string, error) +} + +func newMigrateRepoCmd(gh migrationStarter, logs migrationLogger, log *logger.Logger, env *env.Env) *cobra.Command { + var opts struct { + sourceOrg string + sourceRepo string + targetOrg string + targetRepo string + targetPAT string + // ... 
all flags + } + + cmd := &cobra.Command{ + Use: "migrate-repo", + Short: "Migrate a repository", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + // validation, business logic, API calls + migrationID, err := gh.StartMigration(ctx, github.MigrateOpts{...}) + if err != nil { + return fmt.Errorf("starting migration: %w", err) + } + // poll for completion, download logs on failure, etc. + return nil + }, + } + + cmd.Flags().StringVar(&opts.sourceOrg, "github-source-org", "", "...") + // ... register all flags + return cmd +} +``` + +### CLI Wiring (main.go) + +> **Refactor note:** Phase 1/2 used context-value injection for passing dependencies to command handlers (e.g., `context.WithValue(ctx, "logger", log)` in `PersistentPreRun`, retrieved via `cmd.Context().Value("logger")`). The target architecture replaces this with explicit constructor injection — each command constructor accepts its dependencies as typed parameters — which is more type-safe and testable. The `getLogger(cmd)` / `getEnvProvider()` helper pattern will be removed. + +```go +// cmd/gei/main.go +func main() { + log := logger.New(os.Stdout, os.Stderr) + envProvider := env.New() + + root := &cobra.Command{ + Use: "gei", + Short: "GitHub Enterprise Importer", + Version: version, + } + root.PersistentFlags().BoolVar(&verbose, "verbose", false, "...") + + // Wire dependencies and register commands + ghClient := github.NewClient(envProvider.TargetGithubPAT(), github.WithLogger(log)) + + root.AddCommand( + newMigrateRepoCmd(ghClient, ghClient, log, envProvider), + newGenerateScriptCmd(ghClient, log, envProvider), + newWaitForMigrationCmd(ghClient, log), + // ... 
all commands + ) + + if err := root.Execute(); err != nil { + os.Exit(1) + } +} +``` + +### API Client Design + +**GitHub (`pkg/github`):** + +```go +type Client struct { + rest *gogithub.Client // google/go-github for REST + graphql *graphqlClient // custom thin wrapper for migration mutations + logger *logger.Logger +} + +// REST operations (delegated to go-github) +func (c *Client) GetRepos(ctx context.Context, org string) ([]Repo, error) +func (c *Client) GetTeamMembers(ctx context.Context, org, team string) ([]string, error) + +// GraphQL operations (custom, migration-specific) +func (c *Client) CreateMigrationSource(ctx context.Context, opts MigrationSourceOpts) (string, error) +func (c *Client) StartRepositoryMigration(ctx context.Context, opts MigrateOpts) (string, error) +func (c *Client) GetMigrationState(ctx context.Context, migrationID string) (string, error) +func (c *Client) StartOrganizationMigration(ctx context.Context, opts OrgMigrateOpts) (string, error) +``` + +**ADO (`pkg/ado`):** + +```go +type Client struct { + http *req.Client + pat string + logger *logger.Logger +} + +func NewClient(baseURL, pat string, opts ...Option) *Client +func (c *Client) GetTeamProjects(ctx context.Context, org string) ([]TeamProject, error) +func (c *Client) GetRepos(ctx context.Context, org, project string) ([]Repo, error) +func (c *Client) DisableRepo(ctx context.Context, org, project, repoID string) error +func (c *Client) LockRepo(ctx context.Context, org, project, repoID string) error +// ... 
~20 methods mapping to ADO REST API endpoints +``` + +**BBS (`pkg/bbs`):** + +```go +type Client struct { + http *req.Client + username string + password string + logger *logger.Logger +} + +func NewClient(baseURL, username, password string, opts ...Option) *Client +func (c *Client) GetProjects(ctx context.Context) ([]Project, error) +func (c *Client) GetRepos(ctx context.Context, projectKey string) ([]Repo, error) +func (c *Client) GetArchive(ctx context.Context, projectKey, repoSlug string) (io.ReadCloser, error) +``` + +### Migration from Phase 1/2 + +The Phase 1/2 code established initial implementations with direct parameter injection. The target architecture differs in several ways: + +- **Client constructors:** Phase 1/2 clients (e.g., `github.NewClient(cfg, httpClient, log)`, `ado.NewClient(baseURL, pat, log, httpClient)`) accept dependencies as positional parameters. The target architecture uses functional options (e.g., `github.NewClient(pat, github.WithLogger(log))`) for cleaner extensibility. Phase 3+ work will refactor existing clients to match. +- **HTTP layer:** Phase 1/2 clients depend on `pkg/http.Client` (a thin stdlib wrapper). The target architecture has `pkg/github` using `google/go-github` and `pkg/ado` / `pkg/bbs` using `imroc/req`. The `pkg/http` package will be removed once migration is complete. +- **DI pattern:** Phase 1/2 uses `pkg/app.App` as a centralized provider struct. The target architecture wires dependencies explicitly in `main.go` via constructor injection — no central container. +- **Consumer-defined interfaces:** Phase 1/2 code does not yet define interfaces at the consumer site. The target architecture declares narrow interfaces in each command file (e.g., `migrationStarter` in `migrate_repo.go`), enabling easy mocking and testability. + +### Error Handling + +The C# codebase uses `OctoshiftCliException` for user-friendly errors vs. letting unexpected exceptions bubble up. 
In Go: + +```go +// internal/cmdutil/errors.go +type UserError struct { + Message string + Err error +} + +func (e *UserError) Error() string { return e.Message } +func (e *UserError) Unwrap() error { return e.Err } + +// Usage: return &cmdutil.UserError{Message: "Source org not found. Check the --github-source-org value."} +``` + +The root command's `PersistentPreRunE` or a wrapper handles the distinction: `UserError` gets a clean message; other errors get full stack trace in verbose mode. + +### Logging + +Port `OctoLogger` behavior to `pkg/logger`: + +- **Console output**: Info (stdout), Warning (yellow, stderr), Error (red, stderr) +- **File output**: `.octoshift.log` (info+), `.octoshift.verbose.log` (all including debug) +- **Secret redaction**: `logger.RegisterSecret(secret)` — all output is scrubbed +- **Warning counter**: `logger.Warnings()` returns count for summary output +- **Verbose mode**: Controlled by `--verbose` flag, enables debug output to console + +### Script Generation + +The `pkg/scriptgen` package (Phase 2) generates PowerShell scripts. The generated scripts: + +1. Define helper functions (`Exec`, `ExecAndGetMigrationID`) +2. Validate required environment variables +3. Generate `gh gei migrate-repo` / `gh ado2gh migrate-repo` / `gh bbs2gh migrate-repo` calls +4. In parallel mode: queue migrations with `--queue-only`, collect IDs, then `wait-for-migration` for each +5. In sequential mode: run each migration synchronously via `Exec` +6. Print summary (success/failure counts) + +Scripts must produce byte-identical output to the C# version. The validation script (`scripts/validate-scripts.sh`) diffs C# vs Go output. 
+ +### Cloud Storage + +**Azure Blob (`pkg/storage/azure`):** + +```go +type Client struct { + serviceClient *azblob.Client + logger *logger.Logger +} + +func NewClient(connectionString string, opts ...Option) (*Client, error) +func (c *Client) Upload(ctx context.Context, container, blob string, data io.Reader) (string, error) +func (c *Client) Download(ctx context.Context, container, blob string) (io.ReadCloser, error) +``` + +**AWS S3 (`pkg/storage/aws`):** + +```go +type Client struct { + s3Client *s3.Client + logger *logger.Logger +} + +func NewClient(ctx context.Context, accessKey, secretKey, region string, opts ...Option) (*Client, error) +func (c *Client) Upload(ctx context.Context, bucket, key string, data io.Reader) (string, error) +``` + +**GitHub-owned storage (`pkg/storage/ghowned`):** + +```go +type Client struct { + httpClient *http.Client + logger *logger.Logger +} + +func (c *Client) Upload(ctx context.Context, uploadURL string, data io.Reader, partSizeMiB int) error +``` + +### CI/CD Changes + +**Build workflow changes:** +- Replace `setup-dotnet` with `actions/setup-go` +- Replace `dotnet build` with `go build ./cmd/...` +- Replace `dotnet test` with `go test -race -coverprofile=... ./...` +- Replace `dotnet publish` with cross-compiled `GOOS=X GOARCH=Y go build` +- Replace `dotnet format --verify-no-changes` with `golangci-lint run` +- Update CodeQL from `csharp` to `go` + +**E2e workflow changes (build steps only):** +- `build-for-e2e-test` job: replace `dotnet publish` with `go build` cross-compilation +- Binary naming convention already matches Go convention: `gei-linux-amd64`, `gei-darwin-arm64`, etc. 
+- Keep `dotnet test` for integration test runner (C# harness stays) + +**Validation steps unchanged:** +- Binary download/copy/chmod +- `gh extension install` +- Integration test execution (still C# xunit runner) +- Log file collection and assertion +- Test result publishing + +## Command Inventory + +### gei (13 commands) + +| Command | Shared? | Complexity | +|---------|---------|------------| +| `migrate-repo` | No | High — orchestrates migration source creation, migration start, archive upload, polling | +| `migrate-org` | No | High — organization-level migration | +| `generate-script` | No | Medium — Phase 2 started this | +| `wait-for-migration` | Yes | Low — poll migration state | +| `abort-migration` | Yes | Low — single API call | +| `download-logs` | Yes | Low — fetch log URL, download | +| `create-team` | Yes | Low — create team + set IdP | +| `grant-migrator-role` | Yes | Low — single GraphQL mutation | +| `revoke-migrator-role` | Yes | Low — single GraphQL mutation | +| `reclaim-mannequin` | Yes | Medium — mannequin reclaim logic | +| `generate-mannequin-csv` | Yes | Medium — fetch + format mannequins | +| `migrate-secret-alerts` | No | Medium — paginate + migrate alerts | +| `migrate-code-scanning-alerts` | No | Medium — paginate + migrate alerts | + +### ado2gh (19 commands) + +| Command | Shared? 
| Complexity | +|---------|---------|------------| +| `migrate-repo` | No | High | +| `generate-script` | No | Medium | +| `inventory-report` | No | Medium — fetch all orgs/projects/repos, generate CSV | +| `rewire-pipeline` | No | Medium — update pipeline service connections | +| `test-pipelines` | No | Medium — concurrent pipeline testing | +| `add-team-to-repo` | No | Low | +| `configure-autolink` | No | Low | +| `disable-repo` | No | Low | +| `integrate-boards` | No | Low | +| `lock-repo` | No | Low | +| `share-service-connection` | No | Low | +| `wait-for-migration` | Yes | Low | +| `abort-migration` | Yes | Low | +| `download-logs` | Yes | Low | +| `create-team` | Yes | Low | +| `grant-migrator-role` | Yes | Low | +| `revoke-migrator-role` | Yes | Low | +| `reclaim-mannequin` | Yes | Medium | +| `generate-mannequin-csv` | Yes | Medium | + +### bbs2gh (11 commands) + +| Command | Shared? | Complexity | +|---------|---------|------------| +| `migrate-repo` | No | High — includes SSH/SMB archive download | +| `generate-script` | No | Medium | +| `inventory-report` | No | Medium | +| `wait-for-migration` | Yes | Low | +| `abort-migration` | Yes | Low | +| `download-logs` | Yes | Low | +| `create-team` | Yes | Low | +| `grant-migrator-role` | Yes | Low | +| `revoke-migrator-role` | Yes | Low | +| `reclaim-mannequin` | Yes | Medium | +| `generate-mannequin-csv` | Yes | Medium | + +## Phased Delivery Plan + +### Phase 3: Complete generate-script (PR #3 on stack) +- `gei generate-script` already exists (Phase 2); no work needed for gei +- Wire `generate-script` into ado2gh and bbs2gh CLIs (ADO/BBS-specific variants) +- Validate with `scripts/validate-scripts.sh` +- ~500 lines + +### Phase 4: Core migration commands (PR #4-5) +- `migrate-repo` for all 3 CLIs (highest complexity) +- `wait-for-migration`, `abort-migration`, `download-logs` (shared) +- GitHub GraphQL client for migration APIs +- ~3,000 lines + +### Phase 5: Cloud storage clients (PR #6) +- Azure Blob, 
AWS S3, GitHub-owned storage upload +- Archive upload orchestration +- ~1,500 lines + +### Phase 6: ADO-specific commands (PR #7) +- 8 ADO-only commands (lock-repo, disable-repo, rewire-pipeline, etc.) +- `inventory-report` for ado2gh and bbs2gh +- `test-pipelines` +- ~2,000 lines + +### Phase 7: Remaining commands (PR #8) +- `migrate-org` (gei only) +- Mannequin commands (reclaim, generate-csv) +- Team/role commands (create-team, grant/revoke-migrator-role) +- Alert migration commands (secret-alerts, code-scanning) +- ~2,500 lines + +### Phase 8: CI/CD integration (PR #9) +- Update `CI.yml` build steps for Go +- Update `build-for-e2e-test` for Go cross-compilation +- Update `publish` job for Go binaries +- Keep C# integration test runner +- Update `copilot-setup-steps.yml` +- ~200 lines of workflow YAML + +### Phase 9: Port integration tests to Go (PR #10+) +- Rewrite `OctoshiftCLI.IntegrationTests` in Go +- Remove C# dependency from e2e workflow +- This can be a separate project after the main port + +## Risks and Mitigations + +| Risk | Mitigation | +|------|-----------| +| Script output divergence | `scripts/validate-scripts.sh` runs in CI, diffs C# vs Go output | +| GraphQL API compatibility | Test against real GitHub API; migration mutations are documented | +| BBS SSH/SMB archive download | Go has `golang.org/x/crypto/ssh` for SSH; evaluate `hirochachacha/go-smb2` for SMB | +| go-github missing methods | go-github is comprehensive; for gaps, use raw HTTP via its `Client.NewRequest` | +| E2e test flakiness during transition | Run both C# and Go builds in CI, compare results | +| Performance differences | Go binaries will likely be faster; ensure no timeout assumptions in tests | + +## Dependencies + +``` +# Already in use (Phase 1/2) +github.com/spf13/cobra # CLI framework +github.com/stretchr/testify # Test assertions +github.com/avast/retry-go/v4 # Retry with backoff (used by pkg/retry) + +# To be added +github.com/google/go-github/v68 # GitHub REST API 
+github.com/imroc/req/v3 # HTTP client for ADO/BBS +github.com/aws/aws-sdk-go-v2 # AWS S3 +github.com/Azure/azure-sdk-for-go # Azure Blob Storage +golang.org/x/crypto # SSH client (for BBS) +github.com/hirochachacha/go-smb2 # SMB client (for BBS, needs evaluation) +``` diff --git a/docs/superpowers/plans/2026-03-30-go-port-implementation.md b/docs/superpowers/plans/2026-03-30-go-port-implementation.md new file mode 100644 index 000000000..e083eef1a --- /dev/null +++ b/docs/superpowers/plans/2026-03-30-go-port-implementation.md @@ -0,0 +1,1453 @@ +# gh-gei Go Port — Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Port the gh-gei CLI suite (gei, ado2gh, bbs2gh) from .NET/C# to Go, producing functionally identical binaries with the same CLI interface, output format, and script generation behavior. + +**Architecture:** Cobra-based CLIs with explicit dependency wiring in main.go. Consumer-defined interfaces at point of use. GitHub API via go-github + custom GraphQL; ADO/BBS via imroc/req. Side-by-side coexistence with C# code in the same repo. + +**Tech Stack:** Go 1.25+, Cobra, go-github/v68, imroc/req/v3, Azure SDK for Go, AWS SDK v2, testify, golangci-lint + +**Spec:** `docs/specs/2026-03-30-go-port-design.md` + +**Existing work:** PRs #1500 (Phase 1: base framework) and #1501 (Phase 2: gei generate-script) + +**PR strategy:** Stack of draft PRs on top of `o1/golang-port/2`. Each phase = 1-2 PRs. + +**Local e2e validation:** Focus on macOS. The `GithubToGithub` integration test is the simplest e2e test (only needs `GHEC_PAT`). It exercises `generate-script`, `migrate-repo`, and `download-logs`. After PR 5 is complete, run this test locally as a proof-of-concept. Rely on GitHub CI for Windows and Linux. 
+ +--- + +## Local E2E Setup (Do Once, Before First E2E Validation) + +### Task 0: Local e2e infrastructure + +Add justfile targets for building, installing, and running Go-based e2e tests locally on macOS. + +**Prerequisites:** +- `pwsh` (PowerShell Core) installed +- `gh` CLI installed +- `GHEC_PAT` environment variable set (via direnv `.envrc.local`) +- .NET 8.0 SDK (via mise) + +**Files:** +- Modify: `justfile` (add Go extension install + e2e targets) + +- [ ] **Step 1: Add `go-install-extensions-macos` justfile target** + +```just +# Install Go binaries as gh CLI extensions (macOS) +go-install-extensions-macos: go-publish-macos + #!/usr/bin/env bash + set -euo pipefail + for cli in gei ado2gh bbs2gh; do + dir="gh-${cli}" + mkdir -p "$dir" + cp "./dist/osx-x64/${cli}-darwin-amd64" "./${dir}/gh-${cli}" + chmod +x "./${dir}/gh-${cli}" + cd "$dir" && gh extension install . --force && cd .. + done + echo "Go extensions installed successfully!" +``` + +- [ ] **Step 2: Add `go-e2e-github` justfile target** + +```just +# Run GithubToGithub integration test against Go binaries (macOS) +go-e2e-github: go-install-extensions-macos + direnv exec . dotnet test src/OctoshiftCLI.IntegrationTests/OctoshiftCLI.IntegrationTests.csproj \ + --filter "GithubToGithub" \ + --logger "console;verbosity=normal" \ + /p:VersionPrefix=9.9 +``` + +- [ ] **Step 3: Verify infrastructure works with C# binaries first** + +Run the C# e2e test to establish a baseline: +```bash +just publish-macos +just install-extensions # needs updating for macOS +direnv exec . dotnet test src/OctoshiftCLI.IntegrationTests/OctoshiftCLI.IntegrationTests.csproj \ + --filter "GithubToGithub" --logger "console;verbosity=normal" /p:VersionPrefix=9.9 +``` + +This confirms the test infrastructure, credentials, and target orgs work before we try Go binaries. 
+ +- [ ] **Step 4: Validate `generate-script` output parity** + +After PR 3 tasks are done, use `scripts/validate-scripts.sh` to verify Go `generate-script` output matches C#. This is a lightweight validation that doesn't require real migrations. + +### E2E Milestone Checkpoints + +| After PR | Validation | What's Tested | +|----------|-----------|---------------| +| PR 3 | `scripts/validate-scripts.sh` | `gei generate-script` output parity | +| PR 5 | `just go-e2e-github` | Full `GithubToGithub` e2e (generate-script + migrate-repo + download-logs) | +| PR 6 | GitHub CI: `AdoBasic`, `AdoCsv` | ADO-to-GitHub migration | +| PR 7 | GitHub CI: `Bbs` | BBS-to-GitHub migration | +| PR 9 | All 15 CI matrix combinations | Full cross-platform e2e | + +--- + +## Chunk 1: Foundation — Shared Commands & GitHub API Client + +### PR 3: GitHub GraphQL Client + Shared Low-Complexity Commands + +This PR ports the GitHub API surface needed by the shared commands, then implements all 8 shared commands. These are the commands that appear in all three CLIs (gei, ado2gh, bbs2gh) with minimal variation. + +**Dependency changes:** +- Add `github.com/google/go-github/v68` +- Add `github.com/shurcooL/graphql` or use custom thin GraphQL client (raw HTTP + JSON, matching GithubClient.cs patterns) + +--- + +### Task 1: Port GithubClient HTTP layer to pkg/github + +The C# `GithubClient` handles REST (with Link-header pagination), GraphQL (with cursor pagination), rate limiting, and retry. The Go port splits this: REST goes through `go-github` (which handles pagination and rate limiting natively), while GraphQL uses a thin custom client. 
+
+**Files:**
+- Create: `pkg/github/graphql.go`
+- Create: `pkg/github/graphql_test.go`
+- Modify: `pkg/github/client.go` (add go-github integration, remove raw HTTP)
+- Modify: `pkg/github/client_test.go`
+- Modify: `pkg/github/models.go` (add migration models)
+- Create: `pkg/github/ratelimit.go` (secondary rate limit handling)
+- Create: `pkg/github/ratelimit_test.go`
+
+**Reference:** `src/Octoshift/Services/GithubClient.cs` (364 lines)
+
+- [ ] **Step 1: Write failing tests for GraphQL client**
+
+Test that the GraphQL client:
+- Sends correct `Authorization: Bearer <token>` header
+- Sends `GraphQL-Features: import_api,mannequin_claiming_emu,org_import_api` header
+- Sends `User-Agent: OctoshiftCLI/<version>` header
+- Serializes query + variables correctly
+- Parses successful responses
+- Returns errors from GraphQL `errors` array
+- Handles cursor-based pagination (hasNextPage/endCursor)
+
+```go
+// pkg/github/graphql_test.go
+func TestGraphQLClient_Post(t *testing.T) {
+	// httptest server returning canned response
+	// verify headers, body, parse response
+}
+
+func TestGraphQLClient_Post_WithErrors(t *testing.T) {
+	// server returns {"errors": [{"message": "not found"}]}
+	// verify error is returned
+}
+
+func TestGraphQLClient_PostWithPagination(t *testing.T) {
+	// server returns two pages with hasNextPage=true/false
+	// verify all results collected
+}
+```
+
+Run: `go test ./pkg/github/ -run TestGraphQL -v`
+Expected: FAIL (graphql.go doesn't exist yet)
+
+- [ ] **Step 2: Implement GraphQL client**
+
+```go
+// pkg/github/graphql.go
+type graphqlClient struct {
+	httpClient *http.Client
+	url        string
+	headers    map[string]string
+	logger     *logger.Logger
+}
+
+type graphqlRequest struct {
+	Query     string         `json:"query"`
+	Variables map[string]any `json:"variables,omitempty"`
+}
+
+type graphqlResponse struct {
+	Data   json.RawMessage `json:"data"`
+	Errors []graphqlError  `json:"errors"`
+}
+
+func (c *graphqlClient) Post(ctx context.Context, query string, variables map[string]any) (json.RawMessage, error)
+func (c *graphqlClient) PostWithPagination(ctx context.Context, query string, variables map[string]any, dataPath string, pageInfoPath string) ([]json.RawMessage, error)
+```
+
+Implement secondary rate limit detection (403/429 with specific messages) with exponential backoff (60s/120s/240s, max 3 retries). Match the C# `GithubClient.HandleSecondaryRateLimitAsync` logic.
+
+Run: `go test ./pkg/github/ -run TestGraphQL -v`
+Expected: PASS
+
+- [ ] **Step 3: Write failing tests for go-github REST integration**
+
+Test that the Client:
+- Uses go-github for REST operations (repos, teams, orgs)
+- Correctly maps go-github types to our domain types
+- Handles pagination transparently via go-github's built-in pagination
+
+Run: `go test ./pkg/github/ -run TestClient_GetRepos -v`
+Expected: FAIL or needs updating
+
+- [ ] **Step 4: Migrate Client to use go-github for REST**
+
+Replace the raw HTTP calls in `pkg/github/client.go` with `go-github` client calls. The `go-github` library handles pagination, rate limiting, and auth natively.
+
+```go
+type Client struct {
+	rest    *gogithub.Client // go-github for REST
+	graphql *graphqlClient   // custom for migration GraphQL
+	logger  *logger.Logger
+	apiURL  string
+}
+
+func NewClient(pat string, opts ...Option) (*Client, error)
+```
+
+Options: `WithAPIURL(url)`, `WithLogger(log)`, `WithNoSSLVerify()`, `WithUploadsURL(url)`
+
+Run: `go test ./pkg/github/ -v`
+Expected: PASS
+
+- [ ] **Step 5: Commit**
+
+---
+
+### Task 2: Port GitHub Migration API Methods
+
+Port the migration-specific methods from `GithubApi.cs` that are needed by the shared commands.
+ +**Files:** +- Modify: `pkg/github/client.go` +- Modify: `pkg/github/client_test.go` +- Modify: `pkg/github/models.go` + +**Reference:** `src/Octoshift/Services/GithubApi.cs` — migration, org, mannequin, team, migrator-role sections + +- [ ] **Step 1: Write failing tests for organization/user queries** + +```go +func TestClient_GetOrganizationId(t *testing.T) // GraphQL query +func TestClient_GetOrganizationDatabaseId(t *testing.T) +func TestClient_GetEnterpriseId(t *testing.T) +func TestClient_GetLoginName(t *testing.T) // viewer { login } +func TestClient_GetUserId(t *testing.T) +func TestClient_DoesOrgExist(t *testing.T) // REST, handle 404 +func TestClient_GetOrgMembershipForUser(t *testing.T) // REST, handle 404 +``` + +Run: `go test ./pkg/github/ -run "TestClient_Get(Organization|Enterprise|Login|User|OrgMembership)|TestClient_DoesOrg" -v` +Expected: FAIL + +- [ ] **Step 2: Implement organization/user queries** + +Map each C# method to its Go equivalent. GraphQL methods use `graphqlClient.Post()`, REST methods use `go-github`. + +Run tests. 
Expected: PASS + +- [ ] **Step 3: Write failing tests for migration GraphQL mutations** + +```go +func TestClient_CreateAdoMigrationSource(t *testing.T) +func TestClient_CreateBbsMigrationSource(t *testing.T) +func TestClient_CreateGhecMigrationSource(t *testing.T) +func TestClient_StartMigration(t *testing.T) +func TestClient_StartBbsMigration(t *testing.T) +func TestClient_StartOrganizationMigration(t *testing.T) +func TestClient_GetMigration(t *testing.T) +func TestClient_GetOrganizationMigration(t *testing.T) +func TestClient_GetMigrationLogUrl(t *testing.T) +func TestClient_AbortMigration(t *testing.T) +func TestClient_GrantMigratorRole(t *testing.T) +func TestClient_RevokeMigratorRole(t *testing.T) +``` + +Run: `go test ./pkg/github/ -run "TestClient_(Create.*MigrationSource|Start.*Migration|Get.*Migration|Abort|Grant|Revoke)" -v` +Expected: FAIL + +- [ ] **Step 4: Implement migration GraphQL mutations** + +Each mutation is a string template + variables. Use `graphqlClient.Post()`. Return parsed IDs/states. + +Models needed in `pkg/github/models.go`: +```go +type Migration struct { + ID string + SourceURL string + MigrationLogURL string + State string + WarningsCount int + FailureReason string + RepositoryName string + MigrationSource MigrationSource +} + +type MigrationSource struct { ID, Name, Type string } + +type OrgMigration struct { + State string + SourceOrgURL string + TargetOrgName string + FailureReason string + RemainingRepositoriesCount int + TotalRepositoriesCount int +} +``` + +Run tests. 
Expected: PASS + +- [ ] **Step 5: Write failing tests for team/mannequin REST and GraphQL methods** + +```go +func TestClient_CreateTeam(t *testing.T) +func TestClient_GetTeams(t *testing.T) +func TestClient_GetTeamMembers(t *testing.T) +func TestClient_RemoveTeamMember(t *testing.T) +func TestClient_GetTeamSlug(t *testing.T) +func TestClient_AddTeamSync(t *testing.T) +func TestClient_AddTeamToRepo(t *testing.T) +func TestClient_GetIdpGroupId(t *testing.T) +func TestClient_AddEmuGroupToTeam(t *testing.T) +func TestClient_GetMannequins(t *testing.T) +func TestClient_GetMannequinsByLogin(t *testing.T) +func TestClient_CreateAttributionInvitation(t *testing.T) +func TestClient_ReclaimMannequinSkipInvitation(t *testing.T) +``` + +- [ ] **Step 6: Implement team/mannequin methods** + +Run tests. Expected: PASS + +- [ ] **Step 7: Commit** + +--- + +### Task 3: Port Shared Command Infrastructure + +Port `OctoshiftCliException` equivalent (`UserError`), migration status constants, secret redaction, and the shared command argument validation pattern. 
+ +**Files:** +- Create: `internal/cmdutil/errors.go` +- Create: `internal/cmdutil/errors_test.go` +- Create: `internal/cmdutil/flags.go` +- Create: `internal/cmdutil/flags_test.go` +- Create: `pkg/migration/status.go` (migration status constants and helpers) +- Create: `pkg/migration/status_test.go` + +**Reference:** `src/Octoshift/OctoshiftCliException.cs`, `src/Octoshift/RepositoryMigrationStatus.cs`, `src/Octoshift/OrganizationMigrationStatus.cs` + +- [ ] **Step 1: Write tests for UserError** + +```go +func TestUserError_Error(t *testing.T) { + err := &UserError{Message: "Source org not found"} + assert.Equal(t, "Source org not found", err.Error()) +} + +func TestUserError_Unwrap(t *testing.T) { + inner := errors.New("network failure") + err := &UserError{Message: "Failed", Err: inner} + assert.ErrorIs(t, err, inner) +} +``` + +- [ ] **Step 2: Implement UserError and migration status constants** + +```go +// internal/cmdutil/errors.go +type UserError struct { + Message string + Err error +} + +// pkg/migration/status.go +const ( + RepoMigrationQueued = "QUEUED" + RepoMigrationInProgress = "IN_PROGRESS" + RepoMigrationFailed = "FAILED" + RepoMigrationSucceeded = "SUCCEEDED" + // ... +) + +func IsRepoMigrationPending(state string) bool +func IsRepoMigrationSucceeded(state string) bool +func IsRepoMigrationFailed(state string) bool +``` + +- [ ] **Step 3: Write tests for flag validation helpers** + +URL-vs-org detection, mutual exclusivity checks, etc. + +- [ ] **Step 4: Implement flag validation helpers** + +- [ ] **Step 5: Commit** + +--- + +### Task 4: Port wait-for-migration Command + +This is the simplest shared command — a polling loop over the migration status API. 
+ +**Files:** +- Create: `cmd/gei/wait_for_migration.go` +- Create: `cmd/gei/wait_for_migration_test.go` + +**Reference:** `src/Octoshift/Commands/WaitForMigration/WaitForMigrationCommandHandler.cs` (110 lines) + +- [ ] **Step 1: Write failing tests for wait-for-migration** + +Table-driven tests: +- Repo migration succeeds immediately +- Repo migration succeeds after 2 polls +- Repo migration fails → error +- Org migration succeeds +- Org migration fails → error +- Invalid migration ID prefix → error +- Missing migration ID → error + +```go +// cmd/gei/wait_for_migration_test.go +type mockMigrationWaiter struct { + mock.Mock +} + +func (m *mockMigrationWaiter) GetMigration(ctx context.Context, id string) (*github.Migration, error) { + args := m.Called(ctx, id) + return args.Get(0).(*github.Migration), args.Error(1) +} +``` + +Use consumer-defined interface pattern: +```go +type migrationWaiter interface { + GetMigration(ctx context.Context, id string) (*github.Migration, error) + GetOrganizationMigration(ctx context.Context, id string) (*github.OrgMigration, error) +} +``` + +- [ ] **Step 2: Implement wait-for-migration command** + +Poll interval: 60 seconds (make configurable for tests via a field on the command struct or a variable). + +```go +func newWaitForMigrationCmd(gh migrationWaiter, log *logger.Logger) *cobra.Command +``` + +Flags: `--migration-id` (required), `--github-target-pat` (for gei; `--github-pat` for ado2gh/bbs2gh), `--target-api-url` + +- [ ] **Step 3: Wire into all 3 CLIs** + +Add `newWaitForMigrationCmd(...)` to `cmd/gei/main.go`, `cmd/ado2gh/main.go`, `cmd/bbs2gh/main.go`. Note: in ado2gh/bbs2gh the PAT flag is `--github-pat` not `--github-target-pat`. 
+ +- [ ] **Step 4: Run tests** + +Run: `go test ./cmd/gei/ -run TestWaitForMigration -v` +Expected: PASS + +- [ ] **Step 5: Commit** + +--- + +### Task 5: Port abort-migration Command + +**Files:** +- Create: `cmd/gei/abort_migration.go` +- Create: `cmd/gei/abort_migration_test.go` + +**Reference:** `src/Octoshift/Commands/AbortMigration/AbortMigrationCommandHandler.cs` (29 lines) + +- [ ] **Step 1: Write failing tests** + +- Abort succeeds → log success +- Abort fails (returns false) → log error + +- [ ] **Step 2: Implement abort-migration** + +```go +type migrationAborter interface { + AbortMigration(ctx context.Context, id string) (bool, error) +} + +func newAbortMigrationCmd(gh migrationAborter, log *logger.Logger) *cobra.Command +``` + +Flags: `--migration-id` (required), `--github-target-pat`/`--github-pat`, `--target-api-url` + +- [ ] **Step 3: Wire into all 3 CLIs** + +- [ ] **Step 4: Run tests, commit** + +--- + +### Task 6: Port download-logs Command + +**Files:** +- Create: `cmd/gei/download_logs.go` +- Create: `cmd/gei/download_logs_test.go` +- Create: `pkg/download/service.go` (HttpDownloadService equivalent) +- Create: `pkg/download/service_test.go` + +**Reference:** `src/Octoshift/Commands/DownloadLogs/DownloadLogsCommandHandler.cs` (122 lines), `src/Octoshift/Services/HttpDownloadService.cs` (60 lines) + +- [ ] **Step 1: Write tests for download service** + +- [ ] **Step 2: Implement download service** + +```go +// pkg/download/service.go +type Service struct { + client *http.Client + logger *logger.Logger +} + +func (s *Service) DownloadToFile(ctx context.Context, url, filepath string) error +func (s *Service) DownloadToBytes(ctx context.Context, url string) ([]byte, error) +``` + +- [ ] **Step 3: Write tests for download-logs command** + +Test both paths: +- By migration ID: calls GetMigration, extracts log URL, downloads +- By org/repo: calls GetMigrationLogUrl, downloads +- File exists without --overwrite → error +- File exists with --overwrite 
→ success +- Migration not found → error + +- [ ] **Step 4: Implement download-logs command** + +```go +type logDownloader interface { + GetMigration(ctx context.Context, id string) (*github.Migration, error) + GetMigrationLogUrl(ctx context.Context, org, repo string) (string, error) +} + +func newDownloadLogsCmd(gh logDownloader, dl *download.Service, log *logger.Logger, fs *filesystem.Provider, retry *retry.Policy) *cobra.Command +``` + +Flags: `--github-target-org`, `--github-target-repo`, `--migration-id`, `--github-target-pat`/`--github-pat`, `--target-api-url`, `--migration-log-file`, `--overwrite` + +- [ ] **Step 5: Wire into all 3 CLIs, run tests, commit** + +--- + +### Task 7: Port grant-migrator-role and revoke-migrator-role Commands + +**Files:** +- Create: `cmd/gei/grant_migrator_role.go` +- Create: `cmd/gei/grant_migrator_role_test.go` +- Create: `cmd/gei/revoke_migrator_role.go` +- Create: `cmd/gei/revoke_migrator_role_test.go` + +**Reference:** `src/Octoshift/Commands/GrantMigratorRole/GrantMigratorRoleCommandHandler.cs`, `src/Octoshift/Commands/RevokeMigratorRole/RevokeMigratorRoleCommandHandler.cs` + +- [ ] **Step 1: Write tests for both commands** + +Grant: success → log, failure → log error +Revoke: success → log, failure → log error + +- [ ] **Step 2: Implement both commands** + +```go +type migratorRoleManager interface { + GetOrganizationId(ctx context.Context, org string) (string, error) + GrantMigratorRole(ctx context.Context, orgID, actor, actorType string) (bool, error) + RevokeMigratorRole(ctx context.Context, orgID, actor, actorType string) (bool, error) +} +``` + +Flags: `--github-org` (required), `--actor` (required), `--actor-type` (required, constrained to USER/TEAM), `--github-target-pat`/`--github-pat`, `--target-api-url` + +- [ ] **Step 3: Wire into all 3 CLIs, run tests, commit** + +--- + +### Task 8: Port create-team Command + +**Files:** +- Create: `cmd/gei/create_team.go` +- Create: `cmd/gei/create_team_test.go` + 
+**Reference:** `src/Octoshift/Commands/CreateTeam/CreateTeamCommandHandler.cs` (83 lines) + +- [ ] **Step 1: Write tests** + +- Team doesn't exist → create, link IdP group +- Team already exists → log and skip +- IdP group provided → remove members, link group +- No IdP group → skip linking + +- [ ] **Step 2: Implement create-team** + +```go +type teamCreator interface { + GetTeams(ctx context.Context, org string) ([]github.Team, error) + CreateTeam(ctx context.Context, org, name string) (string, string, error) + GetTeamMembers(ctx context.Context, org, teamSlug string) ([]string, error) + RemoveTeamMember(ctx context.Context, org, teamSlug, member string) error + GetIdpGroupId(ctx context.Context, org, groupName string) (int, error) + AddEmuGroupToTeam(ctx context.Context, org, teamSlug string, groupID int) error +} +``` + +Flags: `--github-org` (required), `--team-name` (required), `--idp-group`, `--github-target-pat`/`--github-pat`, `--target-api-url` + +- [ ] **Step 3: Wire into all 3 CLIs, run tests, commit** + +--- + +### Task 9: Port generate-mannequin-csv and reclaim-mannequin Commands + +**Files:** +- Create: `cmd/gei/generate_mannequin_csv.go` +- Create: `cmd/gei/generate_mannequin_csv_test.go` +- Create: `cmd/gei/reclaim_mannequin.go` +- Create: `cmd/gei/reclaim_mannequin_test.go` +- Create: `pkg/mannequin/service.go` (ReclaimService equivalent) +- Create: `pkg/mannequin/service_test.go` + +**Reference:** `src/Octoshift/Commands/GenerateMannequinCsv/GenerateMannequinCsvCommandHandler.cs`, `src/Octoshift/Commands/ReclaimMannequin/ReclaimMannequinCommandHandler.cs`, `src/Octoshift/Services/ReclaimService.cs` + +- [ ] **Step 1: Write tests for mannequin reclaim service** + +Port the core logic from `ReclaimService.cs`: CSV parsing, mannequin matching, invitation vs skip-invitation paths, force mode, duplicate detection. 
+ +- [ ] **Step 2: Implement mannequin reclaim service** + +- [ ] **Step 3: Write tests for generate-mannequin-csv command** + +- [ ] **Step 4: Implement generate-mannequin-csv command** + +Flags: `--github-org` (required), `--output`, `--include-reclaimed`, `--github-target-pat`/`--github-pat`, `--target-api-url` + +- [ ] **Step 5: Write tests for reclaim-mannequin command** + +- [ ] **Step 6: Implement reclaim-mannequin command** + +Flags: `--github-org` (required), `--csv`, `--mannequin-user`, `--mannequin-id`, `--target-user`, `--force`, `--skip-invitation`, `--no-prompt`, `--github-target-pat`/`--github-pat`, `--target-api-url` + +- [ ] **Step 7: Wire into all 3 CLIs, run tests, commit** + +--- + +### Task 10: Port version checker and GitHub status check + +**Files:** +- Create: `pkg/version/checker.go` +- Create: `pkg/version/checker_test.go` +- Create: `pkg/status/github.go` +- Create: `pkg/status/github_test.go` +- Modify: `cmd/gei/main.go` (wire up PersistentPreRunE) +- Modify: `cmd/ado2gh/main.go` +- Modify: `cmd/bbs2gh/main.go` + +**Reference:** `src/Octoshift/Services/VersionChecker.cs`, `src/Octoshift/Services/GithubStatusApi.cs` + +- [ ] **Step 1: Write tests for version checker** + +- Current version < latest → returns false +- Current version == latest → returns true +- Network error → graceful fallback (don't crash) + +- [ ] **Step 2: Implement version checker** + +```go +// pkg/version/checker.go +type Checker struct { + httpClient *http.Client + logger *logger.Logger + version string // compiled-in version +} + +func (c *Checker) IsLatest(ctx context.Context) (bool, error) +func (c *Checker) GetLatestVersion(ctx context.Context) (string, error) +``` + +Fetch from `https://raw.githubusercontent.com/github/gh-gei/main/LATEST-VERSION.txt`. 
+
+- [ ] **Step 3: Write tests for GitHub status API**
+
+- [ ] **Step 4: Implement GitHub status API**
+
+```go
+// pkg/status/github.go
+func GetUnresolvedIncidentsCount(ctx context.Context) (int, error)
+```
+
+- [ ] **Step 5: Wire into root commands' PersistentPreRunE**
+
+Both checks are performed before every command (unless `GEI_SKIP_VERSION_CHECK` / `GEI_SKIP_STATUS_CHECK` env vars are set).
+
+- [ ] **Step 6: Run all tests, run `golangci-lint`, commit**
+
+---
+
+### Task 11: Push PR 3
+
+- [ ] **Step 1: Run full test suite**
+
+```bash
+go test -race -count=1 ./...
+golangci-lint run
+```
+
+- [ ] **Step 2: Push branch and create draft PR**
+
+Base: `o1/golang-port/2`
+Title: "Phase 3: GitHub API client + all shared commands (Go port)"
+
+---
+
+## Chunk 2: Core Migration Commands
+
+### PR 4: Cloud Storage Clients
+
+Port Azure Blob, AWS S3, and GitHub-owned storage upload.
+
+---
+
+### Task 12: Port Azure Blob Storage client
+
+**Files:**
+- Create: `pkg/storage/azure/client.go`
+- Create: `pkg/storage/azure/client_test.go`
+- Modify: `go.mod` (add `github.com/Azure/azure-sdk-for-go/sdk/storage/azblob`)
+
+**Reference:** `src/Octoshift/Services/AzureApi.cs` (124 lines)
+
+- [ ] **Step 1: Write tests for Azure client**
+
+- Upload returns SAS URL
+- Download returns bytes
+- Container naming: `migration-archives-<guid>`
+- Progress logging every 10 seconds
+
+- [ ] **Step 2: Implement Azure client**
+
+```go
+type Client struct {
+	serviceClient *azblob.Client
+	logger        *logger.Logger
+}
+
+func NewClient(connectionString string, opts ...Option) (*Client, error)
+func (c *Client) Upload(ctx context.Context, fileName string, content io.Reader, size int64) (string, error)
+func (c *Client) Download(ctx context.Context, url string) ([]byte, error)
+```
+
+Upload creates container `migration-archives-<guid>`, uploads blob, generates SAS URL (48h expiry, read-only).
+ +- [ ] **Step 3: Run tests, commit** + +--- + +### Task 13: Port AWS S3 client + +**Files:** +- Create: `pkg/storage/aws/client.go` +- Create: `pkg/storage/aws/client_test.go` +- Modify: `go.mod` (add `github.com/aws/aws-sdk-go-v2`) + +**Reference:** `src/Octoshift/Services/AwsApi.cs` (141 lines) + +- [ ] **Step 1: Write tests for AWS client** + +- Upload from file path → returns pre-signed URL (48h) +- Upload from stream → returns pre-signed URL +- Progress logging + +- [ ] **Step 2: Implement AWS client** + +```go +type Client struct { + s3Client *s3.Client + presignClient *s3.PresignClient + logger *logger.Logger +} + +func NewClient(ctx context.Context, accessKey, secretKey string, opts ...Option) (*Client, error) +func (c *Client) Upload(ctx context.Context, bucket, key string, data io.Reader) (string, error) +func (c *Client) UploadFile(ctx context.Context, bucket, key, filePath string) (string, error) +``` + +Options: `WithRegion(r)`, `WithSessionToken(t)`, `WithLogger(l)` + +- [ ] **Step 3: Run tests, commit** + +--- + +### Task 14: Port GitHub-owned storage multipart upload + +**Files:** +- Create: `pkg/storage/ghowned/client.go` +- Create: `pkg/storage/ghowned/client_test.go` + +**Reference:** `src/Octoshift/Services/ArchiveUploader.cs` (190 lines) + +- [ ] **Step 1: Write tests for multipart upload** + +Test the 3-phase protocol: +- Small archive (< 100 MiB) → single POST +- Large archive → Start (POST) → Parts (PATCH, follow Location header) → Complete (PUT) +- Missing Location header → error +- Custom part size from env var +- Minimum 5 MiB part size enforcement + +- [ ] **Step 2: Implement multipart upload** + +```go +type Client struct { + httpClient *http.Client + uploadsURL string + logger *logger.Logger + retryPolicy *retry.Policy + partSize int64 // default 100 MiB +} + +func NewClient(uploadsURL string, httpClient *http.Client, opts ...Option) *Client +func (c *Client) Upload(ctx context.Context, orgDatabaseID, archiveName string, content 
io.ReadSeeker, size int64) (string, error) +``` + +- [ ] **Step 3: Run tests, commit** + +--- + +### Task 15: Port archive upload orchestration + +**Files:** +- Create: `pkg/archive/uploader.go` +- Create: `pkg/archive/uploader_test.go` + +This coordinates: choose storage backend → upload → return URL. + +**Reference:** `MigrateRepoCommandHandler.cs` `UploadArchive()` method pattern + +- [ ] **Step 1: Write tests** + +- Upload to Azure when Azure configured +- Upload to AWS when AWS configured +- Upload to GitHub-owned when GitHub storage configured +- Error when none configured +- Error when multiple configured + +- [ ] **Step 2: Implement orchestration** + +- [ ] **Step 3: Run tests, commit** + +- [ ] **Step 4: Push PR 4** + +Create draft PR. Base: PR 3 branch. +Title: "Phase 4: Cloud storage clients (Azure Blob, AWS S3, GitHub-owned) (Go port)" + +--- + +### PR 5: gei migrate-repo + migrate-org + +The most complex commands in the suite. + +--- + +### Task 16: Port gei migrate-repo + +**Files:** +- Create: `cmd/gei/migrate_repo.go` +- Create: `cmd/gei/migrate_repo_test.go` + +**Reference:** `src/gei/Commands/MigrateRepo/MigrateRepoCommandHandler.cs` (508 lines), `MigrateRepoCommandArgs.cs` (139 lines) + +- [ ] **Step 1: Write tests for argument validation** + +Port all the cross-field validation from `MigrateRepoCommandArgs.Validate()`: +- Reject URL in org/repo fields +- Default target-repo to source-repo +- Default source PAT to target PAT +- Validate archive URL/path mutual exclusivity +- Validate paired archive options +- AWS bucket requires GHES URL +- no-ssl-verify requires GHES URL +- Azure + GitHub storage conflict + +- [ ] **Step 2: Write tests for the happy path flows** + +- GitHub.com → GitHub.com (direct, no archive upload) +- GHES → GitHub.com via Azure storage +- GHES → GitHub.com via AWS S3 +- GHES → GitHub.com via GitHub-owned storage +- Local archive paths +- Queue-only mode +- Migration failure → error + +- [ ] **Step 3: Implement migrate-repo 
command** + +~25 flags. Consumer-defined interfaces for all dependencies. + +```go +type migrationRunner interface { + GetOrganizationId(ctx context.Context, org string) (string, error) + GetOrganizationDatabaseId(ctx context.Context, org string) (int, error) + CreateGhecMigrationSource(ctx context.Context, orgID string) (string, error) + StartMigration(ctx context.Context, opts github.StartMigrationOpts) (string, error) + GetMigration(ctx context.Context, id string) (*github.Migration, error) + DoesRepoExist(ctx context.Context, org, repo string) (bool, error) + // GHES archive methods... + StartGitArchiveGeneration(ctx context.Context, org string, repos []string) (int, error) + StartMetadataArchiveGeneration(ctx context.Context, org string, repos []string) (int, error) + GetArchiveMigrationStatus(ctx context.Context, org string, id int) (string, error) + GetArchiveMigrationUrl(ctx context.Context, org string, id int) (string, error) + GetEnterpriseServerVersion(ctx context.Context) (string, error) + UploadArchiveToGithubStorage(ctx context.Context, orgDBId int, archiveName string, content io.ReadSeeker, size int64) (string, error) +} +``` + +- [ ] **Step 4: Wire into gei CLI** + +- [ ] **Step 5: Run tests, commit** + +--- + +### Task 17: Port GHES version checker + +**Files:** +- Create: `pkg/ghes/version.go` +- Create: `pkg/ghes/version_test.go` + +**Reference:** `src/gei/Services/GhesVersionChecker.cs` + +The GHES version checker determines if blob credentials are required based on the GHES version. Versions < 3.8.0 require Azure/AWS storage; >= 3.8.0 can use GitHub-owned storage. 
+ +- [ ] **Step 1: Write tests** +- [ ] **Step 2: Implement** +- [ ] **Step 3: Commit** + +--- + +### Task 18: Port gei migrate-org + +**Files:** +- Create: `cmd/gei/migrate_org.go` +- Create: `cmd/gei/migrate_org_test.go` + +**Reference:** `src/gei/Commands/MigrateOrg/MigrateOrgCommandHandler.cs` + +- [ ] **Step 1: Write tests** +- [ ] **Step 2: Implement** + +Flags: `--github-source-org`, `--github-target-org`, `--github-target-enterprise`, `--github-source-pat`, `--github-target-pat`, `--queue-only`, `--target-api-url` + +- [ ] **Step 3: Wire into gei CLI, run tests, commit** + +--- + +### Task 19: Port gei migrate-secret-alerts and migrate-code-scanning-alerts + +**Files:** +- Create: `pkg/alerts/secret_scanning.go` +- Create: `pkg/alerts/secret_scanning_test.go` +- Create: `pkg/alerts/code_scanning.go` +- Create: `pkg/alerts/code_scanning_test.go` +- Create: `cmd/gei/migrate_secret_alerts.go` +- Create: `cmd/gei/migrate_secret_alerts_test.go` +- Create: `cmd/gei/migrate_code_scanning.go` +- Create: `cmd/gei/migrate_code_scanning_test.go` + +**Reference:** `src/Octoshift/Services/SecretScanningAlertService.cs`, `src/Octoshift/Services/CodeScanningAlertService.cs` + +Also need to port the GitHub API methods for secret/code scanning (REST endpoints in GithubApi.cs). + +- [ ] **Step 1: Add secret scanning REST methods to pkg/github** +- [ ] **Step 2: Add code scanning REST methods to pkg/github** +- [ ] **Step 3: Port SecretScanningAlertService** +- [ ] **Step 4: Port CodeScanningAlertService** +- [ ] **Step 5: Implement migrate-secret-alerts command** +- [ ] **Step 6: Implement migrate-code-scanning-alerts command** +- [ ] **Step 7: Wire into gei CLI, run tests, commit** + +- [ ] **Step 8: Push PR 5** + +Create draft PR. Base: PR 4 branch. 
+Title: "Phase 5: gei migrate-repo, migrate-org, alert migration commands (Go port)" + +--- + +## Chunk 3: ADO Client & Commands + +### PR 6: ADO API Client + ado2gh Commands + +--- + +### Task 20: Port ADO API client to imroc/req + +Replace the existing `pkg/ado/client.go` (which uses `pkg/http`) with `imroc/req`. + +**Files:** +- Rewrite: `pkg/ado/client.go` +- Rewrite: `pkg/ado/client_test.go` +- Modify: `pkg/ado/models.go` (add all ADO model types) +- Modify: `go.mod` (add `github.com/imroc/req/v3`) + +**Reference:** `src/Octoshift/Services/AdoClient.cs` (241 lines), `src/Octoshift/Services/AdoApi.cs` (889 lines) + +The ADO client needs: +- Basic auth (`:pat` base64) +- Continuation token pagination (`x-ms-continuationtoken` header) +- Top/skip pagination (`$top`/`$skip` query params) +- Retry on 503 +- Throttling via `Retry-After` header + +- [ ] **Step 1: Write tests for ADO client pagination patterns** + +Test continuation-token pagination and top/skip pagination. + +- [ ] **Step 2: Implement ADO client with imroc/req** + +```go +type Client struct { + http *req.Client + baseURL string + logger *logger.Logger +} + +func NewClient(baseURL, pat string, opts ...Option) *Client +``` + +- [ ] **Step 3: Port all ~39 ADO API methods** + +Group by area and implement in batches: +1. Org/Identity (GetOrgOwner, GetUserId, GetOrganizations, etc.) +2. Team Projects (GetTeamProjects, GetTeamProjectId) +3. Repos (GetRepos, GetEnabledRepos, GetRepoId, DisableRepo, LockRepo) +4. Pipelines (GetPipelines, GetPipelineId, QueueBuild, GetBuildStatus, etc.) +5. Service Connections (GetGithubAppId, ContainsServiceConnection, ShareServiceConnection) +6. Boards (GetBoardsGithubConnection, CreateBoardsGithubEndpoint, etc.) +7. Git Statistics (GetLastPushDate, GetCommitCountSince, etc.) + +Each group gets its own test file or test section. 
+ +- [ ] **Step 4: Run tests, commit** + +--- + +### Task 21: Port ado2gh migrate-repo + +**Files:** +- Create: `cmd/ado2gh/migrate_repo.go` +- Create: `cmd/ado2gh/migrate_repo_test.go` + +**Reference:** `src/ado2gh/Commands/MigrateRepo/MigrateRepoCommandHandler.cs` (105 lines) + +Simpler than gei migrate-repo: no archive upload, just creates migration source and starts migration. + +- [ ] **Step 1: Write tests** +- [ ] **Step 2: Implement** +- [ ] **Step 3: Wire, run tests, commit** + +--- + +### Task 22: Port ado2gh generate-script + +**Files:** +- Create: `cmd/ado2gh/generate_script.go` +- Create: `cmd/ado2gh/generate_script_test.go` +- Create: `pkg/ado/inspector.go` (AdoInspectorService equivalent) +- Create: `pkg/ado/inspector_test.go` + +**Reference:** `src/ado2gh/Commands/GenerateScript/GenerateScriptCommandHandler.cs` (459 lines) + +This is the ADO variant of generate-script. It: +1. Fetches ADO orgs → projects → repos +2. Optionally loads a CSV repo list +3. Generates PowerShell script with ado2gh-specific commands +4. Supports --all flag for create-teams, lock-repos, disable-repos, rewire-pipelines, etc. + +The script generation itself reuses `pkg/scriptgen` (already ported in Phase 2). 
+
+- [ ] **Step 1: Port AdoInspectorService**
+
+```go
+// pkg/ado/inspector.go
+type Inspector struct {
+	client        *Client
+	logger        *logger.Logger
+	orgFilter     string
+	projectFilter string
+}
+
+func (i *Inspector) GetRepos() (map[string]map[string][]Repository, error)
+func (i *Inspector) GetRepoCount() int
+```
+
+- [ ] **Step 2: Write tests for generate-script**
+- [ ] **Step 3: Implement ado2gh generate-script**
+- [ ] **Step 4: Validate with scripts/validate-scripts.sh**
+- [ ] **Step 5: Wire, run tests, commit**
+
+---
+
+### Task 23: Port ado2gh simple commands
+
+8 ado2gh-specific low-complexity commands:
+
+| Command | Handler Lines | Dependencies |
+|---------|:---:|---|
+| `lock-ado-repo` | 36 | AdoApi |
+| `disable-ado-repo` | ~30 | AdoApi |
+| `add-team-to-repo` | ~40 | GithubApi |
+| `configure-autolink` | ~50 | GithubApi |
+| `share-service-connection` | ~40 | AdoApi |
+| `integrate-boards` | ~80 | AdoApi, GithubApi |
+| `rewire-pipeline` | ~100 | AdoApi |
+| `test-pipelines` | ~100 | AdoApi |
+
+**Files per command:** `cmd/ado2gh/<command>.go` + `cmd/ado2gh/<command>_test.go`
+
+- [ ] **Step 1: Port lock-ado-repo and disable-ado-repo**
+
+These are the simplest — each is a few ADO API calls.
+
+- [ ] **Step 2: Port add-team-to-repo and configure-autolink**
+
+GitHub API calls via go-github.
+
+- [ ] **Step 3: Port share-service-connection and integrate-boards**
+
+ADO-specific API calls (contribution queries).
+
+- [ ] **Step 4: Port rewire-pipeline**
+
+More complex: fetches pipeline definition, modifies repository configuration, updates.
+
+- [ ] **Step 5: Port test-pipelines**
+
+Concurrent pipeline testing with status polling.
+ +- [ ] **Step 6: Run all tests, lint, commit** + +--- + +### Task 24: Port ado2gh inventory-report + +**Files:** +- Create: `cmd/ado2gh/inventory_report.go` +- Create: `cmd/ado2gh/inventory_report_test.go` +- Create: `pkg/ado/csvgen.go` (CSV generator services) +- Create: `pkg/ado/csvgen_test.go` + +**Reference:** `src/ado2gh/Commands/InventoryReport/InventoryReportCommandHandler.cs`, `src/ado2gh/Services/OrgsCsvGeneratorService.cs`, `src/ado2gh/Services/TeamProjectsCsvGeneratorService.cs`, `src/ado2gh/Services/ReposCsvGeneratorService.cs`, `src/ado2gh/Services/PipelinesCsvGeneratorService.cs` + +- [ ] **Step 1: Port CSV generators** +- [ ] **Step 2: Port inventory-report command** +- [ ] **Step 3: Run tests, commit** + +- [ ] **Step 4: Push PR 6** + +Create draft PR. Base: PR 5 branch. +Title: "Phase 6: ADO API client + all ado2gh commands (Go port)" + +--- + +## Chunk 4: BBS Client & Commands + +### PR 7: BBS API Client + bbs2gh Commands + +--- + +### Task 25: Port BBS API client to imroc/req + +Replace the existing `pkg/bbs/client.go` with `imroc/req`. + +**Files:** +- Rewrite: `pkg/bbs/client.go` +- Rewrite: `pkg/bbs/client_test.go` +- Modify: `pkg/bbs/models.go` + +**Reference:** `src/Octoshift/Services/BbsClient.cs` (116 lines), `src/Octoshift/Services/BbsApi.cs` (148 lines) + +BBS pagination: `isLastPage` boolean + `nextPageStart` field + `values[]` array. 
+ +- [ ] **Step 1: Write tests for BBS pagination** +- [ ] **Step 2: Implement BBS client with imroc/req** +- [ ] **Step 3: Port all BBS API methods** + +Methods: +- GetServerVersion, StartExport, GetExport +- GetProjects, GetProject, GetRepos +- GetIsRepositoryArchived, GetRepositoryPullRequests, GetRepositoryLatestCommitDate, GetRepositoryAndAttachmentsSize + +- [ ] **Step 4: Run tests, commit** + +--- + +### Task 26: Port BBS archive downloaders (SSH + SMB) + +**Files:** +- Create: `pkg/bbs/ssh_downloader.go` +- Create: `pkg/bbs/ssh_downloader_test.go` +- Create: `pkg/bbs/smb_downloader.go` +- Create: `pkg/bbs/smb_downloader_test.go` +- Modify: `go.mod` (add `golang.org/x/crypto`, evaluate `github.com/hirochachacha/go-smb2`) + +**Reference:** `src/bbs2gh/Services/BbsSshArchiveDownloader.cs`, `src/bbs2gh/Services/BbsSmbArchiveDownloader.cs` + +- [ ] **Step 1: Port SSH archive downloader** + +Uses `golang.org/x/crypto/ssh` to SFTP-download the export archive from BBS. + +- [ ] **Step 2: Port SMB archive downloader** + +Uses `go-smb2` to download over SMB/CIFS. + +- [ ] **Step 3: Write tests, commit** + +--- + +### Task 27: Port bbs2gh migrate-repo + +**Files:** +- Create: `cmd/bbs2gh/migrate_repo.go` +- Create: `cmd/bbs2gh/migrate_repo_test.go` + +**Reference:** `src/bbs2gh/Commands/MigrateRepo/MigrateRepoCommandHandler.cs` (403 lines) + +The most complex BBS command: export generation → archive download (SSH/SMB) → upload (Azure/AWS/GH) → import. 
- [ ] **Step 1: Write tests for each migration phase**
+Title: "Phase 7: BBS API client + all bbs2gh commands (Go port)" + +--- + +## Chunk 5: CI/CD & E2E Integration + +### PR 8: CI/CD Workflow Updates + +--- + +### Task 30: Update CI.yml build job for Go + +**Files:** +- Modify: `.github/workflows/CI.yml` + +**Reference:** Current CI.yml build job, `justfile` Go targets + +Changes to the `build` job: +- Add Go setup step (`actions/setup-go`) +- Add `go-build`, `go-test`, `golangci-lint` steps alongside existing C# steps +- Keep C# steps until the port is complete (both codebases coexist) + +- [ ] **Step 1: Add Go build and test to CI build job** +- [ ] **Step 2: Add Go lint step** +- [ ] **Step 3: Test by pushing to PR** + +--- + +### Task 31: Update build-for-e2e-test for Go binaries + +**Files:** +- Modify: `.github/workflows/CI.yml` (build-for-e2e-test job) +- Modify: `justfile` (ensure go-publish-* targets produce correct binary names) + +The e2e tests expect binaries named `gei-linux-amd64`, `ado2gh-windows-amd64.exe`, etc. The `just go-publish-*` targets must produce cross-compiled binaries with these exact names. + +- [ ] **Step 1: Update build-for-e2e-test to use Go cross-compilation** + +Replace `dotnet publish` with `GOOS=linux GOARCH=amd64 go build -o dist/gei-linux-amd64 ./cmd/gei` etc. + +- [ ] **Step 2: Verify binary artifact naming matches what e2e-test expects** + +The e2e-test job downloads artifacts and copies them into the gh extension directory. Binary names must match. 
+ +- [ ] **Step 3: Test that e2e-test job can install and run Go binaries** + +--- + +### Task 32: Update publish job for Go binaries + +**Files:** +- Create: `publish-go.sh` or modify `publish.ps1` to support Go builds +- Modify: `.github/workflows/CI.yml` (publish job) + +Cross-compile for all 6 platform targets: +- `linux-amd64`, `linux-arm64` +- `darwin-amd64`, `darwin-arm64` +- `windows-amd64`, `windows-386` + +- [ ] **Step 1: Create Go publish script** +- [ ] **Step 2: Update publish job to build Go binaries** +- [ ] **Step 3: Update release creation to use Go binaries** + +--- + +### Task 33: Update CodeQL and other CI items + +**Files:** +- Modify: `.github/workflows/CI.yml` (CodeQL steps) +- Modify: `.github/codeql/codeql-config.yml` +- Modify: `.github/workflows/copilot-setup-steps.yml` +- Modify: `.github/dependabot.yml` (add `gomod` ecosystem) + +- [ ] **Step 1: Add Go language to CodeQL init** +- [ ] **Step 2: Update copilot-setup-steps for Go** +- [ ] **Step 3: Add gomod to dependabot** +- [ ] **Step 4: Run lint, push PR 8** + +Create draft PR. Base: PR 7 branch. +Title: "Phase 8: CI/CD workflow updates for Go binaries" + +--- + +### PR 9: E2E Test Validation + +This PR ensures all integration tests pass against Go binaries. It should change **only build steps**, not validation logic. + +--- + +### Task 34: Run e2e tests and fix any issues + +- [ ] **Step 1: Trigger manual integration test run** + +Use the `integration-tests.yml` workflow_dispatch with the Go PR. + +- [ ] **Step 2: Analyze failures** + +Failures will likely be in: +- Log output format differences (investigate `.octoshift.log` format expectations) +- Binary exit codes (ensure non-zero on error) +- Command output format (ensure exact match) + +- [ ] **Step 3: Fix any format discrepancies** + +The goal is zero changes to the test validation logic — only the Go code adapts. + +- [ ] **Step 4: Re-run e2e tests until all pass** + +- [ ] **Step 5: Push PR 9** + +Create draft PR. 
Base: PR 8 branch. +Title: "Phase 9: E2E test compatibility fixes (Go port)" + +--- + +## Chunk 6: Cleanup & Removal + +### PR 10: Remove pkg/http and pkg/app + +After all consumers are migrated to go-github, imroc/req, and explicit wiring: + +**Files:** +- Delete: `pkg/http/client.go`, `pkg/http/client_test.go` +- Delete: `pkg/app/app.go`, `pkg/app/app_test.go` +- Modify: All consumers to remove references + +- [ ] **Step 1: Verify no remaining imports of pkg/http or pkg/app** +- [ ] **Step 2: Remove files** +- [ ] **Step 3: Run tests, lint, commit** +- [ ] **Step 4: Push PR 10** + +Create draft PR. Base: PR 9 branch. +Title: "Phase 10: Remove deprecated pkg/http and pkg/app packages" + +--- + +## Appendix A: Flag Naming Differences Between CLIs + +The same shared command has different flag names across CLIs: + +| Shared Flag (C# base) | gei | ado2gh | bbs2gh | +|------------------------|-----|--------|--------| +| `--github-pat` | `--github-target-pat` | `--github-pat` | `--github-pat` | +| `--target-api-url` | `--target-api-url` | `--target-api-url` | `--target-api-url` | + +Implementation: Create a shared function returning `*cobra.Command`, then use a wrapper in each CLI's main.go to rename flags: + +```go +// Shared implementation +func newWaitForMigrationCmdBase(gh migrationWaiter, log *logger.Logger, patFlagName string) *cobra.Command + +// gei wiring +newWaitForMigrationCmdBase(gh, log, "github-target-pat") + +// ado2gh/bbs2gh wiring +newWaitForMigrationCmdBase(gh, log, "github-pat") +``` + +## Appendix B: Interface Evaluation + +C# uses many interfaces for DI. 
The Go port should evaluate each: + +| C# Interface | Go Equivalent | Rationale | +|---|---|---| +| `ICommandHandler` | Not needed | Commands are `*cobra.Command` constructors | +| `ISourceGithubApiFactory` | Not needed | Explicit construction in main.go | +| `ITargetGithubApiFactory` | Not needed | Explicit construction in main.go | +| `IVersionProvider` | `version.Checker` (concrete) | No polymorphism needed | +| `IAzureApiFactory` | Not needed | Explicit construction | +| `IBlobServiceClientFactory` | Not needed | Azure SDK handles this | +| `IBbsArchiveDownloader` | Keep as interface | SSH vs SMB runtime dispatch | + +Consumer-defined interfaces at each command file provide testability without global interface declarations. + +## Appendix C: Estimated Sizes + +| PR | Est. Lines | Est. Test Lines | Total | +|---|---:|---:|---:| +| PR 3: GitHub API + shared commands | ~3,000 | ~2,500 | ~5,500 | +| PR 4: Cloud storage clients | ~1,200 | ~800 | ~2,000 | +| PR 5: gei migrate-repo/org + alerts | ~2,500 | ~2,000 | ~4,500 | +| PR 6: ADO client + ado2gh commands | ~3,500 | ~2,500 | ~6,000 | +| PR 7: BBS client + bbs2gh commands | ~2,500 | ~2,000 | ~4,500 | +| PR 8: CI/CD workflows | ~300 | — | ~300 | +| PR 9: E2E fixes | ~200 | — | ~200 | +| PR 10: Cleanup | -400 | -200 | -600 | +| **Total** | **~12,800** | **~9,600** | **~22,400** | + +## Appendix D: Dependency Graph + +``` +PR 3 (shared commands + GitHub API) + └── PR 4 (cloud storage) + └── PR 5 (gei migrate-repo/org + alerts) + └── PR 6 (ADO client + ado2gh) + └── PR 7 (BBS client + bbs2gh) + └── PR 8 (CI/CD) + └── PR 9 (E2E fixes) + └── PR 10 (cleanup) +``` + +Each PR depends on the one above it. They form a linear stack based on `o1/golang-port/2`. 
diff --git a/go.mod b/go.mod index 561ea1273..5f5ed7bde 100644 --- a/go.mod +++ b/go.mod @@ -4,12 +4,14 @@ go 1.25.4 require ( github.com/avast/retry-go/v4 v4.7.0 + github.com/google/go-github/v68 v68.0.0 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.9 // indirect diff --git a/go.sum b/go.sum index 6d38076e3..c87112a70 100644 --- a/go.sum +++ b/go.sum @@ -3,6 +3,13 @@ github.com/avast/retry-go/v4 v4.7.0/go.mod h1:ZMPDa3sY2bKgpLtap9JRUgk2yTAba7cgiF github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-github/v68 v68.0.0 h1:ZW57zeNZiXTdQ16qrDiZ0k6XucrxZ2CGmoTvcCyQG6s= +github.com/google/go-github/v68 v68.0.0/go.mod h1:K9HAUBovM2sLwM408A18h+wd9vqdLOEqTUCbnRIcx68= +github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= +github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -15,6 +22,7 @@ github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An 
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/cmdutil/errors.go b/internal/cmdutil/errors.go new file mode 100644 index 000000000..dd7a7fa49 --- /dev/null +++ b/internal/cmdutil/errors.go @@ -0,0 +1,32 @@ +// Package cmdutil provides shared command infrastructure for the CLI. +package cmdutil + +import "fmt" + +// UserError is a user-facing error that should be displayed without a stack trace. +// It is the Go equivalent of C#'s OctoshiftCliException. +// Use errors.As to check if an error is a UserError at the top-level handler. +type UserError struct { + Message string + Err error // optional inner error for errors.Unwrap chain +} + +func (e *UserError) Error() string { return e.Message } + +// Unwrap returns the inner error, enabling errors.Is/errors.As chains. +func (e *UserError) Unwrap() error { return e.Err } + +// NewUserError creates a UserError with just a message. +func NewUserError(msg string) *UserError { + return &UserError{Message: msg} +} + +// NewUserErrorf creates a UserError with a formatted message. +func NewUserErrorf(format string, args ...any) *UserError { + return &UserError{Message: fmt.Sprintf(format, args...)} +} + +// WrapUserError wraps an existing error as a UserError with a user-friendly message. 
+func WrapUserError(msg string, err error) *UserError { + return &UserError{Message: msg, Err: err} +} diff --git a/internal/cmdutil/errors_test.go b/internal/cmdutil/errors_test.go new file mode 100644 index 000000000..34fe77464 --- /dev/null +++ b/internal/cmdutil/errors_test.go @@ -0,0 +1,83 @@ +package cmdutil_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/github/gh-gei/internal/cmdutil" +) + +func TestUserError_Error_ReturnsMessage(t *testing.T) { + ue := &cmdutil.UserError{Message: "something went wrong"} + if got := ue.Error(); got != "something went wrong" { + t.Errorf("Error() = %q, want %q", got, "something went wrong") + } +} + +func TestUserError_Unwrap_ReturnsInnerError(t *testing.T) { + inner := errors.New("inner cause") + ue := &cmdutil.UserError{Message: "outer", Err: inner} + got := ue.Unwrap() + if !errors.Is(got, inner) { + t.Errorf("Unwrap() = %v, want %v", got, inner) + } +} + +func TestUserError_Unwrap_ReturnsNilWhenNoInnerError(t *testing.T) { + ue := &cmdutil.UserError{Message: "no inner"} + if got := ue.Unwrap(); got != nil { + t.Errorf("Unwrap() = %v, want nil", got) + } +} + +func TestUserError_ErrorsIs_MatchesInnerError(t *testing.T) { + inner := errors.New("root cause") + ue := &cmdutil.UserError{Message: "wrapper", Err: inner} + if !errors.Is(ue, inner) { + t.Error("errors.Is(userErr, innerErr) should be true") + } +} + +func TestUserError_ErrorsAs_FromWrappedChain(t *testing.T) { + ue := &cmdutil.UserError{Message: "user-facing"} + wrapped := fmt.Errorf("context: %w", ue) + + var target *cmdutil.UserError + if !errors.As(wrapped, &target) { + t.Error("errors.As should find UserError in chain") + } + if target.Message != "user-facing" { + t.Errorf("target.Message = %q, want %q", target.Message, "user-facing") + } +} + +func TestNewUserError(t *testing.T) { + ue := cmdutil.NewUserError("bad input") + if ue.Error() != "bad input" { + t.Errorf("Error() = %q, want %q", ue.Error(), "bad input") + } + if ue.Unwrap() != nil { 
+ t.Error("Unwrap() should be nil for NewUserError") + } +} + +func TestNewUserErrorf(t *testing.T) { + ue := cmdutil.NewUserErrorf("value %q is invalid", "foo") + want := `value "foo" is invalid` + if ue.Error() != want { + t.Errorf("Error() = %q, want %q", ue.Error(), want) + } +} + +func TestWrapUserError(t *testing.T) { + inner := errors.New("underlying issue") + ue := cmdutil.WrapUserError("friendly message", inner) + + if ue.Error() != "friendly message" { + t.Errorf("Error() = %q, want %q", ue.Error(), "friendly message") + } + if !errors.Is(ue, inner) { + t.Error("errors.Is should match inner error") + } +} diff --git a/internal/cmdutil/flags.go b/internal/cmdutil/flags.go new file mode 100644 index 000000000..c4b34f4af --- /dev/null +++ b/internal/cmdutil/flags.go @@ -0,0 +1,77 @@ +package cmdutil + +import ( + "net/url" + "strings" +) + +// IsURL checks if a string is a valid HTTP or HTTPS URL. +func IsURL(s string) bool { + if s == "" { + return false + } + u, err := url.ParseRequestURI(s) + if err != nil { + return false + } + return u.Scheme == "http" || u.Scheme == "https" +} + +// ValidateNoURL returns a UserError if the value looks like a URL. +// flagName is the flag name for the error message (e.g. "--github-org"). +func ValidateNoURL(value, flagName string) error { + if IsURL(value) { + return NewUserErrorf("%s expects a name, not a URL. Remove the URL and provide just the name.", flagName) + } + return nil +} + +// ValidateRequired returns a UserError if the value is empty or whitespace-only. +func ValidateRequired(value, flagName string) error { + if strings.TrimSpace(value) == "" { + return NewUserErrorf("%s must be provided", flagName) + } + return nil +} + +// ValidateMutuallyExclusive returns a UserError if both values are non-empty. 
+func ValidateMutuallyExclusive(val1, flag1, val2, flag2 string) error { + if val1 != "" && val2 != "" { + return NewUserErrorf("only one of %s or %s can be set at a time", flag1, flag2) + } + return nil +} + +// ValidatePaired returns a UserError if exactly one of the two values is set. +// Both must be provided together or neither. +func ValidatePaired(val1, flag1, val2, flag2 string) error { + set1 := val1 != "" + set2 := val2 != "" + if set1 != set2 { + return NewUserErrorf("%s and %s must be provided together", flag1, flag2) + } + return nil +} + +// ValidateRequiredWhen returns a UserError if condition is true but value is empty. +func ValidateRequiredWhen(value, flagName string, condition bool, conditionDesc string) error { + if condition && strings.TrimSpace(value) == "" { + return NewUserErrorf("%s must be specified when %s", flagName, conditionDesc) + } + return nil +} + +// ValidateOneOf returns a UserError if value is not in the allowed list (case-insensitive). +// An empty value is considered valid (flag not set). 
+func ValidateOneOf(value, flagName string, allowed ...string) error { + if value == "" { + return nil + } + upper := strings.ToUpper(value) + for _, a := range allowed { + if strings.ToUpper(a) == upper { + return nil + } + } + return NewUserErrorf("%s must be one of: %s", flagName, strings.Join(allowed, ", ")) +} diff --git a/internal/cmdutil/flags_test.go b/internal/cmdutil/flags_test.go new file mode 100644 index 000000000..7016034ef --- /dev/null +++ b/internal/cmdutil/flags_test.go @@ -0,0 +1,192 @@ +package cmdutil_test + +import ( + "errors" + "testing" + + "github.com/github/gh-gei/internal/cmdutil" +) + +func TestIsURL(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"http url", "http://example.com", true}, + {"https url", "https://example.com", true}, + {"https with path", "https://github.com/org", true}, + {"not a url", "my-org", false}, + {"empty string", "", false}, + {"ftp scheme", "ftp://example.com", false}, + {"just a word", "github", false}, + {"url-like but no scheme", "example.com/path", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := cmdutil.IsURL(tt.input); got != tt.want { + t.Errorf("IsURL(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestValidateNoURL(t *testing.T) { + t.Run("returns nil for non-URL value", func(t *testing.T) { + if err := cmdutil.ValidateNoURL("my-org", "--github-org"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns nil for empty value", func(t *testing.T) { + if err := cmdutil.ValidateNoURL("", "--github-org"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns UserError for URL value", func(t *testing.T) { + err := cmdutil.ValidateNoURL("https://github.com/org", "--github-org") + if err == nil { + t.Fatal("expected error, got nil") + } + var ue *cmdutil.UserError + if !errors.As(err, &ue) { + t.Fatalf("expected UserError, got %T", err) + } + }) +} + +func 
TestValidateRequired(t *testing.T) { + t.Run("returns nil for non-empty value", func(t *testing.T) { + if err := cmdutil.ValidateRequired("val", "--flag"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns UserError for empty value", func(t *testing.T) { + err := cmdutil.ValidateRequired("", "--flag") + if err == nil { + t.Fatal("expected error, got nil") + } + var ue *cmdutil.UserError + if !errors.As(err, &ue) { + t.Fatalf("expected UserError, got %T", err) + } + }) + t.Run("returns UserError for whitespace-only value", func(t *testing.T) { + err := cmdutil.ValidateRequired(" ", "--flag") + if err == nil { + t.Fatal("expected error, got nil") + } + }) +} + +func TestValidateMutuallyExclusive(t *testing.T) { + t.Run("returns nil when neither set", func(t *testing.T) { + if err := cmdutil.ValidateMutuallyExclusive("", "--a", "", "--b"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns nil when only first set", func(t *testing.T) { + if err := cmdutil.ValidateMutuallyExclusive("val", "--a", "", "--b"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns nil when only second set", func(t *testing.T) { + if err := cmdutil.ValidateMutuallyExclusive("", "--a", "val", "--b"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns UserError when both set", func(t *testing.T) { + err := cmdutil.ValidateMutuallyExclusive("v1", "--a", "v2", "--b") + if err == nil { + t.Fatal("expected error, got nil") + } + var ue *cmdutil.UserError + if !errors.As(err, &ue) { + t.Fatalf("expected UserError, got %T", err) + } + }) +} + +func TestValidatePaired(t *testing.T) { + t.Run("returns nil when both set", func(t *testing.T) { + if err := cmdutil.ValidatePaired("a", "--git-archive-url", "b", "--metadata-archive-url"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns nil when neither set", func(t *testing.T) { + if err := cmdutil.ValidatePaired("", 
"--git-archive-url", "", "--metadata-archive-url"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns UserError when only first set", func(t *testing.T) { + err := cmdutil.ValidatePaired("a", "--git-archive-url", "", "--metadata-archive-url") + if err == nil { + t.Fatal("expected error, got nil") + } + var ue *cmdutil.UserError + if !errors.As(err, &ue) { + t.Fatalf("expected UserError, got %T", err) + } + }) + t.Run("returns UserError when only second set", func(t *testing.T) { + err := cmdutil.ValidatePaired("", "--git-archive-url", "b", "--metadata-archive-url") + if err == nil { + t.Fatal("expected error, got nil") + } + var ue *cmdutil.UserError + if !errors.As(err, &ue) { + t.Fatalf("expected UserError, got %T", err) + } + }) +} + +func TestValidateRequiredWhen(t *testing.T) { + t.Run("returns nil when condition is false", func(t *testing.T) { + if err := cmdutil.ValidateRequiredWhen("", "--flag", false, "--other is set"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns nil when condition is true and value is set", func(t *testing.T) { + if err := cmdutil.ValidateRequiredWhen("val", "--flag", true, "--other is set"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns UserError when condition is true and value is empty", func(t *testing.T) { + err := cmdutil.ValidateRequiredWhen("", "--ghes-api-url", true, "--no-ssl-verify is specified") + if err == nil { + t.Fatal("expected error, got nil") + } + var ue *cmdutil.UserError + if !errors.As(err, &ue) { + t.Fatalf("expected UserError, got %T", err) + } + }) +} + +func TestValidateOneOf(t *testing.T) { + t.Run("returns nil for valid value", func(t *testing.T) { + if err := cmdutil.ValidateOneOf("TEAM", "--actor-type", "TEAM", "USER"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns nil for valid value case-insensitive", func(t *testing.T) { + if err := cmdutil.ValidateOneOf("team", "--actor-type", 
"TEAM", "USER"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns nil for empty value", func(t *testing.T) { + // empty means the flag was not set, so no validation needed + if err := cmdutil.ValidateOneOf("", "--actor-type", "TEAM", "USER"); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + t.Run("returns UserError for invalid value", func(t *testing.T) { + err := cmdutil.ValidateOneOf("INVALID", "--actor-type", "TEAM", "USER") + if err == nil { + t.Fatal("expected error, got nil") + } + var ue *cmdutil.UserError + if !errors.As(err, &ue) { + t.Fatalf("expected UserError, got %T", err) + } + }) +} diff --git a/justfile b/justfile index 327432f7b..5b162e67c 100644 --- a/justfile +++ b/justfile @@ -195,6 +195,26 @@ go-publish-macos: # Build Go binaries for all platforms go-publish-all: go-publish-linux go-publish-windows go-publish-macos +# Install Go binaries as gh CLI extensions (macOS) +go-install-extensions-macos: go-publish-macos + #!/usr/bin/env bash + set -euo pipefail + for cli in gei ado2gh bbs2gh; do + dir="gh-${cli}" + mkdir -p "$dir" + cp "./dist/osx-x64/${cli}-darwin-amd64" "./${dir}/gh-${cli}" + chmod +x "./${dir}/gh-${cli}" + cd "$dir" && gh extension install . --force && cd .. + done + echo "Go extensions installed successfully!" + +# Run GithubToGithub integration test against Go binaries (macOS) +go-e2e-github: go-install-extensions-macos + direnv exec . 
dotnet test src/OctoshiftCLI.IntegrationTests/OctoshiftCLI.IntegrationTests.csproj \ + --filter "GithubToGithub" \ + --logger "console;verbosity=normal" \ + /p:VersionPrefix=9.9 + # Run Go CI pipeline go-ci: go-format-check go-build go-test diff --git a/mise.toml b/mise.toml new file mode 100644 index 000000000..154473201 --- /dev/null +++ b/mise.toml @@ -0,0 +1,3 @@ +[tools] +dotnet = "8.0.410" +go = "1.25.4" diff --git a/pkg/download/service.go b/pkg/download/service.go new file mode 100644 index 000000000..c84d00bd3 --- /dev/null +++ b/pkg/download/service.go @@ -0,0 +1,88 @@ +// Package download provides HTTP download functionality. +package download + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "time" +) + +const defaultTimeout = 1 * time.Hour + +// Service downloads files over HTTP. +type Service struct { + client *http.Client +} + +// New creates a new download service with the given HTTP client. +// If client is nil, a default client with 1-hour timeout is used. +func New(client *http.Client) *Service { + if client == nil { + client = &http.Client{Timeout: defaultTimeout} + } + return &Service{client: client} +} + +// DownloadToFile downloads content from url and writes it to the given destPath. +func (s *Service) DownloadToFile(ctx context.Context, url, destPath string) error { + resp, err := s.doGet(ctx, url) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download failed: HTTP %d", resp.StatusCode) + } + + f, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("creating file %s: %w", destPath, err) + } + + if _, err := io.Copy(f, resp.Body); err != nil { + f.Close() + os.Remove(destPath) // clean up partial file + return fmt.Errorf("writing to file %s: %w", destPath, err) + } + + // Close explicitly to catch flush/sync errors. 
+ if err := f.Close(); err != nil { + return fmt.Errorf("closing file %s: %w", destPath, err) + } + return nil +} + +// DownloadToBytes downloads content from url and returns it as a byte slice. +func (s *Service) DownloadToBytes(ctx context.Context, url string) ([]byte, error) { + resp, err := s.doGet(ctx, url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("download failed: HTTP %d", resp.StatusCode) + } + + data, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading response body: %w", err) + } + return data, nil +} + +func (s *Service) doGet(ctx context.Context, url string) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("creating request for %s: %w", url, err) + } + resp, err := s.client.Do(req) + if err != nil { + return nil, fmt.Errorf("GET %s: %w", url, err) + } + return resp, nil +} diff --git a/pkg/download/service_test.go b/pkg/download/service_test.go new file mode 100644 index 000000000..20abc9bc9 --- /dev/null +++ b/pkg/download/service_test.go @@ -0,0 +1,81 @@ +package download + +import ( + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDownloadToFile_Success(t *testing.T) { + expectedContent := "this is the log file content" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(expectedContent)) + })) + defer srv.Close() + + svc := New(srv.Client()) + + dest := filepath.Join(t.TempDir(), "output.log") + err := svc.DownloadToFile(context.Background(), srv.URL+"/log", dest) + require.NoError(t, err) + + got, err := os.ReadFile(dest) + require.NoError(t, err) + assert.Equal(t, expectedContent, string(got)) +} + +func 
TestDownloadToFile_Non200Error(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer srv.Close() + + svc := New(srv.Client()) + + dest := filepath.Join(t.TempDir(), "output.log") + err := svc.DownloadToFile(context.Background(), srv.URL+"/missing", dest) + require.Error(t, err) + assert.Contains(t, err.Error(), "404") +} + +func TestDownloadToBytes_Success(t *testing.T) { + expectedContent := "byte content here" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(expectedContent)) + })) + defer srv.Close() + + svc := New(srv.Client()) + + got, err := svc.DownloadToBytes(context.Background(), srv.URL+"/data") + require.NoError(t, err) + assert.Equal(t, []byte(expectedContent), got) +} + +func TestDownloadToBytes_Non200Error(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + svc := New(srv.Client()) + + _, err := svc.DownloadToBytes(context.Background(), srv.URL+"/fail") + require.Error(t, err) + assert.Contains(t, err.Error(), "500") +} + +func TestNew_NilClient_UsesDefault(t *testing.T) { + svc := New(nil) + require.NotNil(t, svc) + require.NotNil(t, svc.client) + assert.Equal(t, defaultTimeout, svc.client.Timeout) +} diff --git a/pkg/github/client.go b/pkg/github/client.go index 672bfb2e5..49f15a9c7 100644 --- a/pkg/github/client.go +++ b/pkg/github/client.go @@ -1,155 +1,1061 @@ +// Package github provides a GitHub API client for the gh-gei migration tool. +// REST operations use go-github; GraphQL uses a thin custom client for migration-specific mutations. 
package github import ( "context" + "crypto/tls" "encoding/json" "fmt" + "net/http" "net/url" "strings" - "github.com/github/gh-gei/pkg/http" + gogithub "github.com/google/go-github/v68/github" + "github.com/github/gh-gei/pkg/logger" ) -// Client is a GitHub API client +const defaultAPIURL = "https://api.github.com" + +// Client is a GitHub API client that uses go-github for REST and a custom +// graphqlClient for migration-specific GraphQL operations. type Client struct { - http *http.Client - apiURL string - pat string - logger *logger.Logger + rest *gogithub.Client // go-github for REST + graphql *graphqlClient // custom for migration GraphQL + logger *logger.Logger + apiURL string +} + +// Option configures a Client. +type Option func(*clientConfig) + +type clientConfig struct { + apiURL string + logger *logger.Logger + noSSLVerify bool + version string +} + +// WithAPIURL sets the base API URL (for GHES). Defaults to https://api.github.com. +func WithAPIURL(u string) Option { + return func(c *clientConfig) { + c.apiURL = u + } +} + +// WithLogger sets the logger. +func WithLogger(l *logger.Logger) Option { + return func(c *clientConfig) { + c.logger = l + } } -// Config contains configuration for the GitHub API client -type Config struct { - APIURL string // Default: "https://api.github.com" - PAT string // Personal Access Token (from GH_PAT, GH_SOURCE_PAT, or command line) - NoSSLVerify bool // For GHES with self-signed certificates +// WithNoSSLVerify disables TLS verification (for GHES with self-signed certs). +func WithNoSSLVerify() Option { + return func(c *clientConfig) { + c.noSSLVerify = true + } } -// DefaultConfig returns a Config with sensible defaults -func DefaultConfig() Config { - return Config{ - APIURL: "https://api.github.com", - NoSSLVerify: false, +// WithVersion sets the CLI version used in User-Agent headers. 
+func WithVersion(v string) Option { + return func(c *clientConfig) { + c.version = v } } -// NewClient creates a new GitHub API client -func NewClient(cfg Config, httpClient *http.Client, log *logger.Logger) *Client { - apiURL := cfg.APIURL - if apiURL == "" { - apiURL = "https://api.github.com" +// NewClient creates a new GitHub API client using a PAT for auth. +func NewClient(pat string, opts ...Option) *Client { + cfg := &clientConfig{ + apiURL: defaultAPIURL, + logger: logger.New(false), + version: "0.0.0", + } + for _, o := range opts { + o(cfg) } + cfg.apiURL = strings.TrimRight(cfg.apiURL, "/") - // Trim trailing slash - apiURL = strings.TrimRight(apiURL, "/") + // Build HTTP transport + transport := &http.Transport{} + if cfg.noSSLVerify { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402 + } + httpClient := &http.Client{Transport: transport} + + // Configure go-github REST client + restClient := gogithub.NewClient(httpClient).WithAuthToken(pat) + if cfg.apiURL != defaultAPIURL { + baseURL, err := url.Parse(cfg.apiURL + "/") + if err != nil { + cfg.logger.Warning("Failed to parse API URL %q, falling back to default: %v", cfg.apiURL, err) + } else { + restClient.BaseURL = baseURL + } + } + + // Configure custom GraphQL client + gql := newGraphQLClient(cfg.apiURL, pat, cfg.version, cfg.logger) + if cfg.noSSLVerify { + gql.httpClient = httpClient + } return &Client{ - http: httpClient, - apiURL: apiURL, - pat: cfg.PAT, - logger: log, + rest: restClient, + graphql: gql, + logger: cfg.logger, + apiURL: cfg.apiURL, } } -// GetRepos fetches all repositories for a given organization -// Corresponds to C# GithubApi.GetRepos() - line 114 in GithubApi.cs +// GetRepos fetches all repositories for a given organization using go-github. +// go-github handles pagination natively via ListByOrg with ListOptions. 
func (c *Client) GetRepos(ctx context.Context, org string) ([]Repo, error) { - // URL encode the org name - escapedOrg := url.PathEscape(org) - apiURL := fmt.Sprintf("%s/orgs/%s/repos?per_page=100", c.apiURL, escapedOrg) - c.logger.Info("Fetching repositories for organization: %s", org) - repos := []Repo{} - page := 1 + var allRepos []Repo + opts := &gogithub.RepositoryListByOrgOptions{ + ListOptions: gogithub.ListOptions{PerPage: 100}, + } for { - pageURL := fmt.Sprintf("%s&page=%d", apiURL, page) - - headers := c.buildHeaders() - body, err := c.http.Get(ctx, pageURL, headers) + ghRepos, resp, err := c.rest.Repositories.ListByOrg(ctx, org, opts) if err != nil { - return nil, fmt.Errorf("failed to fetch repos (page %d): %w", page, err) + return nil, fmt.Errorf("failed to fetch repos: %w", err) } - var pageRepos []map[string]interface{} - if err := json.Unmarshal(body, &pageRepos); err != nil { - return nil, fmt.Errorf("failed to parse repos response: %w", err) + for _, r := range ghRepos { + allRepos = append(allRepos, Repo{ + Name: r.GetName(), + Visibility: r.GetVisibility(), + }) } - // No more repos - if len(pageRepos) == 0 { + c.logger.Debug("Fetched %d repos from page %d", len(ghRepos), opts.Page) + + if resp.NextPage == 0 { break } + opts.Page = resp.NextPage + } + + c.logger.Info("Found %d repositories in organization %s", len(allRepos), org) + return allRepos, nil +} + +// GetVersion fetches the GitHub Enterprise Server version. +// Only applicable for GHES — returns an error for GitHub.com. +// go-github's APIMeta doesn't include installed_version (GHES-specific), +// so we make a raw request via go-github and parse the field ourselves. 
+func (c *Client) GetVersion(ctx context.Context) (*VersionInfo, error) { + if c.apiURL == defaultAPIURL { + return nil, fmt.Errorf("version endpoint not available on GitHub.com") + } + + req, err := c.rest.NewRequest("GET", "meta", nil) + if err != nil { + return nil, fmt.Errorf("failed to create version request: %w", err) + } + + var meta struct { + InstalledVersion string `json:"installed_version"` + } + _, err = c.rest.Do(ctx, req, &meta) + if err != nil { + return nil, fmt.Errorf("failed to fetch version: %w", err) + } + + return &VersionInfo{ + Version: meta.InstalledVersion, + InstalledVersion: meta.InstalledVersion, + }, nil +} + +// GraphQL sends a GraphQL query and returns the raw "data" field. +func (c *Client) GraphQL(ctx context.Context, query string, variables json.RawMessage) (json.RawMessage, error) { + return c.graphql.Post(ctx, query, variables) +} + +// GraphQLWithPagination sends a paginated GraphQL query, collecting all pages. +func (c *Client) GraphQLWithPagination( + ctx context.Context, + query string, + variables json.RawMessage, + dataPath string, + pageInfoPath string, +) (json.RawMessage, error) { + return c.graphql.PostWithPagination(ctx, query, variables, dataPath, pageInfoPath) +} + +// --------------------------------------------------------------------------- +// Organization / User queries +// --------------------------------------------------------------------------- + +// GetOrganizationId returns the node ID (global ID) for the given organization. +func (c *Client) GetOrganizationId(ctx context.Context, org string) (string, error) { + query := `query($login: String!) 
{organization(login: $login) { login, id, name } }` + vars, _ := json.Marshal(map[string]string{"login": org}) + + data, err := c.graphql.Post(ctx, query, vars) + if err != nil { + return "", fmt.Errorf("failed to get organization ID for %q: %w", org, err) + } + + orgData, err := navigateJSON(data, "organization.id") + if err != nil { + return "", fmt.Errorf("failed to parse organization ID for %q: %w", org, err) + } + + var id string + if err := json.Unmarshal(orgData, &id); err != nil { + return "", fmt.Errorf("failed to unmarshal organization ID for %q: %w", org, err) + } + return id, nil +} + +// GetOrganizationDatabaseId returns the database ID (integer) for the given organization as a string. +func (c *Client) GetOrganizationDatabaseId(ctx context.Context, org string) (string, error) { + query := `query($login: String!) {organization(login: $login) { login, databaseId, name } }` + vars, _ := json.Marshal(map[string]string{"login": org}) + + data, err := c.graphql.Post(ctx, query, vars) + if err != nil { + return "", fmt.Errorf("failed to get organization database ID for %q: %w", org, err) + } + + dbIDRaw, err := navigateJSON(data, "organization.databaseId") + if err != nil { + return "", fmt.Errorf("failed to parse organization database ID for %q: %w", org, err) + } + + // databaseId comes back as a JSON number — unmarshal to json.Number then convert to string. + var num json.Number + if err := json.Unmarshal(dbIDRaw, &num); err != nil { + return "", fmt.Errorf("failed to unmarshal organization database ID for %q: %w", org, err) + } + return num.String(), nil +} + +// GetEnterpriseId returns the node ID for the given enterprise. +func (c *Client) GetEnterpriseId(ctx context.Context, enterpriseName string) (string, error) { + query := `query($slug: String!) 
{enterprise (slug: $slug) { slug, id } }` + vars, _ := json.Marshal(map[string]string{"slug": enterpriseName}) + + data, err := c.graphql.Post(ctx, query, vars) + if err != nil { + return "", fmt.Errorf("failed to get enterprise ID for %q: %w", enterpriseName, err) + } + + idRaw, err := navigateJSON(data, "enterprise.id") + if err != nil { + return "", fmt.Errorf("failed to parse enterprise ID for %q: %w", enterpriseName, err) + } + + var id string + if err := json.Unmarshal(idRaw, &id); err != nil { + return "", fmt.Errorf("failed to unmarshal enterprise ID for %q: %w", enterpriseName, err) + } + return id, nil +} + +// GetLoginName returns the login name of the authenticated user (viewer). +func (c *Client) GetLoginName(ctx context.Context) (string, error) { + query := `query{viewer{login}}` + + data, err := c.graphql.Post(ctx, query, nil) + if err != nil { + return "", fmt.Errorf("failed to get login name: %w", err) + } + + loginRaw, err := navigateJSON(data, "viewer.login") + if err != nil { + return "", fmt.Errorf("failed to parse login name: %w", err) + } + + var login string + if err := json.Unmarshal(loginRaw, &login); err != nil { + return "", fmt.Errorf("failed to unmarshal login name: %w", err) + } + return login, nil +} + +// GetUserId returns the node ID for the given user. +func (c *Client) GetUserId(ctx context.Context, login string) (string, error) { + query := `query($login: String!) 
{user(login: $login) { id, name } }` + vars, _ := json.Marshal(map[string]string{"login": login}) + + data, err := c.graphql.Post(ctx, query, vars) + if err != nil { + return "", fmt.Errorf("failed to get user ID for %q: %w", login, err) + } + + idRaw, err := navigateJSON(data, "user.id") + if err != nil { + return "", fmt.Errorf("failed to parse user ID for %q: %w", login, err) + } + + var id string + if err := json.Unmarshal(idRaw, &id); err != nil { + return "", fmt.Errorf("failed to unmarshal user ID for %q: %w", login, err) + } + return id, nil +} + +// DoesOrgExist checks whether an organization exists (REST GET /orgs/{org}). +// Returns false when the API returns 404. +func (c *Client) DoesOrgExist(ctx context.Context, org string) (bool, error) { + _, resp, err := c.rest.Organizations.Get(ctx, org) + if err != nil { + if resp != nil && resp.StatusCode == http.StatusNotFound { + return false, nil + } + return false, fmt.Errorf("failed to check if org %q exists: %w", org, err) + } + return true, nil +} - for _, repoData := range pageRepos { - name, _ := repoData["name"].(string) - visibility, _ := repoData["visibility"].(string) +// GetOrgMembershipForUser returns the role of a user within an organization. +// Returns "" if the user is not a member (404). +func (c *Client) GetOrgMembershipForUser(ctx context.Context, org, member string) (string, error) { + membership, resp, err := c.rest.Organizations.GetOrgMembership(ctx, member, org) + if err != nil { + if resp != nil && resp.StatusCode == http.StatusNotFound { + return "", nil + } + return "", fmt.Errorf("failed to get membership for %q in %q: %w", member, org, err) + } + return membership.GetRole(), nil +} + +// --------------------------------------------------------------------------- +// Migration sources & mutations +// --------------------------------------------------------------------------- + +// createMigrationSource is the shared implementation for creating ADO, BBS, and GHEC migration sources. 
+func (c *Client) createMigrationSource(ctx context.Context, name, sourceURL, orgID, sourceType string) (string, error) { + mutation := `mutation createMigrationSource($name: String!, $url: String!, $ownerId: ID!, $type: MigrationSourceType!) { + createMigrationSource(input: {name: $name, url: $url, ownerId: $ownerId, type: $type}) { + migrationSource { id, name, url, type } + } + }` + vars, _ := json.Marshal(map[string]string{ + "name": name, + "url": sourceURL, + "ownerId": orgID, + "type": sourceType, + }) + + data, err := c.graphql.Post(ctx, mutation, vars) + if err != nil { + return "", fmt.Errorf("failed to create %s migration source: %w", sourceType, err) + } + + idRaw, err := navigateJSON(data, "createMigrationSource.migrationSource.id") + if err != nil { + return "", fmt.Errorf("failed to parse migration source ID: %w", err) + } + + var id string + if err := json.Unmarshal(idRaw, &id); err != nil { + return "", fmt.Errorf("failed to unmarshal migration source ID: %w", err) + } + return id, nil +} + +// CreateAdoMigrationSource creates an Azure DevOps migration source. +func (c *Client) CreateAdoMigrationSource(ctx context.Context, orgID string, adoServerURL string) (string, error) { + sourceURL := adoServerURL + if sourceURL == "" { + sourceURL = "https://dev.azure.com" + } + return c.createMigrationSource(ctx, "Azure DevOps Source", sourceURL, orgID, "AZURE_DEVOPS") +} + +// CreateBbsMigrationSource creates a Bitbucket Server migration source. +func (c *Client) CreateBbsMigrationSource(ctx context.Context, orgID string) (string, error) { + return c.createMigrationSource(ctx, "Bitbucket Server Source", "https://not-used", orgID, "BITBUCKET_SERVER") +} + +// CreateGhecMigrationSource creates a GitHub Enterprise Cloud migration source. 
+func (c *Client) CreateGhecMigrationSource(ctx context.Context, orgID string) (string, error) { + return c.createMigrationSource(ctx, "GHEC Source", "https://github.com", orgID, "GITHUB_ARCHIVE") +} + +// StartMigration starts a repository migration. +func (c *Client) StartMigration(ctx context.Context, migrationSourceID, sourceRepoURL, orgID, repo, sourceToken, targetToken string, opts ...StartMigrationOption) (string, error) { + params := &startMigrationParams{} + for _, o := range opts { + o(params) + } - if name != "" { - repos = append(repos, Repo{ - Name: name, - Visibility: visibility, - }) + mutation := `mutation startRepositoryMigration( + $sourceId: ID!, $ownerId: ID!, $sourceRepositoryUrl: URI!, $repositoryName: String!, + $continueOnError: Boolean!, $gitArchiveUrl: String, $metadataArchiveUrl: String, + $accessToken: String!, $githubPat: String, $skipReleases: Boolean, + $targetRepoVisibility: String, $lockSource: Boolean + ) { + startRepositoryMigration(input: { + sourceId: $sourceId, ownerId: $ownerId, sourceRepositoryUrl: $sourceRepositoryUrl, + repositoryName: $repositoryName, continueOnError: $continueOnError, + gitArchiveUrl: $gitArchiveUrl, metadataArchiveUrl: $metadataArchiveUrl, + accessToken: $accessToken, githubPat: $githubPat, skipReleases: $skipReleases, + targetRepoVisibility: $targetRepoVisibility, lockSource: $lockSource + }) { + repositoryMigration { id, databaseId, migrationSource { id, name, type }, sourceUrl, state, failureReason } + } + }` + + varsMap := map[string]interface{}{ + "sourceId": migrationSourceID, + "ownerId": orgID, + "sourceRepositoryUrl": sourceRepoURL, + "repositoryName": repo, + "continueOnError": true, + "accessToken": sourceToken, + "githubPat": targetToken, + "gitArchiveUrl": params.gitArchiveURL, + "metadataArchiveUrl": params.metadataArchiveURL, + "skipReleases": params.skipReleases, + "targetRepoVisibility": params.targetRepoVisibility, + "lockSource": params.lockSource, + } + vars, _ := 
json.Marshal(varsMap) + + data, err := c.graphql.Post(ctx, mutation, vars) + if err != nil { + return "", fmt.Errorf("failed to start migration for %q: %w", repo, err) + } + + idRaw, err := navigateJSON(data, "startRepositoryMigration.repositoryMigration.id") + if err != nil { + return "", fmt.Errorf("failed to parse migration ID: %w", err) + } + + var id string + if err := json.Unmarshal(idRaw, &id); err != nil { + return "", fmt.Errorf("failed to unmarshal migration ID: %w", err) + } + return id, nil +} + +// StartBbsMigration starts a Bitbucket Server migration. +func (c *Client) StartBbsMigration(ctx context.Context, migrationSourceID, bbsRepoURL, orgID, repo, targetToken, archiveURL, targetRepoVisibility string) (string, error) { + return c.StartMigration(ctx, migrationSourceID, bbsRepoURL, orgID, repo, + "not-used", targetToken, + WithGitArchiveURL(archiveURL), + WithMetadataArchiveURL("https://not-used"), + WithSkipReleases(false), + WithTargetRepoVisibility(targetRepoVisibility), + WithLockSource(false), + ) +} + +// StartOrganizationMigration starts an organization-level migration. +func (c *Client) StartOrganizationMigration(ctx context.Context, sourceOrgURL, targetOrgName, targetEnterpriseID, sourceAccessToken string) (string, error) { + mutation := `mutation startOrganizationMigration( + $sourceOrgUrl: URI!, $targetOrgName: String!, $targetEnterpriseId: ID!, $sourceAccessToken: String! 
+ ) { + startOrganizationMigration(input: { + sourceOrgUrl: $sourceOrgUrl, targetOrgName: $targetOrgName, + targetEnterpriseId: $targetEnterpriseId, sourceAccessToken: $sourceAccessToken + }) { + orgMigration { id, databaseId } + } + }` + vars, _ := json.Marshal(map[string]string{ + "sourceOrgUrl": sourceOrgURL, + "targetOrgName": targetOrgName, + "targetEnterpriseId": targetEnterpriseID, + "sourceAccessToken": sourceAccessToken, + }) + + data, err := c.graphql.Post(ctx, mutation, vars) + if err != nil { + return "", fmt.Errorf("failed to start organization migration: %w", err) + } + + idRaw, err := navigateJSON(data, "startOrganizationMigration.orgMigration.id") + if err != nil { + return "", fmt.Errorf("failed to parse org migration ID: %w", err) + } + + var id string + if err := json.Unmarshal(idRaw, &id); err != nil { + return "", fmt.Errorf("failed to unmarshal org migration ID: %w", err) + } + return id, nil +} + +// GetMigration retrieves migration details by node ID. +func (c *Client) GetMigration(ctx context.Context, migrationID string) (*Migration, error) { + query := `query($id: ID!) { + node(id: $id) { + ... 
on Migration { + id, sourceUrl, migrationLogUrl, migrationSource { name }, state, warningsCount, failureReason, repositoryName } } + }` + vars, _ := json.Marshal(map[string]string{"id": migrationID}) - c.logger.Debug("Fetched %d repos from page %d", len(pageRepos), page) + data, err := c.graphql.Post(ctx, query, vars) + if err != nil { + return nil, fmt.Errorf("failed to get migration %q: %w", migrationID, err) + } - // Check if there are more pages - // GitHub returns less than 100 if it's the last page - if len(pageRepos) < 100 { - break + nodeRaw, err := navigateJSON(data, "node") + if err != nil { + return nil, fmt.Errorf("failed to parse migration %q: %w", migrationID, err) + } + + var result struct { + ID string `json:"id"` + SourceURL string `json:"sourceUrl"` + MigrationLogURL string `json:"migrationLogUrl"` + MigrationSource struct { + Name string `json:"name"` + } `json:"migrationSource"` + State string `json:"state"` + WarningsCount int `json:"warningsCount"` + FailureReason string `json:"failureReason"` + RepositoryName string `json:"repositoryName"` + } + if err := json.Unmarshal(nodeRaw, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal migration %q: %w", migrationID, err) + } + + return &Migration{ + ID: result.ID, + SourceURL: result.SourceURL, + MigrationLogURL: result.MigrationLogURL, + State: result.State, + WarningsCount: result.WarningsCount, + FailureReason: result.FailureReason, + RepositoryName: result.RepositoryName, + MigrationSource: MigrationSource{Name: result.MigrationSource.Name}, + }, nil +} + +// GetOrganizationMigration retrieves an organization migration by node ID. +func (c *Client) GetOrganizationMigration(ctx context.Context, migrationID string) (*OrgMigration, error) { + query := `query($id: ID!) { + node(id: $id) { + ... 
on OrganizationMigration { + state, sourceOrgUrl, targetOrgName, failureReason, remainingRepositoriesCount, totalRepositoriesCount + } + } + }` + vars, _ := json.Marshal(map[string]string{"id": migrationID}) + + data, err := c.graphql.Post(ctx, query, vars) + if err != nil { + return nil, fmt.Errorf("failed to get organization migration %q: %w", migrationID, err) + } + + nodeRaw, err := navigateJSON(data, "node") + if err != nil { + return nil, fmt.Errorf("failed to parse organization migration %q: %w", migrationID, err) + } + + var result struct { + State string `json:"state"` + SourceOrgURL string `json:"sourceOrgUrl"` + TargetOrgName string `json:"targetOrgName"` + FailureReason string `json:"failureReason"` + RemainingRepositoriesCount int `json:"remainingRepositoriesCount"` + TotalRepositoriesCount int `json:"totalRepositoriesCount"` + } + if err := json.Unmarshal(nodeRaw, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal organization migration %q: %w", migrationID, err) + } + + return &OrgMigration{ + State: result.State, + SourceOrgURL: result.SourceOrgURL, + TargetOrgName: result.TargetOrgName, + FailureReason: result.FailureReason, + RemainingRepositoriesCount: result.RemainingRepositoriesCount, + TotalRepositoriesCount: result.TotalRepositoriesCount, + }, nil +} + +// GetMigrationLogUrl looks up the migration log URL for the most recent migration +// of a given repository within an organization. +func (c *Client) GetMigrationLogUrl(ctx context.Context, org, repo string) (*MigrationLogResult, error) { + query := `query($org: String!, $repo: String!) 
{ + organization(login: $org) { + repositoryMigrations(last: 1, repositoryName: $repo) { + nodes { id, migrationLogUrl } + } } + }` + vars, _ := json.Marshal(map[string]string{"org": org, "repo": repo}) - page++ + data, err := c.graphql.Post(ctx, query, vars) + if err != nil { + return nil, fmt.Errorf("failed to get migration log URL for %s/%s: %w", org, repo, err) } - c.logger.Info("Found %d repositories in organization %s", len(repos), org) + nodesRaw, err := navigateJSON(data, "organization.repositoryMigrations.nodes") + if err != nil { + return nil, fmt.Errorf("failed to parse migration log URL response: %w", err) + } - return repos, nil + var nodes []struct { + ID string `json:"id"` + MigrationLogURL string `json:"migrationLogUrl"` + } + if err := json.Unmarshal(nodesRaw, &nodes); err != nil { + return nil, fmt.Errorf("failed to unmarshal migration log URL response: %w", err) + } + + if len(nodes) == 0 { + return &MigrationLogResult{}, nil + } + + return &MigrationLogResult{ + MigrationLogURL: nodes[0].MigrationLogURL, + MigrationID: nodes[0].ID, + }, nil } -// GetVersion fetches the GitHub Enterprise Server version -// Used by generate-script to determine if blob credentials are required -func (c *Client) GetVersion(ctx context.Context) (*VersionInfo, error) { - // Only applicable for GHES - if c.apiURL == "https://api.github.com" { - return nil, fmt.Errorf("version endpoint not available on GitHub.com") +// AbortMigration aborts a repository migration by ID. +// Returns whether the abort was successful. +func (c *Client) AbortMigration(ctx context.Context, migrationID string) (bool, error) { + mutation := `mutation abortRepositoryMigration($migrationId: ID!) 
{ + abortRepositoryMigration(input: { migrationId: $migrationId }) { success } + }` + vars, _ := json.Marshal(map[string]string{"migrationId": migrationID}) + + data, err := c.graphql.Post(ctx, mutation, vars) + if err != nil { + if strings.Contains(err.Error(), "Could not resolve to a node") { + return false, fmt.Errorf("invalid migration id: %s", migrationID) + } + return false, fmt.Errorf("failed to abort migration %q: %w", migrationID, err) } - apiURL := fmt.Sprintf("%s/meta", c.apiURL) + successRaw, err := navigateJSON(data, "abortRepositoryMigration.success") + if err != nil { + return false, fmt.Errorf("failed to parse abort response: %w", err) + } - headers := c.buildHeaders() - body, err := c.http.Get(ctx, apiURL, headers) + var success bool + if err := json.Unmarshal(successRaw, &success); err != nil { + return false, fmt.Errorf("failed to unmarshal abort response: %w", err) + } + return success, nil +} + +// GrantMigratorRole grants the migrator role to an actor within an organization. +func (c *Client) GrantMigratorRole(ctx context.Context, orgID, actor, actorType string) (bool, error) { + mutation := `mutation grantMigratorRole($organizationId: ID!, $actor: String!, $actor_type: ActorType!) 
{ + grantMigratorRole(input: { organizationId: $organizationId, actor: $actor, actorType: $actor_type }) { success } + }` + vars, _ := json.Marshal(map[string]string{ + "organizationId": orgID, + "actor": actor, + "actor_type": actorType, + }) + + data, err := c.graphql.Post(ctx, mutation, vars) if err != nil { - return nil, fmt.Errorf("failed to fetch version: %w", err) + // C# catches HttpRequestException and returns false + c.logger.Warning("Failed to grant migrator role for %q in org: %v", actor, err) + return false, nil //nolint:nilerr } - var meta map[string]interface{} - if err := json.Unmarshal(body, &meta); err != nil { - return nil, fmt.Errorf("failed to parse version response: %w", err) + successRaw, err := navigateJSON(data, "grantMigratorRole.success") + if err != nil { + c.logger.Warning("Failed to parse grant migrator role response: %v", err) + return false, nil } - version, _ := meta["installed_version"].(string) + var success bool + if err := json.Unmarshal(successRaw, &success); err != nil { + c.logger.Warning("Failed to unmarshal grant migrator role response: %v", err) + return false, nil + } + return success, nil +} - return &VersionInfo{ - Version: version, - InstalledVersion: version, +// RevokeMigratorRole revokes the migrator role from an actor within an organization. +func (c *Client) RevokeMigratorRole(ctx context.Context, orgID, actor, actorType string) (bool, error) { + mutation := `mutation revokeMigratorRole($organizationId: ID!, $actor: String!, $actor_type: ActorType!) 
{ + revokeMigratorRole(input: { organizationId: $organizationId, actor: $actor, actorType: $actor_type }) { success } + }` + vars, _ := json.Marshal(map[string]string{ + "organizationId": orgID, + "actor": actor, + "actor_type": actorType, + }) + + data, err := c.graphql.Post(ctx, mutation, vars) + if err != nil { + // C# catches HttpRequestException and returns false + c.logger.Warning("Failed to revoke migrator role for %q in org: %v", actor, err) + return false, nil //nolint:nilerr + } + + successRaw, err := navigateJSON(data, "revokeMigratorRole.success") + if err != nil { + c.logger.Warning("Failed to parse revoke migrator role response: %v", err) + return false, nil + } + + var success bool + if err := json.Unmarshal(successRaw, &success); err != nil { + c.logger.Warning("Failed to unmarshal revoke migrator role response: %v", err) + return false, nil + } + return success, nil +} + +// --------------------------------------------------------------------------- +// Team methods +// --------------------------------------------------------------------------- + +// CreateTeam creates a team in the given organization with "closed" privacy. +// On 5xx errors, it checks whether the team was already created (idempotency). 
+func (c *Client) CreateTeam(ctx context.Context, org, teamName string) (*Team, error) { + team, resp, err := c.rest.Teams.CreateTeam(ctx, org, gogithub.NewTeam{ + Name: teamName, + Privacy: gogithub.Ptr("closed"), + }) + if err != nil { + // On 5xx, check if team was actually created (idempotency) + if resp != nil && resp.StatusCode >= 500 { + teams, listErr := c.GetTeams(ctx, org) + if listErr == nil { + for _, t := range teams { + if strings.EqualFold(t.Name, teamName) { + return &t, nil + } + } + } + } + return nil, fmt.Errorf("failed to create team %q in %q: %w", teamName, org, err) + } + + return &Team{ + ID: fmt.Sprintf("%d", team.GetID()), + Name: team.GetName(), + Slug: team.GetSlug(), }, nil } -// buildHeaders constructs the HTTP headers for GitHub API requests -func (c *Client) buildHeaders() map[string]string { - headers := map[string]string{ - "Accept": "application/vnd.github+json", - "X-GitHub-Api-Version": "2022-11-28", +// GetTeams lists all teams in an organization. +func (c *Client) GetTeams(ctx context.Context, org string) ([]Team, error) { + var allTeams []Team + opts := &gogithub.ListOptions{PerPage: 100} + + for { + teams, resp, err := c.rest.Teams.ListTeams(ctx, org, opts) + if err != nil { + return nil, fmt.Errorf("failed to list teams for %q: %w", org, err) + } + + for _, t := range teams { + allTeams = append(allTeams, Team{ + ID: fmt.Sprintf("%d", t.GetID()), + Name: t.GetName(), + Slug: t.GetSlug(), + }) + } + + if resp.NextPage == 0 { + break + } + opts.Page = resp.NextPage + } + + return allTeams, nil +} + +// GetTeamMembers lists the members of a team by slug. 
+func (c *Client) GetTeamMembers(ctx context.Context, org, teamSlug string) ([]string, error) { + var members []string + opts := &gogithub.TeamListTeamMembersOptions{ + ListOptions: gogithub.ListOptions{PerPage: 100}, + } + + for { + users, resp, err := c.rest.Teams.ListTeamMembersBySlug(ctx, org, teamSlug, opts) + if err != nil { + return nil, fmt.Errorf("failed to list team members for %q/%q: %w", org, teamSlug, err) + } + + for _, u := range users { + members = append(members, u.GetLogin()) + } + + if resp.NextPage == 0 { + break + } + opts.Page = resp.NextPage + } + + return members, nil +} + +// RemoveTeamMember removes a user from a team. +func (c *Client) RemoveTeamMember(ctx context.Context, org, teamSlug, member string) error { + _, err := c.rest.Teams.RemoveTeamMembershipBySlug(ctx, org, teamSlug, member) + if err != nil { + return fmt.Errorf("failed to remove %q from team %q/%q: %w", member, org, teamSlug, err) + } + return nil +} + +// GetTeamSlug finds a team by name (case-insensitive) and returns its slug. +func (c *Client) GetTeamSlug(ctx context.Context, org, teamName string) (string, error) { + teams, err := c.GetTeams(ctx, org) + if err != nil { + return "", err + } + + for _, t := range teams { + if strings.EqualFold(t.Name, teamName) { + return t.Slug, nil + } + } + + return "", fmt.Errorf("team %q not found in organization %q", teamName, org) +} + +// AddTeamSync sets up team sync group mappings for a team. 
+func (c *Client) AddTeamSync(ctx context.Context, org, teamSlug, groupID, groupName, groupDescription string) error { + payload := map[string]interface{}{ + "groups": []map[string]string{ + { + "group_id": groupID, + "group_name": groupName, + "group_description": groupDescription, + }, + }, + } + + u := fmt.Sprintf("orgs/%s/teams/%s/team-sync/group-mappings", org, teamSlug) + req, err := c.rest.NewRequest("PATCH", u, payload) + if err != nil { + return fmt.Errorf("failed to create team sync request: %w", err) + } + + _, err = c.rest.Do(ctx, req, nil) + if err != nil { + return fmt.Errorf("failed to add team sync for %q/%q: %w", org, teamSlug, err) + } + return nil +} + +// AddTeamToRepo adds a team to a repository with the given permission role. +func (c *Client) AddTeamToRepo(ctx context.Context, org, teamSlug, repo, role string) error { + opts := &gogithub.TeamAddTeamRepoOptions{Permission: role} + _, err := c.rest.Teams.AddTeamRepoBySlug(ctx, org, teamSlug, org, repo, opts) + if err != nil { + return fmt.Errorf("failed to add team %q to repo %s/%s: %w", teamSlug, org, repo, err) } + return nil +} + +// GetIdpGroupId looks up the external group ID for a given group name (case-insensitive). 
+func (c *Client) GetIdpGroupId(ctx context.Context, org, groupName string) (int, error) { + u := fmt.Sprintf("orgs/%s/external-groups", org) - if c.pat != "" { - headers["Authorization"] = fmt.Sprintf("Bearer %s", c.pat) + req, err := c.rest.NewRequest("GET", u, nil) + if err != nil { + return 0, fmt.Errorf("failed to create external groups request: %w", err) } - return headers + var result struct { + Groups []struct { + GroupID int `json:"group_id"` + GroupName string `json:"group_name"` + } `json:"groups"` + } + _, err = c.rest.Do(ctx, req, &result) + if err != nil { + return 0, fmt.Errorf("failed to get external groups for %q: %w", org, err) + } + + for _, g := range result.Groups { + if strings.EqualFold(g.GroupName, groupName) { + return g.GroupID, nil + } + } + + return 0, fmt.Errorf("external group %q not found in organization %q", groupName, org) +} + +// AddEmuGroupToTeam links an EMU external group to a team. +func (c *Client) AddEmuGroupToTeam(ctx context.Context, org, teamSlug string, groupID int) error { + payload := map[string]int{"group_id": groupID} + + u := fmt.Sprintf("orgs/%s/teams/%s/external-groups", org, teamSlug) + req, err := c.rest.NewRequest("PATCH", u, payload) + if err != nil { + return fmt.Errorf("failed to create EMU group request: %w", err) + } + + _, err = c.rest.Do(ctx, req, nil) + if err != nil { + return fmt.Errorf("failed to add EMU group %d to team %q/%q: %w", groupID, org, teamSlug, err) + } + return nil +} + +// --------------------------------------------------------------------------- +// Mannequin methods +// --------------------------------------------------------------------------- + +// GetMannequins retrieves all mannequins for an organization (by org node ID). +func (c *Client) GetMannequins(ctx context.Context, orgID string) ([]Mannequin, error) { + query := `query($id: ID!, $first: Int, $after: String) { + node(id: $id) { + ... 
on Organization { + mannequins(first: $first, after: $after) { + pageInfo { endCursor, hasNextPage } + nodes { login, id, claimant { login, id } } + } + } + } + }` + vars, _ := json.Marshal(map[string]string{"id": orgID}) + + data, err := c.graphql.PostWithPagination(ctx, query, vars, + "node.mannequins.nodes", "node.mannequins.pageInfo") + if err != nil { + return nil, fmt.Errorf("failed to get mannequins for org %q: %w", orgID, err) + } + + return parseMannequins(data) +} + +// GetMannequinsByLogin retrieves mannequins for an organization filtered by login. +func (c *Client) GetMannequinsByLogin(ctx context.Context, orgID, login string) ([]Mannequin, error) { + query := `query($id: ID!, $first: Int, $after: String, $login: String) { + node(id: $id) { + ... on Organization { + mannequins(first: $first, after: $after, login: $login) { + pageInfo { endCursor, hasNextPage } + nodes { login, id, claimant { login, id } } + } + } + } + }` + vars, _ := json.Marshal(map[string]interface{}{ + "id": orgID, + "login": login, + }) + + data, err := c.graphql.PostWithPagination(ctx, query, vars, + "node.mannequins.nodes", "node.mannequins.pageInfo") + if err != nil { + return nil, fmt.Errorf("failed to get mannequins by login %q for org %q: %w", login, orgID, err) + } + + return parseMannequins(data) +} + +// parseMannequins converts raw JSON mannequin nodes into Mannequin structs. 
+func parseMannequins(data json.RawMessage) ([]Mannequin, error) { + var nodes []struct { + Login string `json:"login"` + ID string `json:"id"` + Claimant *struct { + Login string `json:"login"` + ID string `json:"id"` + } `json:"claimant"` + } + if err := json.Unmarshal(data, &nodes); err != nil { + return nil, fmt.Errorf("failed to unmarshal mannequins: %w", err) + } + + mannequins := make([]Mannequin, 0, len(nodes)) + for _, n := range nodes { + m := Mannequin{ + ID: n.ID, + Login: n.Login, + } + if n.Claimant != nil { + m.MappedUser = &MannequinUser{ + ID: n.Claimant.ID, + Login: n.Claimant.Login, + } + } + mannequins = append(mannequins, m) + } + return mannequins, nil +} + +// CreateAttributionInvitation creates an attribution invitation to map a mannequin to a user. +func (c *Client) CreateAttributionInvitation(ctx context.Context, orgID, sourceID, targetID string) (*CreateAttributionInvitationResult, error) { + mutation := `mutation($orgId: ID!, $sourceId: ID!, $targetId: ID!) { + createAttributionInvitation(input: { ownerId: $orgId, sourceId: $sourceId, targetId: $targetId }) { + source { ... on Mannequin { id, login } } + target { ... 
on User { id, login } } + } + }` + vars, _ := json.Marshal(map[string]string{ + "orgId": orgID, + "sourceId": sourceID, + "targetId": targetID, + }) + + // Use raw Post to capture potential errors in the response body + data, err := c.graphql.Post(ctx, mutation, vars) + if err != nil { + return nil, fmt.Errorf("failed to create attribution invitation: %w", err) + } + + invRaw, err := navigateJSON(data, "createAttributionInvitation") + if err != nil { + return nil, fmt.Errorf("failed to parse attribution invitation response: %w", err) + } + + var result CreateAttributionInvitationResult + if err := json.Unmarshal(invRaw, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal attribution invitation response: %w", err) + } + return &result, nil +} + +// ReclaimMannequinSkipInvitation reclaims a mannequin, skipping the email invitation. +func (c *Client) ReclaimMannequinSkipInvitation(ctx context.Context, orgID, sourceID, targetID string) (*ReattributeMannequinToUserResult, error) { + mutation := `mutation($orgId: ID!, $sourceId: ID!, $targetId: ID!) { + reattributeMannequinToUser(input: { ownerId: $orgId, sourceId: $sourceId, targetId: $targetId }) { + source { ... on Mannequin { id, login } } + target { ... on User { id, login } } + } + }` + vars, _ := json.Marshal(map[string]string{ + "orgId": orgID, + "sourceId": sourceID, + "targetId": targetID, + }) + + data, err := c.graphql.Post(ctx, mutation, vars) + if err != nil { + errStr := err.Error() + if strings.Contains(errStr, "Field 'reattributeMannequinToUser' doesn't exist on type 'Mutation'") { + return nil, fmt.Errorf("reclaim mannequin (skip invitation) is not available for this GitHub product. 
Error: %w", err) + } + if strings.Contains(errStr, "Target must be a member") { + return &ReattributeMannequinToUserResult{ + Errors: []ErrorData{{Message: errStr}}, + }, nil + } + return nil, fmt.Errorf("failed to reclaim mannequin: %w", err) + } + + resultRaw, err := navigateJSON(data, "reattributeMannequinToUser") + if err != nil { + return nil, fmt.Errorf("failed to parse reclaim mannequin response: %w", err) + } + + var result ReattributeMannequinToUserResult + if err := json.Unmarshal(resultRaw, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal reclaim mannequin response: %w", err) + } + return &result, nil } diff --git a/pkg/github/client_test.go b/pkg/github/client_test.go index f25053f4b..352a6a0d9 100644 --- a/pkg/github/client_test.go +++ b/pkg/github/client_test.go @@ -2,79 +2,78 @@ package github import ( "context" + "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" + "strings" "testing" - ghHttp "github.com/github/gh-gei/pkg/http" "github.com/github/gh-gei/pkg/logger" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestNewClient(t *testing.T) { - log := logger.New(false) - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := DefaultConfig() - - client := NewClient(cfg, httpClient, log) + client := NewClient("test-pat") assert.NotNil(t, client) assert.Equal(t, "https://api.github.com", client.apiURL) + assert.NotNil(t, client.rest) + assert.NotNil(t, client.graphql) + assert.NotNil(t, client.logger) // should get a default logger } -func TestNewClient_CustomAPIURL(t *testing.T) { - log := logger.New(false) - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := Config{ - APIURL: "https://ghes.example.com/api/v3", - PAT: "test-pat", - } - - client := NewClient(cfg, httpClient, log) +func TestNewClient_WithOptions(t *testing.T) { + log := logger.New(true) + client := NewClient("test-pat", + WithAPIURL("https://ghes.example.com/api/v3"), + WithLogger(log), + 
WithVersion("2.0.0"), + ) assert.NotNil(t, client) assert.Equal(t, "https://ghes.example.com/api/v3", client.apiURL) - assert.Equal(t, "test-pat", client.pat) + assert.Equal(t, log, client.logger) } func TestNewClient_TrimsTrailingSlash(t *testing.T) { - log := logger.New(false) - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := Config{ - APIURL: "https://ghes.example.com/api/v3/", - } - - client := NewClient(cfg, httpClient, log) + client := NewClient("test-pat", + WithAPIURL("https://ghes.example.com/api/v3/"), + ) assert.Equal(t, "https://ghes.example.com/api/v3", client.apiURL) } -func TestClient_GetRepos(t *testing.T) { - log := logger.New(false) +func TestNewClient_DefaultsToGitHubDotCom(t *testing.T) { + client := NewClient("test-pat") + assert.Equal(t, "https://api.github.com", client.apiURL) +} + +func TestClient_GetRepos(t *testing.T) { t.Run("successful fetch with single page", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/orgs/test-org/repos", r.URL.Path) - assert.Contains(t, r.URL.RawQuery, "per_page=100") + // go-github requests /api/v3/orgs/{org}/repos or /orgs/{org}/repos + assert.Contains(t, r.URL.Path, "/orgs/test-org/repos") assert.Equal(t, "Bearer test-pat", r.Header.Get("Authorization")) + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write([]byte(`[ + fmt.Fprint(w, `[ {"name": "repo1", "visibility": "public"}, {"name": "repo2", "visibility": "private"}, {"name": "repo3", "visibility": "internal"} - ]`)) + ]`) })) defer server.Close() - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := Config{ - APIURL: server.URL, - PAT: "test-pat", - } - client := NewClient(cfg, httpClient, log) + log := logger.New(false) + client := NewClient("test-pat", + WithAPIURL(server.URL), + WithLogger(log), + ) repos, err := client.GetRepos(context.Background(), "test-org") @@ -90,57 +89,60 @@ func 
TestClient_GetRepos(t *testing.T) { t.Run("successful fetch with multiple pages", func(t *testing.T) { callCount := 0 - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Use a mux so the Link header can reference the server URL + mux := http.NewServeMux() + server := httptest.NewServer(mux) + defer server.Close() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { callCount++ - w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") if callCount == 1 { - // First page - return 100 repos to trigger pagination repos := "[" - for i := 0; i < 100; i++ { + for i := 0; i < 30; i++ { if i > 0 { repos += "," } repos += fmt.Sprintf(`{"name": "repo%d", "visibility": "public"}`, i) } repos += "]" - w.Write([]byte(repos)) + // Link header pointing to the next page on this same server + w.Header().Set("Link", fmt.Sprintf(`<%s%s?page=2>; rel="next"`, server.URL, r.URL.Path)) + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, repos) } else { - // Second page - return fewer than 100 to signal end - w.Write([]byte(`[ - {"name": "repo101", "visibility": "private"} - ]`)) + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `[{"name": "repo-last", "visibility": "private"}]`) } - })) - defer server.Close() + }) - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := Config{ - APIURL: server.URL, - PAT: "test-pat", - } - client := NewClient(cfg, httpClient, log) + log := logger.New(false) + client := NewClient("test-pat", + WithAPIURL(server.URL), + WithLogger(log), + ) repos, err := client.GetRepos(context.Background(), "test-org") require.NoError(t, err) - assert.Equal(t, 101, len(repos)) + assert.Equal(t, 31, len(repos)) assert.Equal(t, 2, callCount) }) t.Run("no repos found", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - 
w.Write([]byte(`[]`)) + fmt.Fprint(w, `[]`) })) defer server.Close() - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := Config{ - APIURL: server.URL, - PAT: "test-pat", - } - client := NewClient(cfg, httpClient, log) + log := logger.New(false) + client := NewClient("test-pat", + WithAPIURL(server.URL), + WithLogger(log), + ) repos, err := client.GetRepos(context.Background(), "empty-org") @@ -150,110 +152,920 @@ func TestClient_GetRepos(t *testing.T) { t.Run("API error", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusNotFound) - w.Write([]byte(`{"message": "Not Found"}`)) + fmt.Fprint(w, `{"message": "Not Found", "documentation_url": "https://docs.github.com"}`) })) defer server.Close() - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := Config{ - APIURL: server.URL, - PAT: "test-pat", - } - client := NewClient(cfg, httpClient, log) + log := logger.New(false) + client := NewClient("test-pat", + WithAPIURL(server.URL), + WithLogger(log), + ) _, err := client.GetRepos(context.Background(), "nonexistent-org") require.Error(t, err) - assert.Contains(t, err.Error(), "failed to fetch repos") }) +} - t.Run("URL encodes org name", func(t *testing.T) { +func TestClient_GetVersion(t *testing.T) { + t.Run("successful version fetch for GHES", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // httptest server decodes URLs, so we check the raw query is properly formed - // The path will be decoded, but we verify the request succeeds - assert.Contains(t, r.URL.Path, "org with spaces") + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write([]byte(`[]`)) + fmt.Fprint(w, `{"installed_version": "3.9.0", "verifiable_password_authentication": true}`) })) defer server.Close() - httpClient := 
ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := Config{ - APIURL: server.URL, - } - client := NewClient(cfg, httpClient, log) + log := logger.New(false) + client := NewClient("test-pat", + WithAPIURL(server.URL), + WithLogger(log), + ) - _, err := client.GetRepos(context.Background(), "org with spaces") + version, err := client.GetVersion(context.Background()) require.NoError(t, err) + assert.NotNil(t, version) + assert.Equal(t, "3.9.0", version.InstalledVersion) + }) + + t.Run("version not available on GitHub.com", func(t *testing.T) { + log := logger.New(false) + client := NewClient("test-pat", + WithAPIURL("https://api.github.com"), + WithLogger(log), + ) + + _, err := client.GetVersion(context.Background()) + + require.Error(t, err) + assert.Contains(t, err.Error(), "not available on GitHub.com") }) } -func TestClient_GetVersion(t *testing.T) { - log := logger.New(false) +func TestClient_GraphQL(t *testing.T) { + t.Run("GraphQL method delegates to graphqlClient", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/graphql" { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"data":{"organization":{"id":"org123"}}}`) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() - t.Run("successful version fetch for GHES", func(t *testing.T) { + log := logger.New(false) + client := NewClient("test-pat", + WithAPIURL(server.URL), + WithLogger(log), + ) + + data, err := client.GraphQL(context.Background(), "query { organization(login: \"test\") { id } }", nil) + + require.NoError(t, err) + assert.JSONEq(t, `{"organization":{"id":"org123"}}`, string(data)) + }) +} + +// --------------------------------------------------------------------------- +// Helper: create a test client with a server that handles GraphQL requests +// --------------------------------------------------------------------------- + +func newGraphQLTestServer(t *testing.T, handler func(w 
http.ResponseWriter, body string)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/graphql" { + t.Errorf("unexpected path: %s", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + return + } + assert.Equal(t, "Bearer test-pat", r.Header.Get("Authorization")) + bodyBytes, _ := io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + handler(w, string(bodyBytes)) + })) +} + +func newTestClient(t *testing.T, server *httptest.Server) *Client { + t.Helper() + return NewClient("test-pat", + WithAPIURL(server.URL), + WithLogger(logger.New(false)), + ) +} + +// --------------------------------------------------------------------------- +// Group 1: Organization/User queries (GraphQL-based) +// --------------------------------------------------------------------------- + +func TestClient_GetOrganizationId(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "organization") + fmt.Fprint(w, `{"data":{"organization":{"login":"test-org","id":"ORG_ID_123","name":"Test Org"}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.GetOrganizationId(context.Background(), "test-org") + + require.NoError(t, err) + assert.Equal(t, "ORG_ID_123", id) +} + +func TestClient_GetOrganizationDatabaseId(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "databaseId") + fmt.Fprint(w, `{"data":{"organization":{"login":"test-org","databaseId":12345,"name":"Test Org"}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.GetOrganizationDatabaseId(context.Background(), "test-org") + + require.NoError(t, err) + assert.Equal(t, "12345", id) +} + +func TestClient_GetEnterpriseId(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + 
assert.Contains(t, body, "enterprise") + fmt.Fprint(w, `{"data":{"enterprise":{"slug":"test-ent","id":"ENT_ID_456"}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.GetEnterpriseId(context.Background(), "test-ent") + + require.NoError(t, err) + assert.Equal(t, "ENT_ID_456", id) +} + +func TestClient_GetLoginName(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "viewer") + fmt.Fprint(w, `{"data":{"viewer":{"login":"monalisa"}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + login, err := client.GetLoginName(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "monalisa", login) +} + +func TestClient_GetUserId(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "user") + fmt.Fprint(w, `{"data":{"user":{"id":"USER_ID_789","name":"Mona Lisa"}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.GetUserId(context.Background(), "monalisa") + + require.NoError(t, err) + assert.Equal(t, "USER_ID_789", id) +} + +// --------------------------------------------------------------------------- +// Group 2: REST org methods +// --------------------------------------------------------------------------- + +func TestClient_DoesOrgExist(t *testing.T) { + t.Run("org exists", func(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/meta", r.URL.Path) + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"installed_version": "3.9.0"}`)) + fmt.Fprint(w, `{"login":"test-org"}`) })) defer server.Close() - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := Config{ - APIURL: server.URL, - PAT: "test-pat", - } - client := NewClient(cfg, httpClient, log) + client := newTestClient(t, server) + 
exists, err := client.DoesOrgExist(context.Background(), "test-org") - version, err := client.GetVersion(context.Background()) + require.NoError(t, err) + assert.True(t, exists) + }) + + t.Run("org not found", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"message":"Not Found"}`) + })) + defer server.Close() + + client := newTestClient(t, server) + exists, err := client.DoesOrgExist(context.Background(), "no-such-org") require.NoError(t, err) - assert.NotNil(t, version) - assert.Equal(t, "3.9.0", version.InstalledVersion) + assert.False(t, exists) }) +} - t.Run("version not available on GitHub.com", func(t *testing.T) { - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) - cfg := Config{ - APIURL: "https://api.github.com", +func TestClient_GetOrgMembershipForUser(t *testing.T) { + t.Run("member found", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"role":"admin","state":"active","url":"https://api.github.com/orgs/test-org/memberships/test-user","organization_url":"https://api.github.com/orgs/test-org"}`) + })) + defer server.Close() + + client := newTestClient(t, server) + role, err := client.GetOrgMembershipForUser(context.Background(), "test-org", "test-user") + + require.NoError(t, err) + assert.Equal(t, "admin", role) + }) + + t.Run("not a member", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{"message":"Not Found"}`) + })) + defer server.Close() + + client := newTestClient(t, server) + role, err := 
client.GetOrgMembershipForUser(context.Background(), "test-org", "not-a-member") + + require.NoError(t, err) + assert.Equal(t, "", role) + }) +} + +// --------------------------------------------------------------------------- +// Group 3: Migration mutations (GraphQL-based) +// --------------------------------------------------------------------------- + +func TestClient_CreateAdoMigrationSource(t *testing.T) { + t.Run("with ado server URL", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "AZURE_DEVOPS") + assert.Contains(t, body, "https://ado.example.com") + fmt.Fprint(w, `{"data":{"createMigrationSource":{"migrationSource":{"id":"MS_ADO_123","name":"Azure DevOps Source","url":"https://ado.example.com","type":"AZURE_DEVOPS"}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.CreateAdoMigrationSource(context.Background(), "ORG_ID", "https://ado.example.com") + + require.NoError(t, err) + assert.Equal(t, "MS_ADO_123", id) + }) + + t.Run("with empty ado server URL defaults to dev.azure.com", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "AZURE_DEVOPS") + assert.Contains(t, body, "https://dev.azure.com") + fmt.Fprint(w, `{"data":{"createMigrationSource":{"migrationSource":{"id":"MS_ADO_456","name":"Azure DevOps Source","url":"https://dev.azure.com","type":"AZURE_DEVOPS"}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.CreateAdoMigrationSource(context.Background(), "ORG_ID", "") + + require.NoError(t, err) + assert.Equal(t, "MS_ADO_456", id) + }) +} + +func TestClient_CreateBbsMigrationSource(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "BITBUCKET_SERVER") + fmt.Fprint(w, 
`{"data":{"createMigrationSource":{"migrationSource":{"id":"MS_BBS_123","name":"Bitbucket Server Source","url":"https://not-used","type":"BITBUCKET_SERVER"}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.CreateBbsMigrationSource(context.Background(), "ORG_ID") + + require.NoError(t, err) + assert.Equal(t, "MS_BBS_123", id) +} + +func TestClient_CreateGhecMigrationSource(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "GITHUB_ARCHIVE") + fmt.Fprint(w, `{"data":{"createMigrationSource":{"migrationSource":{"id":"MS_GHEC_123","name":"GHEC Source","url":"https://github.com","type":"GITHUB_ARCHIVE"}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.CreateGhecMigrationSource(context.Background(), "ORG_ID") + + require.NoError(t, err) + assert.Equal(t, "MS_GHEC_123", id) +} + +func TestClient_StartMigration(t *testing.T) { + t.Run("basic migration", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "startRepositoryMigration") + assert.Contains(t, body, "SRC_ID") + assert.Contains(t, body, "ORG_ID") + assert.Contains(t, body, "my-repo") + fmt.Fprint(w, `{"data":{"startRepositoryMigration":{"repositoryMigration":{"id":"MIG_123","databaseId":"999","migrationSource":{"id":"SRC_ID","name":"Test Source","type":"GITHUB_ARCHIVE"},"sourceUrl":"https://github.com/org/repo","state":"QUEUED","failureReason":""}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.StartMigration(context.Background(), "SRC_ID", "https://github.com/org/repo", "ORG_ID", "my-repo", "source-token", "target-token") + + require.NoError(t, err) + assert.Equal(t, "MIG_123", id) + }) + + t.Run("with options", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + // Parse the body to verify 
options were set + var req struct { + Variables map[string]interface{} `json:"variables"` + } + err := json.Unmarshal([]byte(body), &req) + require.NoError(t, err) + assert.Equal(t, true, req.Variables["skipReleases"]) + assert.Equal(t, true, req.Variables["lockSource"]) + assert.Equal(t, "private", req.Variables["targetRepoVisibility"]) + + fmt.Fprint(w, `{"data":{"startRepositoryMigration":{"repositoryMigration":{"id":"MIG_456","databaseId":"998","migrationSource":{"id":"SRC_ID","name":"Test","type":"GITHUB_ARCHIVE"},"sourceUrl":"https://github.com/org/repo","state":"QUEUED","failureReason":""}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.StartMigration(context.Background(), + "SRC_ID", "https://github.com/org/repo", "ORG_ID", "my-repo", "src-tok", "tgt-tok", + WithSkipReleases(true), + WithLockSource(true), + WithTargetRepoVisibility("private"), + ) + + require.NoError(t, err) + assert.Equal(t, "MIG_456", id) + }) +} + +func TestClient_StartBbsMigration(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + // Verify that StartBbsMigration delegates to StartMigration with proper params + var req struct { + Variables map[string]interface{} `json:"variables"` } - client := NewClient(cfg, httpClient, log) + err := json.Unmarshal([]byte(body), &req) + require.NoError(t, err) + assert.Equal(t, "not-used", req.Variables["accessToken"]) + assert.Equal(t, "https://archive.example.com/archive.tar.gz", req.Variables["gitArchiveUrl"]) - _, err := client.GetVersion(context.Background()) + fmt.Fprint(w, `{"data":{"startRepositoryMigration":{"repositoryMigration":{"id":"MIG_BBS_123","databaseId":"997","migrationSource":{"id":"SRC_ID","name":"BBS Source","type":"BITBUCKET_SERVER"},"sourceUrl":"https://bbs.example.com/repo","state":"QUEUED","failureReason":""}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.StartBbsMigration(context.Background(), 
+ "SRC_ID", "https://bbs.example.com/repo", "ORG_ID", "my-repo", + "target-token", "https://archive.example.com/archive.tar.gz", "private", + ) + + require.NoError(t, err) + assert.Equal(t, "MIG_BBS_123", id) +} + +func TestClient_StartOrganizationMigration(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "startOrganizationMigration") + assert.Contains(t, body, "https://github.com/source-org") + assert.Contains(t, body, "target-org") + fmt.Fprint(w, `{"data":{"startOrganizationMigration":{"orgMigration":{"id":"ORG_MIG_123","databaseId":"888"}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + id, err := client.StartOrganizationMigration(context.Background(), + "https://github.com/source-org", "target-org", "ENT_ID", "source-token", + ) + + require.NoError(t, err) + assert.Equal(t, "ORG_MIG_123", id) +} + +// --------------------------------------------------------------------------- +// Group 4: Migration queries (GraphQL-based) +// --------------------------------------------------------------------------- + +func TestClient_GetMigration(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "node") + fmt.Fprint(w, `{"data":{"node":{ + "id":"MIG_123", + "sourceUrl":"https://github.com/org/repo", + "migrationLogUrl":"https://example.com/log", + "migrationSource":{"name":"GHEC Source"}, + "state":"SUCCEEDED", + "warningsCount":5, + "failureReason":"", + "repositoryName":"my-repo" + }}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + mig, err := client.GetMigration(context.Background(), "MIG_123") + + require.NoError(t, err) + assert.Equal(t, "MIG_123", mig.ID) + assert.Equal(t, "https://github.com/org/repo", mig.SourceURL) + assert.Equal(t, "https://example.com/log", mig.MigrationLogURL) + assert.Equal(t, "GHEC Source", mig.MigrationSource.Name) + assert.Equal(t, "SUCCEEDED", 
mig.State) + assert.Equal(t, 5, mig.WarningsCount) + assert.Equal(t, "", mig.FailureReason) + assert.Equal(t, "my-repo", mig.RepositoryName) +} + +func TestClient_GetOrganizationMigration(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + fmt.Fprint(w, `{"data":{"node":{ + "state":"IN_PROGRESS", + "sourceOrgUrl":"https://github.com/source-org", + "targetOrgName":"target-org", + "failureReason":"", + "remainingRepositoriesCount":3, + "totalRepositoriesCount":10 + }}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + mig, err := client.GetOrganizationMigration(context.Background(), "ORG_MIG_123") + + require.NoError(t, err) + assert.Equal(t, "IN_PROGRESS", mig.State) + assert.Equal(t, "https://github.com/source-org", mig.SourceOrgURL) + assert.Equal(t, "target-org", mig.TargetOrgName) + assert.Equal(t, "", mig.FailureReason) + assert.Equal(t, 3, mig.RemainingRepositoriesCount) + assert.Equal(t, 10, mig.TotalRepositoriesCount) +} + +func TestClient_GetMigrationLogUrl(t *testing.T) { + t.Run("found", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + fmt.Fprint(w, `{"data":{"organization":{"repositoryMigrations":{"nodes":[{"id":"MIG_123","migrationLogUrl":"https://example.com/log"}]}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + result, err := client.GetMigrationLogUrl(context.Background(), "test-org", "my-repo") + + require.NoError(t, err) + assert.Equal(t, "https://example.com/log", result.MigrationLogURL) + assert.Equal(t, "MIG_123", result.MigrationID) + }) + + t.Run("not found", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + fmt.Fprint(w, `{"data":{"organization":{"repositoryMigrations":{"nodes":[]}}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + result, err := client.GetMigrationLogUrl(context.Background(), "test-org", "no-such-repo") + + 
require.NoError(t, err) + assert.Equal(t, "", result.MigrationLogURL) + assert.Equal(t, "", result.MigrationID) + }) +} + +// --------------------------------------------------------------------------- +// Group 5: Abort and migrator role (GraphQL-based) +// --------------------------------------------------------------------------- + +func TestClient_AbortMigration(t *testing.T) { + t.Run("success", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + fmt.Fprint(w, `{"data":{"abortRepositoryMigration":{"success":true}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + success, err := client.AbortMigration(context.Background(), "MIG_123") + + require.NoError(t, err) + assert.True(t, success) + }) + + t.Run("invalid migration id", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + fmt.Fprint(w, `{"errors":[{"message":"Could not resolve to a node with the global id of 'INVALID_ID'.","type":"NOT_FOUND"}]}`) + }) + defer server.Close() + + client := newTestClient(t, server) + _, err := client.AbortMigration(context.Background(), "INVALID_ID") require.Error(t, err) - assert.Contains(t, err.Error(), "not available on GitHub.com") + assert.Contains(t, err.Error(), "invalid migration id") }) } -func TestClient_BuildHeaders(t *testing.T) { - log := logger.New(false) - httpClient := ghHttp.NewClient(ghHttp.DefaultConfig(), log) +func TestClient_GrantMigratorRole(t *testing.T) { + t.Run("success", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "grantMigratorRole") + fmt.Fprint(w, `{"data":{"grantMigratorRole":{"success":true}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + success, err := client.GrantMigratorRole(context.Background(), "ORG_ID", "monalisa", "USER") + + assert.NoError(t, err) + assert.True(t, success) + }) - t.Run("headers with PAT", 
func(t *testing.T) { - cfg := Config{ - PAT: "test-token", + t.Run("error returns false", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, `Internal Server Error`) + })) + defer server.Close() + + client := newTestClient(t, server) + success, err := client.GrantMigratorRole(context.Background(), "ORG_ID", "monalisa", "USER") + + assert.NoError(t, err) // NOT an error — matches C# behavior + assert.False(t, success) + }) +} + +func TestClient_RevokeMigratorRole(t *testing.T) { + t.Run("success", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "revokeMigratorRole") + fmt.Fprint(w, `{"data":{"revokeMigratorRole":{"success":true}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + success, err := client.RevokeMigratorRole(context.Background(), "ORG_ID", "monalisa", "USER") + + assert.NoError(t, err) + assert.True(t, success) + }) + + t.Run("error returns false", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprint(w, `Internal Server Error`) + })) + defer server.Close() + + client := newTestClient(t, server) + success, err := client.RevokeMigratorRole(context.Background(), "ORG_ID", "monalisa", "USER") + + assert.NoError(t, err) // NOT an error — matches C# behavior + assert.False(t, success) + }) +} + +// --------------------------------------------------------------------------- +// Group 6: Team methods (REST-based) +// --------------------------------------------------------------------------- + +func TestClient_CreateTeam(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/orgs/test-org/teams") && r.Method == 
http.MethodPost { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + fmt.Fprint(w, `{"id":42,"name":"my-team","slug":"my-team"}`) + return } - client := NewClient(cfg, httpClient, log) + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + client := newTestClient(t, server) + team, err := client.CreateTeam(context.Background(), "test-org", "my-team") + + require.NoError(t, err) + assert.Equal(t, "42", team.ID) + assert.Equal(t, "my-team", team.Name) + assert.Equal(t, "my-team", team.Slug) +} + +func TestClient_GetTeams(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `[ + {"id":1,"name":"alpha","slug":"alpha"}, + {"id":2,"name":"beta","slug":"beta"}, + {"id":3,"name":"gamma","slug":"gamma"} + ]`) + })) + defer server.Close() + + client := newTestClient(t, server) + teams, err := client.GetTeams(context.Background(), "test-org") + + require.NoError(t, err) + require.Len(t, teams, 3) + assert.Equal(t, "alpha", teams[0].Name) + assert.Equal(t, "beta", teams[1].Name) + assert.Equal(t, "gamma", teams[2].Name) +} + +func TestClient_GetTeamMembers(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `[ + {"login":"alice","id":1}, + {"login":"bob","id":2} + ]`) + })) + defer server.Close() - headers := client.buildHeaders() + client := newTestClient(t, server) + members, err := client.GetTeamMembers(context.Background(), "test-org", "my-team") - assert.Equal(t, "application/vnd.github+json", headers["Accept"]) - assert.Equal(t, "2022-11-28", headers["X-GitHub-Api-Version"]) - assert.Equal(t, "Bearer test-token", headers["Authorization"]) + require.NoError(t, err) + require.Len(t, members, 2) + 
assert.Equal(t, "alice", members[0]) + assert.Equal(t, "bob", members[1]) +} + +func TestClient_RemoveTeamMember(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodDelete, r.Method) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := newTestClient(t, server) + err := client.RemoveTeamMember(context.Background(), "test-org", "my-team", "alice") + + require.NoError(t, err) +} + +func TestClient_GetTeamSlug(t *testing.T) { + t.Run("found", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `[ + {"id":1,"name":"Alpha Team","slug":"alpha-team"}, + {"id":2,"name":"Beta Team","slug":"beta-team"} + ]`) + })) + defer server.Close() + + client := newTestClient(t, server) + // Case-insensitive match + slug, err := client.GetTeamSlug(context.Background(), "test-org", "alpha team") + + require.NoError(t, err) + assert.Equal(t, "alpha-team", slug) + }) + + t.Run("not found", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `[{"id":1,"name":"Alpha Team","slug":"alpha-team"}]`) + })) + defer server.Close() + + client := newTestClient(t, server) + _, err := client.GetTeamSlug(context.Background(), "test-org", "nonexistent-team") + + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} + +func TestClient_AddTeamSync(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PATCH", r.Method) + assert.Contains(t, r.URL.Path, "team-sync/group-mappings") + + body, _ := io.ReadAll(r.Body) + assert.Contains(t, string(body), "group_id") + 
assert.Contains(t, string(body), "Test Group") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"groups":[{"group_id":"42","group_name":"Test Group","group_description":"A test group"}]}`) + })) + defer server.Close() + + client := newTestClient(t, server) + err := client.AddTeamSync(context.Background(), "test-org", "my-team", "42", "Test Group", "A test group") + + require.NoError(t, err) +} + +func TestClient_AddTeamToRepo(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, http.MethodPut, r.Method) + w.WriteHeader(http.StatusNoContent) + })) + defer server.Close() + + client := newTestClient(t, server) + err := client.AddTeamToRepo(context.Background(), "test-org", "my-team", "my-repo", "push") + + require.NoError(t, err) +} + +func TestClient_GetIdpGroupId(t *testing.T) { + t.Run("found", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Contains(t, r.URL.Path, "external-groups") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"groups":[{"group_id":42,"group_name":"Test Group"}]}`) + })) + defer server.Close() + + client := newTestClient(t, server) + groupID, err := client.GetIdpGroupId(context.Background(), "test-org", "Test Group") + + require.NoError(t, err) + assert.Equal(t, 42, groupID) + }) + + t.Run("not found", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"groups":[]}`) + })) + defer server.Close() + + client := newTestClient(t, server) + _, err := client.GetIdpGroupId(context.Background(), "test-org", "No Such Group") + + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) +} + +func 
TestClient_AddEmuGroupToTeam(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PATCH", r.Method) + assert.Contains(t, r.URL.Path, "external-groups") + + body, _ := io.ReadAll(r.Body) + assert.Contains(t, string(body), "group_id") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"group_id":42}`) + })) + defer server.Close() + + client := newTestClient(t, server) + err := client.AddEmuGroupToTeam(context.Background(), "test-org", "my-team", 42) + + require.NoError(t, err) +} + +// --------------------------------------------------------------------------- +// Group 7: Mannequin methods (GraphQL-based) +// --------------------------------------------------------------------------- + +func TestClient_GetMannequins(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "mannequins") + // Return a single page with hasNextPage=false + fmt.Fprint(w, `{"data":{"node":{"mannequins":{ + "pageInfo":{"endCursor":"cursor1","hasNextPage":false}, + "nodes":[ + {"login":"mona","id":"MANN_1","claimant":null}, + {"login":"lisa","id":"MANN_2","claimant":{"login":"real-lisa","id":"USER_2"}} + ] + }}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + mannequins, err := client.GetMannequins(context.Background(), "ORG_ID") + + require.NoError(t, err) + require.Len(t, mannequins, 2) + + assert.Equal(t, "mona", mannequins[0].Login) + assert.Equal(t, "MANN_1", mannequins[0].ID) + assert.Nil(t, mannequins[0].MappedUser) + + assert.Equal(t, "lisa", mannequins[1].Login) + assert.Equal(t, "MANN_2", mannequins[1].ID) + require.NotNil(t, mannequins[1].MappedUser) + assert.Equal(t, "real-lisa", mannequins[1].MappedUser.Login) + assert.Equal(t, "USER_2", mannequins[1].MappedUser.ID) +} + +func TestClient_GetMannequinsByLogin(t *testing.T) { + server := newGraphQLTestServer(t, 
func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "mannequins") + // Verify login variable is present in the body + assert.Contains(t, body, "mona") + fmt.Fprint(w, `{"data":{"node":{"mannequins":{ + "pageInfo":{"endCursor":"","hasNextPage":false}, + "nodes":[ + {"login":"mona","id":"MANN_1","claimant":null} + ] + }}}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + mannequins, err := client.GetMannequinsByLogin(context.Background(), "ORG_ID", "mona") + + require.NoError(t, err) + require.Len(t, mannequins, 1) + assert.Equal(t, "mona", mannequins[0].Login) + assert.Equal(t, "MANN_1", mannequins[0].ID) +} + +func TestClient_CreateAttributionInvitation(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "createAttributionInvitation") + fmt.Fprint(w, `{"data":{"createAttributionInvitation":{ + "source":{"id":"MANN_1","login":"mona"}, + "target":{"id":"USER_1","login":"real-mona"} + }}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + result, err := client.CreateAttributionInvitation(context.Background(), "ORG_ID", "MANN_1", "USER_1") + + require.NoError(t, err) + require.NotNil(t, result.Source) + assert.Equal(t, "MANN_1", result.Source.ID) + assert.Equal(t, "mona", result.Source.Login) + require.NotNil(t, result.Target) + assert.Equal(t, "USER_1", result.Target.ID) + assert.Equal(t, "real-mona", result.Target.Login) +} + +func TestClient_ReclaimMannequinSkipInvitation(t *testing.T) { + t.Run("success", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + assert.Contains(t, body, "reattributeMannequinToUser") + fmt.Fprint(w, `{"data":{"reattributeMannequinToUser":{ + "source":{"id":"MANN_1","login":"mona"}, + "target":{"id":"USER_1","login":"real-mona"} + }}}`) + }) + defer server.Close() + + client := newTestClient(t, server) + result, err := 
client.ReclaimMannequinSkipInvitation(context.Background(), "ORG_ID", "MANN_1", "USER_1") + + require.NoError(t, err) + require.NotNil(t, result.Source) + assert.Equal(t, "MANN_1", result.Source.ID) + require.NotNil(t, result.Target) + assert.Equal(t, "USER_1", result.Target.ID) + assert.Empty(t, result.Errors) + }) + + t.Run("mutation not available", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + fmt.Fprint(w, `{"errors":[{"message":"Field 'reattributeMannequinToUser' doesn't exist on type 'Mutation'"}]}`) + }) + defer server.Close() + + client := newTestClient(t, server) + _, err := client.ReclaimMannequinSkipInvitation(context.Background(), "ORG_ID", "MANN_1", "USER_1") + + require.Error(t, err) + assert.Contains(t, err.Error(), "not available") }) - t.Run("headers without PAT", func(t *testing.T) { - cfg := Config{} - client := NewClient(cfg, httpClient, log) + t.Run("target must be member", func(t *testing.T) { + server := newGraphQLTestServer(t, func(w http.ResponseWriter, body string) { + fmt.Fprint(w, `{"errors":[{"message":"Target must be a member of the organization"}]}`) + }) + defer server.Close() - headers := client.buildHeaders() + client := newTestClient(t, server) + result, err := client.ReclaimMannequinSkipInvitation(context.Background(), "ORG_ID", "MANN_1", "USER_1") - assert.Equal(t, "application/vnd.github+json", headers["Accept"]) - assert.Equal(t, "2022-11-28", headers["X-GitHub-Api-Version"]) - assert.NotContains(t, headers, "Authorization") + // Should NOT return a Go error — returns result with Errors populated + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, result.Errors, 1) + assert.Contains(t, result.Errors[0].Message, "Target must be a member") }) } diff --git a/pkg/github/graphql.go b/pkg/github/graphql.go new file mode 100644 index 000000000..4e43be417 --- /dev/null +++ b/pkg/github/graphql.go @@ -0,0 +1,282 @@ +package github + +import ( + "bytes" + 
"context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/github/gh-gei/pkg/logger" +) + +const ( + graphqlPath = "/graphql" + graphqlFeaturesValue = "import_api,mannequin_claiming_emu,org_import_api" + maxRetries = 3 + defaultPageSize = 100 +) + +// graphqlClient is an internal GraphQL client for GitHub's migration API. +// It handles auth, required headers, pagination, and secondary rate limiting. +type graphqlClient struct { + httpClient *http.Client + baseURL string + pat string + version string + logger *logger.Logger +} + +// graphqlRequest is the JSON body sent to the GraphQL endpoint. +type graphqlRequest struct { + Query string `json:"query"` + Variables json.RawMessage `json:"variables,omitempty"` +} + +// graphqlResponse is the top-level response from the GraphQL endpoint. +type graphqlResponse struct { + Data json.RawMessage `json:"data"` + Errors []graphqlError `json:"errors"` +} + +type graphqlError struct { + Message string `json:"message"` + Type string `json:"type,omitempty"` +} + +// pageInfo mirrors the GraphQL PageInfo type. +type pageInfo struct { + HasNextPage bool `json:"hasNextPage"` + EndCursor string `json:"endCursor"` +} + +func newGraphQLClient(baseURL, pat, version string, log *logger.Logger) *graphqlClient { + return &graphqlClient{ + httpClient: &http.Client{Timeout: 30 * time.Second}, + baseURL: strings.TrimRight(baseURL, "/"), + pat: pat, + version: version, + logger: log, + } +} + +// Post sends a GraphQL query and returns the raw "data" field. 
+func (c *graphqlClient) Post(ctx context.Context, query string, variables json.RawMessage) (json.RawMessage, error) { + return c.doWithRetry(ctx, query, variables, 0) +} + +func (c *graphqlClient) doWithRetry(ctx context.Context, query string, variables json.RawMessage, retryCount int) (json.RawMessage, error) { + reqBody := graphqlRequest{ + Query: query, + Variables: variables, + } + bodyBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("graphql: failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+graphqlPath, bytes.NewReader(bodyBytes)) + if err != nil { + return nil, fmt.Errorf("graphql: failed to create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+c.pat) + req.Header.Set("GraphQL-Features", graphqlFeaturesValue) + req.Header.Set("User-Agent", "OctoshiftCLI/"+c.version) + req.Header.Set("Content-Type", "application/json") + + c.logger.Debug("GraphQL POST: %s", c.baseURL+graphqlPath) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("graphql: request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("graphql: failed to read response: %w", err) + } + + // Check for secondary rate limit before checking status + if isSecondaryRateLimit(resp.StatusCode, string(respBody)) { + if retryCount >= maxRetries { + return nil, fmt.Errorf("graphql: secondary rate limit exceeded after %d retries", maxRetries) + } + delay := computeBackoff(resp, retryCount) + c.logger.Warning("Secondary rate limit hit, retrying in %v (attempt %d/%d)", delay, retryCount+1, maxRetries) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + return c.doWithRetry(ctx, query, variables, retryCount+1) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return nil, fmt.Errorf("graphql: HTTP %d: %s", 
resp.StatusCode, string(respBody)) + } + + var gqlResp graphqlResponse + if err := json.Unmarshal(respBody, &gqlResp); err != nil { + return nil, fmt.Errorf("graphql: failed to parse response: %w", err) + } + + if len(gqlResp.Errors) > 0 { + return nil, fmt.Errorf("graphql: %s", gqlResp.Errors[0].Message) + } + + return gqlResp.Data, nil +} + +// PostWithPagination sends a paginated GraphQL query, collecting all pages. +// dataPath is a dot-separated path to the array in the data (e.g. "organization.repositories.nodes"). +// pageInfoPath is a dot-separated path to the pageInfo object (e.g. "organization.repositories.pageInfo"). +func (c *graphqlClient) PostWithPagination( + ctx context.Context, + query string, + variables json.RawMessage, + dataPath string, + pageInfoPath string, +) (json.RawMessage, error) { + var allItems []json.RawMessage + var cursor *string + + for { + // Inject first and after into variables + vars, err := injectPaginationVars(variables, defaultPageSize, cursor) + if err != nil { + return nil, fmt.Errorf("graphql pagination: failed to inject variables: %w", err) + } + + data, err := c.Post(ctx, query, vars) + if err != nil { + return nil, err + } + + // Navigate to the data array + items, err := navigateJSON(data, dataPath) + if err != nil { + return nil, fmt.Errorf("graphql pagination: failed to navigate data path %q: %w", dataPath, err) + } + + // Parse items as array + var pageItems []json.RawMessage + if err := json.Unmarshal(items, &pageItems); err != nil { + return nil, fmt.Errorf("graphql pagination: data at path %q is not an array: %w", dataPath, err) + } + allItems = append(allItems, pageItems...) 
+ + // Navigate to pageInfo + piRaw, err := navigateJSON(data, pageInfoPath) + if err != nil { + return nil, fmt.Errorf("graphql pagination: failed to navigate pageInfo path %q: %w", pageInfoPath, err) + } + + var pi pageInfo + if err := json.Unmarshal(piRaw, &pi); err != nil { + return nil, fmt.Errorf("graphql pagination: failed to parse pageInfo: %w", err) + } + + if !pi.HasNextPage { + break + } + cursor = &pi.EndCursor + } + + result, err := json.Marshal(allItems) + if err != nil { + return nil, fmt.Errorf("graphql pagination: failed to marshal results: %w", err) + } + return result, nil +} + +// injectPaginationVars merges "first" and "after" into the variables map. +func injectPaginationVars(variables json.RawMessage, first int, after *string) (json.RawMessage, error) { + var vars map[string]interface{} + if len(variables) > 0 { + if err := json.Unmarshal(variables, &vars); err != nil { + return nil, err + } + } + if vars == nil { + vars = make(map[string]interface{}) + } + vars["first"] = first + if after != nil { + vars["after"] = *after + } + return json.Marshal(vars) +} + +// navigateJSON walks a dot-separated path through a JSON object. +func navigateJSON(data json.RawMessage, path string) (json.RawMessage, error) { + parts := strings.Split(path, ".") + current := data + for _, part := range parts { + var obj map[string]json.RawMessage + if err := json.Unmarshal(current, &obj); err != nil { + return nil, fmt.Errorf("expected object at %q: %w", part, err) + } + val, ok := obj[part] + if !ok { + return nil, fmt.Errorf("key %q not found", part) + } + current = val + } + return current, nil +} + +// isSecondaryRateLimit checks whether the response indicates a secondary rate limit. +// It returns true for: +// - Any 429 (unless body contains "API RATE LIMIT EXCEEDED") +// - 403 with body containing "SECONDARY RATE LIMIT" or "ABUSE DETECTION" +// +// It excludes primary rate limits (403 with "API RATE LIMIT EXCEEDED"). 
+func isSecondaryRateLimit(statusCode int, body string) bool { + upper := strings.ToUpper(body) + + // Primary rate limit — never retry + if strings.Contains(upper, "API RATE LIMIT EXCEEDED") { + return false + } + + if statusCode == http.StatusTooManyRequests { + return true + } + + if statusCode == http.StatusForbidden { + return strings.Contains(upper, "SECONDARY RATE LIMIT") || strings.Contains(upper, "ABUSE DETECTION") + } + + return false +} + +// computeBackoff determines how long to wait before retrying. +// Priority: Retry-After header → X-RateLimit-Reset → exponential 60*2^retryCount. +func computeBackoff(resp *http.Response, retryCount int) time.Duration { + // Try Retry-After header (seconds) + if ra := resp.Header.Get("Retry-After"); ra != "" { + if seconds, err := strconv.Atoi(ra); err == nil { + return time.Duration(seconds) * time.Second + } + } + + // Try X-RateLimit-Reset (unix timestamp) + if reset := resp.Header.Get("X-RateLimit-Reset"); reset != "" { + if ts, err := strconv.ParseInt(reset, 10, 64); err == nil { + resetTime := time.Unix(ts, 0) + delay := time.Until(resetTime) + if delay > 0 { + return delay + } + } + } + + // Exponential backoff: 60 * 2^retryCount seconds + return time.Duration(60*(1<<retryCount)) * time.Second +} diff --git a/pkg/mannequin/service.go b/pkg/mannequin/service.go +func (s *ReclaimService) handleInvitationResult(mannequinUser, targetUser string, m github.Mannequin, targetUserID string, result *github.CreateAttributionInvitationResult) bool { + if len(result.Errors) > 0 { + s.log.Errorf("Failed to send reclaim invitation email to %s for mannequin %s (%s): %s", targetUser, mannequinUser, m.ID, result.Errors[0].Message) + return false + } + + if result.Source == nil || result.Target == nil || + result.Source.ID != m.ID || + result.Target.ID != targetUserID { + s.log.Errorf("Failed to send reclaim invitation email to %s for mannequin %s (%s)", targetUser, mannequinUser, m.ID) + return false + } + + s.log.Info("Mannequin reclaim invitation email successfully sent to %s for %s (%s)", targetUser, mannequinUser, m.ID) + return true +} + +func (s *ReclaimService) handleReclamationResult(mannequinUser, targetUser string, m github.Mannequin, targetUserID string, result *github.ReattributeMannequinToUserResult) bool { + if 
len(result.Errors) > 0 { + msg := result.Errors[0].Message + if strings.Contains(msg, "is not an Enterprise Managed Users (EMU) organization") { + s.log.Errorf("Failed to reclaim mannequins. The --skip-invitation flag is only available to EMU organizations.") + return false + } + s.log.Warning("Failed to reattribute content belonging to mannequin %s (%s) to %s: %s", mannequinUser, m.ID, targetUser, msg) + return true + } + + if result.Source == nil || result.Target == nil || + result.Source.ID != m.ID || + result.Target.ID != targetUserID { + s.log.Warning("Failed to reattribute content belonging to mannequin %s (%s) to %s", mannequinUser, m.ID, targetUser) + return true + } + + s.log.Info("Successfully reclaimed content belonging to mannequin %s (%s) to %s", mannequinUser, m.ID, targetUser) + return true +} + +// parsedEntry represents a parsed line from the mannequin CSV. +type parsedEntry struct { + login string + id string + claimantLogin string +} + +// --- helper functions --- + +func filterByLogin(mannequins []github.Mannequin, login, id string) []github.Mannequin { + var result []github.Mannequin + for _, m := range mannequins { + if strings.EqualFold(m.Login, login) && (id == "" || strings.EqualFold(m.ID, id)) { + result = append(result, m) + } + } + return result +} + +func isClaimed(mannequins []github.Mannequin, login, id string) bool { + for _, m := range mannequins { + if strings.EqualFold(m.Login, login) && + (id == "" || strings.EqualFold(m.ID, id)) && + m.MappedUser != nil { + return true + } + } + return false +} + +func findFirst(mannequins []github.Mannequin, login, id string) *github.Mannequin { + for i, m := range mannequins { + if strings.EqualFold(m.Login, login) && strings.EqualFold(m.ID, id) { + return &mannequins[i] + } + } + return nil +} + +func uniqueUsers(mannequins []github.Mannequin) []github.Mannequin { + seen := make(map[string]bool) + var result []github.Mannequin + for _, m := range mannequins { + key := fmt.Sprintf("%s__%s", 
m.ID, m.Login) + if !seen[key] { + seen[key] = true + result = append(result, m) + } + } + return result +} + +func isDuplicate(parsed []parsedEntry, idx int) bool { + target := parsed[idx] + count := 0 + for _, p := range parsed { + if p.login == target.login && p.id == target.id { + count++ + } + } + return count > 1 +} diff --git a/pkg/mannequin/service_test.go b/pkg/mannequin/service_test.go new file mode 100644 index 000000000..51097bf0c --- /dev/null +++ b/pkg/mannequin/service_test.go @@ -0,0 +1,598 @@ +package mannequin_test + +import ( + "bytes" + "context" + "fmt" + "testing" + + "github.com/github/gh-gei/pkg/github" + "github.com/github/gh-gei/pkg/logger" + "github.com/github/gh-gei/pkg/mannequin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockGitHubAPI implements mannequin.GitHubAPI for testing. +type mockGitHubAPI struct { + orgID string + orgIDErr error + + mannequins []github.Mannequin + mannequinsErr error + + mannequinsByLogin []github.Mannequin + mannequinsByLoginErr error + + userID string + userIDErr error + + invitationResult *github.CreateAttributionInvitationResult + invitationErr error + + reclaimResult *github.ReattributeMannequinToUserResult + reclaimErr error + + // capture calls + gotOrgIDOrg string + gotMannequinsOrgID string + gotMannequinsByLogin struct{ OrgID, Login string } + gotUserIDLogin string + invitations []struct{ OrgID, SourceID, TargetID string } + reclaims []struct{ OrgID, SourceID, TargetID string } +} + +func (m *mockGitHubAPI) GetOrganizationId(_ context.Context, org string) (string, error) { + m.gotOrgIDOrg = org + return m.orgID, m.orgIDErr +} + +func (m *mockGitHubAPI) GetMannequins(_ context.Context, orgID string) ([]github.Mannequin, error) { + m.gotMannequinsOrgID = orgID + return m.mannequins, m.mannequinsErr +} + +func (m *mockGitHubAPI) GetMannequinsByLogin(_ context.Context, orgID, login string) ([]github.Mannequin, error) { + m.gotMannequinsByLogin = struct{ OrgID, 
Login string }{orgID, login} + return m.mannequinsByLogin, m.mannequinsByLoginErr +} + +func (m *mockGitHubAPI) GetUserId(_ context.Context, login string) (string, error) { + m.gotUserIDLogin = login + return m.userID, m.userIDErr +} + +func (m *mockGitHubAPI) CreateAttributionInvitation(_ context.Context, orgID, sourceID, targetID string) (*github.CreateAttributionInvitationResult, error) { + m.invitations = append(m.invitations, struct{ OrgID, SourceID, TargetID string }{orgID, sourceID, targetID}) + return m.invitationResult, m.invitationErr +} + +func (m *mockGitHubAPI) ReclaimMannequinSkipInvitation(_ context.Context, orgID, sourceID, targetID string) (*github.ReattributeMannequinToUserResult, error) { + m.reclaims = append(m.reclaims, struct{ OrgID, SourceID, TargetID string }{orgID, sourceID, targetID}) + return m.reclaimResult, m.reclaimErr +} + +const ( + testOrg = "FooOrg" + testOrgID = "org-id-123" +) + +// ---------- ParseLine tests (tested indirectly through ReclaimMannequins) ---------- + +func TestReclaimMannequins_InvalidLine_Skipped(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "bad-line", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.NoError(t, err) + assert.Contains(t, buf.String(), "Invalid line") +} + +func TestReclaimMannequins_EmptyFieldsLine_Skipped(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + ",,", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.NoError(t, err) + 
assert.Contains(t, buf.String(), "Mannequin login is not defined") +} + +func TestReclaimMannequins_EmptyFile_WarnsAndReturns(t *testing.T) { + mock := &mockGitHubAPI{} + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequins(context.Background(), []string{}, testOrg, false, false) + require.NoError(t, err) + assert.Contains(t, buf.String(), "File is empty. Nothing to reclaim") +} + +func TestReclaimMannequins_InvalidHeader_ReturnsError(t *testing.T) { + mock := &mockGitHubAPI{} + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{"bad-header"} + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "Invalid Header") +} + +func TestReclaimMannequins_BlankLines_Skipped(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{}, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "", + " ", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.NoError(t, err) + // No warnings about invalid lines — blank lines are silently skipped +} + +func TestReclaimMannequins_AlreadyClaimed_NotForce_Skips(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona", MappedUser: &github.MannequinUser{ID: "u1", Login: "mona_gh"}}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,target_user", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.NoError(t, err) + assert.Contains(t, buf.String(), "already claimed") +} + +func 
TestReclaimMannequins_MannequinNotFound_Skips(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{}, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,target_user", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.NoError(t, err) + assert.Contains(t, buf.String(), "not found") +} + +func TestReclaimMannequins_DuplicateInCSV_Skips(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userID: "target-id", + invitationResult: &github.CreateAttributionInvitationResult{ + Source: &github.Mannequin{ID: "m1", Login: "mona"}, + Target: &github.MannequinUser{ID: "target-id", Login: "target_user"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,target_user", + "mona,m1,target_user", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.NoError(t, err) + assert.Contains(t, buf.String(), "duplicate") +} + +func TestReclaimMannequins_ClaimantNotFound_Skips(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userIDErr: fmt.Errorf("Could not resolve to a User with the login"), + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,nonexistent_user", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.NoError(t, err) + assert.Contains(t, buf.String(), "Claimant") + assert.Contains(t, buf.String(), "not found") +} + +func TestReclaimMannequins_HappyPath_Invitation(t *testing.T) { + mock := &mockGitHubAPI{ + 
orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userID: "target-id", + invitationResult: &github.CreateAttributionInvitationResult{ + Source: &github.Mannequin{ID: "m1", Login: "mona"}, + Target: &github.MannequinUser{ID: "target-id", Login: "target_user"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,target_user", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.NoError(t, err) + assert.Contains(t, buf.String(), "invitation email successfully sent") + require.Len(t, mock.invitations, 1) + assert.Equal(t, testOrgID, mock.invitations[0].OrgID) + assert.Equal(t, "m1", mock.invitations[0].SourceID) + assert.Equal(t, "target-id", mock.invitations[0].TargetID) +} + +func TestReclaimMannequins_HappyPath_SkipInvitation(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userID: "target-id", + reclaimResult: &github.ReattributeMannequinToUserResult{ + Source: &github.Mannequin{ID: "m1", Login: "mona"}, + Target: &github.MannequinUser{ID: "target-id", Login: "target_user"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,target_user", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, true) + require.NoError(t, err) + assert.Contains(t, buf.String(), "Successfully reclaimed") + require.Len(t, mock.reclaims, 1) +} + +func TestReclaimMannequins_SkipInvitation_EMUError_StopsProcessing(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + {ID: "m2", Login: "lisa"}, + }, + userID: "target-id", + reclaimResult: &github.ReattributeMannequinToUserResult{ + Errors: 
[]github.ErrorData{ + {Message: "is not an Enterprise Managed Users (EMU) organization"}, + }, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,target_user", + "lisa,m2,target_user", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, true) + require.NoError(t, err) + assert.Contains(t, buf.String(), "EMU organizations") + // Should only have tried one reclaim (stopped after first) + assert.Len(t, mock.reclaims, 1) +} + +func TestReclaimMannequins_AlreadyClaimed_Force_Proceeds(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona", MappedUser: &github.MannequinUser{ID: "u1", Login: "old_user"}}, + }, + userID: "target-id", + invitationResult: &github.CreateAttributionInvitationResult{ + Source: &github.Mannequin{ID: "m1", Login: "mona"}, + Target: &github.MannequinUser{ID: "target-id", Login: "target_user"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,target_user", + } + // force=true should proceed even though already claimed + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, true, false) + require.NoError(t, err) + require.Len(t, mock.invitations, 1) +} + +// ---------- ReclaimMannequin (single mode) tests ---------- + +func TestReclaimMannequin_UserNotMannequin_ReturnsError(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequinsByLogin: []github.Mannequin{}, // no mannequins found + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequin(context.Background(), "mona", "", "target", testOrg, false, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "is not a mannequin") +} + +func 
TestReclaimMannequin_AlreadyMapped_NotForce_ReturnsError(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequinsByLogin: []github.Mannequin{ + {ID: "m1", Login: "mona", MappedUser: &github.MannequinUser{ID: "u1", Login: "old"}}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequin(context.Background(), "mona", "", "target", testOrg, false, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "already mapped") +} + +func TestReclaimMannequin_HappyPath_Invitation(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequinsByLogin: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userID: "target-id", + invitationResult: &github.CreateAttributionInvitationResult{ + Source: &github.Mannequin{ID: "m1", Login: "mona"}, + Target: &github.MannequinUser{ID: "target-id", Login: "target_user"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequin(context.Background(), "mona", "", "target_user", testOrg, false, false) + require.NoError(t, err) + assert.Contains(t, buf.String(), "invitation email successfully sent") + require.Len(t, mock.invitations, 1) +} + +func TestReclaimMannequin_HappyPath_SkipInvitation(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequinsByLogin: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userID: "target-id", + reclaimResult: &github.ReattributeMannequinToUserResult{ + Source: &github.Mannequin{ID: "m1", Login: "mona"}, + Target: &github.MannequinUser{ID: "target-id", Login: "target_user"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequin(context.Background(), "mona", "", "target_user", testOrg, false, true) + require.NoError(t, err) + assert.Contains(t, buf.String(), "Successfully reclaimed") + 
require.Len(t, mock.reclaims, 1) +} + +func TestReclaimMannequin_WithMannequinID_FiltersCorrectly(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequinsByLogin: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + {ID: "m2", Login: "mona"}, + }, + userID: "target-id", + invitationResult: &github.CreateAttributionInvitationResult{ + Source: &github.Mannequin{ID: "m2", Login: "mona"}, + Target: &github.MannequinUser{ID: "target-id", Login: "target_user"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequin(context.Background(), "mona", "m2", "target_user", testOrg, false, false) + require.NoError(t, err) + // Should only have invited one mannequin (m2) + require.Len(t, mock.invitations, 1) + assert.Equal(t, "m2", mock.invitations[0].SourceID) +} + +func TestReclaimMannequin_InvitationFails_ReturnsError(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequinsByLogin: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userID: "target-id", + invitationResult: &github.CreateAttributionInvitationResult{ + Errors: []github.ErrorData{{Message: "some error"}}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequin(context.Background(), "mona", "", "target_user", testOrg, false, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "Failed to send reclaim mannequin invitation(s)") +} + +func TestReclaimMannequin_SkipInvitation_EMUError_ReturnsError(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequinsByLogin: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userID: "target-id", + reclaimResult: &github.ReattributeMannequinToUserResult{ + Errors: []github.ErrorData{ + {Message: "is not an Enterprise Managed Users (EMU) organization"}, + }, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := 
mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequin(context.Background(), "mona", "", "target_user", testOrg, false, true) + require.Error(t, err) + assert.Contains(t, err.Error(), "Failed to reclaim mannequin") +} + +func TestReclaimMannequin_InvitationResultMismatch_ReturnsError(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequinsByLogin: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userID: "target-id", + invitationResult: &github.CreateAttributionInvitationResult{ + Source: &github.Mannequin{ID: "wrong-id", Login: "mona"}, + Target: &github.MannequinUser{ID: "target-id", Login: "target_user"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequin(context.Background(), "mona", "", "target_user", testOrg, false, false) + require.Error(t, err) + assert.Contains(t, err.Error(), "Failed to send reclaim mannequin invitation(s)") +} + +func TestReclaimMannequin_Force_AlreadyMapped_Proceeds(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequinsByLogin: []github.Mannequin{ + {ID: "m1", Login: "mona", MappedUser: &github.MannequinUser{ID: "u1", Login: "old"}}, + }, + userID: "target-id", + invitationResult: &github.CreateAttributionInvitationResult{ + Source: &github.Mannequin{ID: "m1", Login: "mona"}, + Target: &github.MannequinUser{ID: "target-id", Login: "target_user"}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + err := svc.ReclaimMannequin(context.Background(), "mona", "", "target_user", testOrg, true, false) + require.NoError(t, err) + require.Len(t, mock.invitations, 1) +} + +// ---------- HandleInvitationResult edge cases ---------- + +func TestReclaimMannequins_Invitation_WithErrors_LogsError(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + }, + userID: 
"target-id", + invitationResult: &github.CreateAttributionInvitationResult{ + Errors: []github.ErrorData{{Message: "invitation error"}}, + }, + } + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,target_user", + } + // In CSV mode, invitation errors are logged but processing continues + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, false) + require.NoError(t, err) + assert.Contains(t, buf.String(), "Failed to send reclaim invitation") +} + +func TestReclaimMannequins_SkipInvitation_OtherError_ContinuesProcessing(t *testing.T) { + mock := &mockGitHubAPI{ + orgID: testOrgID, + mannequins: []github.Mannequin{ + {ID: "m1", Login: "mona"}, + {ID: "m2", Login: "lisa"}, + }, + userID: "target-id", + } + mock.reclaimResult = &github.ReattributeMannequinToUserResult{ + Errors: []github.ErrorData{{Message: "some non-EMU error"}}, + } + + var buf bytes.Buffer + log := logger.New(false, &buf) + svc := mannequin.NewReclaimService(mock, log) + + lines := []string{ + mannequin.CSVHeader, + "mona,m1,target_user", + "lisa,m2,target_user", + } + err := svc.ReclaimMannequins(context.Background(), lines, testOrg, false, true) + require.NoError(t, err) + // Should have tried both reclaims (non-EMU errors continue) + assert.Len(t, mock.reclaims, 2) +} diff --git a/pkg/migration/status.go b/pkg/migration/status.go new file mode 100644 index 000000000..eff12d31e --- /dev/null +++ b/pkg/migration/status.go @@ -0,0 +1,58 @@ +// Package migration provides constants and helpers for migration status tracking. +package migration + +import "strings" + +// Repository migration status constants. 
+const ( + RepoQueued = "QUEUED" + RepoInProgress = "IN_PROGRESS" + RepoFailed = "FAILED" + RepoSucceeded = "SUCCEEDED" + RepoPendingValidation = "PENDING_VALIDATION" + RepoFailedValidation = "FAILED_VALIDATION" +) + +// IsRepoSucceeded returns true if the repository migration state is SUCCEEDED. +func IsRepoSucceeded(state string) bool { return normalize(state) == RepoSucceeded } + +// IsRepoPending returns true if the repository migration is still in progress. +func IsRepoPending(state string) bool { + s := normalize(state) + return s == RepoQueued || s == RepoInProgress || s == RepoPendingValidation +} + +// IsRepoFailed returns true if the repository migration has failed. +// Any state that is neither pending nor succeeded is considered failed. +func IsRepoFailed(state string) bool { return !IsRepoPending(state) && !IsRepoSucceeded(state) } + +// Organization migration status constants. +const ( + OrgQueued = "QUEUED" + OrgInProgress = "IN_PROGRESS" + OrgFailed = "FAILED" + OrgSucceeded = "SUCCEEDED" + OrgNotStarted = "NOT_STARTED" + OrgPostRepoMigration = "POST_REPO_MIGRATION" + OrgPreRepoMigration = "PRE_REPO_MIGRATION" + OrgRepoMigration = "REPO_MIGRATION" +) + +// IsOrgSucceeded returns true if the organization migration state is SUCCEEDED. +func IsOrgSucceeded(state string) bool { return normalize(state) == OrgSucceeded } + +// IsOrgPending returns true if the organization migration is still in progress. +func IsOrgPending(state string) bool { + s := normalize(state) + return s == OrgQueued || s == OrgInProgress || s == OrgNotStarted || + s == OrgPostRepoMigration || s == OrgPreRepoMigration || s == OrgRepoMigration +} + +// IsOrgFailed returns true if the organization migration has failed. +// Any state that is neither pending nor succeeded is considered failed. +func IsOrgFailed(state string) bool { return !IsOrgPending(state) && !IsOrgSucceeded(state) } + +// IsOrgRepoMigration returns true if the organization migration is in the repo migration phase. 
+func IsOrgRepoMigration(state string) bool { return normalize(state) == OrgRepoMigration } + +func normalize(s string) string { return strings.ToUpper(strings.TrimSpace(s)) } diff --git a/pkg/migration/status_test.go b/pkg/migration/status_test.go new file mode 100644 index 000000000..d8ebcb54c --- /dev/null +++ b/pkg/migration/status_test.go @@ -0,0 +1,175 @@ +package migration_test + +import ( + "testing" + + "github.com/github/gh-gei/pkg/migration" +) + +func TestIsRepoSucceeded(t *testing.T) { + tests := []struct { + name string + state string + want bool + }{ + {"exact match", "SUCCEEDED", true}, + {"lowercase", "succeeded", true}, + {"mixed case", "Succeeded", true}, + {"with whitespace", " SUCCEEDED ", true}, + {"queued", "QUEUED", false}, + {"in progress", "IN_PROGRESS", false}, + {"failed", "FAILED", false}, + {"empty", "", false}, + {"unknown", "UNKNOWN", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := migration.IsRepoSucceeded(tt.state); got != tt.want { + t.Errorf("IsRepoSucceeded(%q) = %v, want %v", tt.state, got, tt.want) + } + }) + } +} + +func TestIsRepoPending(t *testing.T) { + tests := []struct { + name string + state string + want bool + }{ + {"queued", "QUEUED", true}, + {"in progress", "IN_PROGRESS", true}, + {"pending validation", "PENDING_VALIDATION", true}, + {"queued lowercase", "queued", true}, + {"with whitespace", " IN_PROGRESS ", true}, + {"succeeded", "SUCCEEDED", false}, + {"failed", "FAILED", false}, + {"empty", "", false}, + {"unknown", "SOMETHING_ELSE", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := migration.IsRepoPending(tt.state); got != tt.want { + t.Errorf("IsRepoPending(%q) = %v, want %v", tt.state, got, tt.want) + } + }) + } +} + +func TestIsRepoFailed(t *testing.T) { + tests := []struct { + name string + state string + want bool + }{ + {"failed", "FAILED", true}, + {"failed validation", "FAILED_VALIDATION", true}, + {"unknown value", 
"WEIRD", true}, + {"empty string", "", true}, + {"succeeded", "SUCCEEDED", false}, + {"queued", "QUEUED", false}, + {"in progress", "IN_PROGRESS", false}, + {"pending validation", "PENDING_VALIDATION", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := migration.IsRepoFailed(tt.state); got != tt.want { + t.Errorf("IsRepoFailed(%q) = %v, want %v", tt.state, got, tt.want) + } + }) + } +} + +func TestIsOrgSucceeded(t *testing.T) { + tests := []struct { + name string + state string + want bool + }{ + {"exact match", "SUCCEEDED", true}, + {"lowercase", "succeeded", true}, + {"with whitespace", " SUCCEEDED ", true}, + {"queued", "QUEUED", false}, + {"failed", "FAILED", false}, + {"empty", "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := migration.IsOrgSucceeded(tt.state); got != tt.want { + t.Errorf("IsOrgSucceeded(%q) = %v, want %v", tt.state, got, tt.want) + } + }) + } +} + +func TestIsOrgPending(t *testing.T) { + tests := []struct { + name string + state string + want bool + }{ + {"queued", "QUEUED", true}, + {"in progress", "IN_PROGRESS", true}, + {"not started", "NOT_STARTED", true}, + {"post repo migration", "POST_REPO_MIGRATION", true}, + {"pre repo migration", "PRE_REPO_MIGRATION", true}, + {"repo migration", "REPO_MIGRATION", true}, + {"lowercase", "queued", true}, + {"with whitespace", " IN_PROGRESS ", true}, + {"succeeded", "SUCCEEDED", false}, + {"failed", "FAILED", false}, + {"empty", "", false}, + {"unknown", "SOMETHING_ELSE", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := migration.IsOrgPending(tt.state); got != tt.want { + t.Errorf("IsOrgPending(%q) = %v, want %v", tt.state, got, tt.want) + } + }) + } +} + +func TestIsOrgFailed(t *testing.T) { + tests := []struct { + name string + state string + want bool + }{ + {"failed", "FAILED", true}, + {"unknown value", "WEIRD", true}, + {"empty string", "", true}, + {"succeeded", 
"SUCCEEDED", false}, + {"queued", "QUEUED", false}, + {"in progress", "IN_PROGRESS", false}, + {"repo migration", "REPO_MIGRATION", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := migration.IsOrgFailed(tt.state); got != tt.want { + t.Errorf("IsOrgFailed(%q) = %v, want %v", tt.state, got, tt.want) + } + }) + } +} + +func TestIsOrgRepoMigration(t *testing.T) { + tests := []struct { + name string + state string + want bool + }{ + {"exact match", "REPO_MIGRATION", true}, + {"lowercase", "repo_migration", true}, + {"with whitespace", " REPO_MIGRATION ", true}, + {"other state", "IN_PROGRESS", false}, + {"empty", "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := migration.IsOrgRepoMigration(tt.state); got != tt.want { + t.Errorf("IsOrgRepoMigration(%q) = %v, want %v", tt.state, got, tt.want) + } + }) + } +} diff --git a/pkg/status/github.go b/pkg/status/github.go new file mode 100644 index 000000000..63f109b3f --- /dev/null +++ b/pkg/status/github.go @@ -0,0 +1,47 @@ +package status + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" +) + +// statusResponse is the minimal shape of the GitHub status API response. +type statusResponse struct { + Incidents []json.RawMessage `json:"incidents"` +} + +// GetUnresolvedIncidentsCount fetches the count of unresolved GitHub incidents. +// baseURL allows overriding the status API base URL for testing. 
+func GetUnresolvedIncidentsCount(ctx context.Context, client *http.Client, baseURL string) (int, error) { + url := baseURL + "/api/v2/incidents/unresolved.json" + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return 0, fmt.Errorf("creating status request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return 0, fmt.Errorf("fetching GitHub status: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return 0, fmt.Errorf("GitHub status API returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) // 1MB limit + if err != nil { + return 0, fmt.Errorf("reading status response: %w", err) + } + + var result statusResponse + if err := json.Unmarshal(body, &result); err != nil { + return 0, fmt.Errorf("parsing status response: %w", err) + } + + return len(result.Incidents), nil +} diff --git a/pkg/status/github_test.go b/pkg/status/github_test.go new file mode 100644 index 000000000..5d1d416e8 --- /dev/null +++ b/pkg/status/github_test.go @@ -0,0 +1,83 @@ +package status + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetUnresolvedIncidentsCount(t *testing.T) { + t.Run("returns count of unresolved incidents", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/v2/incidents/unresolved.json", r.URL.Path) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"incidents":[{"id":"1","name":"incident1"},{"id":"2","name":"incident2"}]}`)) + })) + defer server.Close() + + count, err := GetUnresolvedIncidentsCount(context.Background(), &http.Client{}, server.URL) + + require.NoError(t, err) + assert.Equal(t, 2, count) + }) + + t.Run("returns zero when no incidents", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"incidents":[]}`)) + })) + defer server.Close() + + count, err := GetUnresolvedIncidentsCount(context.Background(), &http.Client{}, server.URL) + + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + + t.Run("returns error on network failure", func(t *testing.T) { + _, err := GetUnresolvedIncidentsCount(context.Background(), &http.Client{}, "http://localhost:1") + + assert.Error(t, err) + }) + + t.Run("returns error on non-200 status", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer server.Close() + + _, err := GetUnresolvedIncidentsCount(context.Background(), &http.Client{}, server.URL) + + assert.Error(t, err) + }) + + t.Run("returns error on invalid JSON", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`not json`)) + })) + defer server.Close() + + _, err := GetUnresolvedIncidentsCount(context.Background(), &http.Client{}, server.URL) + + assert.Error(t, err) + }) + + t.Run("returns zero when incidents key missing", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"page":{}}`)) + })) + defer server.Close() + + // When "incidents" key is missing, the result should be 0 incidents (empty array default) + count, err := GetUnresolvedIncidentsCount(context.Background(), &http.Client{}, server.URL) + + require.NoError(t, err) + assert.Equal(t, 0, count) + }) +} diff --git a/pkg/version/checker.go b/pkg/version/checker.go new file mode 100644 index 000000000..2a67c1156 --- /dev/null +++ b/pkg/version/checker.go @@ -0,0 +1,148 @@ +package version + +import ( + "context" + "fmt" + "io" + "net/http" + "strconv" + "strings" 
+ + "github.com/github/gh-gei/pkg/logger" +) + +const defaultVersionURL = "https://raw.githubusercontent.com/github/gh-gei/main/LATEST-VERSION.txt" + +// semver represents a simple major.minor.patch version. +type semver struct { + major, minor, patch int +} + +func (v semver) compare(other semver) int { + if v.major != other.major { + if v.major > other.major { + return 1 + } + return -1 + } + if v.minor != other.minor { + if v.minor > other.minor { + return 1 + } + return -1 + } + if v.patch != other.patch { + if v.patch > other.patch { + return 1 + } + return -1 + } + return 0 +} + +// parseVersion parses a version string like "v1.27.0" into a semver. +func parseVersion(s string) (semver, error) { + s = strings.TrimSpace(s) + s = strings.TrimPrefix(s, "v") + s = strings.TrimPrefix(s, "V") + + parts := strings.Split(s, ".") + if len(parts) != 3 { + return semver{}, fmt.Errorf("invalid version format: %q (expected major.minor.patch)", s) + } + + major, err := strconv.Atoi(parts[0]) + if err != nil { + return semver{}, fmt.Errorf("invalid major version %q: %w", parts[0], err) + } + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return semver{}, fmt.Errorf("invalid minor version %q: %w", parts[1], err) + } + patch, err := strconv.Atoi(parts[2]) + if err != nil { + return semver{}, fmt.Errorf("invalid patch version %q: %w", parts[2], err) + } + + return semver{major, minor, patch}, nil +} + +// Checker checks whether the current CLI version is the latest. +type Checker struct { + httpClient *http.Client + logger *logger.Logger + version string + latestVersion *string // cached + versionURL string +} + +// NewChecker creates a new version Checker. +func NewChecker(httpClient *http.Client, log *logger.Logger, version string) *Checker { + return &Checker{ + httpClient: httpClient, + logger: log, + version: version, + versionURL: defaultVersionURL, + } +} + +// GetLatestVersion fetches the latest version string from GitHub. 
+func (c *Checker) GetLatestVersion(ctx context.Context) (string, error) { + if c.latestVersion != nil { + return *c.latestVersion, nil + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.versionURL, nil) + if err != nil { + return "", fmt.Errorf("creating version check request: %w", err) + } + req.Header.Set("User-Agent", "OctoshiftCLI/"+c.version) + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetching latest version: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("version check returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1024)) // version file is tiny + if err != nil { + return "", fmt.Errorf("reading version response: %w", err) + } + + raw := strings.TrimSpace(string(body)) + + // Validate it parses (parseVersion handles v/V prefix stripping) + v, err := parseVersion(raw) + if err != nil { + return "", err + } + + // Cache the normalized version string (no prefix) + normalized := fmt.Sprintf("%d.%d.%d", v.major, v.minor, v.patch) + c.latestVersion = &normalized + return normalized, nil +} + +// IsLatest returns true if the current version is >= the latest published version. 
+func (c *Checker) IsLatest(ctx context.Context) (bool, error) { + latestStr, err := c.GetLatestVersion(ctx) + if err != nil { + return false, err + } + + current, err := parseVersion(c.version) + if err != nil { + return false, fmt.Errorf("parsing current version: %w", err) + } + + latest, err := parseVersion(latestStr) + if err != nil { + return false, fmt.Errorf("parsing latest version: %w", err) + } + + return current.compare(latest) >= 0, nil +} diff --git a/pkg/version/checker_test.go b/pkg/version/checker_test.go new file mode 100644 index 000000000..85e7d0a26 --- /dev/null +++ b/pkg/version/checker_test.go @@ -0,0 +1,211 @@ +package version + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/gh-gei/pkg/logger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseVersion(t *testing.T) { + tests := []struct { + name string + input string + want semver + wantErr bool + }{ + {"three parts", "1.27.0", semver{1, 27, 0}, false}, + {"with v prefix", "v1.27.0", semver{1, 27, 0}, false}, + {"with V prefix", "V1.27.0", semver{1, 27, 0}, false}, + {"with whitespace", " v1.27.0\n", semver{1, 27, 0}, false}, + {"two parts", "1.27", semver{0, 0, 0}, true}, + {"empty", "", semver{0, 0, 0}, true}, + {"non-numeric", "a.b.c", semver{0, 0, 0}, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseVersion(tt.input) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestSemverCompare(t *testing.T) { + tests := []struct { + name string + a, b semver + want int + }{ + {"equal", semver{1, 2, 3}, semver{1, 2, 3}, 0}, + {"major greater", semver{2, 0, 0}, semver{1, 9, 9}, 1}, + {"major less", semver{1, 0, 0}, semver{2, 0, 0}, -1}, + {"minor greater", semver{1, 3, 0}, semver{1, 2, 9}, 1}, + {"minor less", semver{1, 2, 0}, semver{1, 3, 0}, -1}, + {"patch greater", semver{1, 2, 
4}, semver{1, 2, 3}, 1}, + {"patch less", semver{1, 2, 3}, semver{1, 2, 4}, -1}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.a.compare(tt.b) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestChecker_IsLatest(t *testing.T) { + t.Run("current version equals latest returns true", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("v1.27.0\n")) + })) + defer server.Close() + + log := logger.New(false) + checker := NewChecker(&http.Client{}, log, "1.27.0") + checker.versionURL = server.URL + + isLatest, err := checker.IsLatest(context.Background()) + + require.NoError(t, err) + assert.True(t, isLatest) + }) + + t.Run("current version greater than latest returns true", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("v1.26.0\n")) + })) + defer server.Close() + + log := logger.New(false) + checker := NewChecker(&http.Client{}, log, "1.27.0") + checker.versionURL = server.URL + + isLatest, err := checker.IsLatest(context.Background()) + + require.NoError(t, err) + assert.True(t, isLatest) + }) + + t.Run("current version less than latest returns false", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("v1.28.0\n")) + })) + defer server.Close() + + log := logger.New(false) + checker := NewChecker(&http.Client{}, log, "1.27.0") + checker.versionURL = server.URL + + isLatest, err := checker.IsLatest(context.Background()) + + require.NoError(t, err) + assert.False(t, isLatest) + }) + + t.Run("network error returns error", func(t *testing.T) { + log := logger.New(false) + checker := NewChecker(&http.Client{}, log, "1.27.0") + checker.versionURL = "http://localhost:1" // nothing listening + + _, 
err := checker.IsLatest(context.Background()) + + assert.Error(t, err) + }) + + t.Run("server returns non-200 returns error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + log := logger.New(false) + checker := NewChecker(&http.Client{}, log, "1.27.0") + checker.versionURL = server.URL + + _, err := checker.IsLatest(context.Background()) + + assert.Error(t, err) + }) + + t.Run("server returns invalid version returns error", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("not-a-version")) + })) + defer server.Close() + + log := logger.New(false) + checker := NewChecker(&http.Client{}, log, "1.27.0") + checker.versionURL = server.URL + + _, err := checker.IsLatest(context.Background()) + + assert.Error(t, err) + }) + + t.Run("caches latest version after first fetch", func(t *testing.T) { + calls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls++ + w.WriteHeader(http.StatusOK) + w.Write([]byte("v1.27.0\n")) + })) + defer server.Close() + + log := logger.New(false) + checker := NewChecker(&http.Client{}, log, "1.27.0") + checker.versionURL = server.URL + + _, _ = checker.IsLatest(context.Background()) + _, _ = checker.IsLatest(context.Background()) + + assert.Equal(t, 1, calls) + }) + + t.Run("sets user-agent header", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "OctoshiftCLI/1.27.0", r.Header.Get("User-Agent")) + w.WriteHeader(http.StatusOK) + w.Write([]byte("v1.27.0\n")) + })) + defer server.Close() + + log := logger.New(false) + checker := NewChecker(&http.Client{}, log, "1.27.0") + checker.versionURL = server.URL + + _, _ = 
checker.IsLatest(context.Background()) + }) +} + +func TestChecker_GetLatestVersion(t *testing.T) { + t.Run("returns fetched version string", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("v1.28.0\n")) + })) + defer server.Close() + + log := logger.New(false) + checker := NewChecker(&http.Client{}, log, "1.27.0") + checker.versionURL = server.URL + + ver, err := checker.GetLatestVersion(context.Background()) + + require.NoError(t, err) + assert.Equal(t, "1.28.0", ver) + }) +}