Files
welcome-to-codevalet-as-a-p…/tests/test_git_credential_forge.py
2026-04-29 09:38:02 -04:00

299 lines
10 KiB
Python
Executable File

"""Unit tests for scripts/git-credential-forge.py.
Run with: python3 -m unittest tests.test_git_credential_forge
"""
from __future__ import annotations
import importlib.util
import io
import json
import os
import sys
import tempfile
import time
import unittest
from pathlib import Path
from unittest import mock
# Directory containing this test file, and the directory one level above it
# (treated as the repository root for locating ``scripts/``).
HERE = Path(__file__).resolve().parent
ROOT = HERE.parent
# Put ROOT/scripts at the front of sys.path so the plain ``import forge_auth``
# below can resolve. NOTE(review): the helper presumably imports forge_auth
# itself too (tests reach it via ``gcf.forge_auth``), which would make this
# insert a prerequisite for exec_module as well -- confirm.
sys.path.insert(0, str(ROOT / "scripts"))
# Load the helper as a module even though its filename has hyphens.
_helper_path = ROOT / "scripts" / "git-credential-forge.py"
_spec = importlib.util.spec_from_file_location("gcf", _helper_path)
# Stop here with a clear failure if no spec or loader was produced, rather
# than hitting an AttributeError on the lines below.
assert _spec and _spec.loader
gcf = importlib.util.module_from_spec(_spec)
# Run the helper's top-level code so ``gcf`` is populated with its functions.
_spec.loader.exec_module(gcf)
import forge_auth as fa # noqa: E402 (imported after sys.path is extended)
def _write_store(tmp: Path, payload: dict) -> Path:
    """Dump *payload* as JSON into ``client-auth.json`` inside *tmp*.

    Returns the path of the file that was written (UTF-8 encoded).
    """
    store_file = tmp / "client-auth.json"
    with store_file.open("w", encoding="utf-8") as handle:
        handle.write(json.dumps(payload))
    return store_file
# --------------------------------------------------------------------
# read_git_fields
# --------------------------------------------------------------------
class ReadGitFieldsTests(unittest.TestCase):
    """Checks how ``gcf.read_git_fields`` parses ``key=value`` input streams."""

    def test_full_block(self) -> None:
        """A block ended by a blank line yields every key/value pair."""
        stream = io.StringIO(
            "protocol=https\nhost=g.example:8443\npath=a/b.git\n\n"
        )
        expected = {
            "protocol": "https",
            "host": "g.example:8443",
            "path": "a/b.git",
        }
        self.assertEqual(gcf.read_git_fields(stream), expected)

    def test_eof_without_blank_line(self) -> None:
        """Input that stops at EOF with no trailing blank line still parses."""
        stream = io.StringIO("protocol=https\nhost=x\n")
        expected = {"protocol": "https", "host": "x"}
        self.assertEqual(gcf.read_git_fields(stream), expected)

    def test_malformed_line_raises(self) -> None:
        """The line ``notakeyvalue`` must be rejected with ValueError."""
        self.assertRaises(
            ValueError,
            gcf.read_git_fields,
            io.StringIO("notakeyvalue\n"),
        )
# --------------------------------------------------------------------
# host matching
# --------------------------------------------------------------------
class RequestMatchesTests(unittest.TestCase):
    """Checks ``gcf._request_matches`` for scheme, host and port comparisons.

    Each test passes a request dict (``protocol`` / ``host`` keys) and a
    stored 3-tuple that, judging by the values and test names, reads as
    ``(scheme, host, port)``.
    """

    # Stored endpoint with an explicit port.
    _STORED_8443 = ("https", "g.example", 8443)
    # Stored endpoint whose port slot is None (no explicit port).
    _STORED_NO_PORT = ("https", "g.example", None)

    def _matches(self, protocol, host, stored):
        """Build the request dict and return what ``_request_matches`` says."""
        request = {"protocol": protocol, "host": host}
        return gcf._request_matches(request, stored)

    def test_exact_match_including_port(self) -> None:
        verdict = self._matches("https", "g.example:8443", self._STORED_8443)
        self.assertTrue(verdict)

    def test_wrong_scheme_no_match(self) -> None:
        verdict = self._matches("http", "g.example:8443", self._STORED_8443)
        self.assertFalse(verdict)

    def test_wrong_host_no_match(self) -> None:
        verdict = self._matches(
            "https", "other.example:8443", self._STORED_8443
        )
        self.assertFalse(verdict)

    def test_wrong_port_no_match(self) -> None:
        verdict = self._matches("https", "g.example:7000", self._STORED_8443)
        self.assertFalse(verdict)

    def test_stored_default_https_request_no_port(self) -> None:
        # The stored URL carried no explicit port, so 443 is the implied
        # default; a request that also omits the port is acceptable.
        verdict = self._matches("https", "g.example", self._STORED_NO_PORT)
        self.assertTrue(verdict)

    def test_stored_default_https_request_with_443(self) -> None:
        # Spelling out :443 in the request is accepted against a stored
        # endpoint with no port.
        verdict = self._matches("https", "g.example:443", self._STORED_NO_PORT)
        self.assertTrue(verdict)

    def test_stored_default_https_request_with_other_port_no_match(self) -> None:
        # Any other explicit port must not match a stored endpoint with no port.
        verdict = self._matches(
            "https", "g.example:8443", self._STORED_NO_PORT
        )
        self.assertFalse(verdict)
# --------------------------------------------------------------------
# cmd_get (end-to-end, inside a sandboxed FSDGG_AUTH_STORE_PATH)
# --------------------------------------------------------------------
class CmdGetTests(unittest.TestCase):
    """End-to-end checks of ``gcf.cmd_get`` against a throwaway auth store.

    Every run happens inside a temporary directory with ``os.environ``
    replaced wholesale, so nothing outside the sandbox is consulted through
    environment variables.
    """

    # Forge base URL used both in the environment and in stored payloads.
    _FORGE_URL = "https://g.example:8443"
    # Credential request addressed to that forge (no ``path`` field).
    _FORGE_REQUEST = "protocol=https\nhost=g.example:8443\n\n"

    @staticmethod
    def _parse_output(text: str) -> dict[str, str]:
        """Turn ``key=value`` output lines into a dict (split on first ``=``)."""
        return dict(line.split("=", 1) for line in text.strip().splitlines())

    def _run(
        self,
        *,
        store_payload: dict | None,
        stdin_text: str,
        env_extra: dict[str, str] | None = None,
    ) -> tuple[int, str, str]:
        """Invoke ``gcf.cmd_get`` in a sandbox and capture what it prints.

        Args:
            store_payload: Contents for ``client-auth.json``; ``None`` means
                no store file is created.
            stdin_text: Raw ``key=value`` request, parsed with
                ``gcf.read_git_fields`` before the call.
            env_extra: Optional variables layered over the default sandbox
                environment.

        Returns:
            ``(return code, captured stdout, captured stderr)``.
        """
        with tempfile.TemporaryDirectory() as workdir:
            sandbox = Path(workdir)
            environment = {
                "FSDGG_AUTH_STORE_PATH": str(sandbox / "client-auth.json"),
                "FORGE_GITEA_URL": self._FORGE_URL,
                "FSDGG_CLI_CLIENT_ID": "client-1",
                "FSDGG_CLI_REDIRECT_URI": "http://127.0.0.1:38111/callback",
                "HOME": str(sandbox),
                **(env_extra or {}),
            }
            if store_payload is not None:
                _write_store(sandbox, store_payload)
            request = gcf.read_git_fields(io.StringIO(stdin_text))

            captured_out = io.StringIO()
            captured_err = io.StringIO()
            # The patchers restore the real streams and environment on exit,
            # including when cmd_get raises.
            with mock.patch.object(sys, "stdout", captured_out):
                with mock.patch.object(sys, "stderr", captured_err):
                    with mock.patch.dict(os.environ, environment, clear=True):
                        exit_code = gcf.cmd_get(request)
            return exit_code, captured_out.getvalue(), captured_err.getvalue()

    # --- matching host -------------------------------------------------

    def test_match_live_token_returns_credentials(self) -> None:
        """An unexpired token for the matching host is emitted as password."""
        stored = {
            "username": "alice",
            "gitea_access_token": "LIVETOKEN",
            "gitea_token_expires_at": time.time() + 3600,
            "_forge_gitea_base_url": self._FORGE_URL,
        }
        exit_code, stdout_text, _ = self._run(
            store_payload=stored,
            stdin_text="protocol=https\nhost=g.example:8443\npath=org/repo.git\n\n",
        )
        self.assertEqual(exit_code, 0)
        reply = self._parse_output(stdout_text)
        self.assertEqual(reply["username"], "alice")
        self.assertEqual(reply["password"], "LIVETOKEN")
        self.assertEqual(reply["host"], "g.example:8443")

    def test_no_store_passes_through(self) -> None:
        """With no store file, the request is echoed without a password."""
        exit_code, stdout_text, _ = self._run(
            store_payload=None,
            stdin_text=self._FORGE_REQUEST,
        )
        self.assertEqual(exit_code, 0)
        self.assertNotIn("password=", stdout_text)
        self.assertIn("host=g.example:8443", stdout_text)

    def test_non_matching_host_passes_through(self) -> None:
        """A request for another host never receives the stored token."""
        stored = {
            "username": "alice",
            "gitea_access_token": "LIVETOKEN",
            "gitea_token_expires_at": time.time() + 3600,
            "_forge_gitea_base_url": self._FORGE_URL,
        }
        exit_code, stdout_text, _ = self._run(
            store_payload=stored,
            stdin_text="protocol=https\nhost=github.com\n\n",
        )
        self.assertEqual(exit_code, 0)
        self.assertNotIn("password=", stdout_text)
        self.assertIn("host=github.com", stdout_text)

    def test_match_but_no_token_passes_through(self) -> None:
        """A matching host with an empty stored token yields no password."""
        stored = {
            "username": "alice",
            "gitea_access_token": "",
            "_forge_gitea_base_url": self._FORGE_URL,
        }
        exit_code, stdout_text, _ = self._run(
            store_payload=stored,
            stdin_text=self._FORGE_REQUEST,
        )
        self.assertEqual(exit_code, 0)
        self.assertNotIn("password=", stdout_text)

    # --- expired token + refresh -------------------------------------

    def test_expired_token_triggers_refresh(self) -> None:
        """An expired token is replaced by the one the refresh step stores."""
        stored = {
            "username": "alice",
            "gitea_access_token": "EXPIRED",
            "gitea_token_expires_at": time.time() - 10,
            "_forge_gitea_base_url": self._FORGE_URL,
            "_forge_refresh_token": "REFRESH",
            "_forge_client_id": "client-1",
        }

        # Stand-in for forge_auth.run_refresh: rewrites the sandboxed store
        # with a rotated token instead of performing a real refresh.
        def fake_refresh(config, *, must_refresh=False):
            auth_file = fa.AuthFile.read(fa.auth_store_path())
            auth_file.merge_refresh(
                gitea_access_token="ROTATED",
                gitea_token_expires_at=time.time() + 3600,
                refresh_token="NEW-RT",
            )
            auth_file.write(fa.auth_store_path())
            return auth_file

        refresh_patch = mock.patch.object(
            gcf.forge_auth, "run_refresh", side_effect=fake_refresh
        )
        with refresh_patch:
            exit_code, stdout_text, _ = self._run(
                store_payload=stored,
                stdin_text=self._FORGE_REQUEST,
            )
        self.assertEqual(exit_code, 0)
        reply = self._parse_output(stdout_text)
        self.assertEqual(reply["password"], "ROTATED")

    def test_refresh_failure_passes_through_with_stderr(self) -> None:
        """A failing refresh gives no password and explains itself on stderr."""
        stored = {
            "username": "alice",
            "gitea_access_token": "EXPIRED",
            "gitea_token_expires_at": time.time() - 10,
            "_forge_gitea_base_url": self._FORGE_URL,
            "_forge_refresh_token": "DEAD",
        }
        # An exception instance as side_effect makes every call raise it.
        refresh_patch = mock.patch.object(
            gcf.forge_auth,
            "run_refresh",
            side_effect=fa.AuthError("refresh token revoked"),
        )
        with refresh_patch:
            exit_code, stdout_text, stderr_text = self._run(
                store_payload=stored,
                stdin_text=self._FORGE_REQUEST,
            )
        self.assertEqual(exit_code, 0)
        self.assertNotIn("password=", stdout_text)
        self.assertIn("token refresh failed", stderr_text)
        self.assertIn("just login", stderr_text)
# --------------------------------------------------------------------
# main() dispatcher
# --------------------------------------------------------------------
class MainDispatcherTests(unittest.TestCase):
    """Checks the return codes of ``gcf.main`` for each action argument."""

    def _main(self, argv: list[str], stdin_text: str = "") -> int:
        """Call ``gcf.main(argv)`` with ``sys.stdin`` fed from *stdin_text*.

        The real stdin is put back afterwards, even if ``main`` raises.
        """
        fake_stdin = io.StringIO(stdin_text)
        with mock.patch.object(sys, "stdin", fake_stdin):
            return gcf.main(argv)

    def test_store_is_noop(self) -> None:
        """``store`` returns 0."""
        exit_code = self._main(["h", "store"], "username=x\npassword=y\n\n")
        self.assertEqual(exit_code, 0)

    def test_erase_is_noop(self) -> None:
        """``erase`` returns 0."""
        exit_code = self._main(["h", "erase"], "username=x\npassword=y\n\n")
        self.assertEqual(exit_code, 0)

    def test_no_action_rc_2(self) -> None:
        """Omitting the action argument returns 2."""
        exit_code = self._main(["h"])
        self.assertEqual(exit_code, 2)

    def test_unknown_action_rc_2(self) -> None:
        """An unrecognised action returns 2."""
        exit_code = self._main(["h", "bogus"])
        self.assertEqual(exit_code, 2)
if __name__ == "__main__":
    # Executed directly rather than through ``python3 -m unittest``: hand
    # control to unittest's own command-line runner.
    unittest.main()