Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions tests/cli/test_cli_auth.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import io
import json
import tempfile
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import requests

from codecarbon.cli import auth
from codecarbon.cli.auth import _CallbackHandler

Expand Down Expand Up @@ -100,6 +104,31 @@ def test_validate_access_token_valid(
):
self.assertTrue(auth._validate_access_token("token"))

@patch("codecarbon.cli.auth._discover_endpoints", return_value={"jwks_uri": "jwks"})
@patch(
"codecarbon.cli.auth.requests.get",
side_effect=requests.RequestException("offline"),
)
def test_validate_access_token_network_error_returns_true(
self, mock_get, mock_discover
):
self.assertTrue(auth._validate_access_token("token"))

@patch("codecarbon.cli.auth._discover_endpoints", return_value={"jwks_uri": "jwks"})
@patch("codecarbon.cli.auth.requests.get")
@patch("codecarbon.cli.auth.JsonWebKey.import_key_set")
@patch(
"codecarbon.cli.auth.jose_jwt.decode",
side_effect=Exception("invalid"),
)
def test_validate_access_token_invalid_returns_false(
self, mock_decode, mock_import_key_set, mock_get, mock_discover
):
mock_get.return_value.json.return_value = {"keys": []}
mock_get.return_value.raise_for_status.return_value = None

self.assertFalse(auth._validate_access_token("token"))

@patch("codecarbon.cli.auth.requests.post")
@patch("codecarbon.cli.auth._discover_endpoints")
def test_refresh_tokens(self, mock_discover, mock_post):
Expand All @@ -120,6 +149,18 @@ def test_get_access_token_valid(self, mock_validate, mock_load):
mock_validate.return_value = True
self.assertEqual(auth.get_access_token(), "a")

@patch("codecarbon.cli.auth._load_credentials", side_effect=OSError("missing"))
def test_get_access_token_raises_when_credentials_missing(self, mock_load):
with self.assertRaises(ValueError):
auth.get_access_token()

@patch("codecarbon.cli.auth._load_credentials")
def test_get_access_token_raises_when_access_token_missing(self, mock_load):
mock_load.return_value = {"refresh_token": "r"}

with self.assertRaises(ValueError):
auth.get_access_token()

@patch("codecarbon.cli.auth._load_credentials")
@patch("codecarbon.cli.auth._validate_access_token")
@patch("codecarbon.cli.auth._refresh_tokens")
Expand All @@ -133,11 +174,138 @@ def test_get_access_token_refresh(
self.assertEqual(auth.get_access_token(), "b")
mock_save.assert_called()

@patch("codecarbon.cli.auth._refresh_tokens", side_effect=Exception("expired"))
@patch("codecarbon.cli.auth._validate_access_token", return_value=False)
@patch(
"codecarbon.cli.auth._load_credentials",
return_value={"access_token": "a", "refresh_token": "r"},
)
def test_get_access_token_refresh_failure_deletes_credentials(
self, mock_load, mock_validate, mock_refresh
):
original_credentials_file = auth._CREDENTIALS_FILE
with tempfile.TemporaryDirectory() as tmp_dir:
temp_credentials = Path(tmp_dir) / "test_credentials.json"
temp_credentials.write_text("{}")
try:
auth._CREDENTIALS_FILE = temp_credentials
with self.assertRaises(ValueError):
auth.get_access_token()
self.assertFalse(temp_credentials.exists())
finally:
auth._CREDENTIALS_FILE = original_credentials_file

@patch("codecarbon.cli.auth._load_credentials")
def test_get_id_token(self, mock_load):
mock_load.return_value = {"id_token": "i"}
self.assertEqual(auth.get_id_token(), "i")

@patch("codecarbon.cli.auth._save_credentials")
@patch("codecarbon.cli.auth.webbrowser.open")
@patch("codecarbon.cli.auth.HTTPServer")
@patch("codecarbon.cli.auth.OAuth2Session")
@patch(
"codecarbon.cli.auth._discover_endpoints",
return_value={
"authorization_endpoint": "https://auth.example/authorize",
"token_endpoint": "https://auth.example/token",
},
)
def test_authorize_success(
self,
mock_discover,
mock_session_cls,
mock_server_cls,
mock_browser_open,
mock_save_credentials,
):
mock_session = MagicMock()
mock_session.create_authorization_url.return_value = (
"https://auth.example/authorize?state=abc",
"abc",
)
mock_session.fetch_token.return_value = {"access_token": "token"}
mock_session_cls.return_value = mock_session

mock_server = MagicMock()
mock_server.handle_request.side_effect = lambda: setattr(
auth._CallbackHandler,
"callback_url",
"http://localhost:8090/callback?code=123",
)
mock_server_cls.return_value = mock_server

auth._CallbackHandler.callback_url = None
auth._CallbackHandler.error = None

result = auth.authorize()

self.assertEqual(result, {"access_token": "token"})
mock_browser_open.assert_called_once()
mock_server.handle_request.assert_called_once()
mock_server.server_close.assert_called_once()
mock_save_credentials.assert_called_once_with({"access_token": "token"})

@patch("codecarbon.cli.auth.HTTPServer")
@patch("codecarbon.cli.auth.OAuth2Session")
@patch(
"codecarbon.cli.auth._discover_endpoints",
return_value={
"authorization_endpoint": "https://auth.example/authorize",
"token_endpoint": "https://auth.example/token",
},
)
def test_authorize_raises_on_callback_error(
self, mock_discover, mock_session_cls, mock_server_cls
):
mock_session = MagicMock()
mock_session.create_authorization_url.return_value = (
"https://auth.example/authorize?state=abc",
"abc",
)
mock_session_cls.return_value = mock_session
mock_server = MagicMock()
mock_server.handle_request.side_effect = lambda: setattr(
auth._CallbackHandler,
"error",
"access_denied",
)
mock_server_cls.return_value = mock_server

auth._CallbackHandler.callback_url = None
auth._CallbackHandler.error = None

with self.assertRaises(ValueError):
auth.authorize()
mock_server.handle_request.assert_called_once()
mock_server.server_close.assert_called_once()

@patch("codecarbon.cli.auth.HTTPServer")
@patch("codecarbon.cli.auth.OAuth2Session")
@patch(
"codecarbon.cli.auth._discover_endpoints",
return_value={
"authorization_endpoint": "https://auth.example/authorize",
"token_endpoint": "https://auth.example/token",
},
)
def test_authorize_raises_when_no_callback_received(
self, mock_discover, mock_session_cls, mock_server_cls
):
mock_session = MagicMock()
mock_session.create_authorization_url.return_value = (
"https://auth.example/authorize?state=abc",
"abc",
)
mock_session_cls.return_value = mock_session
mock_server_cls.return_value = MagicMock()

auth._CallbackHandler.callback_url = None
auth._CallbackHandler.error = None

with self.assertRaises(ValueError):
auth.authorize()


if __name__ == "__main__":
unittest.main()
107 changes: 107 additions & 0 deletions tests/cli/test_cli_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import configparser

import pytest

from codecarbon.cli import cli_utils


def test_get_config_reads_codecarbon_section(tmp_path):
config_path = tmp_path / ".codecarbon.config"
config_path.write_text("[codecarbon]\napi_endpoint=https://example.test\n")

config = cli_utils.get_config(config_path)

assert config["api_endpoint"] == "https://example.test"


def test_get_config_raises_when_missing(tmp_path):
with pytest.raises(FileNotFoundError):
cli_utils.get_config(tmp_path / ".codecarbon.config")


def test_get_api_endpoint_appends_default_when_missing_key(tmp_path):
config_path = tmp_path / ".codecarbon.config"
config_path.write_text("[codecarbon]\n")

endpoint = cli_utils.get_api_endpoint(config_path)

assert endpoint == "https://api.codecarbon.io"
parser = configparser.ConfigParser()
parser.read(config_path)
assert parser["codecarbon"]["api_endpoint"] == "https://api.codecarbon.io"


def test_get_api_endpoint_returns_default_when_file_missing(tmp_path):
endpoint = cli_utils.get_api_endpoint(tmp_path / ".codecarbon.config")

assert endpoint == "https://api.codecarbon.io"


def test_get_existing_exp_id_returns_none_on_key_error(monkeypatch):
def raise_key_error():
raise KeyError("missing")

monkeypatch.setattr(cli_utils, "get_hierarchical_config", raise_key_error)

assert cli_utils.get_existing_exp_id() is None


def test_get_existing_exp_id_reads_experiment_id(monkeypatch):
monkeypatch.setattr(
cli_utils, "get_hierarchical_config", lambda: {"experiment_id": "exp-123"}
)

assert cli_utils.get_existing_exp_id() == "exp-123"


def test_write_local_exp_id_creates_section(tmp_path):
config_path = tmp_path / ".codecarbon.config"

cli_utils.write_local_exp_id("exp-456", config_path)

parser = configparser.ConfigParser()
parser.read(config_path)
assert parser["codecarbon"]["experiment_id"] == "exp-456"


def test_overwrite_local_config_updates_existing_file(tmp_path):
config_path = tmp_path / ".codecarbon.config"
config_path.write_text("[codecarbon]\nexperiment_id=old\n")

cli_utils.overwrite_local_config("experiment_id", "new", config_path)

parser = configparser.ConfigParser()
parser.read(config_path)
assert parser["codecarbon"]["experiment_id"] == "new"


def test_create_new_config_file_creates_parent_and_file(monkeypatch, tmp_path):
target = tmp_path / "nested" / ".codecarbon.config"
prompts = iter([str(target)])

monkeypatch.setattr(
cli_utils.typer, "prompt", lambda *args, **kwargs: next(prompts)
)
monkeypatch.setattr(cli_utils.Confirm, "ask", lambda *args, **kwargs: True)

created_path = cli_utils.create_new_config_file()

assert created_path == target
assert target.exists()
assert target.read_text() == "[codecarbon]\n"


def test_create_new_config_file_expands_home(monkeypatch, tmp_path):
home = tmp_path / "home"
home.mkdir()
target = home / ".codecarbon.config"

monkeypatch.setattr(cli_utils.Path, "home", lambda: home)
monkeypatch.setattr(
cli_utils.typer, "prompt", lambda *args, **kwargs: "~/.codecarbon.config"
)

created_path = cli_utils.create_new_config_file()

assert created_path == target
assert target.exists()
Loading
Loading