316 lines
11 KiB
Python
316 lines
11 KiB
Python
"""Unit tests for JWT token generation and validation."""
|
|
|
|
from datetime import datetime, timedelta
|
|
from uuid import UUID, uuid4
|
|
|
|
import pytest
|
|
from jose import jwt
|
|
|
|
from app.auth.jwt import create_access_token, decode_access_token
|
|
from app.core.config import settings
|
|
|
|
|
|
class TestCreateAccessToken:
    """Tests for ``create_access_token``.

    Tokens are inspected by decoding them with ``jwt.decode`` using
    ``settings.SECRET_KEY`` and ``settings.ALGORITHM``.
    """

    _EMAIL = "test@example.com"

    @staticmethod
    def _decode(token):
        """Return the claims of *token*.

        This is a *verifying* decode: ``jwt.decode`` is called with the
        configured key and algorithm and no ``options`` override, so the
        library's default signature/expiry validation applies.
        """
        return jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])

    def _assert_expires_after(self, lifetime, **token_kwargs):
        """Create a token and assert that ``exp`` lies *lifetime* after creation.

        Args:
            lifetime: ``timedelta`` the token is expected to live for.
            **token_kwargs: Extra keyword arguments forwarded to
                ``create_access_token`` (e.g. ``expires_delta``).

        The comparison is done in epoch seconds rather than via
        ``datetime.fromtimestamp``: that function yields *local* time, which
        cannot be compared with naive UTC values unless the machine runs in
        UTC.
        NOTE(review): assumes ``exp`` is a standard NumericDate (UTC epoch
        seconds) -- confirm against ``create_access_token``.
        """
        import time

        earliest = time.time()
        token = create_access_token(uuid4(), self._EMAIL, **token_kwargs)
        latest = time.time()

        exp = self._decode(token)["exp"]
        offset = lifetime.total_seconds()

        # One second of slack on each side: ``exp`` may have been truncated or
        # rounded to whole seconds, while the bounds carry sub-second precision.
        assert earliest + offset - 1 <= exp <= latest + offset + 1

    def test_create_access_token_returns_string(self):
        """The token is a non-empty ``str``."""
        token = create_access_token(uuid4(), self._EMAIL)

        assert isinstance(token, str)
        assert token

    def test_create_access_token_contains_user_data(self):
        """The payload carries the user ID (as ``sub``) and the email."""
        user_id = uuid4()

        payload = self._decode(create_access_token(user_id, self._EMAIL))

        assert payload["sub"] == str(user_id)
        assert payload["email"] == self._EMAIL

    def test_create_access_token_contains_required_claims(self):
        """The payload contains every claim the application relies on."""
        payload = self._decode(create_access_token(uuid4(), self._EMAIL))

        # sub = subject (user ID), exp = expiration, iat = issued at,
        # type = token type.
        for claim in ("sub", "email", "exp", "iat", "type"):
            assert claim in payload, f"missing claim: {claim}"

    def test_create_access_token_default_expiration(self):
        """Without ``expires_delta`` the lifetime comes from settings."""
        self._assert_expires_after(
            timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
        )

    def test_create_access_token_custom_expiration(self):
        """An explicit ``expires_delta`` overrides the default lifetime."""
        custom_delta = timedelta(hours=2)

        self._assert_expires_after(custom_delta, expires_delta=custom_delta)

    def test_create_access_token_type_is_access(self):
        """The ``type`` claim is set to ``'access'``."""
        payload = self._decode(create_access_token(uuid4(), self._EMAIL))

        assert payload["type"] == "access"

    def test_create_access_token_different_users_different_tokens(self):
        """Two different users never receive the same token."""
        token1 = create_access_token(uuid4(), "user1@example.com")
        token2 = create_access_token(uuid4(), "user2@example.com")

        assert token1 != token2

    def test_create_access_token_same_user_different_tokens(self):
        """The same user gets a different token when issued at a later time.

        If the time claims are encoded as whole seconds (python-jose does this
        for ``datetime`` values), two tokens minted a few milliseconds apart
        are byte-identical. Instead of a fixed short sleep, re-issue until the
        token changes or a deadline passes: this returns immediately when the
        implementation has sub-second uniqueness and otherwise waits for the
        next clock second.
        """
        import time

        user_id = uuid4()
        first = create_access_token(user_id, self._EMAIL)

        deadline = time.monotonic() + 2.0
        second = create_access_token(user_id, self._EMAIL)
        while second == first and time.monotonic() < deadline:
            time.sleep(0.05)
            second = create_access_token(user_id, self._EMAIL)

        assert first != second
|
|
|
|
|
|
class TestDecodeAccessToken:
    """Tests for ``decode_access_token``.

    The tests here expect ``decode_access_token`` to return the claims dict
    for an acceptable token and ``None`` for a rejected one.
    """

    _EMAIL = "test@example.com"

    @classmethod
    def _well_formed_claims(cls, user_id):
        """Return an unexpired claim set with the claims the app's tokens carry.

        Hand-signed tokens built from this differ from a genuine token only in
        the property under test (key or algorithm), so a rejection cannot be
        caused by an incidental missing claim.
        """
        import time

        now = int(time.time())
        return {
            "sub": str(user_id),
            "email": cls._EMAIL,
            "type": "access",
            "iat": now,
            "exp": now + 30 * 60,
        }

    def test_decode_access_token_valid_token(self):
        """A freshly created token decodes to its original user data."""
        user_id = uuid4()

        payload = decode_access_token(create_access_token(user_id, self._EMAIL))

        assert payload is not None
        assert payload["sub"] == str(user_id)
        assert payload["email"] == self._EMAIL

    def test_decode_access_token_invalid_token(self):
        """Strings that are not well-formed JWTs decode to ``None``."""
        invalid_tokens = [
            "invalid.token.here",
            "not_a_jwt",
            "",
            "a.b.c.d.e",  # Too many parts
        ]

        for token in invalid_tokens:
            assert decode_access_token(token) is None, f"accepted: {token!r}"

    def test_decode_access_token_wrong_secret(self):
        """A token signed with a different secret is rejected."""
        forged = jwt.encode(
            self._well_formed_claims(uuid4()),
            "wrong_secret_key",
            algorithm=settings.ALGORITHM,
        )

        assert decode_access_token(forged) is None

    def test_decode_access_token_expired_token(self):
        """A token whose expiry lies in the past is rejected."""
        # A negative delta makes the token expire one hour before creation.
        token = create_access_token(
            uuid4(), self._EMAIL, expires_delta=timedelta(hours=-1)
        )

        assert decode_access_token(token) is None

    def test_decode_access_token_wrong_algorithm(self):
        """A token signed with an algorithm other than the configured one is rejected."""
        # Pick an HMAC algorithm guaranteed to differ from the configured one,
        # so the test stays meaningful whatever settings.ALGORITHM is.
        other_algorithm = "HS512" if settings.ALGORITHM != "HS512" else "HS256"
        token = jwt.encode(
            self._well_formed_claims(uuid4()),
            settings.SECRET_KEY,
            algorithm=other_algorithm,
        )

        assert decode_access_token(token) is None

    def test_decode_access_token_missing_required_claims(self):
        """A correctly signed token without ``exp`` is rejected or decoded faithfully.

        Whether such a token is accepted depends on how
        ``decode_access_token`` configures validation, which is not visible
        from this test module. Both outcomes are therefore allowed, but an
        accepted token must round-trip its claims unchanged.
        TODO(review): decide whether tokens without ``exp`` must be rejected
        and tighten this to ``assert payload is None`` if so.
        """
        subject = str(uuid4())
        token_no_exp = jwt.encode(
            {"sub": subject, "email": self._EMAIL},
            settings.SECRET_KEY,
            algorithm=settings.ALGORITHM,
        )

        payload = decode_access_token(token_no_exp)

        if payload is not None:
            assert payload["sub"] == subject
            assert payload["email"] == self._EMAIL
            assert "exp" not in payload

    def test_decode_access_token_preserves_all_claims(self):
        """Decoding keeps every claim that was put into the token."""
        payload = decode_access_token(create_access_token(uuid4(), self._EMAIL))

        assert payload is not None
        for claim in ("sub", "email", "exp", "iat", "type"):
            assert claim in payload, f"missing claim: {claim}"
        assert payload["type"] == "access"
|
|
|
|
|
|
class TestJWTSecurityProperties:
    """Tests for security-relevant properties of the issued tokens."""

    _EMAIL = "test@example.com"

    def test_jwt_token_is_url_safe(self):
        """Tokens consist only of URL-safe characters."""
        import string

        token = create_access_token(uuid4(), self._EMAIL)

        # Letters, digits, '-' and '_', plus the '.' that joins the segments.
        url_safe_chars = set(string.ascii_letters + string.digits + "-_.")
        assert set(token) <= url_safe_chars

    def test_jwt_token_cannot_be_tampered(self):
        """Altering the tail of a token makes it undecodable."""
        token = create_access_token(uuid4(), self._EMAIL)

        # Overwrite the last five characters of the token.
        tampered_token = token[:-5] + "XXXXX"
        # Guard against the (astronomically unlikely) case where the original
        # already ends in "XXXXX", which would make the test vacuous.
        assert tampered_token != token

        assert decode_access_token(tampered_token) is None

    def test_jwt_user_id_is_string_uuid(self):
        """``sub`` is a string that parses back to the original UUID."""
        user_id = uuid4()

        payload = decode_access_token(create_access_token(user_id, self._EMAIL))

        assert payload is not None
        assert isinstance(payload["sub"], str)
        # UUID() raises ValueError if the string is not a valid UUID.
        assert UUID(payload["sub"]) == user_id

    def test_jwt_email_preserved_correctly(self):
        """The email claim keeps its exact casing and format."""
        user_id = uuid4()
        test_emails = [
            "test@example.com",
            "Test.User@Example.COM",
            "user+tag@domain.co.uk",
            "first.last@sub.domain.org",
        ]

        for email in test_emails:
            payload = decode_access_token(create_access_token(user_id, email))

            assert payload is not None, f"token for {email!r} did not decode"
            assert payload["email"] == email

    def test_jwt_expiration_is_timestamp(self):
        """``exp`` is a numeric Unix timestamp close to now plus the default lifetime.

        The plausibility window is derived from the current clock and
        ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` instead of fixed calendar
        years, so the test does not start failing once a hard-coded year
        passes.
        NOTE(review): assumes ``exp`` is expressed in UTC epoch seconds and
        that the default lifetime is non-negative -- confirm against
        ``create_access_token``.
        """
        import time

        lifetime = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES).total_seconds()

        before = time.time()
        payload = decode_access_token(create_access_token(uuid4(), self._EMAIL))
        after = time.time()

        assert payload is not None
        exp = payload["exp"]
        assert isinstance(exp, (int, float))
        # One second of slack on each side for whole-second truncation.
        assert before - 1 <= exp <= after + lifetime + 1

    def test_jwt_iat_before_exp(self):
        """The issued-at time precedes the expiration time."""
        payload = decode_access_token(create_access_token(uuid4(), self._EMAIL))

        assert payload is not None
        assert payload["iat"] < payload["exp"]
|
|
|