Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions tap_github/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,13 @@ def verify_repo_access(self, url_for_repo, repo):
message = "HTTP-error-code: 404, Error: Please check the repository name \'{}\' or you do not have sufficient permissions to access this repository.".format(repo)
raise NotFoundException(message) from None

def verify_access_for_repo(self):
def verify_access_for_repo(self, repositories=None):
"""
For all the repositories mentioned in the config, check the access for each repos.
Accepts an optional precomputed list of repositories to avoid redundant API calls.
"""
repositories, org = self.extract_repos_from_config() # pylint: disable=unused-variable
if repositories is None:
repositories, _ = self.extract_repos_from_config()

for repo in repositories:

Expand All @@ -263,6 +265,19 @@ def verify_access_for_repo(self):
# Verifying for Repo access
self.verify_repo_access(url_for_repo, repo)

def check_stream_accessible(self, source, url):
"""
Check if a stream endpoint is accessible by making a test request.
Returns True if accessible (HTTP 200), False if permission is denied (403)
or the resource is not found (404).
"""
try:
self.authed_get(source, url, should_skip_404=False)
return True
except GithubException as e:
LOGGER.warning("Stream '%s' is not accessible: %s", source, str(e))
return False

def extract_orgs_from_config(self):
"""
Extracts all organizations from the config
Expand Down
74 changes: 72 additions & 2 deletions tap_github/discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,90 @@
from singer import metadata
from singer.catalog import Catalog, CatalogEntry, Schema
from tap_github.schema import get_schemas
from tap_github.streams import STREAMS

LOGGER = singer.get_logger()


def _build_stream_probe_url(base_url, stream_obj, repo_path, org):
"""
Build a minimal URL to probe whether a stream endpoint is accessible.
Uses per_page=1 to keep the response small.
"""
# Strip any existing query parameters so we control the query string.
base_path = stream_obj.path.split('?')[0]
if stream_obj.use_organization:
url = f"{base_url}/{base_path.format(org)}"
else:
url = f"{base_url}/repos/{repo_path}/{base_path}"
return url + '?per_page=1'


def _is_stream_and_ancestors_accessible(stream_name, inaccessible_streams):
"""
Recursively check whether a stream or any of its ancestors is inaccessible.
Returns False if the stream itself or any ancestor appears in inaccessible_streams.
"""
if stream_name in inaccessible_streams:
return False
parent = STREAMS[stream_name].parent
if parent:
return _is_stream_and_ancestors_accessible(parent, inaccessible_streams)
return True


def _identify_inaccessible_streams(client, repositories):
"""
Verify repo access and probe each top-level stream endpoint.
Returns a set of stream names that are not accessible (403/404).
"""
# Sort for deterministic probe behavior across runs.
repositories = sorted(repositories)
client.verify_access_for_repo(repositories)

# Derive org from the first repo to ensure consistency.
repo_path = repositories[0] if repositories else None
org = repo_path.split('/')[0] if repo_path else None

inaccessible_streams = set()
if repo_path:
for stream_name, stream_class in STREAMS.items():
if stream_class.parent is None:
test_url = _build_stream_probe_url(client.base_url, stream_class, repo_path, org)
if not client.check_stream_accessible(stream_name, test_url):
inaccessible_streams.add(stream_name)
LOGGER.warning(
"Stream '%s' will be excluded from the catalog: "
"insufficient permissions or resource not found.",
stream_name
)
Comment on lines +50 to +61
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discovery() method is handling mulitple responsibilities, please separate those out.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed

return inaccessible_streams


def discover(client):
"""
Run the discovery mode, prepare the catalog file and return catalog.
Streams whose API endpoints are not accessible (403/404) are excluded.
"""
# Check credential in the discover mode.
client.verify_access_for_repo()
# Extract repos/orgs once and reuse to avoid double API calls.
repositories, _ = client.extract_repos_from_config()

inaccessible_streams = _identify_inaccessible_streams(client, repositories)

schemas, field_metadata = get_schemas()
catalog = Catalog([])

for stream_name, schema_dict in schemas.items():
# Exclude streams that are inaccessible or whose ancestor is inaccessible.
if not _is_stream_and_ancestors_accessible(stream_name, inaccessible_streams):
if stream_name not in inaccessible_streams:
LOGGER.warning(
"Stream '%s' will be excluded from the catalog: "
"parent stream is not accessible.",
stream_name
)
continue

try:
schema = Schema.from_dict(schema_dict)
mdata = field_metadata[stream_name]
Expand Down
5 changes: 3 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def expected_metadata(self):

def expected_replication_method(self):
"""
Return a dictionary with key of table name
Return a dictionary with key of table name
and value of replication method
"""
return {table: properties.get(self.REPLICATION_METHOD, None)
Expand Down Expand Up @@ -266,7 +266,8 @@ def run_and_verify_check_mode(self, conn_id):

found_catalog_names = set(map(lambda c: c['stream_name'], found_catalogs))
LOGGER.info(found_catalog_names)
self.assertSetEqual(self.expected_streams(), found_catalog_names, msg="discovered schemas do not match")
unexpected_streams = found_catalog_names - self.expected_streams()
self.assertFalse(unexpected_streams, msg="discovered unexpected schemas: {}".format(unexpected_streams))
LOGGER.info("discovered schemas are OK")

return found_catalogs
Expand Down
8 changes: 4 additions & 4 deletions tests/test_github_all_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@
'mentions_count',
'reactions'
},
'collaborators': {
'email',
'name'
},
'reviews': {
'body_text',
'body_html'
},
'collaborators': {
'email',
'name'
},
'teams': {
'permissions'
},
Expand Down
30 changes: 17 additions & 13 deletions tests/unittests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

class MockArgs:
"""Mock args object class"""

def __init__(self, config = None, properties = None, state = None, discover = False) -> None:
self.config = config
self.config = config
self.properties = properties
self.state = state
self.discover = discover
Expand All @@ -20,14 +20,14 @@ class TestDiscoverMode(unittest.TestCase):
"""

mock_config = {"start_date": "", "access_token": ""}

@mock.patch("tap_github._discover")
def test_discover_with_config(self, mock_discover, mock_args, mock_verify_access):
"""Test `_discover` function is called for discover mode"""
mock_discover.return_value = dict()
mock_args.return_value = MockArgs(discover = True, config = self.mock_config)
main()

self.assertTrue(mock_discover.called)


Expand All @@ -49,22 +49,22 @@ def test_sync_with_properties(self, mock_discover, mock_sync, mock_args, mock_cl
mock_client.return_value = "mock_client"
mock_args.return_value = MockArgs(config=self.mock_config, properties=self.mock_catalog)
main()

# Verify `_sync` is called with expected arguments
mock_sync.assert_called_with("mock_client", self.mock_config, {}, self.mock_catalog)

# verify `_discover` function is not called
self.assertFalse(mock_discover.called)

@mock.patch("tap_github._discover")
def test_sync_without_properties(self, mock_discover, mock_sync, mock_args, mock_client):
"""Test sync mode without properties given in args"""

mock_discover.return_value = {"schema": "", "metadata": ""}
mock_client.return_value = "mock_client"
mock_args.return_value = MockArgs(config=self.mock_config)
main()

# Verify `_sync` is called with expected arguments
mock_sync.assert_called_with("mock_client", self.mock_config, {}, {"schema": "", "metadata": ""})

Expand All @@ -77,25 +77,29 @@ def test_sync_with_state(self, mock_sync, mock_args, mock_client):
mock_client.return_value = "mock_client"
mock_args.return_value = MockArgs(config=self.mock_config, properties=self.mock_catalog, state=mock_state)
main()

# Verify `_sync` is called with expected arguments
mock_sync.assert_called_with("mock_client", self.mock_config, mock_state, self.mock_catalog)

@mock.patch("tap_github.GithubClient")
class TestDiscover(unittest.TestCase):
"""Test `discover` function."""

def test_discover(self, mock_client):

mock_client.extract_repos_from_config.return_value = (['org/repo'], {'org'})
mock_client.check_stream_accessible.return_value = True

return_catalog = discover(mock_client)

self.assertIsInstance(return_catalog, dict)

@mock.patch("tap_github.discover.Schema")
@mock.patch("tap_github.discover.LOGGER.error")
def test_discover_error_handling(self, mock_logger, mock_schema, mock_client):
"""Test discover function if exception arises."""
mock_schema.from_dict.side_effect = [Exception]
mock_client.extract_repos_from_config.return_value = (['org/repo'], {'org'})
mock_client.check_stream_accessible.return_value = True
mock_schema.from_dict.side_effect = Exception
with self.assertRaises(Exception):
discover(mock_client)

Expand Down