Initial Commit

This commit is contained in:
FanaticPythoner (Nathan Trudeau)
2026-04-19 17:11:58 -04:00
parent eccb05b97f
commit a591cd21f2
23 changed files with 4896 additions and 1 deletions

230
tests/mock_oidc_server.py Executable file
View File

@@ -0,0 +1,230 @@
"""Minimal OIDC + OAuth2 PKCE server used by the integration test.
Implements just enough of Gitea's `/.well-known/openid-configuration`,
`/login/oauth/authorize`, `/login/oauth/access_token`, and
`/login/oauth/userinfo` surface for the welcome-repo's `forge_auth.py`
to run an end-to-end login + refresh without touching a real Gitea.
Test fixture only. Binds to loopback, accepts any non-empty
`client_id`, and issues deterministic opaque tokens; it does not
model authentication or authorisation. Not suitable for any purpose
other than driving the welcome-repo client during tests.
"""
from __future__ import annotations
import base64
import hashlib
import json
import sys
import threading
import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any, Callable, cast
from urllib.parse import parse_qs, urlencode, urlparse
class _State:
"""In-memory bookkeeping shared by every handler instance."""
def __init__(self, *, base_url: str, username: str) -> None:
self.base_url = base_url
self.username = username
# code -> {client_id, redirect_uri, code_challenge, used}
self.pending_codes: dict[str, dict[str, Any]] = {}
# refresh_token -> {client_id, revoked}
self.refresh_tokens: dict[str, dict[str, Any]] = {}
self.access_token_expires_in = 3600
self.access_token_counter = 0
self.refresh_token_counter = 0
def issue_access_token(self) -> str:
self.access_token_counter += 1
return f"access-{self.access_token_counter}"
def issue_refresh_token(self, *, client_id: str) -> str:
self.refresh_token_counter += 1
tok = f"refresh-{self.refresh_token_counter}"
self.refresh_tokens[tok] = {"client_id": client_id, "revoked": False}
return tok
def _verify_pkce(challenge: str, verifier: str) -> bool:
expected = (
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest())
.rstrip(b"=")
.decode("ascii")
)
return expected == challenge
class _Handler(BaseHTTPRequestHandler):
state: _State
def _send_json(self, code: int, body: dict) -> None:
data = json.dumps(body).encode("utf-8")
self.send_response(code)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", str(len(data)))
self.end_headers()
self.wfile.write(data)
def _read_form(self) -> dict[str, str]:
length = int(self.headers.get("Content-Length", "0") or "0")
raw = self.rfile.read(length).decode("ascii") if length else ""
return {k: v[0] for k, v in parse_qs(raw, keep_blank_values=True).items()}
def do_GET(self) -> None: # noqa: N802
parsed = urlparse(self.path)
path = parsed.path
query = {k: v[0] for k, v in parse_qs(parsed.query).items()}
if path == "/.well-known/openid-configuration":
base = self.state.base_url
self._send_json(200, {
"issuer": base,
"authorization_endpoint": f"{base}/login/oauth/authorize",
"token_endpoint": f"{base}/login/oauth/access_token",
"userinfo_endpoint": f"{base}/login/oauth/userinfo",
})
return
if path == "/login/oauth/authorize":
client_id = query.get("client_id", "")
redirect_uri = query.get("redirect_uri", "")
code_challenge = query.get("code_challenge", "")
state_value = query.get("state", "")
if not (client_id and redirect_uri and code_challenge and state_value):
self._send_json(400, {"error": "invalid_request"})
return
code = f"code-{len(self.state.pending_codes) + 1}"
self.state.pending_codes[code] = {
"client_id": client_id,
"redirect_uri": redirect_uri,
"code_challenge": code_challenge,
"used": False,
}
sep = "&" if "?" in redirect_uri else "?"
location = (
f"{redirect_uri}{sep}"
f"{urlencode({'code': code, 'state': state_value})}"
)
self.send_response(302)
self.send_header("Location", location)
self.end_headers()
return
if path == "/login/oauth/userinfo":
auth = self.headers.get("Authorization", "")
if not auth.startswith("Bearer access-"):
self._send_json(401, {"error": "unauthorized"})
return
self._send_json(200, {
"sub": "1",
"preferred_username": self.state.username,
"name": self.state.username,
"email": f"{self.state.username}@example.test",
})
return
self._send_json(404, {"error": "not_found"})
def do_POST(self) -> None: # noqa: N802
if self.path != "/login/oauth/access_token":
self._send_json(404, {"error": "not_found"})
return
form = self._read_form()
grant = form.get("grant_type", "")
if grant == "authorization_code":
code = form.get("code", "")
verifier = form.get("code_verifier", "")
client_id = form.get("client_id", "")
entry = self.state.pending_codes.get(code)
if not entry or entry["used"]:
self._send_json(400, {"error": "invalid_grant",
"error_description": "code not found or already used"})
return
if entry["client_id"] != client_id:
self._send_json(400, {"error": "invalid_client"})
return
if not _verify_pkce(entry["code_challenge"], verifier):
self._send_json(400, {"error": "invalid_grant",
"error_description": "PKCE verification failed"})
return
entry["used"] = True
access = self.state.issue_access_token()
refresh = self.state.issue_refresh_token(client_id=client_id)
self._send_json(200, {
"access_token": access,
"refresh_token": refresh,
"expires_in": self.state.access_token_expires_in,
"token_type": "Bearer",
"scope": "openid profile email",
})
return
if grant == "refresh_token":
rt = form.get("refresh_token", "")
client_id = form.get("client_id", "")
entry = self.state.refresh_tokens.get(rt)
if not entry or entry["revoked"]:
self._send_json(400, {"error": "invalid_grant",
"error_description": "refresh token invalid or revoked"})
return
if entry["client_id"] != client_id:
self._send_json(400, {"error": "invalid_client"})
return
# Rotate: revoke the old refresh token, issue new pair.
entry["revoked"] = True
access = self.state.issue_access_token()
new_rt = self.state.issue_refresh_token(client_id=client_id)
self._send_json(200, {
"access_token": access,
"refresh_token": new_rt,
"expires_in": self.state.access_token_expires_in,
"token_type": "Bearer",
})
return
self._send_json(400, {"error": "unsupported_grant_type"})
def log_message(self, format: str, *args: Any) -> None: # noqa: A003
return
def make_server(*, username: str = "testuser") -> tuple[ThreadingHTTPServer, _State, str]:
# Bind to ephemeral port, then set base_url so the handler knows its URL.
server = ThreadingHTTPServer(("127.0.0.1", 0), _Handler)
port = server.server_address[1]
base_url = f"http://127.0.0.1:{port}"
state = _State(base_url=base_url, username=username)
# Share the state across all handler instances via the class attr.
cast(type, _Handler).state = state # type: ignore[assignment]
return server, state, base_url
def serve_forever(server: ThreadingHTTPServer) -> threading.Thread:
thread = threading.Thread(target=server.serve_forever, daemon=True)
thread.start()
return thread
def main() -> int:
import os
username = os.environ.get("MOCK_OIDC_USERNAME", "testuser")
server, state, base_url = make_server(username=username)
serve_forever(server)
sys.stdout.write(f"{base_url}\n")
sys.stdout.flush()
try:
# Block until killed.
while True:
time.sleep(3600)
except KeyboardInterrupt:
server.shutdown()
return 0
if __name__ == "__main__":
sys.exit(main())