231 lines
8.7 KiB
Python
Executable File
231 lines
8.7 KiB
Python
Executable File
"""Minimal OIDC + OAuth2 PKCE server used by the integration test.
|
|
|
|
Implements the subset of Gitea's `/.well-known/openid-configuration`,
|
|
`/login/oauth/authorize`, `/login/oauth/access_token`, and
|
|
`/login/oauth/userinfo` surface for the welcome-repo's `forge_auth.py`
|
|
to run an end-to-end login + refresh without touching a real Gitea.
|
|
|
|
Test fixture only. Binds to loopback, accepts any non-empty
|
|
`client_id`, and issues deterministic opaque tokens; it does not
|
|
model authentication or authorisation. Not suitable for any purpose
|
|
other than driving the welcome-repo client during tests.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import hashlib
|
|
import json
|
|
import sys
|
|
import threading
|
|
import time
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from typing import Any, Callable, cast
|
|
from urllib.parse import parse_qs, urlencode, urlparse
|
|
|
|
|
|
class _State:
|
|
"""In-memory bookkeeping shared by every handler instance."""
|
|
|
|
def __init__(self, *, base_url: str, username: str) -> None:
|
|
self.base_url = base_url
|
|
self.username = username
|
|
# code -> {client_id, redirect_uri, code_challenge, used}
|
|
self.pending_codes: dict[str, dict[str, Any]] = {}
|
|
# refresh_token -> {client_id, revoked}
|
|
self.refresh_tokens: dict[str, dict[str, Any]] = {}
|
|
self.access_token_expires_in = 3600
|
|
self.access_token_counter = 0
|
|
self.refresh_token_counter = 0
|
|
|
|
def issue_access_token(self) -> str:
|
|
self.access_token_counter += 1
|
|
return f"access-{self.access_token_counter}"
|
|
|
|
def issue_refresh_token(self, *, client_id: str) -> str:
|
|
self.refresh_token_counter += 1
|
|
tok = f"refresh-{self.refresh_token_counter}"
|
|
self.refresh_tokens[tok] = {"client_id": client_id, "revoked": False}
|
|
return tok
|
|
|
|
|
|
def _verify_pkce(challenge: str, verifier: str) -> bool:
|
|
expected = (
|
|
base64.urlsafe_b64encode(hashlib.sha256(verifier.encode("ascii")).digest())
|
|
.rstrip(b"=")
|
|
.decode("ascii")
|
|
)
|
|
return expected == challenge
|
|
|
|
|
|
class _Handler(BaseHTTPRequestHandler):
|
|
state: _State
|
|
|
|
def _send_json(self, code: int, body: dict) -> None:
|
|
data = json.dumps(body).encode("utf-8")
|
|
self.send_response(code)
|
|
self.send_header("Content-Type", "application/json")
|
|
self.send_header("Content-Length", str(len(data)))
|
|
self.end_headers()
|
|
self.wfile.write(data)
|
|
|
|
def _read_form(self) -> dict[str, str]:
|
|
length = int(self.headers.get("Content-Length", "0") or "0")
|
|
raw = self.rfile.read(length).decode("ascii") if length else ""
|
|
return {k: v[0] for k, v in parse_qs(raw, keep_blank_values=True).items()}
|
|
|
|
def do_GET(self) -> None: # noqa: N802
|
|
parsed = urlparse(self.path)
|
|
path = parsed.path
|
|
query = {k: v[0] for k, v in parse_qs(parsed.query).items()}
|
|
|
|
if path == "/.well-known/openid-configuration":
|
|
base = self.state.base_url
|
|
self._send_json(200, {
|
|
"issuer": base,
|
|
"authorization_endpoint": f"{base}/login/oauth/authorize",
|
|
"token_endpoint": f"{base}/login/oauth/access_token",
|
|
"userinfo_endpoint": f"{base}/login/oauth/userinfo",
|
|
})
|
|
return
|
|
|
|
if path == "/login/oauth/authorize":
|
|
client_id = query.get("client_id", "")
|
|
redirect_uri = query.get("redirect_uri", "")
|
|
code_challenge = query.get("code_challenge", "")
|
|
state_value = query.get("state", "")
|
|
if not (client_id and redirect_uri and code_challenge and state_value):
|
|
self._send_json(400, {"error": "invalid_request"})
|
|
return
|
|
code = f"code-{len(self.state.pending_codes) + 1}"
|
|
self.state.pending_codes[code] = {
|
|
"client_id": client_id,
|
|
"redirect_uri": redirect_uri,
|
|
"code_challenge": code_challenge,
|
|
"used": False,
|
|
}
|
|
sep = "&" if "?" in redirect_uri else "?"
|
|
location = (
|
|
f"{redirect_uri}{sep}"
|
|
f"{urlencode({'code': code, 'state': state_value})}"
|
|
)
|
|
self.send_response(302)
|
|
self.send_header("Location", location)
|
|
self.end_headers()
|
|
return
|
|
|
|
if path == "/login/oauth/userinfo":
|
|
auth = self.headers.get("Authorization", "")
|
|
if not auth.startswith("Bearer access-"):
|
|
self._send_json(401, {"error": "unauthorized"})
|
|
return
|
|
self._send_json(200, {
|
|
"sub": "1",
|
|
"preferred_username": self.state.username,
|
|
"name": self.state.username,
|
|
"email": f"{self.state.username}@example.test",
|
|
})
|
|
return
|
|
|
|
self._send_json(404, {"error": "not_found"})
|
|
|
|
def do_POST(self) -> None: # noqa: N802
|
|
if self.path != "/login/oauth/access_token":
|
|
self._send_json(404, {"error": "not_found"})
|
|
return
|
|
form = self._read_form()
|
|
grant = form.get("grant_type", "")
|
|
|
|
if grant == "authorization_code":
|
|
code = form.get("code", "")
|
|
verifier = form.get("code_verifier", "")
|
|
client_id = form.get("client_id", "")
|
|
entry = self.state.pending_codes.get(code)
|
|
if not entry or entry["used"]:
|
|
self._send_json(400, {"error": "invalid_grant",
|
|
"error_description": "code not found or already used"})
|
|
return
|
|
if entry["client_id"] != client_id:
|
|
self._send_json(400, {"error": "invalid_client"})
|
|
return
|
|
if not _verify_pkce(entry["code_challenge"], verifier):
|
|
self._send_json(400, {"error": "invalid_grant",
|
|
"error_description": "PKCE verification failed"})
|
|
return
|
|
entry["used"] = True
|
|
access = self.state.issue_access_token()
|
|
refresh = self.state.issue_refresh_token(client_id=client_id)
|
|
self._send_json(200, {
|
|
"access_token": access,
|
|
"refresh_token": refresh,
|
|
"expires_in": self.state.access_token_expires_in,
|
|
"token_type": "Bearer",
|
|
"scope": "openid profile email",
|
|
})
|
|
return
|
|
|
|
if grant == "refresh_token":
|
|
rt = form.get("refresh_token", "")
|
|
client_id = form.get("client_id", "")
|
|
entry = self.state.refresh_tokens.get(rt)
|
|
if not entry or entry["revoked"]:
|
|
self._send_json(400, {"error": "invalid_grant",
|
|
"error_description": "refresh token invalid or revoked"})
|
|
return
|
|
if entry["client_id"] != client_id:
|
|
self._send_json(400, {"error": "invalid_client"})
|
|
return
|
|
# Rotate: revoke the old refresh token, issue new pair.
|
|
entry["revoked"] = True
|
|
access = self.state.issue_access_token()
|
|
new_rt = self.state.issue_refresh_token(client_id=client_id)
|
|
self._send_json(200, {
|
|
"access_token": access,
|
|
"refresh_token": new_rt,
|
|
"expires_in": self.state.access_token_expires_in,
|
|
"token_type": "Bearer",
|
|
})
|
|
return
|
|
|
|
self._send_json(400, {"error": "unsupported_grant_type"})
|
|
|
|
def log_message(self, format: str, *args: Any) -> None: # noqa: A003
|
|
return
|
|
|
|
|
|
def make_server(*, username: str = "testuser") -> tuple[ThreadingHTTPServer, _State, str]:
    """Create a mock OIDC server bound to an ephemeral loopback port.

    The server is returned unstarted; pass it to ``serve_forever`` to run it.
    Each call installs its own handler subclass carrying that server's
    ``_State``. Several servers created in one process therefore do not read
    each other's issuer URL, codes or tokens through a shared class attribute.

    Args:
        username: Name reported by the userinfo endpoint.

    Returns:
        ``(server, state, base_url)``, where ``base_url`` is
        ``http://127.0.0.1:<port>``.
    """
    # Bind first: the ephemeral port (and hence base_url) is only known once
    # the socket is bound.
    server = ThreadingHTTPServer(("127.0.0.1", 0), _Handler)
    port = server.server_address[1]
    base_url = f"http://127.0.0.1:{port}"
    state = _State(base_url=base_url, username=username)
    # Swap in a per-server handler class before serving starts; the server
    # instantiates RequestHandlerClass for every incoming request.
    server.RequestHandlerClass = type("_BoundHandler", (_Handler,), {"state": state})
    # Earlier versions stored the state on _Handler itself. Keep doing so for
    # any caller that still reads _Handler.state; it reflects only the most
    # recently created server.
    cast(type, _Handler).state = state  # type: ignore[assignment]
    return server, state, base_url
|
|
|
|
|
|
def serve_forever(server: ThreadingHTTPServer) -> threading.Thread:
    """Run ``server.serve_forever`` on a background thread and return it.

    The thread is a daemon, so it does not keep the interpreter alive once the
    main thread exits. It has already been started when returned.
    """
    worker = threading.Thread(daemon=True, target=server.serve_forever)
    worker.start()
    return worker
|
|
|
|
|
|
def main() -> int:
    """Start the mock server and block until interrupted.

    The base URL is written to stdout as a single line, so a parent process
    can read it. The username reported by userinfo comes from the
    ``MOCK_OIDC_USERNAME`` environment variable (default ``testuser``).

    Returns:
        0 once a KeyboardInterrupt has stopped the server.
    """
    import os

    username = os.environ.get("MOCK_OIDC_USERNAME", "testuser")
    server, _state, base_url = make_server(username=username)
    serve_forever(server)
    sys.stdout.write(f"{base_url}\n")
    sys.stdout.flush()
    try:
        # Block until killed; time.sleep stays interruptible by Ctrl-C.
        while True:
            time.sleep(3600)
    except KeyboardInterrupt:
        # shutdown() only stops the serve loop on the worker thread;
        # server_close() then releases the listening socket.
        server.shutdown()
        server.server_close()
    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|