diff --git a/src/sftp/HISTORY.rst b/src/sftp/HISTORY.rst index c815f77e924..ad611bc18b6 100644 --- a/src/sftp/HISTORY.rst +++ b/src/sftp/HISTORY.rst @@ -3,6 +3,10 @@ Release History =============== +1.0.0b3 ++++++++ +* Prompt the user before overwriting an existing SSH key pair when generating new keys for ``az sftp cert`` and ``az sftp connect``. Add ``--yes/-y`` to skip the prompt and overwrite without confirmation. + 1.0.0b2 +++++++ * Add ``--buffer-size`` parameter to ``az sftp connect`` for configuring SFTP transfer buffer size. diff --git a/src/sftp/azext_sftp/_help.py b/src/sftp/azext_sftp/_help.py index 6ff47f5bea0..8840742a476 100644 --- a/src/sftp/azext_sftp/_help.py +++ b/src/sftp/azext_sftp/_help.py @@ -43,6 +43,14 @@ - Certificates are valid for a limited time (typically 1 hour) - Private keys are generated with 'id_rsa' name when key pair is created + KEY PAIR OVERWRITE: + - When a key pair already exists at the target location, you will be + prompted before it is overwritten. Selecting 'y' (default) regenerates + the key pair; selecting 'n' reuses the existing keys. + - Pass --yes/-y to skip the prompt and overwrite without confirmation. + - In non-interactive sessions (no TTY) the existing keys are reused + unless --yes is supplied. + The certificate can be used with 'az sftp connect' or with standard SFTP clients. examples: - name: Generate a certificate using an existing public key @@ -51,6 +59,8 @@ text: az sftp cert --file ~/my_cert.pub - name: Generate a certificate with custom SSH client folder text: az sftp cert --file ~/my_cert.pub --ssh-client-folder "C:\\Program Files\\OpenSSH" + - name: Generate a certificate and overwrite an existing key pair without prompting + text: az sftp cert --file ~/my_cert.pub --yes """ helps['sftp connect'] = """ @@ -98,6 +108,8 @@ text: az sftp connect --storage-account mystorageaccount --buffer-size 1048576 - name: Connect with a custom endpoint suffix text: az sftp connect --storage-account mystorageaccount --endpoint-suffix blob.core.usgovcloudapi.net + - name: Connect and overwrite an existing SSH key pair without prompting + text: az sftp connect --storage-account mystorageaccount --yes - name: Connect with additional SFTP arguments for debugging text: az sftp connect --storage-account mystorageaccount --sftp-args="-v" - name: Connect with custom SSH client folder (Windows) diff --git a/src/sftp/azext_sftp/_params.py b/src/sftp/azext_sftp/_params.py index 4549a4c5ad5..998511b019e 100644 --- a/src/sftp/azext_sftp/_params.py +++ b/src/sftp/azext_sftp/_params.py @@ -16,6 +16,8 @@ def load_arguments(self, _): c.argument('ssh_client_folder', options_list=['--ssh-client-folder'], help='Folder path that contains ssh executables (ssh-keygen, ssh). ' 'Default to ssh executables in your PATH or C:\\Windows\\System32\\OpenSSH on Windows.') + c.argument('yes', options_list=['--yes', '-y'], action='store_true', + help='Do not prompt for confirmation when overwriting an existing SSH key pair.') with self.argument_context('sftp connect') as c: c.argument('storage_account', options_list=['--storage-account', '-s'], @@ -48,3 +50,5 @@ def load_arguments(self, _): help='Custom storage account endpoint suffix. ' 'Default: Uses endpoint based on Azure environment ' '(e.g., blob.core.windows.net).') + c.argument('yes', options_list=['--yes', '-y'], action='store_true', + help='Do not prompt for confirmation when overwriting an existing SSH key pair.') diff --git a/src/sftp/azext_sftp/custom.py b/src/sftp/azext_sftp/custom.py index ab73bef58c6..a6b3c7d5cce 100644 --- a/src/sftp/azext_sftp/custom.py +++ b/src/sftp/azext_sftp/custom.py @@ -19,7 +19,7 @@ logger = log.get_logger(__name__) -def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None): +def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None, yes=False): """Generate SSH certificate for SFTP authentication using Azure AD.""" logger.debug("Starting SFTP certificate generation") @@ -53,7 +53,7 @@ def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None) try: public_key_file, _, delete_keys = file_utils.check_or_create_public_private_files( - public_key_file, None, keys_folder, ssh_client_folder) + public_key_file, None, keys_folder, ssh_client_folder, yes_without_prompt=yes) cert_file, _ = file_utils.get_and_write_certificate(cmd, public_key_file, cert_path, ssh_client_folder) except Exception as e: logger.debug("Certificate generation failed: %s", str(e)) @@ -75,7 +75,7 @@ def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None) def sftp_connect(cmd, storage_account, port=None, cert_file=None, private_key_file=None, public_key_file=None, sftp_args=None, ssh_client_folder=None, - buffer_size=None, endpoint_suffix=None): + buffer_size=None, endpoint_suffix=None, yes=False): """Connect to Azure Storage Account via SFTP with automatic certificate generation if needed.""" logger.debug("Starting SFTP connection to storage account: %s", storage_account) @@ -116,14 +116,14 @@ def sftp_connect(cmd, storage_account, port=None, cert_file=None, private_key_fi try: if auto_generate_cert: public_key_file, private_key_file, _ = file_utils.check_or_create_public_private_files( - None, None, credentials_folder, ssh_client_folder) + None, None, credentials_folder, ssh_client_folder, yes_without_prompt=yes) cert_file, user = file_utils.get_and_write_certificate(cmd, public_key_file, None, ssh_client_folder) elif not cert_file: profile = Profile(cli_ctx=cmd.cli_ctx) profile.get_subscription() public_key_file, private_key_file, _ = file_utils.check_or_create_public_private_files( - public_key_file, private_key_file, None, ssh_client_folder) + public_key_file, private_key_file, None, ssh_client_folder, yes_without_prompt=yes) print_styled_text((Style.ACTION, "Generating certificate...")) cert_file, user = file_utils.get_and_write_certificate(cmd, public_key_file, None, ssh_client_folder) delete_cert = True diff --git a/src/sftp/azext_sftp/file_utils.py b/src/sftp/azext_sftp/file_utils.py index 776b979db06..2c7d07d6f23 100644 --- a/src/sftp/azext_sftp/file_utils.py +++ b/src/sftp/azext_sftp/file_utils.py @@ -13,6 +13,7 @@ from azure.cli.core import telemetry from azure.cli.core._profile import Profile from knack import log +from knack.prompting import prompt_y_n, NoTTYException from . import rsa_parser from . import sftp_utils @@ -20,6 +21,46 @@ logger = log.get_logger(__name__) +def _should_regenerate_key_pair(private_key_file, public_key_file, yes_without_prompt): + """Determine whether to regenerate an SSH key pair when files already exist. + + Returns True if new keys should be generated (no existing keys, --yes specified, + or user confirmed overwrite at the prompt). Returns False to reuse existing keys. + """ + private_key_exists = bool(private_key_file) and os.path.isfile(private_key_file) + public_key_exists = bool(public_key_file) and os.path.isfile(public_key_file) + + if not private_key_exists and not public_key_exists: + return True + + keys_folder = os.path.dirname(private_key_file or public_key_file) + if private_key_exists and public_key_exists: + existing_files = "private key and public key" + elif private_key_exists: + existing_files = "private key only" + else: + existing_files = "public key only" + + logger.debug("Existing SSH key pair detected in '%s' (%s)", keys_folder, existing_files) + + if yes_without_prompt: + logger.debug("--yes specified, will overwrite existing key pair") + return True + + message = (f"An existing SSH key pair was found in '{keys_folder}' ({existing_files}). " + "Selecting 'y' will generate a new key pair and overwrite the existing files. " + "Selecting 'n' will use the existing key pair. Overwrite?") + try: + overwrite = prompt_y_n(message, default='y') + except NoTTYException: + logger.warning("No TTY available to prompt for key pair overwrite. Reusing existing " + "key pair in '%s'. Use --yes to overwrite without prompting.", keys_folder) + return False + + logger.debug("User chose to %s existing key pair", "overwrite" if overwrite else "reuse") + return overwrite + + def delete_file(file_path, message, warning=False): """Delete a file with error handling.""" if os.path.isfile(file_path): @@ -33,7 +74,8 @@ def delete_file(file_path, message, warning=False): raise azclierror.FileOperationError(f"{message}Error: {str(e)}") from e -def check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder, ssh_client_folder=None): +def check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder, + ssh_client_folder=None, yes_without_prompt=False): """Check for existing key files or create new ones if needed.""" delete_keys = False @@ -47,13 +89,14 @@ def check_or_create_public_private_files(public_key_file, private_key_file, cred public_key_file = os.path.join(credentials_folder, "id_rsa.pub") private_key_file = os.path.join(credentials_folder, "id_rsa") - # Check if existing keys are present before generating new ones - if not (os.path.isfile(public_key_file) and os.path.isfile(private_key_file)): - # Only generate new keys if both don't exist + if _should_regenerate_key_pair(private_key_file, public_key_file, yes_without_prompt): + # Remove existing files so ssh-keygen does not prompt again to overwrite. + for stale in (private_key_file, public_key_file): + if os.path.isfile(stale): + delete_file(stale, f"Failed to remove existing key file {stale}. ", warning=True) sftp_utils.create_ssh_keyfile(private_key_file, ssh_client_folder) - # Only set delete_keys to True when we actually create new keys delete_keys = True - # If existing keys are found, delete_keys remains False + # else: existing keys reused, delete_keys stays False if not public_key_file: if private_key_file: diff --git a/src/sftp/azext_sftp/tests/latest/test_custom.py b/src/sftp/azext_sftp/tests/latest/test_custom.py index 1e154979222..c9b3874c7fd 100644 --- a/src/sftp/azext_sftp/tests/latest/test_custom.py +++ b/src/sftp/azext_sftp/tests/latest/test_custom.py @@ -83,7 +83,7 @@ def test_sftp_cert(self, mock_write_cert, mock_get_keys, mock_abspath, mock_isdi custom.sftp_cert(cmd, "cert", "pubkey") - mock_get_keys.assert_called_once_with('/pubkey/path', None, None, None) + mock_get_keys.assert_called_once_with('/pubkey/path', None, None, None, yes_without_prompt=False) mock_write_cert.assert_called_once_with(cmd, 'pubkey', '/cert/path', None) @mock.patch('azext_sftp.custom._do_sftp_op') @@ -181,7 +181,7 @@ def test_sftp_connect_key_generation_scenarios(self, mock_mkdtemp, mock_create_k ) # Verify function calls - mock_create_keys.assert_called_once_with(*expected_create_keys_args) + mock_create_keys.assert_called_once_with(*expected_create_keys_args, yes_without_prompt=False) mock_gen_cert.assert_called_once() mock_do_sftp.assert_called_once() @@ -463,7 +463,7 @@ def test_sftp_cert_parameter_combinations(self, mock_abspath, mock_isdir, mock_c # Verify calls expected_keys_dir = os.path.dirname(cert_path) if expected_keys_folder == "cert_dir" else expected_keys_folder - mock_check_files.assert_called_once_with(public_key_file, None, expected_keys_dir, ssh_client_folder) + mock_check_files.assert_called_once_with(public_key_file, None, expected_keys_dir, ssh_client_folder, yes_without_prompt=False) mock_write_cert.assert_called_once_with(cmd, effective_public_key, cert_path, ssh_client_folder) def test_sftp_cert_error_cases(self): @@ -551,7 +551,7 @@ def test_sftp_cert_valid_minimal_call(self): custom.sftp_cert(cmd, public_key_file="pubkey.pub") # Verify function was called correctly - mock_check_files.assert_called_once_with("pubkey.pub", None, None, None) + mock_check_files.assert_called_once_with("pubkey.pub", None, None, None, yes_without_prompt=False) # Additional tests for private helper functions diff --git a/src/sftp/azext_sftp/tests/latest/test_file_utils.py b/src/sftp/azext_sftp/tests/latest/test_file_utils.py index 75ee1143337..55f3047853f 100644 --- a/src/sftp/azext_sftp/tests/latest/test_file_utils.py +++ b/src/sftp/azext_sftp/tests/latest/test_file_utils.py @@ -371,5 +371,146 @@ def test_get_modulus_exponent_parse_error(self, mock_parser_class): self.assertIn("Could not parse public key", str(context.exception)) +class SftpKeyPairOverwritePromptTest(unittest.TestCase): + """Test suite for the SSH key-pair overwrite prompt.""" + + def setUp(self): + super().setUp() + self.temp_dir = tempfile.mkdtemp(prefix="sftp_keypair_prompt_test_") + self.private_key = os.path.join(self.temp_dir, "id_rsa") + self.public_key = os.path.join(self.temp_dir, "id_rsa.pub") + + def tearDown(self): + super().tearDown() + if os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + + def _write_existing_keys(self): + with open(self.private_key, 'w') as f: + f.write("existing private key") + with open(self.public_key, 'w') as f: + f.write("existing public key") + + def test_should_regenerate_no_existing_keys_returns_true(self): + """No existing keys -> generate without prompting.""" + with mock.patch('azext_sftp.file_utils.prompt_y_n') as mock_prompt: + result = file_utils._should_regenerate_key_pair( + self.private_key, self.public_key, yes_without_prompt=False) + self.assertTrue(result) + mock_prompt.assert_not_called() + + def test_should_regenerate_yes_flag_skips_prompt(self): + """--yes specified -> overwrite without prompting.""" + self._write_existing_keys() + with mock.patch('azext_sftp.file_utils.prompt_y_n') as mock_prompt: + result = file_utils._should_regenerate_key_pair( + self.private_key, self.public_key, yes_without_prompt=True) + self.assertTrue(result) + mock_prompt.assert_not_called() + + @mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=True) + def test_should_regenerate_user_confirms_overwrite(self, mock_prompt): + """Existing keys + user answers 'y' -> regenerate.""" + self._write_existing_keys() + result = file_utils._should_regenerate_key_pair( + self.private_key, self.public_key, yes_without_prompt=False) + self.assertTrue(result) + mock_prompt.assert_called_once() + + @mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=False) + def test_should_regenerate_user_declines_overwrite(self, mock_prompt): + """Existing keys + user answers 'n' -> reuse.""" + self._write_existing_keys() + result = file_utils._should_regenerate_key_pair( + self.private_key, self.public_key, yes_without_prompt=False) + self.assertFalse(result) + mock_prompt.assert_called_once() + + @mock.patch('azext_sftp.file_utils.prompt_y_n', + side_effect=file_utils.NoTTYException()) + def test_should_regenerate_no_tty_defaults_to_reuse(self, mock_prompt): + """No TTY available -> reuse existing keys (safe default).""" + self._write_existing_keys() + result = file_utils._should_regenerate_key_pair( + self.private_key, self.public_key, yes_without_prompt=False) + self.assertFalse(result) + + @mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=True) + def test_should_regenerate_only_private_key_exists(self, mock_prompt): + """Partial existence (private only) is detected and prompted.""" + with open(self.private_key, 'w') as f: + f.write("existing private key") + result = file_utils._should_regenerate_key_pair( + self.private_key, self.public_key, yes_without_prompt=False) + self.assertTrue(result) + prompt_text = mock_prompt.call_args[0][0] + self.assertIn("private key only", prompt_text) + + @mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=True) + def test_should_regenerate_only_public_key_exists(self, mock_prompt): + """Partial existence (public only) is detected and prompted.""" + with open(self.public_key, 'w') as f: + f.write("existing public key") + result = file_utils._should_regenerate_key_pair( + self.private_key, self.public_key, yes_without_prompt=False) + self.assertTrue(result) + prompt_text = mock_prompt.call_args[0][0] + self.assertIn("public key only", prompt_text) + + @mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile') + @mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=False) + def test_check_or_create_reuses_existing_keys_when_user_declines( + self, mock_prompt, mock_create_keyfile): + """Existing keys + user declines -> reuse, do not regenerate, delete_keys=False.""" + self._write_existing_keys() + public_key, private_key, delete_keys = file_utils.check_or_create_public_private_files( + None, None, self.temp_dir, ssh_client_folder=None, yes_without_prompt=False) + self.assertEqual(public_key, self.public_key) + self.assertEqual(private_key, self.private_key) + self.assertFalse(delete_keys) + mock_create_keyfile.assert_not_called() + mock_prompt.assert_called_once() + + @mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile') + @mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=True) + def test_check_or_create_regenerates_keys_when_user_confirms( + self, mock_prompt, mock_create_keyfile): + """Existing keys + user confirms -> regenerate, delete_keys=True.""" + self._write_existing_keys() + + def recreate(private_key_path, _ssh_client_folder): + with open(private_key_path, 'w') as f: + f.write("new private key") + with open(private_key_path + ".pub", 'w') as f: + f.write("new public key") + mock_create_keyfile.side_effect = recreate + + _, private_key, delete_keys = file_utils.check_or_create_public_private_files( + None, None, self.temp_dir, ssh_client_folder=None, yes_without_prompt=False) + self.assertTrue(delete_keys) + mock_create_keyfile.assert_called_once_with(private_key, None) + mock_prompt.assert_called_once() + + @mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile') + @mock.patch('azext_sftp.file_utils.prompt_y_n') + def test_check_or_create_yes_flag_skips_prompt_and_regenerates( + self, mock_prompt, mock_create_keyfile): + """Existing keys + yes_without_prompt=True -> regenerate without prompt.""" + self._write_existing_keys() + + def recreate(private_key_path, _ssh_client_folder): + with open(private_key_path, 'w') as f: + f.write("new private key") + with open(private_key_path + ".pub", 'w') as f: + f.write("new public key") + mock_create_keyfile.side_effect = recreate + + _, private_key, delete_keys = file_utils.check_or_create_public_private_files( + None, None, self.temp_dir, ssh_client_folder=None, yes_without_prompt=True) + self.assertTrue(delete_keys) + mock_create_keyfile.assert_called_once_with(private_key, None) + mock_prompt.assert_not_called() + + if __name__ == '__main__': unittest.main() diff --git a/src/sftp/setup.py b/src/sftp/setup.py index 486792ced88..994121f5784 100644 --- a/src/sftp/setup.py +++ b/src/sftp/setup.py @@ -16,7 +16,7 @@ # TODO: Confirm this is the right version number you want and it matches your # HISTORY.rst entry. -VERSION = '1.0.0b2' +VERSION = '1.0.0b3' # The full list of classifiers is available at # https://pypi.python.org/pypi?%3Aaction=list_classifiers