Initial Commit
This commit is contained in:
230
tests/mock_oidc_server.py
Executable file
230
tests/mock_oidc_server.py
Executable file
@@ -0,0 +1,230 @@
|
||||
"""Minimal OIDC + OAuth2 PKCE server used by the integration test.
|
||||
|
||||
Implements the subset of Gitea's `/.well-known/openid-configuration`,
|
||||
`/login/oauth/authorize`, `/login/oauth/access_token`, and
|
||||
`/login/oauth/userinfo` surface for the welcome-repo's `forge_auth.py`
|
||||
to run an end-to-end login + refresh without touching a real Gitea.
|
||||
|
||||
Test fixture only. Binds to loopback, accepts any non-empty
|
||||
`client_id`, and issues deterministic opaque tokens; it does not
|
||||
model authentication or authorisation. Not suitable for any purpose
|
||||
other than driving the welcome-repo client during tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from typing import Any, Callable, cast
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
|
||||
class _State:
|
||||
"""In-memory bookkeeping shared by every handler instance."""
|
||||
|
||||
def __init__(self, *, base_url: str, username: str) -> None:
|
||||
self.base_url = base_url
|
||||
self.username = username
|
||||
# code -> {client_id, redirect_uri, code_challenge, used}
|
||||
self.pending_codes: dict[str, dict[str, Any]] = {}
|
||||
# refresh_token -> {client_id, revoked}
|
||||
self.refresh_tokens: dict[str, dict[str, Any]] = {}
|
||||
self.access_token_expires_in = 3600
|
||||
self.access_token_counter = 0
|
||||
self.refresh_token_counter = 0
|
||||
|
||||
def issue_access_token(self) -> str:
|
||||
self.access_token_counter += 1
|
||||
return f"access-{self.access_token_counter}"
|
||||
|
||||
def issue_refresh_token(self, *, client_id: str) -> str:
|
||||
self.refresh_token_counter += 1
|
||||
tok = f"refresh-{self.refresh_token_counter}"
|
||||
self.refresh_tokens[tok] = {"client_id": client_id, "revoked": False}
|
||||
return tok
|
||||
|
||||
|
||||
def _verify_pkce(challenge: str, verifier: str) -> bool:
|
||||
expected = (
|
||||
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest())
|
||||
.rstrip(b"=")
|
||||
.decode("ascii")
|
||||
)
|
||||
return expected == challenge
|
||||
|
||||
|
||||
class _Handler(BaseHTTPRequestHandler):
|
||||
state: _State
|
||||
|
||||
def _send_json(self, code: int, body: dict) -> None:
|
||||
data = json.dumps(body).encode("utf-8")
|
||||
self.send_response(code)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.send_header("Content-Length", str(len(data)))
|
||||
self.end_headers()
|
||||
self.wfile.write(data)
|
||||
|
||||
def _read_form(self) -> dict[str, str]:
|
||||
length = int(self.headers.get("Content-Length", "0") or "0")
|
||||
raw = self.rfile.read(length).decode("ascii") if length else ""
|
||||
return {k: v[0] for k, v in parse_qs(raw, keep_blank_values=True).items()}
|
||||
|
||||
def do_GET(self) -> None: # noqa: N802
|
||||
parsed = urlparse(self.path)
|
||||
path = parsed.path
|
||||
query = {k: v[0] for k, v in parse_qs(parsed.query).items()}
|
||||
|
||||
if path == "/.well-known/openid-configuration":
|
||||
base = self.state.base_url
|
||||
self._send_json(200, {
|
||||
"issuer": base,
|
||||
"authorization_endpoint": f"{base}/login/oauth/authorize",
|
||||
"token_endpoint": f"{base}/login/oauth/access_token",
|
||||
"userinfo_endpoint": f"{base}/login/oauth/userinfo",
|
||||
})
|
||||
return
|
||||
|
||||
if path == "/login/oauth/authorize":
|
||||
client_id = query.get("client_id", "")
|
||||
redirect_uri = query.get("redirect_uri", "")
|
||||
code_challenge = query.get("code_challenge", "")
|
||||
state_value = query.get("state", "")
|
||||
if not (client_id and redirect_uri and code_challenge and state_value):
|
||||
self._send_json(400, {"error": "invalid_request"})
|
||||
return
|
||||
code = f"code-{len(self.state.pending_codes) + 1}"
|
||||
self.state.pending_codes[code] = {
|
||||
"client_id": client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_challenge": code_challenge,
|
||||
"used": False,
|
||||
}
|
||||
sep = "&" if "?" in redirect_uri else "?"
|
||||
location = (
|
||||
f"{redirect_uri}{sep}"
|
||||
f"{urlencode({'code': code, 'state': state_value})}"
|
||||
)
|
||||
self.send_response(302)
|
||||
self.send_header("Location", location)
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
if path == "/login/oauth/userinfo":
|
||||
auth = self.headers.get("Authorization", "")
|
||||
if not auth.startswith("Bearer access-"):
|
||||
self._send_json(401, {"error": "unauthorized"})
|
||||
return
|
||||
self._send_json(200, {
|
||||
"sub": "1",
|
||||
"preferred_username": self.state.username,
|
||||
"name": self.state.username,
|
||||
"email": f"{self.state.username}@example.test",
|
||||
})
|
||||
return
|
||||
|
||||
self._send_json(404, {"error": "not_found"})
|
||||
|
||||
def do_POST(self) -> None: # noqa: N802
|
||||
if self.path != "/login/oauth/access_token":
|
||||
self._send_json(404, {"error": "not_found"})
|
||||
return
|
||||
form = self._read_form()
|
||||
grant = form.get("grant_type", "")
|
||||
|
||||
if grant == "authorization_code":
|
||||
code = form.get("code", "")
|
||||
verifier = form.get("code_verifier", "")
|
||||
client_id = form.get("client_id", "")
|
||||
entry = self.state.pending_codes.get(code)
|
||||
if not entry or entry["used"]:
|
||||
self._send_json(400, {"error": "invalid_grant",
|
||||
"error_description": "code not found or already used"})
|
||||
return
|
||||
if entry["client_id"] != client_id:
|
||||
self._send_json(400, {"error": "invalid_client"})
|
||||
return
|
||||
if not _verify_pkce(entry["code_challenge"], verifier):
|
||||
self._send_json(400, {"error": "invalid_grant",
|
||||
"error_description": "PKCE verification failed"})
|
||||
return
|
||||
entry["used"] = True
|
||||
access = self.state.issue_access_token()
|
||||
refresh = self.state.issue_refresh_token(client_id=client_id)
|
||||
self._send_json(200, {
|
||||
"access_token": access,
|
||||
"refresh_token": refresh,
|
||||
"expires_in": self.state.access_token_expires_in,
|
||||
"token_type": "Bearer",
|
||||
"scope": "openid profile email",
|
||||
})
|
||||
return
|
||||
|
||||
if grant == "refresh_token":
|
||||
rt = form.get("refresh_token", "")
|
||||
client_id = form.get("client_id", "")
|
||||
entry = self.state.refresh_tokens.get(rt)
|
||||
if not entry or entry["revoked"]:
|
||||
self._send_json(400, {"error": "invalid_grant",
|
||||
"error_description": "refresh token invalid or revoked"})
|
||||
return
|
||||
if entry["client_id"] != client_id:
|
||||
self._send_json(400, {"error": "invalid_client"})
|
||||
return
|
||||
# Rotate: revoke the old refresh token, issue new pair.
|
||||
entry["revoked"] = True
|
||||
access = self.state.issue_access_token()
|
||||
new_rt = self.state.issue_refresh_token(client_id=client_id)
|
||||
self._send_json(200, {
|
||||
"access_token": access,
|
||||
"refresh_token": new_rt,
|
||||
"expires_in": self.state.access_token_expires_in,
|
||||
"token_type": "Bearer",
|
||||
})
|
||||
return
|
||||
|
||||
self._send_json(400, {"error": "unsupported_grant_type"})
|
||||
|
||||
def log_message(self, format: str, *args: Any) -> None: # noqa: A003
|
||||
return
|
||||
|
||||
|
||||
def make_server(*, username: str = "testuser") -> tuple[ThreadingHTTPServer, _State, str]:
|
||||
# Bind to ephemeral port, then set base_url so the handler knows its URL.
|
||||
server = ThreadingHTTPServer(("127.0.0.1", 0), _Handler)
|
||||
port = server.server_address[1]
|
||||
base_url = f"http://127.0.0.1:{port}"
|
||||
state = _State(base_url=base_url, username=username)
|
||||
# Share the state across all handler instances via the class attr.
|
||||
cast(type, _Handler).state = state # type: ignore[assignment]
|
||||
return server, state, base_url
|
||||
|
||||
|
||||
def serve_forever(server: ThreadingHTTPServer) -> threading.Thread:
|
||||
thread = threading.Thread(target=server.serve_forever, daemon=True)
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
|
||||
def main() -> int:
|
||||
import os
|
||||
|
||||
username = os.environ.get("MOCK_OIDC_USERNAME", "testuser")
|
||||
server, state, base_url = make_server(username=username)
|
||||
serve_forever(server)
|
||||
sys.stdout.write(f"{base_url}\n")
|
||||
sys.stdout.flush()
|
||||
try:
|
||||
# Block until killed.
|
||||
while True:
|
||||
time.sleep(3600)
|
||||
except KeyboardInterrupt:
|
||||
server.shutdown()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user