RotatingJWTTokenService(
key_provider,
*,
alg=RS256,
base_kid=None,
create_spec=None,
default_issuer=None,
rotate_every_s=None,
max_tokens_per_key=None,
previous_key_ttl_s=86400,
)
Bases: TokenServiceBase
JWT issuer/verifier that rotates the signing key.
Rotation can be triggered by elapsed time (rotate_every_s) or after a
maximum number of minted tokens (max_tokens_per_key). Previous key
versions are retained for a configurable window so that older tokens remain
verifiable.
Create a rotating JWT token service.
key_provider (IKeyProvider): Backend used to store and rotate keys.
alg (JWAAlg): Signing algorithm for issued tokens.
base_kid (str | None): Existing key identifier to bootstrap from.
create_spec (KeySpec | None): Specification for creating the initial
key if one does not already exist.
default_issuer (str | None): Default iss claim for minted tokens.
rotate_every_s (int | None): Seconds between automatic rotations.
max_tokens_per_key (int | None): Maximum tokens minted before forcing
rotation.
previous_key_ttl_s (int): Time in seconds to retain previous keys for
verification.
RETURNS (None): This constructor does not return anything.
Source code in swarmauri_tokens_rotatingjwt/rotating_jwt.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def __init__(
    self,
    key_provider: IKeyProvider,
    *,
    alg: JWAAlg = JWAAlg.RS256,
    base_kid: Optional[str] = None,
    create_spec: Optional[KeySpec] = None,
    default_issuer: Optional[str] = None,
    rotate_every_s: Optional[int] = None,
    max_tokens_per_key: Optional[int] = None,
    previous_key_ttl_s: int = 86_400,
) -> None:
    """Build a token service whose signing key is rotated over time.

    key_provider (IKeyProvider): Backend that stores, exports and rotates keys.
    alg (JWAAlg): Signing algorithm for every token this service mints; must be
        one of the supported signing algorithms.
    base_kid (str | None): Identifier of an existing key to start from.
    create_spec (KeySpec | None): Specification used to create the initial key
        when none exists yet.
    default_issuer (str | None): ``iss`` claim applied when a mint call gives none.
    rotate_every_s (int | None): Seconds between automatic rotations.
    max_tokens_per_key (int | None): Number of tokens minted before a rotation
        is forced.
    previous_key_ttl_s (int): How long, in seconds, a retired key version stays
        available for verification.
    RETURNS (None): Constructors return nothing.
    RAISES (ValueError): If ``alg`` is not a supported signing algorithm.
    """
    super().__init__()

    # Fail fast, before any state is stored, on an algorithm we cannot sign with.
    if alg not in _SIGN_ALGS:
        raise ValueError(f"Unsupported alg: {alg.value}")

    # Falsy limits (None or 0) are normalised to None; the rotation logic that
    # consumes them is outside this view (presumably None means "no trigger").
    if rotate_every_s:
        every_s = int(rotate_every_s)
    else:
        every_s = None
    if max_tokens_per_key:
        token_cap = int(max_tokens_per_key)
    else:
        token_cap = None

    self._kp = key_provider
    self._alg = alg
    self._iss = default_issuer
    self._rotate_every_s = every_s
    self._max_tokens = token_cap
    self._prev_ttl = int(previous_key_ttl_s)

    # Active signing key identity. Only declared here; it is expected to be
    # assigned by _init_signing_key at the end of this constructor.
    self._kid: str
    self._ver: int

    # Retired version -> timestamp until which it is still published in jwks().
    self._prev_versions: Dict[int, int] = {}
    self._mint_count = 0
    self._next_rotate_at: Optional[int] = None

    # Must run last: it relies on the attributes configured above.
    self._init_signing_key(base_kid, create_spec)
|
type
class-attribute
instance-attribute
type = 'RotatingJWTTokenService'
current_signing_key
property
Return the current key identifier, version and algorithm.
RETURNS (Tuple[str, int, JWAAlg]): Details of the active signing key.
supports
Return the token formats and algorithms supported.
RETURNS (Dict[str, Iterable[JWAAlg]]): Mapping of supported formats and
algorithms.
Source code in swarmauri_tokens_rotatingjwt/rotating_jwt.py
188
189
190
191
192
193
194
def supports(self) -> Dict[str, Iterable[JWAAlg]]:
    """Describe the token formats and signing algorithms this service handles.

    Only the single algorithm chosen at construction time is advertised.

    RETURNS (Dict[str, Iterable[JWAAlg]]): ``{"formats": ("JWT", "JWS"),
        "algs": (<configured alg>,)}``.
    """
    token_formats = ("JWT", "JWS")
    signing_algs = (self._alg,)
    return {"formats": token_formats, "algs": signing_algs}
|
mint
async
mint(
claims,
*,
alg,
kid=None,
key_version=None,
headers=None,
lifetime_s=3600,
issuer=None,
subject=None,
audience=None,
scope=None,
)
Generate a signed JWT.
claims (Dict[str, object]): Claims to include in the payload.
alg (JWAAlg): Algorithm used for signing; must match the service
configuration.
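kid (str | None): Accepted for interface compatibility but currently ignored; tokens are always signed with the active rotating key.
key_version (int | None): Accepted for interface compatibility but currently ignored.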
headers (Dict[str, object] | None): Additional headers to include.
lifetime_s (int | None): Lifetime of the token in seconds.
issuer (str | None): Issuer claim to set for the token.
subject (str | None): Subject claim to include.
audience (str | list[str] | None): Audience claim for the token.
scope (str | None): Optional scope value.
RETURNS (str): Encoded JWT token.
Source code in swarmauri_tokens_rotatingjwt/rotating_jwt.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
async def mint(
    self,
    claims: Dict[str, object],
    *,
    alg: JWAAlg,
    kid: str | None = None,
    key_version: int | None = None,
    headers: Optional[Dict[str, object]] = None,
    lifetime_s: Optional[int] = 3600,
    issuer: Optional[str] = None,
    subject: Optional[str] = None,
    audience: Optional[str | list[str]] = None,
    scope: Optional[str] = None,
) -> str:
    """Generate a JWT signed with the currently active rotating key.

    ``_maybe_rotate`` is awaited first (presumably rotating when a time/count
    trigger is due). The header ``kid`` is then set to ``"<kid>.<version>"`` of
    the active key. ``iat``, ``nbf``, ``exp``, ``iss``, ``sub``, ``aud`` and
    ``scope`` are only added when not already present in ``claims``.

    claims (Dict[str, object]): Claims to include in the payload; the caller's
        dict is copied, not mutated.
    alg (JWAAlg): Algorithm used for signing; must match the service
        configuration.
    kid (str | None): Accepted for interface compatibility but currently
        ignored; tokens are always signed with the active rotating key.
    key_version (int | None): Accepted for interface compatibility but
        currently ignored.
    headers (Dict[str, object] | None): Additional headers to include; any
        ``alg``/``kid`` entries are overwritten.
    lifetime_s (int | None): Lifetime of the token in seconds; a falsy value
        omits ``exp``.
    issuer (str | None): Issuer claim; falls back to the service default.
    subject (str | None): Subject claim to include.
    audience (str | list[str] | None): Audience claim for the token.
    scope (str | None): Optional scope value.
    RETURNS (str): Encoded JWT token.
    RAISES (ValueError): If ``alg`` differs from the configured algorithm.
    RAISES (RuntimeError): If the key provider does not expose the signing
        material.
    """
    if alg != self._alg:
        raise ValueError(
            f"This service is configured for alg={self._alg.value}, got {alg.value}"
        )
    await self._maybe_rotate()

    now = _now()
    payload = dict(claims)  # copy so the caller's mapping is left untouched
    payload.setdefault("iat", now)
    payload.setdefault("nbf", now)
    if lifetime_s:
        payload.setdefault("exp", now + int(lifetime_s))
    effective_issuer = issuer or self._iss
    if effective_issuer:
        payload.setdefault("iss", effective_issuer)
    if subject:
        payload.setdefault("sub", subject)
    if audience:
        payload.setdefault("aud", audience)
    if scope:
        payload.setdefault("scope", scope)

    # Snapshot the active key once so the header kid and the key actually
    # used for signing are guaranteed to agree.
    signing_kid, signing_ver = self._kid, self._ver
    hdr = dict(headers or {})
    hdr["alg"] = self._alg.value
    hdr["kid"] = f"{signing_kid}.{signing_ver}"

    ref = await self._kp.get_key(signing_kid, signing_ver, include_secret=True)
    key = ref.material
    if key is None:
        # The check is the same for every algorithm; only the message differs.
        if self._alg == JWAAlg.HS256:
            raise RuntimeError("HMAC secret is not exportable under current policy")
        raise RuntimeError("Signing key is not exportable under current policy")

    token = jwt.encode(payload, key, algorithm=self._alg.value, headers=hdr)
    self._mint_count += 1
    return token
|
verify
async
verify(token, *, issuer=None, audience=None, leeway_s=60)
Validate a JWT and return its claims.
token (str): Encoded JWT to verify.
issuer (str | None): Expected issuer value.
audience (str | list[str] | None): Expected audience claim.
leeway_s (int): Allowable clock skew in seconds.
RETURNS (Dict[str, object]): Verified claims payload.
Source code in swarmauri_tokens_rotatingjwt/rotating_jwt.py
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
async def verify(
    self,
    token: str,
    *,
    issuer: Optional[str] = None,
    audience: Optional[str | list[str]] = None,
    leeway_s: int = 60,
) -> Dict[str, object]:
    """Validate a JWT and return its claims.

    The header must carry a non-empty string ``kid`` (``mint`` writes it as
    ``"<kid>.<version>"``). Its ``alg`` must equal the algorithm this service
    was configured with. Tokens declaring any other algorithm are rejected
    before any key material is looked up.

    token (str): Encoded JWT to verify.
    issuer (str | None): Expected issuer value; falls back to the service's
        default issuer.
    audience (str | list[str] | None): Expected audience claim. When None,
        audience verification is disabled.
    leeway_s (int): Allowable clock skew in seconds.
    RETURNS (Dict[str, object]): Verified claims payload.
    RAISES (jwt.InvalidTokenError): If the header is malformed, declares an
        unexpected algorithm, or no verification key can be resolved. Errors
        raised by ``jwt.decode`` itself propagate unchanged.
    """
    try:
        header = jwt.get_unverified_header(token)
    except Exception as exc:  # pragma: no cover - propagating original error
        raise jwt.InvalidTokenError(f"Invalid JWS/JWT header: {exc}") from exc

    header_kid = header.get("kid")
    alg_val = header.get("alg")
    # The header is attacker-controlled. Require a real, non-empty string kid
    # (RFC 7515 defines "kid" as a string) so _parse_kid_ver and the string
    # operations below never receive ints, lists or other JSON types.
    if not isinstance(header_kid, str) or not header_kid or alg_val is None:
        raise jwt.InvalidTokenError("Missing or unsupported kid/alg in header")
    try:
        alg = JWAAlg(alg_val)
    except ValueError as exc:
        raise jwt.InvalidTokenError(f"Unsupported alg: {alg_val}") from exc
    if alg not in _SIGN_ALGS:
        raise jwt.InvalidTokenError("Missing or unsupported kid/alg in header")
    # Pin verification to the configured algorithm (RFC 8725 section 3.1).
    # Trusting the header's alg would let a token switch e.g. RS256 -> HS256
    # and make this service export secret key material for an unauthenticated
    # request.
    if alg != self._alg:
        raise jwt.InvalidTokenError(
            f"Unexpected alg {alg.value}; this service verifies {self._alg.value} only"
        )

    kid, ver = _parse_kid_ver(header_kid)

    async def resolve_key() -> object | None:
        # Symmetric algorithm: the verification key is the shared secret itself.
        if alg == JWAAlg.HS256:
            ref = await self._kp.get_key(kid, ver, include_secret=True)
            return ref.material
        # Asymmetric algorithms: ask for the exact version first.
        if ver is not None:
            try:
                return _jwk_to_key(await self._kp.get_public_jwk(kid, ver))
            except Exception:
                pass  # best effort; fall back to the published key set
        key_set = await self._kp.jwks()
        candidates = key_set.get("keys", [])
        for jwk in candidates:
            if jwk.get("kid") == header_kid:
                return _jwk_to_key(jwk)
        # Last resort: any published version of the same base kid. A wrong
        # key here simply fails the signature check in jwt.decode.
        prefix = kid + "."
        for jwk in candidates:
            jwk_kid = jwk.get("kid")
            if isinstance(jwk_kid, str) and jwk_kid.startswith(prefix):
                return _jwk_to_key(jwk)
        return None

    key_obj = await resolve_key()
    if key_obj is None:
        raise jwt.InvalidTokenError("Unable to resolve verification key")
    # Only enforce "aud" when the caller asked for a specific audience.
    options = {"verify_aud": audience is not None}
    return jwt.decode(
        token,
        key=key_obj,
        algorithms=[alg.value],
        audience=audience,
        issuer=issuer or self._iss,
        leeway=leeway_s,
        options=options,
    )
|
jwks
async
Return the JSON Web Key Set for verification.
RETURNS (dict): JWKS containing current and previous public keys.
Source code in swarmauri_tokens_rotatingjwt/rotating_jwt.py
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
async def jwks(self) -> dict:
    """Return the JSON Web Key Set for verification.

    Starts from the key provider's own JWKS. If missing, it adds the public JWK
    of the active signing version and of every previous version whose
    retention window has not yet elapsed. As a side effect, expired entries are
    pruned from ``_prev_versions``. Individual lookups that fail are skipped, so
    one unavailable version does not break the whole set.

    RETURNS (dict): JWKS containing current and previous public keys.
    """
    base = await self._kp.jwks()
    base_keys = base.get("keys", [])
    seen = {k.get("kid") for k in base_keys if isinstance(k, dict)}
    keys = list(base_keys)

    # Snapshot the active identity: the awaits below yield to the event loop,
    # and a concurrent rotation must not change which kid we publish under.
    kid, current_ver = self._kid, self._ver
    current_kid = f"{kid}.{current_ver}"
    if current_kid not in seen:
        try:
            keys.append(await self._kp.get_public_jwk(kid, current_ver))
            seen.add(current_kid)
        except Exception:
            pass  # best effort: publish what the provider already exposes

    now = _now()
    expired = [v for v, until in self._prev_versions.items() if until <= now]
    for v in expired:
        self._prev_versions.pop(v, None)

    # Iterate over a copy: get_public_jwk is awaited inside the loop, and a
    # concurrent rotation writing to _prev_versions would otherwise raise
    # "RuntimeError: dictionary changed size during iteration".
    for v in list(self._prev_versions):
        kidv = f"{kid}.{v}"
        if kidv in seen:
            continue
        try:
            keys.append(await self._kp.get_public_jwk(kid, v))
        except Exception:
            continue  # skip versions the provider can no longer serve
        seen.add(kidv)
    return {"keys": keys}
|
force_rotate
async
Force immediate key rotation.
RETURNS (Tuple[str, int]): New key identifier and version.
Source code in swarmauri_tokens_rotatingjwt/rotating_jwt.py
426
427
428
429
430
431
432
433
434
435
436
async def force_rotate(self) -> Tuple[str, int]:
    """Rotate the signing key now, independent of the automatic triggers.

    The version that is active when rotation starts is recorded in
    ``_prev_versions`` with an expiry of now + ``previous_key_ttl_s``. This
    keeps ``jwks()`` publishing it for that window.

    RETURNS (Tuple[str, int]): The new active key identifier and version.
    """
    # NOTE(review): if _maybe_rotate performs a due rotation here, the key is
    # rotated a second time below; confirm that two new versions is intended.
    await self._maybe_rotate()

    retiring_version = self._ver
    self._prev_versions[retiring_version] = _now() + self._prev_ttl

    rotated = await self._kp.rotate_key(self._kid)
    self._kid, self._ver = rotated.kid, rotated.version
    self._schedule_next_rotation()
    return self._kid, self._ver
|