"""Unit tests for scripts/git-credential-forge.py. Run with: python3 -m unittest tests.test_git_credential_forge """ from __future__ import annotations import importlib.util import io import json import os import sys import tempfile import time import unittest from pathlib import Path from unittest import mock HERE = Path(__file__).resolve().parent ROOT = HERE.parent sys.path.insert(0, str(ROOT / "scripts")) # Load the helper as a module even though its filename has hyphens. _helper_path = ROOT / "scripts" / "git-credential-forge.py" _spec = importlib.util.spec_from_file_location("gcf", _helper_path) assert _spec and _spec.loader gcf = importlib.util.module_from_spec(_spec) _spec.loader.exec_module(gcf) import forge_auth as fa # noqa: E402 (imported after sys.path is extended) def _write_store(tmp: Path, payload: dict) -> Path: p = tmp / "client-auth.json" p.write_text(json.dumps(payload), encoding="utf-8") return p # -------------------------------------------------------------------- # read_git_fields # -------------------------------------------------------------------- class ReadGitFieldsTests(unittest.TestCase): def test_full_block(self) -> None: buf = io.StringIO("protocol=https\nhost=g.example:6006\npath=a/b.git\n\n") self.assertEqual( gcf.read_git_fields(buf), {"protocol": "https", "host": "g.example:6006", "path": "a/b.git"}, ) def test_eof_without_blank_line(self) -> None: self.assertEqual( gcf.read_git_fields(io.StringIO("protocol=https\nhost=x\n")), {"protocol": "https", "host": "x"}, ) def test_malformed_line_raises(self) -> None: with self.assertRaises(ValueError): gcf.read_git_fields(io.StringIO("notakeyvalue\n")) # -------------------------------------------------------------------- # host matching # -------------------------------------------------------------------- class RequestMatchesTests(unittest.TestCase): def test_exact_match_including_port(self) -> None: self.assertTrue( gcf._request_matches( {"protocol": "https", "host": "g.example:6006"}, ("https", "g.example", 6006), ) ) def test_wrong_scheme_no_match(self) -> None: self.assertFalse( gcf._request_matches( {"protocol": "http", "host": "g.example:6006"}, ("https", "g.example", 6006), ) ) def test_wrong_host_no_match(self) -> None: self.assertFalse( gcf._request_matches( {"protocol": "https", "host": "other.example:6006"}, ("https", "g.example", 6006), ) ) def test_wrong_port_no_match(self) -> None: self.assertFalse( gcf._request_matches( {"protocol": "https", "host": "g.example:7000"}, ("https", "g.example", 6006), ) ) def test_stored_default_https_request_no_port(self) -> None: # Stored URL had no explicit port → default 443 inferred. # Request without port is OK. self.assertTrue( gcf._request_matches( {"protocol": "https", "host": "g.example"}, ("https", "g.example", None), ) ) def test_stored_default_https_request_with_443(self) -> None: self.assertTrue( gcf._request_matches( {"protocol": "https", "host": "g.example:443"}, ("https", "g.example", None), ) ) def test_stored_default_https_request_with_other_port_no_match(self) -> None: self.assertFalse( gcf._request_matches( {"protocol": "https", "host": "g.example:6006"}, ("https", "g.example", None), ) ) # -------------------------------------------------------------------- # cmd_get (end-to-end, inside a sandboxed FSDGG_AUTH_STORE_PATH) # -------------------------------------------------------------------- class CmdGetTests(unittest.TestCase): def _run( self, *, store_payload: dict | None, stdin_text: str, env_extra: dict[str, str] | None = None, ) -> tuple[int, str, str]: with tempfile.TemporaryDirectory() as d: tmp = Path(d) env = { "FSDGG_AUTH_STORE_PATH": str(tmp / "client-auth.json"), "FORGE_GITEA_URL": "https://g.example:6006", "FSDGG_CLI_CLIENT_ID": "client-1", "FSDGG_CLI_REDIRECT_URI": "http://127.0.0.1:38111/callback", "HOME": str(tmp), } if env_extra: env.update(env_extra) if store_payload is not None: _write_store(tmp, store_payload) fields = gcf.read_git_fields(io.StringIO(stdin_text)) buf_out, buf_err = io.StringIO(), io.StringIO() real_out, real_err = sys.stdout, sys.stderr sys.stdout, sys.stderr = buf_out, buf_err try: with mock.patch.dict(os.environ, env, clear=True): rc = gcf.cmd_get(fields) finally: sys.stdout, sys.stderr = real_out, real_err return rc, buf_out.getvalue(), buf_err.getvalue() # --- matching host ------------------------------------------------- def test_match_live_token_returns_credentials(self) -> None: payload = { "username": "alice", "gitea_access_token": "LIVETOKEN", "gitea_token_expires_at": time.time() + 3600, "_forge_gitea_base_url": "https://g.example:6006", } rc, out, _ = self._run( store_payload=payload, stdin_text="protocol=https\nhost=g.example:6006\npath=org/repo.git\n\n", ) self.assertEqual(rc, 0) parsed = dict(l.split("=", 1) for l in out.strip().splitlines()) self.assertEqual(parsed["username"], "alice") self.assertEqual(parsed["password"], "LIVETOKEN") self.assertEqual(parsed["host"], "g.example:6006") def test_no_store_passes_through(self) -> None: rc, out, _ = self._run( store_payload=None, stdin_text="protocol=https\nhost=g.example:6006\n\n", ) self.assertEqual(rc, 0) self.assertNotIn("password=", out) self.assertIn("host=g.example:6006", out) def test_non_matching_host_passes_through(self) -> None: payload = { "username": "alice", "gitea_access_token": "LIVETOKEN", "gitea_token_expires_at": time.time() + 3600, "_forge_gitea_base_url": "https://g.example:6006", } rc, out, _ = self._run( store_payload=payload, stdin_text="protocol=https\nhost=github.com\n\n", ) self.assertEqual(rc, 0) self.assertNotIn("password=", out) self.assertIn("host=github.com", out) def test_match_but_no_token_passes_through(self) -> None: payload = { "username": "alice", "gitea_access_token": "", "_forge_gitea_base_url": "https://g.example:6006", } rc, out, _ = self._run( store_payload=payload, stdin_text="protocol=https\nhost=g.example:6006\n\n", ) self.assertEqual(rc, 0) self.assertNotIn("password=", out) # --- expired token + refresh ------------------------------------- def test_expired_token_triggers_refresh(self) -> None: payload = { "username": "alice", "gitea_access_token": "EXPIRED", "gitea_token_expires_at": time.time() - 10, "_forge_gitea_base_url": "https://g.example:6006", "_forge_refresh_token": "REFRESH", "_forge_client_id": "client-1", } # Patch the refresh path directly. def fake_refresh(config, *, must_refresh=False): f = fa.AuthFile.read(fa.auth_store_path()) f.merge_refresh( gitea_access_token="ROTATED", gitea_token_expires_at=time.time() + 3600, refresh_token="NEW-RT", ) f.write(fa.auth_store_path()) return f with mock.patch.object(gcf.forge_auth, "run_refresh", side_effect=fake_refresh): rc, out, _ = self._run( store_payload=payload, stdin_text="protocol=https\nhost=g.example:6006\n\n", ) self.assertEqual(rc, 0) parsed = dict(l.split("=", 1) for l in out.strip().splitlines()) self.assertEqual(parsed["password"], "ROTATED") def test_refresh_failure_passes_through_with_stderr(self) -> None: payload = { "username": "alice", "gitea_access_token": "EXPIRED", "gitea_token_expires_at": time.time() - 10, "_forge_gitea_base_url": "https://g.example:6006", "_forge_refresh_token": "DEAD", } def fake_refresh(*_args, **_kwargs): raise fa.AuthError("refresh token revoked") with mock.patch.object(gcf.forge_auth, "run_refresh", side_effect=fake_refresh): rc, out, err = self._run( store_payload=payload, stdin_text="protocol=https\nhost=g.example:6006\n\n", ) self.assertEqual(rc, 0) self.assertNotIn("password=", out) self.assertIn("token refresh failed", err) self.assertIn("just login", err) # -------------------------------------------------------------------- # main() dispatcher # -------------------------------------------------------------------- class MainDispatcherTests(unittest.TestCase): def _main(self, argv: list[str], stdin_text: str = "") -> int: real_stdin = sys.stdin sys.stdin = io.StringIO(stdin_text) try: return gcf.main(argv) finally: sys.stdin = real_stdin def test_store_is_noop(self) -> None: self.assertEqual(self._main(["h", "store"], "username=x\npassword=y\n\n"), 0) def test_erase_is_noop(self) -> None: self.assertEqual(self._main(["h", "erase"], "username=x\npassword=y\n\n"), 0) def test_no_action_rc_2(self) -> None: self.assertEqual(self._main(["h"]), 2) def test_unknown_action_rc_2(self) -> None: self.assertEqual(self._main(["h", "bogus"]), 2) if __name__ == "__main__": unittest.main()