Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sftp/HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
Release History
===============

1.0.0b3
+++++++
* Prompt the user before overwriting an existing SSH key pair when generating new keys for ``az sftp cert`` and ``az sftp connect``. Add ``--yes/-y`` to skip the prompt and overwrite without confirmation.

1.0.0b2
+++++++
* Add ``--buffer-size`` parameter to ``az sftp connect`` for configuring SFTP transfer buffer size.
Expand Down
12 changes: 12 additions & 0 deletions src/sftp/azext_sftp/_help.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@
- Certificates are valid for a limited time (typically 1 hour)
- Private keys are generated with 'id_rsa' name when key pair is created

KEY PAIR OVERWRITE:
- When a key pair already exists at the target location, you will be
prompted before it is overwritten. Selecting 'y' (default) regenerates
the key pair; selecting 'n' reuses the existing keys.
- Pass --yes/-y to skip the prompt and overwrite without confirmation.
- In non-interactive sessions (no TTY) the existing keys are reused
unless --yes is supplied.

The certificate can be used with 'az sftp connect' or with standard SFTP clients.
examples:
- name: Generate a certificate using an existing public key
Expand All @@ -51,6 +59,8 @@
text: az sftp cert --file ~/my_cert.pub
- name: Generate a certificate with custom SSH client folder
text: az sftp cert --file ~/my_cert.pub --ssh-client-folder "C:\\Program Files\\OpenSSH"
- name: Generate a certificate and overwrite an existing key pair without prompting
text: az sftp cert --file ~/my_cert.pub --yes
"""

helps['sftp connect'] = """
Expand Down Expand Up @@ -98,6 +108,8 @@
text: az sftp connect --storage-account mystorageaccount --buffer-size 1048576
- name: Connect with a custom endpoint suffix
text: az sftp connect --storage-account mystorageaccount --endpoint-suffix blob.core.usgovcloudapi.net
- name: Connect and overwrite an existing SSH key pair without prompting
text: az sftp connect --storage-account mystorageaccount --yes
- name: Connect with additional SFTP arguments for debugging
text: az sftp connect --storage-account mystorageaccount --sftp-args="-v"
- name: Connect with custom SSH client folder (Windows)
Expand Down
4 changes: 4 additions & 0 deletions src/sftp/azext_sftp/_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def load_arguments(self, _):
c.argument('ssh_client_folder', options_list=['--ssh-client-folder'],
help='Folder path that contains ssh executables (ssh-keygen, ssh). '
'Default to ssh executables in your PATH or C:\\Windows\\System32\\OpenSSH on Windows.')
c.argument('yes', options_list=['--yes', '-y'], action='store_true',
help='Do not prompt for confirmation when overwriting an existing SSH key pair.')

with self.argument_context('sftp connect') as c:
c.argument('storage_account', options_list=['--storage-account', '-s'],
Expand Down Expand Up @@ -48,3 +50,5 @@ def load_arguments(self, _):
help='Custom storage account endpoint suffix. '
'Default: Uses endpoint based on Azure environment '
'(e.g., blob.core.windows.net).')
c.argument('yes', options_list=['--yes', '-y'], action='store_true',
help='Do not prompt for confirmation when overwriting an existing SSH key pair.')
10 changes: 5 additions & 5 deletions src/sftp/azext_sftp/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
logger = log.get_logger(__name__)


def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None):
def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None, yes=False):
"""Generate SSH certificate for SFTP authentication using Azure AD."""
logger.debug("Starting SFTP certificate generation")

Expand Down Expand Up @@ -53,7 +53,7 @@ def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None)

try:
public_key_file, _, delete_keys = file_utils.check_or_create_public_private_files(
public_key_file, None, keys_folder, ssh_client_folder)
public_key_file, None, keys_folder, ssh_client_folder, yes_without_prompt=yes)
cert_file, _ = file_utils.get_and_write_certificate(cmd, public_key_file, cert_path, ssh_client_folder)
except Exception as e:
logger.debug("Certificate generation failed: %s", str(e))
Expand All @@ -75,7 +75,7 @@ def sftp_cert(cmd, cert_path=None, public_key_file=None, ssh_client_folder=None)

def sftp_connect(cmd, storage_account, port=None, cert_file=None, private_key_file=None,
public_key_file=None, sftp_args=None, ssh_client_folder=None,
buffer_size=None, endpoint_suffix=None):
buffer_size=None, endpoint_suffix=None, yes=False):
"""Connect to Azure Storage Account via SFTP with automatic certificate generation if needed."""
logger.debug("Starting SFTP connection to storage account: %s", storage_account)

Expand Down Expand Up @@ -116,14 +116,14 @@ def sftp_connect(cmd, storage_account, port=None, cert_file=None, private_key_fi
try:
if auto_generate_cert:
public_key_file, private_key_file, _ = file_utils.check_or_create_public_private_files(
None, None, credentials_folder, ssh_client_folder)
None, None, credentials_folder, ssh_client_folder, yes_without_prompt=yes)
cert_file, user = file_utils.get_and_write_certificate(cmd, public_key_file, None, ssh_client_folder)
elif not cert_file:
profile = Profile(cli_ctx=cmd.cli_ctx)
profile.get_subscription()

public_key_file, private_key_file, _ = file_utils.check_or_create_public_private_files(
public_key_file, private_key_file, None, ssh_client_folder)
public_key_file, private_key_file, None, ssh_client_folder, yes_without_prompt=yes)
print_styled_text((Style.ACTION, "Generating certificate..."))
cert_file, user = file_utils.get_and_write_certificate(cmd, public_key_file, None, ssh_client_folder)
delete_cert = True
Expand Down
55 changes: 49 additions & 6 deletions src/sftp/azext_sftp/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,54 @@
from azure.cli.core import telemetry
from azure.cli.core._profile import Profile
from knack import log
from knack.prompting import prompt_y_n, NoTTYException

from . import rsa_parser
from . import sftp_utils

logger = log.get_logger(__name__)


def _should_regenerate_key_pair(private_key_file, public_key_file, yes_without_prompt):
"""Determine whether to regenerate an SSH key pair when files already exist.

Returns True if new keys should be generated (no existing keys, --yes specified,
or user confirmed overwrite at the prompt). Returns False to reuse existing keys.
"""
private_key_exists = bool(private_key_file) and os.path.isfile(private_key_file)
public_key_exists = bool(public_key_file) and os.path.isfile(public_key_file)

if not private_key_exists and not public_key_exists:
return True

keys_folder = os.path.dirname(private_key_file or public_key_file)
if private_key_exists and public_key_exists:
existing_files = "private key and public key"
elif private_key_exists:
existing_files = "private key only"
else:
existing_files = "public key only"

Comment thread
DevanshG1 marked this conversation as resolved.
logger.debug("Existing SSH key pair detected in '%s' (%s)", keys_folder, existing_files)

if yes_without_prompt:
logger.debug("--yes specified, will overwrite existing key pair")
return True

message = (f"An existing SSH key pair was found in '{keys_folder}' ({existing_files}). "
"Selecting 'y' will generate a new key pair and overwrite the existing files. "
"Selecting 'n' will use the existing key pair. Overwrite?")
try:
overwrite = prompt_y_n(message, default='y')
Comment on lines +52 to +54
except NoTTYException:
logger.warning("No TTY available to prompt for key pair overwrite. Reusing existing "
"key pair in '%s'. Use --yes to overwrite without prompting.", keys_folder)
return False

logger.debug("User chose to %s existing key pair", "overwrite" if overwrite else "reuse")
return overwrite


def delete_file(file_path, message, warning=False):
"""Delete a file with error handling."""
if os.path.isfile(file_path):
Expand All @@ -33,7 +74,8 @@ def delete_file(file_path, message, warning=False):
raise azclierror.FileOperationError(f"{message}Error: {str(e)}") from e


def check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder, ssh_client_folder=None):
def check_or_create_public_private_files(public_key_file, private_key_file, credentials_folder,
ssh_client_folder=None, yes_without_prompt=False):
"""Check for existing key files or create new ones if needed."""
delete_keys = False

Expand All @@ -47,13 +89,14 @@ def check_or_create_public_private_files(public_key_file, private_key_file, cred
public_key_file = os.path.join(credentials_folder, "id_rsa.pub")
private_key_file = os.path.join(credentials_folder, "id_rsa")

# Check if existing keys are present before generating new ones
if not (os.path.isfile(public_key_file) and os.path.isfile(private_key_file)):
# Only generate new keys if both don't exist
if _should_regenerate_key_pair(private_key_file, public_key_file, yes_without_prompt):
# Remove existing files so ssh-keygen does not prompt again to overwrite.
for stale in (private_key_file, public_key_file):
if os.path.isfile(stale):
delete_file(stale, f"Failed to remove existing key file {stale}. ", warning=True)
sftp_utils.create_ssh_keyfile(private_key_file, ssh_client_folder)
# Only set delete_keys to True when we actually create new keys
delete_keys = True
# If existing keys are found, delete_keys remains False
# else: existing keys reused, delete_keys stays False

if not public_key_file:
if private_key_file:
Expand Down
8 changes: 4 additions & 4 deletions src/sftp/azext_sftp/tests/latest/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_sftp_cert(self, mock_write_cert, mock_get_keys, mock_abspath, mock_isdi

custom.sftp_cert(cmd, "cert", "pubkey")

mock_get_keys.assert_called_once_with('/pubkey/path', None, None, None)
mock_get_keys.assert_called_once_with('/pubkey/path', None, None, None, yes_without_prompt=False)
mock_write_cert.assert_called_once_with(cmd, 'pubkey', '/cert/path', None)

@mock.patch('azext_sftp.custom._do_sftp_op')
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_sftp_connect_key_generation_scenarios(self, mock_mkdtemp, mock_create_k
)

# Verify function calls
mock_create_keys.assert_called_once_with(*expected_create_keys_args)
mock_create_keys.assert_called_once_with(*expected_create_keys_args, yes_without_prompt=False)
mock_gen_cert.assert_called_once()
mock_do_sftp.assert_called_once()

Expand Down Expand Up @@ -463,7 +463,7 @@ def test_sftp_cert_parameter_combinations(self, mock_abspath, mock_isdir, mock_c

# Verify calls
expected_keys_dir = os.path.dirname(cert_path) if expected_keys_folder == "cert_dir" else expected_keys_folder
mock_check_files.assert_called_once_with(public_key_file, None, expected_keys_dir, ssh_client_folder)
mock_check_files.assert_called_once_with(public_key_file, None, expected_keys_dir, ssh_client_folder, yes_without_prompt=False)
mock_write_cert.assert_called_once_with(cmd, effective_public_key, cert_path, ssh_client_folder)

def test_sftp_cert_error_cases(self):
Expand Down Expand Up @@ -551,7 +551,7 @@ def test_sftp_cert_valid_minimal_call(self):
custom.sftp_cert(cmd, public_key_file="pubkey.pub")

# Verify function was called correctly
mock_check_files.assert_called_once_with("pubkey.pub", None, None, None)
mock_check_files.assert_called_once_with("pubkey.pub", None, None, None, yes_without_prompt=False)

# Additional tests for private helper functions

Expand Down
141 changes: 141 additions & 0 deletions src/sftp/azext_sftp/tests/latest/test_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,5 +371,146 @@ def test_get_modulus_exponent_parse_error(self, mock_parser_class):
self.assertIn("Could not parse public key", str(context.exception))


class SftpKeyPairOverwritePromptTest(unittest.TestCase):
"""Test suite for the SSH key-pair overwrite prompt."""

def setUp(self):
super().setUp()
self.temp_dir = tempfile.mkdtemp(prefix="sftp_keypair_prompt_test_")
self.private_key = os.path.join(self.temp_dir, "id_rsa")
self.public_key = os.path.join(self.temp_dir, "id_rsa.pub")

def tearDown(self):
super().tearDown()
if os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir)

def _write_existing_keys(self):
with open(self.private_key, 'w') as f:
f.write("existing private key")
with open(self.public_key, 'w') as f:
f.write("existing public key")

def test_should_regenerate_no_existing_keys_returns_true(self):
"""No existing keys -> generate without prompting."""
with mock.patch('azext_sftp.file_utils.prompt_y_n') as mock_prompt:
result = file_utils._should_regenerate_key_pair(
self.private_key, self.public_key, yes_without_prompt=False)
self.assertTrue(result)
mock_prompt.assert_not_called()

def test_should_regenerate_yes_flag_skips_prompt(self):
"""--yes specified -> overwrite without prompting."""
self._write_existing_keys()
with mock.patch('azext_sftp.file_utils.prompt_y_n') as mock_prompt:
result = file_utils._should_regenerate_key_pair(
self.private_key, self.public_key, yes_without_prompt=True)
self.assertTrue(result)
mock_prompt.assert_not_called()

@mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=True)
def test_should_regenerate_user_confirms_overwrite(self, mock_prompt):
"""Existing keys + user answers 'y' -> regenerate."""
self._write_existing_keys()
result = file_utils._should_regenerate_key_pair(
self.private_key, self.public_key, yes_without_prompt=False)
self.assertTrue(result)
mock_prompt.assert_called_once()

@mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=False)
def test_should_regenerate_user_declines_overwrite(self, mock_prompt):
"""Existing keys + user answers 'n' -> reuse."""
self._write_existing_keys()
result = file_utils._should_regenerate_key_pair(
self.private_key, self.public_key, yes_without_prompt=False)
self.assertFalse(result)
mock_prompt.assert_called_once()

@mock.patch('azext_sftp.file_utils.prompt_y_n',
side_effect=file_utils.NoTTYException())
def test_should_regenerate_no_tty_defaults_to_reuse(self, mock_prompt):
"""No TTY available -> reuse existing keys (safe default)."""
self._write_existing_keys()
result = file_utils._should_regenerate_key_pair(
self.private_key, self.public_key, yes_without_prompt=False)
self.assertFalse(result)

@mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=True)
def test_should_regenerate_only_private_key_exists(self, mock_prompt):
"""Partial existence (private only) is detected and prompted."""
with open(self.private_key, 'w') as f:
f.write("existing private key")
result = file_utils._should_regenerate_key_pair(
self.private_key, self.public_key, yes_without_prompt=False)
self.assertTrue(result)
prompt_text = mock_prompt.call_args[0][0]
self.assertIn("private key only", prompt_text)

@mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=True)
def test_should_regenerate_only_public_key_exists(self, mock_prompt):
"""Partial existence (public only) is detected and prompted."""
with open(self.public_key, 'w') as f:
f.write("existing public key")
result = file_utils._should_regenerate_key_pair(
self.private_key, self.public_key, yes_without_prompt=False)
self.assertTrue(result)
prompt_text = mock_prompt.call_args[0][0]
self.assertIn("public key only", prompt_text)

@mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile')
@mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=False)
def test_check_or_create_reuses_existing_keys_when_user_declines(
self, mock_prompt, mock_create_keyfile):
"""Existing keys + user declines -> reuse, do not regenerate, delete_keys=False."""
self._write_existing_keys()
public_key, private_key, delete_keys = file_utils.check_or_create_public_private_files(
None, None, self.temp_dir, ssh_client_folder=None, yes_without_prompt=False)
self.assertEqual(public_key, self.public_key)
self.assertEqual(private_key, self.private_key)
self.assertFalse(delete_keys)
mock_create_keyfile.assert_not_called()
mock_prompt.assert_called_once()

@mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile')
@mock.patch('azext_sftp.file_utils.prompt_y_n', return_value=True)
def test_check_or_create_regenerates_keys_when_user_confirms(
self, mock_prompt, mock_create_keyfile):
"""Existing keys + user confirms -> regenerate, delete_keys=True."""
self._write_existing_keys()

def recreate(private_key_path, _ssh_client_folder):
with open(private_key_path, 'w') as f:
f.write("new private key")
with open(private_key_path + ".pub", 'w') as f:
f.write("new public key")
mock_create_keyfile.side_effect = recreate

_, private_key, delete_keys = file_utils.check_or_create_public_private_files(
None, None, self.temp_dir, ssh_client_folder=None, yes_without_prompt=False)
self.assertTrue(delete_keys)
mock_create_keyfile.assert_called_once_with(private_key, None)
mock_prompt.assert_called_once()

@mock.patch('azext_sftp.sftp_utils.create_ssh_keyfile')
@mock.patch('azext_sftp.file_utils.prompt_y_n')
def test_check_or_create_yes_flag_skips_prompt_and_regenerates(
self, mock_prompt, mock_create_keyfile):
"""Existing keys + yes_without_prompt=True -> regenerate without prompt."""
self._write_existing_keys()

def recreate(private_key_path, _ssh_client_folder):
with open(private_key_path, 'w') as f:
f.write("new private key")
with open(private_key_path + ".pub", 'w') as f:
f.write("new public key")
mock_create_keyfile.side_effect = recreate

_, private_key, delete_keys = file_utils.check_or_create_public_private_files(
None, None, self.temp_dir, ssh_client_folder=None, yes_without_prompt=True)
self.assertTrue(delete_keys)
mock_create_keyfile.assert_called_once_with(private_key, None)
mock_prompt.assert_not_called()


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion src/sftp/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

# TODO: Confirm this is the right version number you want and it matches your
# HISTORY.rst entry.
VERSION = '1.0.0b2'
VERSION = '1.0.0b3'

# The full list of classifiers is available at
# https://pypi.python.org/pypi?%3Aaction=list_classifiers
Expand Down
Loading