"""Minimal OIDC + OAuth2 PKCE server used by the integration test. Implements just enough of Gitea's `/.well-known/openid-configuration`, `/login/oauth/authorize`, `/login/oauth/access_token`, and `/login/oauth/userinfo` surface for the welcome-repo's `forge_auth.py` to run an end-to-end login + refresh without touching a real Gitea. Test fixture only. Binds to loopback, accepts any non-empty `client_id`, and issues deterministic opaque tokens; it does not model authentication or authorisation. Not suitable for any purpose other than driving the welcome-repo client during tests. """ from __future__ import annotations import base64 import hashlib import json import sys import threading import time from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from typing import Any, Callable, cast from urllib.parse import parse_qs, urlencode, urlparse class _State: """In-memory bookkeeping shared by every handler instance.""" def __init__(self, *, base_url: str, username: str) -> None: self.base_url = base_url self.username = username # code -> {client_id, redirect_uri, code_challenge, used} self.pending_codes: dict[str, dict[str, Any]] = {} # refresh_token -> {client_id, revoked} self.refresh_tokens: dict[str, dict[str, Any]] = {} self.access_token_expires_in = 3600 self.access_token_counter = 0 self.refresh_token_counter = 0 def issue_access_token(self) -> str: self.access_token_counter += 1 return f"access-{self.access_token_counter}" def issue_refresh_token(self, *, client_id: str) -> str: self.refresh_token_counter += 1 tok = f"refresh-{self.refresh_token_counter}" self.refresh_tokens[tok] = {"client_id": client_id, "revoked": False} return tok def _verify_pkce(challenge: str, verifier: str) -> bool: expected = ( base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest()) .rstrip(b"=") .decode("ascii") ) return expected == challenge class _Handler(BaseHTTPRequestHandler): state: _State def _send_json(self, code: int, body: dict) -> None: data = json.dumps(body).encode("utf-8") self.send_response(code) self.send_header("Content-Type", "application/json") self.send_header("Content-Length", str(len(data))) self.end_headers() self.wfile.write(data) def _read_form(self) -> dict[str, str]: length = int(self.headers.get("Content-Length", "0") or "0") raw = self.rfile.read(length).decode("ascii") if length else "" return {k: v[0] for k, v in parse_qs(raw, keep_blank_values=True).items()} def do_GET(self) -> None: # noqa: N802 parsed = urlparse(self.path) path = parsed.path query = {k: v[0] for k, v in parse_qs(parsed.query).items()} if path == "/.well-known/openid-configuration": base = self.state.base_url self._send_json(200, { "issuer": base, "authorization_endpoint": f"{base}/login/oauth/authorize", "token_endpoint": f"{base}/login/oauth/access_token", "userinfo_endpoint": f"{base}/login/oauth/userinfo", }) return if path == "/login/oauth/authorize": client_id = query.get("client_id", "") redirect_uri = query.get("redirect_uri", "") code_challenge = query.get("code_challenge", "") state_value = query.get("state", "") if not (client_id and redirect_uri and code_challenge and state_value): self._send_json(400, {"error": "invalid_request"}) return code = f"code-{len(self.state.pending_codes) + 1}" self.state.pending_codes[code] = { "client_id": client_id, "redirect_uri": redirect_uri, "code_challenge": code_challenge, "used": False, } sep = "&" if "?" in redirect_uri else "?" location = ( f"{redirect_uri}{sep}" f"{urlencode({'code': code, 'state': state_value})}" ) self.send_response(302) self.send_header("Location", location) self.end_headers() return if path == "/login/oauth/userinfo": auth = self.headers.get("Authorization", "") if not auth.startswith("Bearer access-"): self._send_json(401, {"error": "unauthorized"}) return self._send_json(200, { "sub": "1", "preferred_username": self.state.username, "name": self.state.username, "email": f"{self.state.username}@example.test", }) return self._send_json(404, {"error": "not_found"}) def do_POST(self) -> None: # noqa: N802 if self.path != "/login/oauth/access_token": self._send_json(404, {"error": "not_found"}) return form = self._read_form() grant = form.get("grant_type", "") if grant == "authorization_code": code = form.get("code", "") verifier = form.get("code_verifier", "") client_id = form.get("client_id", "") entry = self.state.pending_codes.get(code) if not entry or entry["used"]: self._send_json(400, {"error": "invalid_grant", "error_description": "code not found or already used"}) return if entry["client_id"] != client_id: self._send_json(400, {"error": "invalid_client"}) return if not _verify_pkce(entry["code_challenge"], verifier): self._send_json(400, {"error": "invalid_grant", "error_description": "PKCE verification failed"}) return entry["used"] = True access = self.state.issue_access_token() refresh = self.state.issue_refresh_token(client_id=client_id) self._send_json(200, { "access_token": access, "refresh_token": refresh, "expires_in": self.state.access_token_expires_in, "token_type": "Bearer", "scope": "openid profile email", }) return if grant == "refresh_token": rt = form.get("refresh_token", "") client_id = form.get("client_id", "") entry = self.state.refresh_tokens.get(rt) if not entry or entry["revoked"]: self._send_json(400, {"error": "invalid_grant", "error_description": "refresh token invalid or revoked"}) return if entry["client_id"] != client_id: self._send_json(400, {"error": "invalid_client"}) return # Rotate: revoke the old refresh token, issue new pair. entry["revoked"] = True access = self.state.issue_access_token() new_rt = self.state.issue_refresh_token(client_id=client_id) self._send_json(200, { "access_token": access, "refresh_token": new_rt, "expires_in": self.state.access_token_expires_in, "token_type": "Bearer", }) return self._send_json(400, {"error": "unsupported_grant_type"}) def log_message(self, format: str, *args: Any) -> None: # noqa: A003 return def make_server(*, username: str = "testuser") -> tuple[ThreadingHTTPServer, _State, str]: # Bind to ephemeral port, then set base_url so the handler knows its URL. server = ThreadingHTTPServer(("127.0.0.1", 0), _Handler) port = server.server_address[1] base_url = f"http://127.0.0.1:{port}" state = _State(base_url=base_url, username=username) # Share the state across all handler instances via the class attr. cast(type, _Handler).state = state # type: ignore[assignment] return server, state, base_url def serve_forever(server: ThreadingHTTPServer) -> threading.Thread: thread = threading.Thread(target=server.serve_forever, daemon=True) thread.start() return thread def main() -> int: import os username = os.environ.get("MOCK_OIDC_USERNAME", "testuser") server, state, base_url = make_server(username=username) serve_forever(server) sys.stdout.write(f"{base_url}\n") sys.stdout.flush() try: # Block until killed. while True: time.sleep(3600) except KeyboardInterrupt: server.shutdown() return 0 if __name__ == "__main__": sys.exit(main())