299 lines
10 KiB
Python
Executable File
299 lines
10 KiB
Python
Executable File
"""Unit tests for scripts/git-credential-forge.py.

Run with: python3 -m unittest tests.test_git_credential_forge
"""
|
|
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import io
|
|
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import time
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest import mock
|
|
|
|
# Directory containing this test file, and the directory one level above it.
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
# Put ``<ROOT>/scripts`` first on sys.path so ``forge_auth`` (imported below)
# can be found.
sys.path.insert(0, str(ROOT / "scripts"))

# Load the helper as a module even though its filename has hyphens.
_helper_path = ROOT / "scripts" / "git-credential-forge.py"
_spec = importlib.util.spec_from_file_location("gcf", _helper_path)
# Validate explicitly rather than with ``assert``: assertions are stripped
# under ``python -O``, which would turn a missing spec/loader into an opaque
# AttributeError on the lines below.
if _spec is None or _spec.loader is None:
    raise ImportError(f"cannot load helper module from {_helper_path}")
gcf = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(gcf)

import forge_auth as fa  # noqa: E402 (imported after sys.path is extended)
|
|
|
|
|
|
def _write_store(tmp: Path, payload: dict) -> Path:
    """Write *payload* as JSON to ``client-auth.json`` inside *tmp*.

    Returns the path of the file that was written.
    """
    store_file = tmp.joinpath("client-auth.json")
    serialized = json.dumps(payload)
    store_file.write_text(serialized, encoding="utf-8")
    return store_file
|
|
|
|
|
|
# --------------------------------------------------------------------
|
|
# read_git_fields
|
|
# --------------------------------------------------------------------
|
|
class ReadGitFieldsTests(unittest.TestCase):
    """Exercise ``gcf.read_git_fields`` on well-formed and malformed input."""

    def test_full_block(self) -> None:
        # Three key=value lines followed by the terminating blank line.
        stream = io.StringIO("protocol=https\nhost=g.example:8443\npath=a/b.git\n\n")
        parsed = gcf.read_git_fields(stream)
        expected = {"protocol": "https", "host": "g.example:8443", "path": "a/b.git"}
        self.assertEqual(parsed, expected)

    def test_eof_without_blank_line(self) -> None:
        # Input ends right after the last field, with no blank separator line.
        stream = io.StringIO("protocol=https\nhost=x\n")
        parsed = gcf.read_git_fields(stream)
        self.assertEqual(parsed, {"protocol": "https", "host": "x"})

    def test_malformed_line_raises(self) -> None:
        # A line without "=" is expected to be rejected with ValueError.
        stream = io.StringIO("notakeyvalue\n")
        self.assertRaises(ValueError, gcf.read_git_fields, stream)
|
|
|
|
|
|
# --------------------------------------------------------------------
|
|
# host matching
|
|
# --------------------------------------------------------------------
|
|
class RequestMatchesTests(unittest.TestCase):
    """Check ``gcf._request_matches`` for scheme, host and port comparisons.

    Each test passes a request dict (``protocol``/``host`` keys) together with
    a ``(scheme, host, port)`` triple and asserts on the truthiness of the
    result.
    """

    # Triples used as the second argument of ``_request_matches``.
    _WITH_PORT = ("https", "g.example", 8443)
    _NO_PORT = ("https", "g.example", None)

    @staticmethod
    def _matches(host: str, stored: tuple, *, protocol: str = "https"):
        """Build the request dict and return ``gcf._request_matches``' result."""
        request = {"protocol": protocol, "host": host}
        return gcf._request_matches(request, stored)

    def test_exact_match_including_port(self) -> None:
        self.assertTrue(self._matches("g.example:8443", self._WITH_PORT))

    def test_wrong_scheme_no_match(self) -> None:
        outcome = self._matches("g.example:8443", self._WITH_PORT, protocol="http")
        self.assertFalse(outcome)

    def test_wrong_host_no_match(self) -> None:
        self.assertFalse(self._matches("other.example:8443", self._WITH_PORT))

    def test_wrong_port_no_match(self) -> None:
        self.assertFalse(self._matches("g.example:7000", self._WITH_PORT))

    def test_stored_default_https_request_no_port(self) -> None:
        # Stored URL had no explicit port → default 443 inferred.
        # Request without port is OK.
        self.assertTrue(self._matches("g.example", self._NO_PORT))

    def test_stored_default_https_request_with_443(self) -> None:
        # An explicit :443 on the request also matches the port-less triple.
        self.assertTrue(self._matches("g.example:443", self._NO_PORT))

    def test_stored_default_https_request_with_other_port_no_match(self) -> None:
        self.assertFalse(self._matches("g.example:8443", self._NO_PORT))
|
|
|
|
|
|
# --------------------------------------------------------------------
|
|
# cmd_get (end-to-end, inside a sandboxed FSDGG_AUTH_STORE_PATH)
|
|
# --------------------------------------------------------------------
|
|
class CmdGetTests(unittest.TestCase):
    """End-to-end checks of ``gcf.cmd_get`` against a throwaway auth store.

    Every test goes through :meth:`_run`, which replaces ``os.environ`` with a
    minimal environment whose ``FSDGG_AUTH_STORE_PATH`` and ``HOME`` point into
    a temporary directory.
    """

    @staticmethod
    def _parse_fields(output: str) -> dict[str, str]:
        """Split the ``key=value`` lines of *output* into a dict.

        Raises ``ValueError`` if a non-empty line has no ``=``.
        """
        return dict(line.split("=", 1) for line in output.strip().splitlines())

    def _run(
        self,
        *,
        store_payload: dict | None,
        stdin_text: str,
        env_extra: dict[str, str] | None = None,
    ) -> tuple[int, str, str]:
        """Call ``gcf.cmd_get`` in a sandbox and capture its output.

        Args:
            store_payload: JSON-serializable dict written to the auth store
                file before the call; ``None`` leaves the store file absent.
            stdin_text: Raw credential request, parsed with
                ``gcf.read_git_fields`` and handed to ``cmd_get``.
            env_extra: Optional variables added to (or overriding) the base
                environment.

        Returns:
            ``(return_code, captured_stdout, captured_stderr)``.
        """
        with tempfile.TemporaryDirectory() as d:
            tmp = Path(d)
            env = {
                "FSDGG_AUTH_STORE_PATH": str(tmp / "client-auth.json"),
                "FORGE_GITEA_URL": "https://g.example:8443",
                "FSDGG_CLI_CLIENT_ID": "client-1",
                "FSDGG_CLI_REDIRECT_URI": "http://127.0.0.1:38111/callback",
                "HOME": str(tmp),
            }
            if env_extra:
                env.update(env_extra)
            if store_payload is not None:
                _write_store(tmp, store_payload)
            fields = gcf.read_git_fields(io.StringIO(stdin_text))
            buf_out, buf_err = io.StringIO(), io.StringIO()
            # Both patches restore the previous state on exit, even when
            # cmd_get raises; clear=True drops every variable not listed in env.
            captured = mock.patch.multiple(sys, stdout=buf_out, stderr=buf_err)
            with captured, mock.patch.dict(os.environ, env, clear=True):
                rc = gcf.cmd_get(fields)
            return rc, buf_out.getvalue(), buf_err.getvalue()

    # --- matching host -------------------------------------------------

    def test_match_live_token_returns_credentials(self) -> None:
        payload = {
            "username": "alice",
            "gitea_access_token": "LIVETOKEN",
            "gitea_token_expires_at": time.time() + 3600,
            "_forge_gitea_base_url": "https://g.example:8443",
        }
        rc, out, _ = self._run(
            store_payload=payload,
            stdin_text="protocol=https\nhost=g.example:8443\npath=org/repo.git\n\n",
        )
        self.assertEqual(rc, 0)
        parsed = self._parse_fields(out)
        self.assertEqual(parsed["username"], "alice")
        self.assertEqual(parsed["password"], "LIVETOKEN")
        self.assertEqual(parsed["host"], "g.example:8443")

    def test_no_store_passes_through(self) -> None:
        rc, out, _ = self._run(
            store_payload=None,
            stdin_text="protocol=https\nhost=g.example:8443\n\n",
        )
        self.assertEqual(rc, 0)
        self.assertNotIn("password=", out)
        self.assertIn("host=g.example:8443", out)

    def test_non_matching_host_passes_through(self) -> None:
        payload = {
            "username": "alice",
            "gitea_access_token": "LIVETOKEN",
            "gitea_token_expires_at": time.time() + 3600,
            "_forge_gitea_base_url": "https://g.example:8443",
        }
        rc, out, _ = self._run(
            store_payload=payload,
            stdin_text="protocol=https\nhost=github.com\n\n",
        )
        self.assertEqual(rc, 0)
        self.assertNotIn("password=", out)
        self.assertIn("host=github.com", out)

    def test_match_but_no_token_passes_through(self) -> None:
        payload = {
            "username": "alice",
            "gitea_access_token": "",
            "_forge_gitea_base_url": "https://g.example:8443",
        }
        rc, out, _ = self._run(
            store_payload=payload,
            stdin_text="protocol=https\nhost=g.example:8443\n\n",
        )
        self.assertEqual(rc, 0)
        self.assertNotIn("password=", out)

    # --- expired token + refresh -------------------------------------

    def test_expired_token_triggers_refresh(self) -> None:
        payload = {
            "username": "alice",
            "gitea_access_token": "EXPIRED",
            "gitea_token_expires_at": time.time() - 10,
            "_forge_gitea_base_url": "https://g.example:8443",
            "_forge_refresh_token": "REFRESH",
            "_forge_client_id": "client-1",
        }

        # Patch the refresh path directly.
        def fake_refresh(config, *, must_refresh=False):
            """Stand in for ``run_refresh``: rewrite the store with a rotated token."""
            f = fa.AuthFile.read(fa.auth_store_path())
            f.merge_refresh(
                gitea_access_token="ROTATED",
                gitea_token_expires_at=time.time() + 3600,
                refresh_token="NEW-RT",
            )
            f.write(fa.auth_store_path())
            return f

        with mock.patch.object(gcf.forge_auth, "run_refresh", side_effect=fake_refresh):
            rc, out, _ = self._run(
                store_payload=payload,
                stdin_text="protocol=https\nhost=g.example:8443\n\n",
            )
        self.assertEqual(rc, 0)
        parsed = self._parse_fields(out)
        self.assertEqual(parsed["password"], "ROTATED")

    def test_refresh_failure_passes_through_with_stderr(self) -> None:
        payload = {
            "username": "alice",
            "gitea_access_token": "EXPIRED",
            "gitea_token_expires_at": time.time() - 10,
            "_forge_gitea_base_url": "https://g.example:8443",
            "_forge_refresh_token": "DEAD",
        }

        def fake_refresh(*_args, **_kwargs):
            """Stand in for ``run_refresh``: always fail with ``AuthError``."""
            raise fa.AuthError("refresh token revoked")

        with mock.patch.object(gcf.forge_auth, "run_refresh", side_effect=fake_refresh):
            rc, out, err = self._run(
                store_payload=payload,
                stdin_text="protocol=https\nhost=g.example:8443\n\n",
            )
        self.assertEqual(rc, 0)
        self.assertNotIn("password=", out)
        self.assertIn("token refresh failed", err)
        self.assertIn("just login", err)
|
|
|
|
|
|
# --------------------------------------------------------------------
|
|
# main() dispatcher
|
|
# --------------------------------------------------------------------
|
|
class MainDispatcherTests(unittest.TestCase):
    """Check the return code of ``gcf.main`` for each action argument."""

    # Credential block fed on stdin to the ``store`` and ``erase`` actions.
    _CREDENTIALS = "username=x\npassword=y\n\n"

    def _main(self, argv: list[str], stdin_text: str = "") -> int:
        """Call ``gcf.main(argv)`` with ``sys.stdin`` temporarily set to *stdin_text*."""
        fake_stdin = io.StringIO(stdin_text)
        # patch.object puts the previous sys.stdin back on exit, error or not.
        with mock.patch.object(sys, "stdin", fake_stdin):
            return gcf.main(argv)

    def test_store_is_noop(self) -> None:
        rc = self._main(["h", "store"], self._CREDENTIALS)
        self.assertEqual(rc, 0)

    def test_erase_is_noop(self) -> None:
        rc = self._main(["h", "erase"], self._CREDENTIALS)
        self.assertEqual(rc, 0)

    def test_no_action_rc_2(self) -> None:
        rc = self._main(["h"])
        self.assertEqual(rc, 2)

    def test_unknown_action_rc_2(self) -> None:
        rc = self._main(["h", "bogus"])
        self.assertEqual(rc, 2)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: hand control to unittest's runner.
    unittest.main()
|