Review notes (text before the first "diff --git" header is ignored by patch tools):
  * Line breaks and indentation were restored from a whitespace-collapsed copy.
    Hunk line counts match the @@ headers, but whitespace-only removed lines and
    context indentation are best-effort -- confirm against the original commit.
  * check_stream_accessible() treats every GithubException as "not accessible",
    while its docstring names only 403/404 -- confirm that other failures (for
    example rate limiting, if it is raised as a GithubException) should really
    drop streams from the catalog.
  * tests/base.py now only asserts that no unexpected stream is discovered; a
    missing expected stream no longer fails run_and_verify_check_mode().

diff --git a/tap_github/client.py b/tap_github/client.py
index 4f04442..3198bbe 100644
--- a/tap_github/client.py
+++ b/tap_github/client.py
@@ -249,11 +249,13 @@ def verify_repo_access(self, url_for_repo, repo):
             message = "HTTP-error-code: 404, Error: Please check the repository name \'{}\' or you do not have sufficient permissions to access this repository.".format(repo)
             raise NotFoundException(message) from None
 
-    def verify_access_for_repo(self):
+    def verify_access_for_repo(self, repositories=None):
         """
         For all the repositories mentioned in the config, check the access for each repos.
+        Accepts an optional precomputed list of repositories to avoid redundant API calls.
         """
-        repositories, org = self.extract_repos_from_config() # pylint: disable=unused-variable
+        if repositories is None:
+            repositories, _ = self.extract_repos_from_config()
 
         for repo in repositories:
 
@@ -263,6 +265,19 @@ def verify_access_for_repo(self):
             # Verifying for Repo access
             self.verify_repo_access(url_for_repo, repo)
 
+    def check_stream_accessible(self, source, url):
+        """
+        Check if a stream endpoint is accessible by making a test request.
+        Returns True if accessible (HTTP 200), False if permission is denied (403)
+        or the resource is not found (404).
+        """
+        try:
+            self.authed_get(source, url, should_skip_404=False)
+            return True
+        except GithubException as e:
+            LOGGER.warning("Stream '%s' is not accessible: %s", source, str(e))
+            return False
+
     def extract_orgs_from_config(self):
         """
         Extracts all organizations from the config
diff --git a/tap_github/discover.py b/tap_github/discover.py
index b39449e..e05e988 100644
--- a/tap_github/discover.py
+++ b/tap_github/discover.py
@@ -2,20 +2,90 @@
 from singer import metadata
 from singer.catalog import Catalog, CatalogEntry, Schema
 from tap_github.schema import get_schemas
+from tap_github.streams import STREAMS
 
 LOGGER = singer.get_logger()
 
+
+def _build_stream_probe_url(base_url, stream_obj, repo_path, org):
+    """
+    Build a minimal URL to probe whether a stream endpoint is accessible.
+    Uses per_page=1 to keep the response small.
+    """
+    # Strip any existing query parameters so we control the query string.
+    base_path = stream_obj.path.split('?')[0]
+    if stream_obj.use_organization:
+        url = f"{base_url}/{base_path.format(org)}"
+    else:
+        url = f"{base_url}/repos/{repo_path}/{base_path}"
+    return url + '?per_page=1'
+
+
+def _is_stream_and_ancestors_accessible(stream_name, inaccessible_streams):
+    """
+    Recursively check whether a stream or any of its ancestors is inaccessible.
+    Returns False if the stream itself or any ancestor appears in inaccessible_streams.
+    """
+    if stream_name in inaccessible_streams:
+        return False
+    parent = STREAMS[stream_name].parent
+    if parent:
+        return _is_stream_and_ancestors_accessible(parent, inaccessible_streams)
+    return True
+
+
+def _identify_inaccessible_streams(client, repositories):
+    """
+    Verify repo access and probe each top-level stream endpoint.
+    Returns a set of stream names that are not accessible (403/404).
+    """
+    # Sort for deterministic probe behavior across runs.
+    repositories = sorted(repositories)
+    client.verify_access_for_repo(repositories)
+
+    # Derive org from the first repo to ensure consistency.
+    repo_path = repositories[0] if repositories else None
+    org = repo_path.split('/')[0] if repo_path else None
+
+    inaccessible_streams = set()
+    if repo_path:
+        for stream_name, stream_class in STREAMS.items():
+            if stream_class.parent is None:
+                test_url = _build_stream_probe_url(client.base_url, stream_class, repo_path, org)
+                if not client.check_stream_accessible(stream_name, test_url):
+                    inaccessible_streams.add(stream_name)
+                    LOGGER.warning(
+                        "Stream '%s' will be excluded from the catalog: "
+                        "insufficient permissions or resource not found.",
+                        stream_name
+                    )
+    return inaccessible_streams
+
+
 def discover(client):
     """
     Run the discovery mode, prepare the catalog file and return catalog.
+    Streams whose API endpoints are not accessible (403/404) are excluded.
     """
-    # Check credential in the discover mode.
-    client.verify_access_for_repo()
+    # Extract repos/orgs once and reuse to avoid double API calls.
+    repositories, _ = client.extract_repos_from_config()
+
+    inaccessible_streams = _identify_inaccessible_streams(client, repositories)
 
     schemas, field_metadata = get_schemas()
     catalog = Catalog([])
 
     for stream_name, schema_dict in schemas.items():
+        # Exclude streams that are inaccessible or whose ancestor is inaccessible.
+        if not _is_stream_and_ancestors_accessible(stream_name, inaccessible_streams):
+            if stream_name not in inaccessible_streams:
+                LOGGER.warning(
+                    "Stream '%s' will be excluded from the catalog: "
+                    "parent stream is not accessible.",
+                    stream_name
+                )
+            continue
+
         try:
             schema = Schema.from_dict(schema_dict)
             mdata = field_metadata[stream_name]
diff --git a/tests/base.py b/tests/base.py
index c2e6114..009b4c6 100644
--- a/tests/base.py
+++ b/tests/base.py
@@ -180,7 +180,7 @@ def expected_metadata(self):
 
     def expected_replication_method(self):
         """
-        Return a dictionary with key of table name 
+        Return a dictionary with key of table name
         and value of replication method
         """
         return {table: properties.get(self.REPLICATION_METHOD, None)
@@ -266,7 +266,8 @@ def run_and_verify_check_mode(self, conn_id):
 
         found_catalog_names = set(map(lambda c: c['stream_name'], found_catalogs))
         LOGGER.info(found_catalog_names)
-        self.assertSetEqual(self.expected_streams(), found_catalog_names, msg="discovered schemas do not match")
+        unexpected_streams = found_catalog_names - self.expected_streams()
+        self.assertFalse(unexpected_streams, msg="discovered unexpected schemas: {}".format(unexpected_streams))
         LOGGER.info("discovered schemas are OK")
 
         return found_catalogs
diff --git a/tests/test_github_all_fields.py b/tests/test_github_all_fields.py
index f3f27a4..2b61b68 100644
--- a/tests/test_github_all_fields.py
+++ b/tests/test_github_all_fields.py
@@ -70,14 +70,14 @@
         'mentions_count',
         'reactions'
     },
-    'collaborators': {
-        'email',
-        'name'
-    },
     'reviews': {
         'body_text',
         'body_html'
     },
+    'collaborators': {
+        'email',
+        'name'
+    },
     'teams': {
         'permissions'
     },
diff --git a/tests/unittests/test_main.py b/tests/unittests/test_main.py
index 44d5d22..44141a0 100644
--- a/tests/unittests/test_main.py
+++ b/tests/unittests/test_main.py
@@ -5,9 +5,9 @@
 
 class MockArgs:
     """Mock args object class"""
-    
+
     def __init__(self, config = None, properties = None, state = None, discover = False) -> None:
-        self.config = config 
+        self.config = config
         self.properties = properties
         self.state = state
         self.discover = discover
@@ -20,14 +20,14 @@ class TestDiscoverMode(unittest.TestCase):
     """
 
     mock_config = {"start_date": "", "access_token": ""}
-    
+
     @mock.patch("tap_github._discover")
     def test_discover_with_config(self, mock_discover, mock_args, mock_verify_access):
         """Test `_discover` function is called for discover mode"""
         mock_discover.return_value = dict()
         mock_args.return_value = MockArgs(discover = True, config = self.mock_config)
         main()
-        
+
         self.assertTrue(mock_discover.called)
 
 
@@ -49,22 +49,22 @@ def test_sync_with_properties(self, mock_discover, mock_sync, mock_args, mock_cl
         mock_client.return_value = "mock_client"
         mock_args.return_value = MockArgs(config=self.mock_config, properties=self.mock_catalog)
         main()
-        
+
         # Verify `_sync` is called with expected arguments
         mock_sync.assert_called_with("mock_client", self.mock_config, {}, self.mock_catalog)
-        
+
         # verify `_discover` function is not called
         self.assertFalse(mock_discover.called)
 
     @mock.patch("tap_github._discover")
     def test_sync_without_properties(self, mock_discover, mock_sync, mock_args, mock_client):
         """Test sync mode without properties given in args"""
-        
+
         mock_discover.return_value = {"schema": "", "metadata": ""}
         mock_client.return_value = "mock_client"
         mock_args.return_value = MockArgs(config=self.mock_config)
         main()
-        
+
         # Verify `_sync` is called with expected arguments
         mock_sync.assert_called_with("mock_client", self.mock_config, {}, {"schema": "", "metadata": ""})
 
@@ -77,25 +77,29 @@ def test_sync_with_state(self, mock_sync, mock_args, mock_client):
         mock_client.return_value = "mock_client"
         mock_args.return_value = MockArgs(config=self.mock_config, properties=self.mock_catalog, state=mock_state)
         main()
-        
+
         # Verify `_sync` is called with expected arguments
         mock_sync.assert_called_with("mock_client", self.mock_config, mock_state, self.mock_catalog)
 
 
 @mock.patch("tap_github.GithubClient")
 class TestDiscover(unittest.TestCase):
     """Test `discover` function."""
-    
+
     def test_discover(self, mock_client):
-        
+        mock_client.extract_repos_from_config.return_value = (['org/repo'], {'org'})
+        mock_client.check_stream_accessible.return_value = True
+
         return_catalog = discover(mock_client)
-        
+
         self.assertIsInstance(return_catalog, dict)
 
     @mock.patch("tap_github.discover.Schema")
     @mock.patch("tap_github.discover.LOGGER.error")
     def test_discover_error_handling(self, mock_logger, mock_schema, mock_client):
         """Test discover function if exception arises."""
-        mock_schema.from_dict.side_effect = [Exception]
+        mock_client.extract_repos_from_config.return_value = (['org/repo'], {'org'})
+        mock_client.check_stream_accessible.return_value = True
+        mock_schema.from_dict.side_effect = Exception
         with self.assertRaises(Exception):
             discover(mock_client)