Initial Commit
This commit is contained in:
298
tests/test_git_credential_forge.py
Executable file
298
tests/test_git_credential_forge.py
Executable file
@@ -0,0 +1,298 @@
|
||||
"""Unit tests for scripts/git-credential-forge.py.
|
||||
|
||||
Run with: python3 -m unittest tests.test_git_credential_forge
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
HERE = Path(__file__).resolve().parent
|
||||
ROOT = HERE.parent
|
||||
sys.path.insert(0, str(ROOT / "scripts"))
|
||||
|
||||
# Load the helper as a module even though its filename has hyphens.
|
||||
_helper_path = ROOT / "scripts" / "git-credential-forge.py"
|
||||
_spec = importlib.util.spec_from_file_location("gcf", _helper_path)
|
||||
assert _spec and _spec.loader
|
||||
gcf = importlib.util.module_from_spec(_spec)
|
||||
_spec.loader.exec_module(gcf)
|
||||
|
||||
import forge_auth as fa # noqa: E402 (imported after sys.path is extended)
|
||||
|
||||
|
||||
def _write_store(tmp: Path, payload: dict) -> Path:
|
||||
p = tmp / "client-auth.json"
|
||||
p.write_text(json.dumps(payload), encoding="utf-8")
|
||||
return p
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# read_git_fields
|
||||
# --------------------------------------------------------------------
|
||||
class ReadGitFieldsTests(unittest.TestCase):
|
||||
|
||||
def test_full_block(self) -> None:
|
||||
buf = io.StringIO("protocol=https\nhost=g.example:6006\npath=a/b.git\n\n")
|
||||
self.assertEqual(
|
||||
gcf.read_git_fields(buf),
|
||||
{"protocol": "https", "host": "g.example:6006", "path": "a/b.git"},
|
||||
)
|
||||
|
||||
def test_eof_without_blank_line(self) -> None:
|
||||
self.assertEqual(
|
||||
gcf.read_git_fields(io.StringIO("protocol=https\nhost=x\n")),
|
||||
{"protocol": "https", "host": "x"},
|
||||
)
|
||||
|
||||
def test_malformed_line_raises(self) -> None:
|
||||
with self.assertRaises(ValueError):
|
||||
gcf.read_git_fields(io.StringIO("notakeyvalue\n"))
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# host matching
|
||||
# --------------------------------------------------------------------
|
||||
class RequestMatchesTests(unittest.TestCase):
|
||||
|
||||
def test_exact_match_including_port(self) -> None:
|
||||
self.assertTrue(
|
||||
gcf._request_matches(
|
||||
{"protocol": "https", "host": "g.example:6006"},
|
||||
("https", "g.example", 6006),
|
||||
)
|
||||
)
|
||||
|
||||
def test_wrong_scheme_no_match(self) -> None:
|
||||
self.assertFalse(
|
||||
gcf._request_matches(
|
||||
{"protocol": "http", "host": "g.example:6006"},
|
||||
("https", "g.example", 6006),
|
||||
)
|
||||
)
|
||||
|
||||
def test_wrong_host_no_match(self) -> None:
|
||||
self.assertFalse(
|
||||
gcf._request_matches(
|
||||
{"protocol": "https", "host": "other.example:6006"},
|
||||
("https", "g.example", 6006),
|
||||
)
|
||||
)
|
||||
|
||||
def test_wrong_port_no_match(self) -> None:
|
||||
self.assertFalse(
|
||||
gcf._request_matches(
|
||||
{"protocol": "https", "host": "g.example:7000"},
|
||||
("https", "g.example", 6006),
|
||||
)
|
||||
)
|
||||
|
||||
def test_stored_default_https_request_no_port(self) -> None:
|
||||
# Stored URL had no explicit port → default 443 inferred.
|
||||
# Request without port is OK.
|
||||
self.assertTrue(
|
||||
gcf._request_matches(
|
||||
{"protocol": "https", "host": "g.example"},
|
||||
("https", "g.example", None),
|
||||
)
|
||||
)
|
||||
|
||||
def test_stored_default_https_request_with_443(self) -> None:
|
||||
self.assertTrue(
|
||||
gcf._request_matches(
|
||||
{"protocol": "https", "host": "g.example:443"},
|
||||
("https", "g.example", None),
|
||||
)
|
||||
)
|
||||
|
||||
def test_stored_default_https_request_with_other_port_no_match(self) -> None:
|
||||
self.assertFalse(
|
||||
gcf._request_matches(
|
||||
{"protocol": "https", "host": "g.example:6006"},
|
||||
("https", "g.example", None),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# cmd_get (end-to-end, inside a sandboxed FSDGG_AUTH_STORE_PATH)
|
||||
# --------------------------------------------------------------------
|
||||
class CmdGetTests(unittest.TestCase):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
*,
|
||||
store_payload: dict | None,
|
||||
stdin_text: str,
|
||||
env_extra: dict[str, str] | None = None,
|
||||
) -> tuple[int, str, str]:
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
tmp = Path(d)
|
||||
env = {
|
||||
"FSDGG_AUTH_STORE_PATH": str(tmp / "client-auth.json"),
|
||||
"FORGE_GITEA_URL": "https://g.example:6006",
|
||||
"FSDGG_CLI_CLIENT_ID": "client-1",
|
||||
"FSDGG_CLI_REDIRECT_URI": "http://127.0.0.1:38111/callback",
|
||||
"HOME": str(tmp),
|
||||
}
|
||||
if env_extra:
|
||||
env.update(env_extra)
|
||||
if store_payload is not None:
|
||||
_write_store(tmp, store_payload)
|
||||
fields = gcf.read_git_fields(io.StringIO(stdin_text))
|
||||
buf_out, buf_err = io.StringIO(), io.StringIO()
|
||||
real_out, real_err = sys.stdout, sys.stderr
|
||||
sys.stdout, sys.stderr = buf_out, buf_err
|
||||
try:
|
||||
with mock.patch.dict(os.environ, env, clear=True):
|
||||
rc = gcf.cmd_get(fields)
|
||||
finally:
|
||||
sys.stdout, sys.stderr = real_out, real_err
|
||||
return rc, buf_out.getvalue(), buf_err.getvalue()
|
||||
|
||||
# --- matching host -------------------------------------------------
|
||||
|
||||
def test_match_live_token_returns_credentials(self) -> None:
|
||||
payload = {
|
||||
"username": "alice",
|
||||
"gitea_access_token": "LIVETOKEN",
|
||||
"gitea_token_expires_at": time.time() + 3600,
|
||||
"_forge_gitea_base_url": "https://g.example:6006",
|
||||
}
|
||||
rc, out, _ = self._run(
|
||||
store_payload=payload,
|
||||
stdin_text="protocol=https\nhost=g.example:6006\npath=org/repo.git\n\n",
|
||||
)
|
||||
self.assertEqual(rc, 0)
|
||||
parsed = dict(l.split("=", 1) for l in out.strip().splitlines())
|
||||
self.assertEqual(parsed["username"], "alice")
|
||||
self.assertEqual(parsed["password"], "LIVETOKEN")
|
||||
self.assertEqual(parsed["host"], "g.example:6006")
|
||||
|
||||
def test_no_store_passes_through(self) -> None:
|
||||
rc, out, _ = self._run(
|
||||
store_payload=None,
|
||||
stdin_text="protocol=https\nhost=g.example:6006\n\n",
|
||||
)
|
||||
self.assertEqual(rc, 0)
|
||||
self.assertNotIn("password=", out)
|
||||
self.assertIn("host=g.example:6006", out)
|
||||
|
||||
def test_non_matching_host_passes_through(self) -> None:
|
||||
payload = {
|
||||
"username": "alice",
|
||||
"gitea_access_token": "LIVETOKEN",
|
||||
"gitea_token_expires_at": time.time() + 3600,
|
||||
"_forge_gitea_base_url": "https://g.example:6006",
|
||||
}
|
||||
rc, out, _ = self._run(
|
||||
store_payload=payload,
|
||||
stdin_text="protocol=https\nhost=github.com\n\n",
|
||||
)
|
||||
self.assertEqual(rc, 0)
|
||||
self.assertNotIn("password=", out)
|
||||
self.assertIn("host=github.com", out)
|
||||
|
||||
def test_match_but_no_token_passes_through(self) -> None:
|
||||
payload = {
|
||||
"username": "alice",
|
||||
"gitea_access_token": "",
|
||||
"_forge_gitea_base_url": "https://g.example:6006",
|
||||
}
|
||||
rc, out, _ = self._run(
|
||||
store_payload=payload,
|
||||
stdin_text="protocol=https\nhost=g.example:6006\n\n",
|
||||
)
|
||||
self.assertEqual(rc, 0)
|
||||
self.assertNotIn("password=", out)
|
||||
|
||||
# --- expired token + refresh -------------------------------------
|
||||
|
||||
def test_expired_token_triggers_refresh(self) -> None:
|
||||
payload = {
|
||||
"username": "alice",
|
||||
"gitea_access_token": "EXPIRED",
|
||||
"gitea_token_expires_at": time.time() - 10,
|
||||
"_forge_gitea_base_url": "https://g.example:6006",
|
||||
"_forge_refresh_token": "REFRESH",
|
||||
"_forge_client_id": "client-1",
|
||||
}
|
||||
# Patch the refresh path directly.
|
||||
def fake_refresh(config, *, must_refresh=False):
|
||||
f = fa.AuthFile.read(fa.auth_store_path())
|
||||
f.merge_refresh(
|
||||
gitea_access_token="ROTATED",
|
||||
gitea_token_expires_at=time.time() + 3600,
|
||||
refresh_token="NEW-RT",
|
||||
)
|
||||
f.write(fa.auth_store_path())
|
||||
return f
|
||||
|
||||
with mock.patch.object(gcf.forge_auth, "run_refresh", side_effect=fake_refresh):
|
||||
rc, out, _ = self._run(
|
||||
store_payload=payload,
|
||||
stdin_text="protocol=https\nhost=g.example:6006\n\n",
|
||||
)
|
||||
self.assertEqual(rc, 0)
|
||||
parsed = dict(l.split("=", 1) for l in out.strip().splitlines())
|
||||
self.assertEqual(parsed["password"], "ROTATED")
|
||||
|
||||
def test_refresh_failure_passes_through_with_stderr(self) -> None:
|
||||
payload = {
|
||||
"username": "alice",
|
||||
"gitea_access_token": "EXPIRED",
|
||||
"gitea_token_expires_at": time.time() - 10,
|
||||
"_forge_gitea_base_url": "https://g.example:6006",
|
||||
"_forge_refresh_token": "DEAD",
|
||||
}
|
||||
|
||||
def fake_refresh(*_args, **_kwargs):
|
||||
raise fa.AuthError("refresh token revoked")
|
||||
|
||||
with mock.patch.object(gcf.forge_auth, "run_refresh", side_effect=fake_refresh):
|
||||
rc, out, err = self._run(
|
||||
store_payload=payload,
|
||||
stdin_text="protocol=https\nhost=g.example:6006\n\n",
|
||||
)
|
||||
self.assertEqual(rc, 0)
|
||||
self.assertNotIn("password=", out)
|
||||
self.assertIn("token refresh failed", err)
|
||||
self.assertIn("just login", err)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------
|
||||
# main() dispatcher
|
||||
# --------------------------------------------------------------------
|
||||
class MainDispatcherTests(unittest.TestCase):
|
||||
|
||||
def _main(self, argv: list[str], stdin_text: str = "") -> int:
|
||||
real_stdin = sys.stdin
|
||||
sys.stdin = io.StringIO(stdin_text)
|
||||
try:
|
||||
return gcf.main(argv)
|
||||
finally:
|
||||
sys.stdin = real_stdin
|
||||
|
||||
def test_store_is_noop(self) -> None:
|
||||
self.assertEqual(self._main(["h", "store"], "username=x\npassword=y\n\n"), 0)
|
||||
|
||||
def test_erase_is_noop(self) -> None:
|
||||
self.assertEqual(self._main(["h", "erase"], "username=x\npassword=y\n\n"), 0)
|
||||
|
||||
def test_no_action_rc_2(self) -> None:
|
||||
self.assertEqual(self._main(["h"]), 2)
|
||||
|
||||
def test_unknown_action_rc_2(self) -> None:
|
||||
self.assertEqual(self._main(["h", "bogus"]), 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user