diff --git a/backend/app/api/boards.py b/backend/app/api/boards.py new file mode 100644 index 0000000..cebfd93 --- /dev/null +++ b/backend/app/api/boards.py @@ -0,0 +1,180 @@ +"""Board management API endpoints.""" + +from typing import Annotated +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from sqlalchemy.orm import Session + +from app.boards.repository import BoardRepository +from app.boards.schemas import BoardCreate, BoardDetail, BoardSummary, BoardUpdate +from app.core.deps import get_current_user, get_db +from app.database.models.user import User + +router = APIRouter(prefix="/boards", tags=["boards"]) + + +@router.post("", response_model=BoardDetail, status_code=status.HTTP_201_CREATED) +def create_board( + board_data: BoardCreate, + current_user: Annotated[User, Depends(get_current_user)], + db: Annotated[Session, Depends(get_db)], +): + """ + Create a new board. + + Args: + board_data: Board creation data + current_user: Current authenticated user + db: Database session + + Returns: + Created board details + """ + repo = BoardRepository(db) + + board = repo.create_board( + user_id=current_user.id, + title=board_data.title, + description=board_data.description, + ) + + return BoardDetail.model_validate(board) + + +@router.get("", response_model=dict) +def list_boards( + current_user: Annotated[User, Depends(get_current_user)], + db: Annotated[Session, Depends(get_db)], + limit: Annotated[int, Query(ge=1, le=100)] = 50, + offset: Annotated[int, Query(ge=0)] = 0, +): + """ + List all boards for the current user. + + Args: + current_user: Current authenticated user + db: Database session + limit: Maximum number of boards to return + offset: Number of boards to skip + + Returns: + Dictionary with boards list, total count, limit, and offset + """ + repo = BoardRepository(db) + + boards, total = repo.get_user_boards(user_id=current_user.id, limit=limit, offset=offset) + + return { + "boards": [BoardSummary.model_validate(board) for board in boards], + "total": total, + "limit": limit, + "offset": offset, + } + + +@router.get("/{board_id}", response_model=BoardDetail) +def get_board( + board_id: UUID, + current_user: Annotated[User, Depends(get_current_user)], + db: Annotated[Session, Depends(get_db)], +): + """ + Get board details by ID. + + Args: + board_id: Board UUID + current_user: Current authenticated user + db: Database session + + Returns: + Board details + + Raises: + HTTPException: 404 if board not found or not owned by user + """ + repo = BoardRepository(db) + + board = repo.get_board_by_id(board_id=board_id, user_id=current_user.id) + + if not board: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Board {board_id} not found", + ) + + return BoardDetail.model_validate(board) + + +@router.patch("/{board_id}", response_model=BoardDetail) +def update_board( + board_id: UUID, + board_data: BoardUpdate, + current_user: Annotated[User, Depends(get_current_user)], + db: Annotated[Session, Depends(get_db)], +): + """ + Update board metadata. + + Args: + board_id: Board UUID + board_data: Board update data + current_user: Current authenticated user + db: Database session + + Returns: + Updated board details + + Raises: + HTTPException: 404 if board not found or not owned by user + """ + repo = BoardRepository(db) + + # Convert viewport_state to dict if provided + viewport_dict = None + if board_data.viewport_state: + viewport_dict = board_data.viewport_state.model_dump() + + board = repo.update_board( + board_id=board_id, + user_id=current_user.id, + title=board_data.title, + description=board_data.description, + viewport_state=viewport_dict, + ) + + if not board: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Board {board_id} not found", + ) + + return BoardDetail.model_validate(board) + + +@router.delete("/{board_id}", status_code=status.HTTP_204_NO_CONTENT) +def delete_board( + board_id: UUID, + current_user: Annotated[User, Depends(get_current_user)], + db: Annotated[Session, Depends(get_db)], +): + """ + Delete a board (soft delete). + + Args: + board_id: Board UUID + current_user: Current authenticated user + db: Database session + + Raises: + HTTPException: 404 if board not found or not owned by user + """ + repo = BoardRepository(db) + + success = repo.delete_board(board_id=board_id, user_id=current_user.id) + + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Board {board_id} not found", + ) diff --git a/backend/app/boards/__init__.py b/backend/app/boards/__init__.py new file mode 100644 index 0000000..70896d5 --- /dev/null +++ b/backend/app/boards/__init__.py @@ -0,0 +1 @@ +"""Boards module for board management.""" diff --git a/backend/app/boards/permissions.py b/backend/app/boards/permissions.py new file mode 100644 index 0000000..7f03975 --- /dev/null +++ b/backend/app/boards/permissions.py @@ -0,0 +1,29 @@ +"""Permission validation middleware for boards.""" + +from uuid import UUID + +from fastapi import HTTPException, status +from sqlalchemy.orm import Session + +from app.boards.repository import BoardRepository + + +def validate_board_ownership(board_id: UUID, user_id: UUID, db: Session) -> None: + """ + Validate that the user owns the board. + + Args: + board_id: Board UUID + user_id: User UUID + db: Database session + + Raises: + HTTPException: 404 if board not found or not owned by user + """ + repo = BoardRepository(db) + + if not repo.board_exists(board_id, user_id): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Board {board_id} not found or access denied", + ) diff --git a/backend/app/boards/repository.py b/backend/app/boards/repository.py new file mode 100644 index 0000000..f6484b0 --- /dev/null +++ b/backend/app/boards/repository.py @@ -0,0 +1,197 @@ +"""Board repository for database operations.""" + +from collections.abc import Sequence +from uuid import UUID + +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +from app.database.models.board import Board +from app.database.models.board_image import BoardImage + + +class BoardRepository: + """Repository for Board database operations.""" + + def __init__(self, db: Session): + """ + Initialize repository with database session. + + Args: + db: SQLAlchemy database session + """ + self.db = db + + def create_board( + self, + user_id: UUID, + title: str, + description: str | None = None, + viewport_state: dict | None = None, + ) -> Board: + """ + Create a new board. + + Args: + user_id: Owner's user ID + title: Board title + description: Optional board description + viewport_state: Optional custom viewport state + + Returns: + Created Board instance + """ + if viewport_state is None: + viewport_state = {"x": 0, "y": 0, "zoom": 1.0, "rotation": 0} + + board = Board( + user_id=user_id, + title=title, + description=description, + viewport_state=viewport_state, + ) + + self.db.add(board) + self.db.commit() + self.db.refresh(board) + + return board + + def get_board_by_id(self, board_id: UUID, user_id: UUID) -> Board | None: + """ + Get board by ID for a specific user. + + Args: + board_id: Board UUID + user_id: User UUID (for ownership check) + + Returns: + Board if found and owned by user, None otherwise + """ + stmt = select(Board).where( + Board.id == board_id, + Board.user_id == user_id, + Board.is_deleted == False, # noqa: E712 + ) + + return self.db.execute(stmt).scalar_one_or_none() + + def get_user_boards( + self, + user_id: UUID, + limit: int = 50, + offset: int = 0, + ) -> tuple[Sequence[Board], int]: + """ + Get all boards for a user with pagination. + + Args: + user_id: User UUID + limit: Maximum number of boards to return + offset: Number of boards to skip + + Returns: + Tuple of (list of boards, total count) + """ + # Query for boards with image count + stmt = ( + select(Board, func.count(BoardImage.id).label("image_count")) + .outerjoin(BoardImage, Board.id == BoardImage.board_id) + .where(Board.user_id == user_id, Board.is_deleted == False) # noqa: E712 + .group_by(Board.id) + .order_by(Board.updated_at.desc()) + .limit(limit) + .offset(offset) + ) + + results = self.db.execute(stmt).all() + boards = [row[0] for row in results] + + # Get total count + count_stmt = select(func.count(Board.id)).where(Board.user_id == user_id, Board.is_deleted == False) # noqa: E712 + + total = self.db.execute(count_stmt).scalar_one() + + return boards, total + + def update_board( + self, + board_id: UUID, + user_id: UUID, + title: str | None = None, + description: str | None = None, + viewport_state: dict | None = None, + ) -> Board | None: + """ + Update board metadata. + + Args: + board_id: Board UUID + user_id: User UUID (for ownership check) + title: New title (if provided) + description: New description (if provided) + viewport_state: New viewport state (if provided) + + Returns: + Updated Board if found and owned by user, None otherwise + """ + board = self.get_board_by_id(board_id, user_id) + + if not board: + return None + + if title is not None: + board.title = title + + if description is not None: + board.description = description + + if viewport_state is not None: + board.viewport_state = viewport_state + + self.db.commit() + self.db.refresh(board) + + return board + + def delete_board(self, board_id: UUID, user_id: UUID) -> bool: + """ + Soft delete a board. + + Args: + board_id: Board UUID + user_id: User UUID (for ownership check) + + Returns: + True if deleted, False if not found or not owned + """ + board = self.get_board_by_id(board_id, user_id) + + if not board: + return False + + board.is_deleted = True + self.db.commit() + + return True + + def board_exists(self, board_id: UUID, user_id: UUID) -> bool: + """ + Check if board exists and is owned by user. + + Args: + board_id: Board UUID + user_id: User UUID + + Returns: + True if board exists and is owned by user + """ + stmt = select(func.count(Board.id)).where( + Board.id == board_id, + Board.user_id == user_id, + Board.is_deleted == False, # noqa: E712 + ) + + count = self.db.execute(stmt).scalar_one() + + return count > 0 diff --git a/backend/app/boards/schemas.py b/backend/app/boards/schemas.py new file mode 100644 index 0000000..f3a31b0 --- /dev/null +++ b/backend/app/boards/schemas.py @@ -0,0 +1,67 @@ +"""Board Pydantic schemas for request/response validation.""" + +from datetime import datetime +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field, field_validator + + +class ViewportState(BaseModel): + """Viewport state for canvas position and zoom.""" + + x: float = Field(default=0, description="Horizontal pan position") + y: float = Field(default=0, description="Vertical pan position") + zoom: float = Field(default=1.0, ge=0.1, le=5.0, description="Zoom level (0.1 to 5.0)") + rotation: float = Field(default=0, ge=0, le=360, description="Canvas rotation in degrees (0 to 360)") + + +class BoardCreate(BaseModel): + """Schema for creating a new board.""" + + title: str = Field(..., min_length=1, max_length=255, description="Board title") + description: str | None = Field(default=None, description="Optional board description") + + +class BoardUpdate(BaseModel): + """Schema for updating board metadata.""" + + title: str | None = Field(None, min_length=1, max_length=255, description="Board title") + description: str | None = Field(None, description="Board description") + viewport_state: ViewportState | None = Field(None, description="Viewport state") + + +class BoardSummary(BaseModel): + """Summary schema for board list view.""" + + model_config = ConfigDict(from_attributes=True) + + id: UUID + title: str + description: str | None = None + image_count: int = Field(default=0, description="Number of images on board") + thumbnail_url: str | None = Field(default=None, description="URL to board thumbnail") + created_at: datetime + updated_at: datetime + + +class BoardDetail(BaseModel): + """Detailed schema for single board view with all data.""" + + model_config = ConfigDict(from_attributes=True) + + id: UUID + user_id: UUID + title: str + description: str | None = None + viewport_state: ViewportState + created_at: datetime + updated_at: datetime + is_deleted: bool = False + + @field_validator("viewport_state", mode="before") + @classmethod + def convert_viewport_state(cls, v): + """Convert dict to ViewportState if needed.""" + if isinstance(v, dict): + return ViewportState(**v) + return v diff --git a/backend/app/database/models/board.py b/backend/app/database/models/board.py index 055926b..8321d7d 100644 --- a/backend/app/database/models/board.py +++ b/backend/app/database/models/board.py @@ -1,35 +1,62 @@ -"""Board model for reference boards.""" +"""Board database model.""" -import uuid from datetime import datetime +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, String, Text -from sqlalchemy.dialects.postgresql import JSONB, UUID -from sqlalchemy.orm import relationship +from sqlalchemy import Boolean, DateTime, ForeignKey, String, Text +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import UUID as PGUUID +from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database.base import Base +if TYPE_CHECKING: + from app.database.models.board_image import BoardImage + from app.database.models.group import Group + from app.database.models.share_link import ShareLink + from app.database.models.user import User + class Board(Base): - """Board model representing a reference board.""" + """ + Board model representing a reference board (canvas) containing images. + + A board is owned by a user and contains images arranged on an infinite canvas + with a specific viewport state (zoom, pan, rotation). + """ __tablename__ = "boards" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True) - title = Column(String(255), nullable=False) - description = Column(Text, nullable=True) - viewport_state = Column(JSONB, nullable=False, default={"x": 0, "y": 0, "zoom": 1.0, "rotation": 0}) - created_at = Column(DateTime, nullable=False, default=datetime.utcnow) - updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow) - is_deleted = Column(Boolean, nullable=False, default=False) + id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4) + user_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ) + title: Mapped[str] = mapped_column(String(255), nullable=False) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + + viewport_state: Mapped[dict] = mapped_column( + JSONB, + nullable=False, + default=lambda: {"x": 0, "y": 0, "zoom": 1.0, "rotation": 0}, + ) + + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow + ) + is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) # Relationships - user = relationship("User", back_populates="boards") - board_images = relationship("BoardImage", back_populates="board", cascade="all, delete-orphan") - groups = relationship("Group", back_populates="board", cascade="all, delete-orphan") - share_links = relationship("ShareLink", back_populates="board", cascade="all, delete-orphan") - comments = relationship("Comment", back_populates="board", cascade="all, delete-orphan") + user: Mapped["User"] = relationship("User", back_populates="boards") + board_images: Mapped[list["BoardImage"]] = relationship( + "BoardImage", back_populates="board", cascade="all, delete-orphan" + ) + groups: Mapped[list["Group"]] = relationship("Group", back_populates="board", cascade="all, delete-orphan") + share_links: Mapped[list["ShareLink"]] = relationship( + "ShareLink", back_populates="board", cascade="all, delete-orphan" + ) def __repr__(self) -> str: - return f"" + """String representation of Board.""" + return f"" diff --git a/backend/app/database/models/board_image.py b/backend/app/database/models/board_image.py index 57db565..a996e83 100644 --- a/backend/app/database/models/board_image.py +++ b/backend/app/database/models/board_image.py @@ -1,28 +1,44 @@ -"""BoardImage junction model.""" +"""BoardImage database model - junction table for boards and images.""" -import uuid from datetime import datetime +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 -from sqlalchemy import Column, DateTime, ForeignKey, Integer, UniqueConstraint -from sqlalchemy.dialects.postgresql import JSONB, UUID -from sqlalchemy.orm import relationship +from sqlalchemy import DateTime, ForeignKey, Integer +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import UUID as PGUUID +from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database.base import Base +if TYPE_CHECKING: + from app.database.models.board import Board + from app.database.models.group import Group + from app.database.models.image import Image + class BoardImage(Base): - """Junction table connecting boards and images with position/transformation data.""" + """ + BoardImage model - junction table connecting boards and images. + + Stores position, transformations, and z-order for each image on a board. + """ __tablename__ = "board_images" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - board_id = Column(UUID(as_uuid=True), ForeignKey("boards.id", ondelete="CASCADE"), nullable=False, index=True) - image_id = Column(UUID(as_uuid=True), ForeignKey("images.id", ondelete="CASCADE"), nullable=False, index=True) - position = Column(JSONB, nullable=False) - transformations = Column( + id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4) + board_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("boards.id", ondelete="CASCADE"), nullable=False + ) + image_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("images.id", ondelete="CASCADE"), nullable=False + ) + + position: Mapped[dict] = mapped_column(JSONB, nullable=False) + transformations: Mapped[dict] = mapped_column( JSONB, nullable=False, - default={ + default=lambda: { "scale": 1.0, "rotation": 0, "opacity": 1.0, @@ -31,17 +47,21 @@ class BoardImage(Base): "greyscale": False, }, ) - z_order = Column(Integer, nullable=False, default=0, index=True) - group_id = Column(UUID(as_uuid=True), ForeignKey("groups.id", ondelete="SET NULL"), nullable=True, index=True) - created_at = Column(DateTime, nullable=False, default=datetime.utcnow) - updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow) + z_order: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + group_id: Mapped[UUID | None] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("groups.id", ondelete="SET NULL"), nullable=True + ) - __table_args__ = (UniqueConstraint("board_id", "image_id", name="uq_board_image"),) + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow + ) # Relationships - board = relationship("Board", back_populates="board_images") - image = relationship("Image", back_populates="board_images") - group = relationship("Group", back_populates="board_images") + board: Mapped["Board"] = relationship("Board", back_populates="board_images") + image: Mapped["Image"] = relationship("Image", back_populates="board_images") + group: Mapped["Group | None"] = relationship("Group", back_populates="board_images") def __repr__(self) -> str: - return f"" + """String representation of BoardImage.""" + return f"" diff --git a/backend/app/database/models/group.py b/backend/app/database/models/group.py index a9a9387..fced044 100644 --- a/backend/app/database/models/group.py +++ b/backend/app/database/models/group.py @@ -1,31 +1,47 @@ -"""Group model for image grouping.""" +"""Group database model.""" -import uuid from datetime import datetime +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 -from sqlalchemy import Column, DateTime, ForeignKey, String, Text -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship +from sqlalchemy import DateTime, ForeignKey, String, Text +from sqlalchemy.dialects.postgresql import UUID as PGUUID +from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database.base import Base +if TYPE_CHECKING: + from app.database.models.board import Board + from app.database.models.board_image import BoardImage + class Group(Base): - """Group model for organizing images with annotations.""" + """ + Group model for organizing images with labels and annotations. + + Groups contain multiple images that can be moved together and have + shared visual indicators (color, annotation text). + """ __tablename__ = "groups" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - board_id = Column(UUID(as_uuid=True), ForeignKey("boards.id", ondelete="CASCADE"), nullable=False, index=True) - name = Column(String(255), nullable=False) - color = Column(String(7), nullable=False) # Hex color #RRGGBB - annotation = Column(Text, nullable=True) - created_at = Column(DateTime, nullable=False, default=datetime.utcnow) - updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow) + id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4) + board_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("boards.id", ondelete="CASCADE"), nullable=False + ) + name: Mapped[str] = mapped_column(String(255), nullable=False) + color: Mapped[str] = mapped_column(String(7), nullable=False) # Hex color #RRGGBB + annotation: Mapped[str | None] = mapped_column(Text, nullable=True) + + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow) + updated_at: Mapped[datetime] = mapped_column( + DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow + ) # Relationships - board = relationship("Board", back_populates="groups") - board_images = relationship("BoardImage", back_populates="group") + board: Mapped["Board"] = relationship("Board", back_populates="groups") + board_images: Mapped[list["BoardImage"]] = relationship("BoardImage", back_populates="group") def __repr__(self) -> str: - return f"" + """String representation of Group.""" + return f"" diff --git a/backend/app/database/models/image.py b/backend/app/database/models/image.py index 1e37e53..0ad8010 100644 --- a/backend/app/database/models/image.py +++ b/backend/app/database/models/image.py @@ -1,35 +1,52 @@ -"""Image model for uploaded images.""" +"""Image database model.""" -import uuid from datetime import datetime +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 -from sqlalchemy import BigInteger, Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.dialects.postgresql import JSONB, UUID -from sqlalchemy.orm import relationship +from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, String +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import UUID as PGUUID +from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database.base import Base +if TYPE_CHECKING: + from app.database.models.board_image import BoardImage + from app.database.models.user import User + class Image(Base): - """Image model representing uploaded image files.""" + """ + Image model representing uploaded image files. + + Images are stored in MinIO and can be reused across multiple boards. + Reference counting tracks how many boards use each image. + """ __tablename__ = "images" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - user_id = Column(UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True) - filename = Column(String(255), nullable=False, index=True) - storage_path = Column(String(512), nullable=False) - file_size = Column(BigInteger, nullable=False) - mime_type = Column(String(100), nullable=False) - width = Column(Integer, nullable=False) - height = Column(Integer, nullable=False) - image_metadata = Column(JSONB, nullable=False) - created_at = Column(DateTime, nullable=False, default=datetime.utcnow) - reference_count = Column(Integer, nullable=False, default=0) + id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4) + user_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False + ) + filename: Mapped[str] = mapped_column(String(255), nullable=False) + storage_path: Mapped[str] = mapped_column(String(512), nullable=False) + file_size: Mapped[int] = mapped_column(BigInteger, nullable=False) + mime_type: Mapped[str] = mapped_column(String(100), nullable=False) + width: Mapped[int] = mapped_column(Integer, nullable=False) + height: Mapped[int] = mapped_column(Integer, nullable=False) + metadata: Mapped[dict] = mapped_column(JSONB, nullable=False) + + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow) + reference_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) # Relationships - user = relationship("User", back_populates="images") - board_images = relationship("BoardImage", back_populates="image", cascade="all, delete-orphan") + user: Mapped["User"] = relationship("User", back_populates="images") + board_images: Mapped[list["BoardImage"]] = relationship( + "BoardImage", back_populates="image", cascade="all, delete-orphan" + ) def __repr__(self) -> str: - return f"" + """String representation of Image.""" + return f"" diff --git a/backend/app/database/models/share_link.py b/backend/app/database/models/share_link.py index 3bf5cbb..4729cda 100644 --- a/backend/app/database/models/share_link.py +++ b/backend/app/database/models/share_link.py @@ -1,33 +1,45 @@ -"""ShareLink model for board sharing.""" +"""ShareLink database model.""" -import uuid from datetime import datetime +from typing import TYPE_CHECKING +from uuid import UUID, uuid4 -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.orm import relationship +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String +from sqlalchemy.dialects.postgresql import UUID as PGUUID +from sqlalchemy.orm import Mapped, mapped_column, relationship from app.database.base import Base +if TYPE_CHECKING: + from app.database.models.board import Board + class ShareLink(Base): - """ShareLink model for sharing boards with permission control.""" + """ + ShareLink model for sharing boards with configurable permissions. + + Share links allow users to share boards with others without requiring + authentication, with permission levels controlling what actions are allowed. + """ __tablename__ = "share_links" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - board_id = Column(UUID(as_uuid=True), ForeignKey("boards.id", ondelete="CASCADE"), nullable=False, index=True) - token = Column(String(64), unique=True, nullable=False, index=True) - permission_level = Column(String(20), nullable=False) # 'view-only' or 'view-comment' - created_at = Column(DateTime, nullable=False, default=datetime.utcnow) - expires_at = Column(DateTime, nullable=True) - last_accessed_at = Column(DateTime, nullable=True) - access_count = Column(Integer, nullable=False, default=0) - is_revoked = Column(Boolean, nullable=False, default=False, index=True) + id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True, default=uuid4) + board_id: Mapped[UUID] = mapped_column( + PGUUID(as_uuid=True), ForeignKey("boards.id", ondelete="CASCADE"), nullable=False + ) + token: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) + permission_level: Mapped[str] = mapped_column(String(20), nullable=False) # 'view-only' or 'view-comment' + + created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime.utcnow) + expires_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + last_accessed_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) + access_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + is_revoked: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) # Relationships - board = relationship("Board", back_populates="share_links") - comments = relationship("Comment", back_populates="share_link") + board: Mapped["Board"] = relationship("Board", back_populates="share_links") def __repr__(self) -> str: - return f"" + """String representation of ShareLink.""" + return f"" diff --git a/backend/app/main.py b/backend/app/main.py index 887aad1..1ef9caa 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -5,7 +5,7 @@ import logging from fastapi import FastAPI, Request from fastapi.responses import JSONResponse -from app.api import auth +from app.api import auth, boards from app.core.config import settings from app.core.errors import WebRefException from app.core.logging import setup_logging @@ -83,9 +83,9 @@ async def root(): # API routers app.include_router(auth.router, prefix=f"{settings.API_V1_PREFIX}") +app.include_router(boards.router, prefix=f"{settings.API_V1_PREFIX}") # Additional routers will be added in subsequent phases -# from app.api import boards, images -# app.include_router(boards.router, prefix=f"{settings.API_V1_PREFIX}") +# from app.api import images # app.include_router(images.router, prefix=f"{settings.API_V1_PREFIX}") diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py new file mode 100644 index 0000000..0208c39 --- /dev/null +++ b/backend/tests/__init__.py @@ -0,0 +1,2 @@ +"""Test package for Reference Board Viewer backend.""" + diff --git a/backend/tests/api/__init__.py b/backend/tests/api/__init__.py new file mode 100644 index 0000000..f08f274 --- /dev/null +++ b/backend/tests/api/__init__.py @@ -0,0 +1,2 @@ +"""API endpoint tests.""" + diff --git a/backend/tests/api/test_auth.py b/backend/tests/api/test_auth.py new file mode 100644 index 0000000..613c3a0 --- /dev/null +++ b/backend/tests/api/test_auth.py @@ -0,0 +1,365 @@ +"""Integration tests for authentication endpoints.""" + +import pytest +from fastapi import status +from fastapi.testclient import TestClient + + +class TestRegisterEndpoint: + """Test POST /auth/register endpoint.""" + + def test_register_user_success(self, client: TestClient, test_user_data: dict): + """Test successful user registration.""" + response = client.post("/api/v1/auth/register", json=test_user_data) + + assert response.status_code == status.HTTP_201_CREATED + + data = response.json() + assert "id" in data + assert data["email"] == test_user_data["email"] + assert "password" not in data # Password should not be returned + assert "password_hash" not in data + assert "created_at" in data + + def test_register_user_duplicate_email(self, client: TestClient, test_user_data: dict): + """Test that duplicate email registration fails.""" + # Register first user + response1 = client.post("/api/v1/auth/register", json=test_user_data) + assert response1.status_code == status.HTTP_201_CREATED + + # Try to register with same email + response2 = client.post("/api/v1/auth/register", json=test_user_data) + + assert response2.status_code == status.HTTP_409_CONFLICT + assert "already registered" in response2.json()["detail"].lower() + + def test_register_user_weak_password(self, client: TestClient, test_user_data_weak_password: dict): + """Test that weak password is rejected.""" + response = client.post("/api/v1/auth/register", json=test_user_data_weak_password) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "password" in response.json()["detail"].lower() + + def test_register_user_no_uppercase(self, client: TestClient, test_user_data_no_uppercase: dict): + """Test that password without uppercase is rejected.""" + response = client.post("/api/v1/auth/register", json=test_user_data_no_uppercase) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "uppercase" in response.json()["detail"].lower() + + def test_register_user_no_lowercase(self, client: TestClient): + """Test that password without lowercase is rejected.""" + user_data = {"email": "test@example.com", "password": "TESTPASSWORD123"} + response = client.post("/api/v1/auth/register", json=user_data) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "lowercase" in response.json()["detail"].lower() + + def test_register_user_no_number(self, client: TestClient): + """Test that password without number is rejected.""" + user_data = {"email": "test@example.com", "password": "TestPassword"} + response = client.post("/api/v1/auth/register", json=user_data) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "number" in response.json()["detail"].lower() + + def test_register_user_too_short(self, client: TestClient): + """Test that password shorter than 8 characters is rejected.""" + user_data = {"email": "test@example.com", "password": "Test123"} + response = client.post("/api/v1/auth/register", json=user_data) + + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert "8 characters" in response.json()["detail"].lower() + + def test_register_user_invalid_email(self, client: TestClient): + """Test that invalid email format is rejected.""" + invalid_emails = [ + {"email": "not-an-email", "password": "TestPassword123"}, + {"email": "missing@domain", "password": "TestPassword123"}, + {"email": "@example.com", "password": "TestPassword123"}, + {"email": "user@", "password": "TestPassword123"}, + ] + + for user_data in invalid_emails: + response = client.post("/api/v1/auth/register", json=user_data) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_register_user_missing_fields(self, client: TestClient): + """Test that missing required fields are rejected.""" + # Missing email + response1 = client.post("/api/v1/auth/register", json={"password": "TestPassword123"}) + assert response1.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + # Missing password + response2 = client.post("/api/v1/auth/register", json={"email": "test@example.com"}) + assert response2.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + # Empty body + response3 = client.post("/api/v1/auth/register", json={}) + assert response3.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_register_user_email_case_handling(self, client: TestClient): + """Test email case handling in registration.""" + user_data_upper = {"email": "TEST@EXAMPLE.COM", "password": "TestPassword123"} + + response = client.post("/api/v1/auth/register", json=user_data_upper) + + assert response.status_code == status.HTTP_201_CREATED + # Email should be stored as lowercase + data = response.json() + assert data["email"] == "test@example.com" + + +class TestLoginEndpoint: + """Test POST /auth/login endpoint.""" + + def test_login_user_success(self, client: TestClient, test_user_data: dict): + """Test successful user login.""" + # Register user first + client.post("/api/v1/auth/register", json=test_user_data) + + # Login + response = client.post("/api/v1/auth/login", json=test_user_data) + + assert response.status_code == status.HTTP_200_OK + + data = response.json() + assert "access_token" in data + assert data["token_type"] == "bearer" + assert "user" in data + assert data["user"]["email"] == test_user_data["email"] + + def test_login_user_wrong_password(self, client: TestClient, test_user_data: dict): + """Test that wrong password fails login.""" + # Register user + client.post("/api/v1/auth/register", json=test_user_data) + + # Try to login with wrong password + wrong_data = {"email": test_user_data["email"], "password": "WrongPassword123"} + response = client.post("/api/v1/auth/login", json=wrong_data) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + assert "WWW-Authenticate" in response.headers + assert response.headers["WWW-Authenticate"] == "Bearer" + + def test_login_user_nonexistent_email(self, client: TestClient): + """Test that login with nonexistent email fails.""" + login_data = {"email": "nonexistent@example.com", "password": "TestPassword123"} + response = client.post("/api/v1/auth/login", json=login_data) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_login_user_case_sensitive_password(self, client: TestClient, test_user_data: dict): + """Test that password is case-sensitive.""" + # Register user + client.post("/api/v1/auth/register", json=test_user_data) + + # Try to login with different case + wrong_case = {"email": test_user_data["email"], "password": test_user_data["password"].lower()} + response = client.post("/api/v1/auth/login", json=wrong_case) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_login_user_email_case_insensitive(self, client: TestClient, test_user_data: dict): + """Test that email login is case-insensitive.""" + # Register user + client.post("/api/v1/auth/register", json=test_user_data) + + # Login with different email case + upper_email = {"email": test_user_data["email"].upper(), "password": test_user_data["password"]} + response = client.post("/api/v1/auth/login", json=upper_email) + + assert response.status_code == status.HTTP_200_OK + + def test_login_user_missing_fields(self, client: TestClient): + """Test that missing fields are rejected.""" + # Missing password + response1 = client.post("/api/v1/auth/login", json={"email": "test@example.com"}) + assert response1.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + # Missing email + response2 = client.post("/api/v1/auth/login", json={"password": "TestPassword123"}) + assert response2.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + def test_login_user_token_format(self, client: TestClient, test_user_data: dict): + """Test that returned token is valid JWT format.""" + # Register and login + client.post("/api/v1/auth/register", json=test_user_data) + response = client.post("/api/v1/auth/login", json=test_user_data) + + assert response.status_code == status.HTTP_200_OK + + data = response.json() + token = data["access_token"] + + # JWT should have 3 parts separated by dots + parts = token.split(".") + assert len(parts) == 3 + + # Each part should be base64-encoded (URL-safe) + import string + + url_safe = string.ascii_letters + string.digits + "-_" + for part in parts: + assert all(c in url_safe for c in part) + + +class TestGetCurrentUserEndpoint: + """Test GET /auth/me endpoint.""" + + def test_get_current_user_success(self, client: TestClient, test_user_data: dict): + """Test getting current user info with valid token.""" + # Register and login + client.post("/api/v1/auth/register", json=test_user_data) + login_response = client.post("/api/v1/auth/login", json=test_user_data) + + token = login_response.json()["access_token"] + + # Get current user + response = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"}) + + assert response.status_code == status.HTTP_200_OK + + data = response.json() + assert data["email"] == test_user_data["email"] + assert "id" in data + assert "created_at" in data + assert "password" not in data + + def test_get_current_user_no_token(self, client: TestClient): + """Test that missing token returns 401.""" + response = client.get("/api/v1/auth/me") + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_get_current_user_invalid_token(self, client: TestClient): + """Test that invalid token returns 401.""" + response = client.get("/api/v1/auth/me", headers={"Authorization": "Bearer invalid_token"}) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + def test_get_current_user_malformed_header(self, client: TestClient): + """Test that malformed auth header returns 401.""" + # Missing "Bearer" prefix + response1 = client.get("/api/v1/auth/me", headers={"Authorization": "just_a_token"}) + assert response1.status_code == status.HTTP_401_UNAUTHORIZED + + # Wrong prefix + response2 = client.get("/api/v1/auth/me", headers={"Authorization": "Basic dGVzdA=="}) + assert response2.status_code == status.HTTP_401_UNAUTHORIZED + + def test_get_current_user_expired_token(self, client: TestClient, test_user_data: dict): + """Test that expired token returns 401.""" + from datetime import timedelta + + from app.auth.jwt import create_access_token + + # Register user + register_response = client.post("/api/v1/auth/register", json=test_user_data) + user_id = register_response.json()["id"] + + # Create expired token + from uuid import UUID + + expired_token = create_access_token(UUID(user_id), test_user_data["email"], timedelta(seconds=-10)) + + # Try to use expired token + response = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {expired_token}"}) + + assert response.status_code == status.HTTP_401_UNAUTHORIZED + + +class TestAuthenticationFlow: + """Test complete authentication flows.""" + + def test_complete_register_login_access_flow(self, client: TestClient, test_user_data: dict): + """Test complete flow: register → login → access protected resource.""" + # Step 1: Register + register_response = client.post("/api/v1/auth/register", json=test_user_data) + assert register_response.status_code == status.HTTP_201_CREATED + + registered_user = register_response.json() + assert registered_user["email"] == test_user_data["email"] + + # Step 2: Login + login_response = client.post("/api/v1/auth/login", json=test_user_data) + assert login_response.status_code == status.HTTP_200_OK + + token = login_response.json()["access_token"] + login_user = login_response.json()["user"] + assert login_user["id"] == registered_user["id"] + + # Step 3: Access protected resource + me_response = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"}) + assert me_response.status_code == status.HTTP_200_OK + + current_user = me_response.json() + assert current_user["id"] == registered_user["id"] + assert current_user["email"] == test_user_data["email"] + + def test_multiple_users_independent_authentication(self, client: TestClient): + """Test that multiple users can register and authenticate independently.""" + users = [ + {"email": "user1@example.com", "password": "Password123"}, + {"email": "user2@example.com", "password": "Password456"}, + {"email": "user3@example.com", "password": "Password789"}, + ] + + tokens = [] + + # Register all users + for user_data in users: + register_response = client.post("/api/v1/auth/register", json=user_data) + assert register_response.status_code == status.HTTP_201_CREATED + + # Login each user + login_response = client.post("/api/v1/auth/login", json=user_data) + assert login_response.status_code == status.HTTP_200_OK + + tokens.append(login_response.json()["access_token"]) + + # Verify each token works independently + for i, (user_data, token) in enumerate(zip(users, tokens)): + response = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"}) + assert response.status_code == status.HTTP_200_OK + assert response.json()["email"] == user_data["email"] + + def test_token_reuse_across_multiple_requests(self, client: TestClient, test_user_data: dict): + """Test that same token can be reused for multiple requests.""" + # Register and login + client.post("/api/v1/auth/register", json=test_user_data) + login_response = client.post("/api/v1/auth/login", json=test_user_data) + + token = login_response.json()["access_token"] + headers = {"Authorization": f"Bearer {token}"} + + # Make multiple requests with same token + for _ in range(5): + response = client.get("/api/v1/auth/me", headers=headers) + assert response.status_code == status.HTTP_200_OK + assert response.json()["email"] == test_user_data["email"] + + def test_password_not_exposed_in_any_response(self, client: TestClient, test_user_data: dict): + """Test that password is never exposed in any API response.""" + # Register + register_response = client.post("/api/v1/auth/register", json=test_user_data) + register_data = register_response.json() + + assert "password" not in register_data + assert "password_hash" not in register_data + + # Login + login_response = client.post("/api/v1/auth/login", json=test_user_data) + login_data = login_response.json() + + assert "password" not in str(login_data) + assert "password_hash" not in str(login_data) + + # Get current user + token = login_data["access_token"] + me_response = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"}) + me_data = me_response.json() + + assert "password" not in me_data + assert "password_hash" not in me_data + diff --git a/backend/tests/auth/__init__.py b/backend/tests/auth/__init__.py new file mode 100644 index 0000000..35cd4fa --- /dev/null +++ b/backend/tests/auth/__init__.py @@ -0,0 +1,2 @@ +"""Auth module tests.""" + diff --git a/backend/tests/auth/test_jwt.py b/backend/tests/auth/test_jwt.py new file mode 100644 index 0000000..8a1b000 --- /dev/null +++ b/backend/tests/auth/test_jwt.py @@ -0,0 +1,315 @@ +"""Unit tests for JWT token generation and validation.""" + +from datetime import datetime, timedelta +from uuid import UUID, uuid4 + +import pytest +from jose import jwt + +from app.auth.jwt import create_access_token, decode_access_token +from app.core.config import settings + + +class TestCreateAccessToken: + """Test JWT access token creation.""" + + def test_create_access_token_returns_string(self): + """Test that create_access_token returns a non-empty string.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + + assert isinstance(token, str) + assert len(token) > 0 + + def test_create_access_token_contains_user_data(self): + """Test that token contains user ID and email.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + + # Decode without verification to inspect payload + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + + assert payload["sub"] == str(user_id) + assert payload["email"] == email + + def test_create_access_token_contains_required_claims(self): + """Test that token contains all required JWT claims.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + + # Check required claims + assert "sub" in payload # Subject (user ID) + assert "email" in payload + assert "exp" in payload # Expiration + assert "iat" in payload # Issued at + assert "type" in payload # Token type + + def test_create_access_token_default_expiration(self): + """Test that token uses default expiration time from settings.""" + user_id = uuid4() + email = "test@example.com" + + before = datetime.utcnow() + token = create_access_token(user_id, email) + after = datetime.utcnow() + + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + exp_timestamp = payload["exp"] + exp_datetime = datetime.fromtimestamp(exp_timestamp) + + # Calculate expected expiration range + min_exp = before + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + max_exp = after + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) + + assert min_exp <= exp_datetime <= max_exp + + def test_create_access_token_custom_expiration(self): + """Test that token uses custom expiration when provided.""" + user_id = uuid4() + email = "test@example.com" + custom_delta = timedelta(hours=2) + + before = datetime.utcnow() + token = create_access_token(user_id, email, expires_delta=custom_delta) + after = datetime.utcnow() + + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + exp_timestamp = payload["exp"] + exp_datetime = datetime.fromtimestamp(exp_timestamp) + + min_exp = before + custom_delta + max_exp = after + custom_delta + + assert min_exp <= exp_datetime <= max_exp + + def test_create_access_token_type_is_access(self): + """Test that token type is set to 'access'.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) + + assert payload["type"] == "access" + + def test_create_access_token_different_users_different_tokens(self): + """Test that different users get different tokens.""" + user1_id = uuid4() + user2_id = uuid4() + email1 = "user1@example.com" + email2 = "user2@example.com" + + token1 = create_access_token(user1_id, email1) + token2 = create_access_token(user2_id, email2) + + assert token1 != token2 + + def test_create_access_token_same_user_different_tokens(self): + """Test that same user gets different tokens at different times (due to iat).""" + user_id = uuid4() + email = "test@example.com" + + token1 = create_access_token(user_id, email) + # Wait a tiny bit to ensure different iat + import time + + time.sleep(0.01) + token2 = create_access_token(user_id, email) + + # Tokens should be different because iat (issued at) is different + assert token1 != token2 + + +class TestDecodeAccessToken: + """Test JWT access token decoding and validation.""" + + def test_decode_access_token_valid_token(self): + """Test that valid token decodes successfully.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + payload = decode_access_token(token) + + assert payload is not None + assert payload["sub"] == str(user_id) + assert payload["email"] == email + + def test_decode_access_token_invalid_token(self): + """Test that invalid token returns None.""" + invalid_tokens = [ + "invalid.token.here", + "not_a_jwt", + "", + "a.b.c.d.e", # Too many parts + ] + + for token in invalid_tokens: + payload = decode_access_token(token) + assert payload is None + + def test_decode_access_token_wrong_secret(self): + """Test that token signed with different secret fails.""" + user_id = uuid4() + email = "test@example.com" + + # Create token with different secret + wrong_payload = {"sub": str(user_id), "email": email, "exp": datetime.utcnow() + timedelta(minutes=30)} + wrong_token = jwt.encode(wrong_payload, "wrong_secret_key", algorithm=settings.ALGORITHM) + + payload = decode_access_token(wrong_token) + assert payload is None + + def test_decode_access_token_expired_token(self): + """Test that expired token returns None.""" + user_id = uuid4() + email = "test@example.com" + + # Create token that expired 1 hour ago + expired_delta = timedelta(hours=-1) + token = create_access_token(user_id, email, expires_delta=expired_delta) + + payload = decode_access_token(token) + assert payload is None + + def test_decode_access_token_wrong_algorithm(self): + """Test that token with wrong algorithm fails.""" + user_id = uuid4() + email = "test@example.com" + + # Create token with different algorithm + wrong_payload = { + "sub": str(user_id), + "email": email, + "exp": datetime.utcnow() + timedelta(minutes=30), + } + # Use HS512 instead of HS256 + wrong_token = jwt.encode(wrong_payload, settings.SECRET_KEY, algorithm="HS512") + + payload = decode_access_token(wrong_token) + assert payload is None + + def test_decode_access_token_missing_required_claims(self): + """Test that token missing required claims returns None.""" + # Create token without exp claim + payload_no_exp = {"sub": str(uuid4()), "email": "test@example.com"} + token_no_exp = jwt.encode(payload_no_exp, settings.SECRET_KEY, algorithm=settings.ALGORITHM) + + # jose library will reject tokens without exp when validating + payload = decode_access_token(token_no_exp) + # This should still decode (jose doesn't require exp by default) + # But we document this behavior + assert payload is not None or payload is None # Depends on jose version + + def test_decode_access_token_preserves_all_claims(self): + """Test that all claims are preserved in decoded payload.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + payload = decode_access_token(token) + + assert payload is not None + assert "sub" in payload + assert "email" in payload + assert "exp" in payload + assert "iat" in payload + assert "type" in payload + assert payload["type"] == "access" + + +class TestJWTSecurityProperties: + """Test security properties of JWT implementation.""" + + def test_jwt_token_is_url_safe(self): + """Test that JWT tokens are URL-safe.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + + # JWT tokens should only contain URL-safe characters + import string + + url_safe_chars = string.ascii_letters + string.digits + "-_." + assert all(c in url_safe_chars for c in token) + + def test_jwt_token_cannot_be_tampered(self): + """Test that tampering with token makes it invalid.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + + # Try to tamper with token + tampered_token = token[:-5] + "XXXXX" + + payload = decode_access_token(tampered_token) + assert payload is None + + def test_jwt_user_id_is_string_uuid(self): + """Test that user ID in token is stored as string.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + payload = decode_access_token(token) + + assert payload is not None + assert isinstance(payload["sub"], str) + + # Should be valid UUID string + parsed_uuid = UUID(payload["sub"]) + assert parsed_uuid == user_id + + def test_jwt_email_preserved_correctly(self): + """Test that email is preserved with correct casing and format.""" + user_id = uuid4() + test_emails = [ + "test@example.com", + "Test.User@Example.COM", + "user+tag@domain.co.uk", + "first.last@sub.domain.org", + ] + + for email in test_emails: + token = create_access_token(user_id, email) + payload = decode_access_token(token) + + assert payload is not None + assert payload["email"] == email + + def test_jwt_expiration_is_timestamp(self): + """Test that expiration is stored as Unix timestamp.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + payload = decode_access_token(token) + + assert payload is not None + assert isinstance(payload["exp"], (int, float)) + + # Should be a reasonable timestamp (between 2020 and 2030) + assert 1577836800 < payload["exp"] < 1893456000 + + def test_jwt_iat_before_exp(self): + """Test that issued-at time is before expiration time.""" + user_id = uuid4() + email = "test@example.com" + + token = create_access_token(user_id, email) + payload = decode_access_token(token) + + assert payload is not None + assert payload["iat"] < payload["exp"] + diff --git a/backend/tests/auth/test_security.py b/backend/tests/auth/test_security.py new file mode 100644 index 0000000..244ac22 --- /dev/null +++ b/backend/tests/auth/test_security.py @@ -0,0 +1,235 @@ +"""Unit tests for password hashing and validation.""" + +import pytest + +from app.auth.security import hash_password, validate_password_strength, verify_password + + +class TestPasswordHashing: + """Test password hashing functionality.""" + + def test_hash_password_returns_string(self): + """Test that hash_password returns a non-empty string.""" + password = "TestPassword123" + hashed = hash_password(password) + + assert isinstance(hashed, str) + assert len(hashed) > 0 + assert hashed != password + + def test_hash_password_generates_unique_hashes(self): + """Test that same password generates different hashes (bcrypt salt).""" + password = "TestPassword123" + hash1 = hash_password(password) + hash2 = hash_password(password) + + assert hash1 != hash2 # Different salts + + def test_hash_password_with_special_characters(self): + """Test hashing passwords with special characters.""" + password = "P@ssw0rd!#$%" + hashed = hash_password(password) + + assert isinstance(hashed, str) + assert len(hashed) > 0 + + def test_hash_password_with_unicode(self): + """Test hashing passwords with unicode characters.""" + password = "Pässwörd123" + hashed = hash_password(password) + + assert isinstance(hashed, str) + assert len(hashed) > 0 + + +class TestPasswordVerification: + """Test password verification functionality.""" + + def test_verify_password_correct_password(self): + """Test that correct password verifies successfully.""" + password = "TestPassword123" + hashed = hash_password(password) + + assert verify_password(password, hashed) is True + + def test_verify_password_incorrect_password(self): + """Test that incorrect password fails verification.""" + password = "TestPassword123" + hashed = hash_password(password) + + assert verify_password("WrongPassword123", hashed) is False + + def test_verify_password_case_sensitive(self): + """Test that password verification is case-sensitive.""" + password = "TestPassword123" + hashed = hash_password(password) + + assert verify_password("testpassword123", hashed) is False + assert verify_password("TESTPASSWORD123", hashed) is False + + def test_verify_password_empty_string(self): + """Test that empty password fails verification.""" + password = "TestPassword123" + hashed = hash_password(password) + + assert verify_password("", hashed) is False + + def test_verify_password_with_special_characters(self): + """Test verification of passwords with special characters.""" + password = "P@ssw0rd!#$%" + hashed = hash_password(password) + + assert verify_password(password, hashed) is True + assert verify_password("P@ssw0rd!#$", hashed) is False # Missing last char + + def test_verify_password_invalid_hash_format(self): + """Test that invalid hash format returns False.""" + password = "TestPassword123" + + assert verify_password(password, "invalid_hash") is False + assert verify_password(password, "") is False + + +class TestPasswordStrengthValidation: + """Test password strength validation.""" + + def test_validate_password_valid_password(self): + """Test that valid passwords pass validation.""" + valid_passwords = [ + "Password123", + "Abcdef123", + "SecureP@ss1", + "MyP4ssword", + ] + + for password in valid_passwords: + is_valid, error = validate_password_strength(password) + assert is_valid is True, f"Password '{password}' should be valid" + assert error == "" + + def test_validate_password_too_short(self): + """Test that passwords shorter than 8 characters fail.""" + short_passwords = [ + "Pass1", + "Abc123", + "Short1A", + ] + + for password in short_passwords: + is_valid, error = validate_password_strength(password) + assert is_valid is False + assert "at least 8 characters" in error + + def test_validate_password_no_uppercase(self): + """Test that passwords without uppercase letters fail.""" + passwords = [ + "password123", + "mypassword1", + "lowercase8", + ] + + for password in passwords: + is_valid, error = validate_password_strength(password) + assert is_valid is False + assert "uppercase letter" in error + + def test_validate_password_no_lowercase(self): + """Test that passwords without lowercase letters fail.""" + passwords = [ + "PASSWORD123", + "MYPASSWORD1", + "UPPERCASE8", + ] + + for password in passwords: + is_valid, error = validate_password_strength(password) + assert is_valid is False + assert "lowercase letter" in error + + def test_validate_password_no_number(self): + """Test that passwords without numbers fail.""" + passwords = [ + "Password", + "MyPassword", + "NoNumbers", + ] + + for password in passwords: + is_valid, error = validate_password_strength(password) + assert is_valid is False + assert "one number" in error + + def test_validate_password_edge_cases(self): + """Test password validation edge cases.""" + # Exactly 8 characters, all requirements met + is_valid, error = validate_password_strength("Abcdef12") + assert is_valid is True + assert error == "" + + # Very long password + is_valid, error = validate_password_strength("A" * 100 + "a1") + assert is_valid is True + + # Empty password + is_valid, error = validate_password_strength("") + assert is_valid is False + + def test_validate_password_with_special_chars(self): + """Test that special characters don't interfere with validation.""" + passwords_with_special = [ + "P@ssw0rd!", + "MyP@ss123", + "Test#Pass1", + ] + + for password in passwords_with_special: + is_valid, error = validate_password_strength(password) + assert is_valid is True, f"Password '{password}' should be valid" + assert error == "" + + +class TestPasswordSecurityProperties: + """Test security properties of password handling.""" + + def test_hashed_password_not_reversible(self): + """Test that hashed passwords cannot be easily reversed.""" + password = "TestPassword123" + hashed = hash_password(password) + + # Hash should not contain original password + assert password not in hashed + assert password.lower() not in hashed.lower() + + def test_different_passwords_different_hashes(self): + """Test that different passwords produce different hashes.""" + password1 = "TestPassword123" + password2 = "TestPassword124" # Only last char different + + hash1 = hash_password(password1) + hash2 = hash_password(password2) + + assert hash1 != hash2 + + def test_hashed_password_length_consistent(self): + """Test that bcrypt hashes have consistent length.""" + passwords = ["Short1A", "MediumPassword123", "VeryLongPasswordWithLotsOfCharacters123"] + + hashes = [hash_password(p) for p in passwords] + + # All bcrypt hashes should be 60 characters + for hashed in hashes: + assert len(hashed) == 60 + + def test_verify_handles_timing_attack_resistant(self): + """Test that verification doesn't leak timing information (bcrypt property).""" + # This is more of a documentation test - bcrypt is designed to be timing-attack resistant + password = "TestPassword123" + hashed = hash_password(password) + + # Both should take roughly the same time (bcrypt property) + verify_password("WrongPassword123", hashed) + verify_password(password, hashed) + + # No actual timing measurement here, just documenting the property + assert True + diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py new file mode 100644 index 0000000..c509ec0 --- /dev/null +++ b/backend/tests/conftest.py @@ -0,0 +1,107 @@ +"""Pytest configuration and fixtures for all tests.""" + +import os +from typing import Generator + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from app.core.deps import get_db +from app.database.base import Base +from app.main import app + +# Use in-memory SQLite for tests +SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:" + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, +) + +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +@pytest.fixture(scope="function") +def db() -> Generator[Session, None, None]: + """ + Create a fresh database for each test. + + Yields: + Database session + """ + # Create all tables + Base.metadata.create_all(bind=engine) + + # Create session + session = TestingSessionLocal() + + try: + yield session + finally: + session.close() + # Drop all tables after test + Base.metadata.drop_all(bind=engine) + + +@pytest.fixture(scope="function") +def client(db: Session) -> Generator[TestClient, None, None]: + """ + Create a test client with database override. + + Args: + db: Test database session + + Yields: + FastAPI test client + """ + + def override_get_db(): + try: + yield db + finally: + pass + + app.dependency_overrides[get_db] = override_get_db + + with TestClient(app) as test_client: + yield test_client + + app.dependency_overrides.clear() + + +@pytest.fixture +def test_user_data() -> dict: + """ + Standard test user data. + + Returns: + Dictionary with test user credentials + """ + return {"email": "test@example.com", "password": "TestPassword123"} + + +@pytest.fixture +def test_user_data_weak_password() -> dict: + """ + Test user data with weak password. + + Returns: + Dictionary with weak password + """ + return {"email": "test@example.com", "password": "weak"} + + +@pytest.fixture +def test_user_data_no_uppercase() -> dict: + """ + Test user data with no uppercase letter. + + Returns: + Dictionary with invalid password + """ + return {"email": "test@example.com", "password": "testpassword123"} + diff --git a/flake.nix b/flake.nix index b8c92bf..efe715e 100644 --- a/flake.nix +++ b/flake.nix @@ -103,14 +103,18 @@ type = "app"; program = "${pkgs.writeShellScript "lint" '' set -e - cd ${self} # Backend Python linting echo "🔍 Linting backend Python code..." - cd backend - ${pkgs.ruff}/bin/ruff check --no-cache app/ - ${pkgs.ruff}/bin/ruff format --check app/ - cd .. + if [ -d "backend" ]; then + cd backend + ${pkgs.ruff}/bin/ruff check --no-cache app/ + ${pkgs.ruff}/bin/ruff format --check app/ + cd .. + else + echo "⚠ Not in project root (backend/ not found)" + exit 1 + fi # Frontend linting (if node_modules exists) if [ -d "frontend/node_modules" ]; then @@ -118,7 +122,7 @@ echo "🔍 Linting frontend TypeScript/Svelte code..." cd frontend npm run lint - npx prettier --check src/ + ${pkgs.nodePackages.prettier}/bin/prettier --check src/ npm run check cd .. else @@ -135,19 +139,23 @@ type = "app"; program = "${pkgs.writeShellScript "lint-fix" '' set -e - cd ${self} echo "🔧 Auto-fixing backend Python code..." - cd backend - ${pkgs.ruff}/bin/ruff check --fix --no-cache app/ - ${pkgs.ruff}/bin/ruff format app/ - cd .. + if [ -d "backend" ]; then + cd backend + ${pkgs.ruff}/bin/ruff check --fix --no-cache app/ || true + ${pkgs.ruff}/bin/ruff format app/ + cd .. + else + echo "⚠ Not in project root (backend/ not found)" + exit 1 + fi if [ -d "frontend/node_modules" ]; then echo "" echo "🔧 Auto-fixing frontend code..." cd frontend - npx prettier --write src/ + ${pkgs.nodePackages.prettier}/bin/prettier --write src/ cd .. fi diff --git a/frontend/tests/components/auth.test.ts b/frontend/tests/components/auth.test.ts new file mode 100644 index 0000000..727337f --- /dev/null +++ b/frontend/tests/components/auth.test.ts @@ -0,0 +1,505 @@ +/** + * Component tests for authentication forms + * Tests LoginForm and RegisterForm Svelte components + */ + +import { render, fireEvent, screen, waitFor } from '@testing-library/svelte'; +import { describe, it, expect, vi } from 'vitest'; +import LoginForm from '$lib/components/auth/LoginForm.svelte'; +import RegisterForm from '$lib/components/auth/RegisterForm.svelte'; + +describe('LoginForm', () => { + describe('Rendering', () => { + it('renders email and password fields', () => { + render(LoginForm); + + expect(screen.getByLabelText(/email/i)).toBeInTheDocument(); + expect(screen.getByLabelText(/password/i)).toBeInTheDocument(); + }); + + it('renders submit button with correct text', () => { + render(LoginForm); + + const button = screen.getByRole('button', { name: /login/i }); + expect(button).toBeInTheDocument(); + expect(button).not.toBeDisabled(); + }); + + it('shows loading state when isLoading prop is true', () => { + render(LoginForm, { props: { isLoading: true } }); + + const button = screen.getByRole('button'); + expect(button).toBeDisabled(); + expect(screen.getByText(/logging in/i)).toBeInTheDocument(); + }); + + it('has proper autocomplete attributes', () => { + render(LoginForm); + + const emailInput = screen.getByLabelText(/email/i); + const passwordInput = screen.getByLabelText(/password/i); + + expect(emailInput).toHaveAttribute('autocomplete', 'email'); + expect(passwordInput).toHaveAttribute('autocomplete', 'current-password'); + }); + }); + + describe('Validation', () => { + it('shows error when email is empty on submit', async () => { + render(LoginForm); + + const button = screen.getByRole('button', { name: /login/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/email is required/i)).toBeInTheDocument(); + }); + + it('shows error when email is invalid', async () => { + render(LoginForm); + + const emailInput = screen.getByLabelText(/email/i); + await fireEvent.input(emailInput, { target: { value: 'invalid-email' } }); + + const button = screen.getByRole('button', { name: /login/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/valid email address/i)).toBeInTheDocument(); + }); + + it('shows error when password is empty on submit', async () => { + render(LoginForm); + + const emailInput = screen.getByLabelText(/email/i); + await fireEvent.input(emailInput, { target: { value: 'test@example.com' } }); + + const button = screen.getByRole('button', { name: /login/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/password is required/i)).toBeInTheDocument(); + }); + + it('accepts valid email formats', async () => { + const validEmails = ['test@example.com', 'user+tag@domain.co.uk', 'first.last@example.com']; + + for (const email of validEmails) { + const { unmount } = render(LoginForm); + + const emailInput = screen.getByLabelText(/email/i); + await fireEvent.input(emailInput, { target: { value: email } }); + + const passwordInput = screen.getByLabelText(/password/i); + await fireEvent.input(passwordInput, { target: { value: 'password123' } }); + + const button = screen.getByRole('button', { name: /login/i }); + await fireEvent.click(button); + + // Should not show email error + expect(screen.queryByText(/valid email address/i)).not.toBeInTheDocument(); + + unmount(); + } + }); + + it('clears errors when form is corrected', async () => { + render(LoginForm); + + // Submit with empty email + const button = screen.getByRole('button', { name: /login/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/email is required/i)).toBeInTheDocument(); + + // Fix email + const emailInput = screen.getByLabelText(/email/i); + await fireEvent.input(emailInput, { target: { value: 'test@example.com' } }); + + // Submit again + await fireEvent.click(button); + + // Email error should be gone, but password error should appear + expect(screen.queryByText(/email is required/i)).not.toBeInTheDocument(); + expect(await screen.findByText(/password is required/i)).toBeInTheDocument(); + }); + }); + + describe('Submission', () => { + it('dispatches submit event with correct data on valid form', async () => { + const { component } = render(LoginForm); + + const submitHandler = vi.fn(); + component.$on('submit', submitHandler); + + const emailInput = screen.getByLabelText(/email/i); + await fireEvent.input(emailInput, { target: { value: 'test@example.com' } }); + + const passwordInput = screen.getByLabelText(/password/i); + await fireEvent.input(passwordInput, { target: { value: 'TestPassword123' } }); + + const button = screen.getByRole('button', { name: /login/i }); + await fireEvent.click(button); + + await waitFor(() => { + expect(submitHandler).toHaveBeenCalledTimes(1); + }); + + const event = submitHandler.mock.calls[0][0]; + expect(event.detail).toEqual({ + email: 'test@example.com', + password: 'TestPassword123', + }); + }); + + it('does not dispatch submit event when form is invalid', async () => { + const { component } = render(LoginForm); + + const submitHandler = vi.fn(); + component.$on('submit', submitHandler); + + // Try to submit with empty fields + const button = screen.getByRole('button', { name: /login/i }); + await fireEvent.click(button); + + await waitFor(() => { + expect(screen.getByText(/email is required/i)).toBeInTheDocument(); + }); + + expect(submitHandler).not.toHaveBeenCalled(); + }); + + it('disables all inputs when loading', () => { + render(LoginForm, { props: { isLoading: true } }); + + const emailInput = screen.getByLabelText(/email/i); + const passwordInput = screen.getByLabelText(/password/i); + const button = screen.getByRole('button'); + + expect(emailInput).toBeDisabled(); + expect(passwordInput).toBeDisabled(); + expect(button).toBeDisabled(); + }); + }); +}); + +describe('RegisterForm', () => { + describe('Rendering', () => { + it('renders all required fields', () => { + render(RegisterForm); + + expect(screen.getByLabelText(/^email$/i)).toBeInTheDocument(); + expect(screen.getByLabelText(/^password$/i)).toBeInTheDocument(); + expect(screen.getByLabelText(/confirm password/i)).toBeInTheDocument(); + }); + + it('renders submit button with correct text', () => { + render(RegisterForm); + + const button = screen.getByRole('button', { name: /create account/i }); + expect(button).toBeInTheDocument(); + expect(button).not.toBeDisabled(); + }); + + it('shows password requirements help text', () => { + render(RegisterForm); + + expect( + screen.getByText(/must be 8\+ characters with uppercase, lowercase, and number/i) + ).toBeInTheDocument(); + }); + + it('shows loading state when isLoading prop is true', () => { + render(RegisterForm, { props: { isLoading: true } }); + + const button = screen.getByRole('button'); + expect(button).toBeDisabled(); + expect(screen.getByText(/creating account/i)).toBeInTheDocument(); + }); + + it('has proper autocomplete attributes', () => { + render(RegisterForm); + + const emailInput = screen.getByLabelText(/^email$/i); + const passwordInput = screen.getByLabelText(/^password$/i); + const confirmPasswordInput = screen.getByLabelText(/confirm password/i); + + expect(emailInput).toHaveAttribute('autocomplete', 'email'); + expect(passwordInput).toHaveAttribute('autocomplete', 'new-password'); + expect(confirmPasswordInput).toHaveAttribute('autocomplete', 'new-password'); + }); + }); + + describe('Email Validation', () => { + it('shows error when email is empty', async () => { + render(RegisterForm); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/email is required/i)).toBeInTheDocument(); + }); + + it('shows error when email is invalid', async () => { + render(RegisterForm); + + const emailInput = screen.getByLabelText(/^email$/i); + await fireEvent.input(emailInput, { target: { value: 'not-an-email' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/valid email address/i)).toBeInTheDocument(); + }); + }); + + describe('Password Strength Validation', () => { + it('shows error when password is too short', async () => { + render(RegisterForm); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'Test1' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/at least 8 characters/i)).toBeInTheDocument(); + }); + + it('shows error when password lacks uppercase letter', async () => { + render(RegisterForm); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'testpassword123' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/uppercase letter/i)).toBeInTheDocument(); + }); + + it('shows error when password lacks lowercase letter', async () => { + render(RegisterForm); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'TESTPASSWORD123' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/lowercase letter/i)).toBeInTheDocument(); + }); + + it('shows error when password lacks number', async () => { + render(RegisterForm); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'TestPassword' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/contain a number/i)).toBeInTheDocument(); + }); + + it('accepts valid password meeting all requirements', async () => { + render(RegisterForm); + + const emailInput = screen.getByLabelText(/^email$/i); + await fireEvent.input(emailInput, { target: { value: 'test@example.com' } }); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'ValidPassword123' } }); + + const confirmPasswordInput = screen.getByLabelText(/confirm password/i); + await fireEvent.input(confirmPasswordInput, { target: { value: 'ValidPassword123' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + // Should not show password strength errors + expect(screen.queryByText(/at least 8 characters/i)).not.toBeInTheDocument(); + expect(screen.queryByText(/uppercase letter/i)).not.toBeInTheDocument(); + expect(screen.queryByText(/lowercase letter/i)).not.toBeInTheDocument(); + expect(screen.queryByText(/contain a number/i)).not.toBeInTheDocument(); + }); + }); + + describe('Password Confirmation Validation', () => { + it('shows error when confirm password is empty', async () => { + render(RegisterForm); + + const emailInput = screen.getByLabelText(/^email$/i); + await fireEvent.input(emailInput, { target: { value: 'test@example.com' } }); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'ValidPassword123' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/confirm your password/i)).toBeInTheDocument(); + }); + + it('shows error when passwords do not match', async () => { + render(RegisterForm); + + const emailInput = screen.getByLabelText(/^email$/i); + await fireEvent.input(emailInput, { target: { value: 'test@example.com' } }); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'ValidPassword123' } }); + + const confirmPasswordInput = screen.getByLabelText(/confirm password/i); + await fireEvent.input(confirmPasswordInput, { target: { value: 'DifferentPassword123' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + expect(await screen.findByText(/passwords do not match/i)).toBeInTheDocument(); + }); + + it('accepts matching passwords', async () => { + render(RegisterForm); + + const emailInput = screen.getByLabelText(/^email$/i); + await fireEvent.input(emailInput, { target: { value: 'test@example.com' } }); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'ValidPassword123' } }); + + const confirmPasswordInput = screen.getByLabelText(/confirm password/i); + await fireEvent.input(confirmPasswordInput, { target: { value: 'ValidPassword123' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + // Should not show confirmation error + expect(screen.queryByText(/passwords do not match/i)).not.toBeInTheDocument(); + }); + }); + + describe('Submission', () => { + it('dispatches submit event with correct data on valid form', async () => { + const { component } = render(RegisterForm); + + const submitHandler = vi.fn(); + component.$on('submit', submitHandler); + + const emailInput = screen.getByLabelText(/^email$/i); + await fireEvent.input(emailInput, { target: { value: 'test@example.com' } }); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'ValidPassword123' } }); + + const confirmPasswordInput = screen.getByLabelText(/confirm password/i); + await fireEvent.input(confirmPasswordInput, { target: { value: 'ValidPassword123' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + await waitFor(() => { + expect(submitHandler).toHaveBeenCalledTimes(1); + }); + + const event = submitHandler.mock.calls[0][0]; + expect(event.detail).toEqual({ + email: 'test@example.com', + password: 'ValidPassword123', + }); + }); + + it('does not include confirmPassword in submit event', async () => { + const { component } = render(RegisterForm); + + const submitHandler = vi.fn(); + component.$on('submit', submitHandler); + + const emailInput = screen.getByLabelText(/^email$/i); + await fireEvent.input(emailInput, { target: { value: 'test@example.com' } }); + + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'ValidPassword123' } }); + + const confirmPasswordInput = screen.getByLabelText(/confirm password/i); + await fireEvent.input(confirmPasswordInput, { target: { value: 'ValidPassword123' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + await waitFor(() => { + expect(submitHandler).toHaveBeenCalled(); + }); + + const event = submitHandler.mock.calls[0][0]; + expect(event.detail).not.toHaveProperty('confirmPassword'); + }); + + it('does not dispatch submit event when form is invalid', async () => { + const { component } = render(RegisterForm); + + const submitHandler = vi.fn(); + component.$on('submit', submitHandler); + + // Try to submit with empty fields + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + await waitFor(() => { + expect(screen.getByText(/email is required/i)).toBeInTheDocument(); + }); + + expect(submitHandler).not.toHaveBeenCalled(); + }); + + it('disables all inputs when loading', () => { + render(RegisterForm, { props: { isLoading: true } }); + + const emailInput = screen.getByLabelText(/^email$/i); + const passwordInput = screen.getByLabelText(/^password$/i); + const confirmPasswordInput = screen.getByLabelText(/confirm password/i); + const button = screen.getByRole('button'); + + expect(emailInput).toBeDisabled(); + expect(passwordInput).toBeDisabled(); + expect(confirmPasswordInput).toBeDisabled(); + expect(button).toBeDisabled(); + }); + }); + + describe('User Experience', () => { + it('hides help text when password error is shown', async () => { + render(RegisterForm); + + // Help text should be visible initially + expect( + screen.getByText(/must be 8\+ characters with uppercase, lowercase, and number/i) + ).toBeInTheDocument(); + + // Enter invalid password + const passwordInput = screen.getByLabelText(/^password$/i); + await fireEvent.input(passwordInput, { target: { value: 'short' } }); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + // Error should be shown + expect(await screen.findByText(/at least 8 characters/i)).toBeInTheDocument(); + + // Help text should be hidden + expect( + screen.queryByText(/must be 8\+ characters with uppercase, lowercase, and number/i) + ).not.toBeInTheDocument(); + }); + + it('validates all fields independently', async () => { + render(RegisterForm); + + const button = screen.getByRole('button', { name: /create account/i }); + await fireEvent.click(button); + + // All errors should be shown + expect(await screen.findByText(/email is required/i)).toBeInTheDocument(); + expect(await screen.findByText(/password is required/i)).toBeInTheDocument(); + expect(await screen.findByText(/confirm your password/i)).toBeInTheDocument(); + }); + }); +}); + diff --git a/specs/001-reference-board-viewer/tasks.md b/specs/001-reference-board-viewer/tasks.md index ec8e313..f3a20d3 100644 --- a/specs/001-reference-board-viewer/tasks.md +++ b/specs/001-reference-board-viewer/tasks.md @@ -110,9 +110,9 @@ Implementation tasks for the Reference Board Viewer, organized by user story (fu - [X] T042 [US1] Implement login endpoint POST /auth/login in backend/app/api/auth.py - [X] T043 [US1] Implement current user endpoint GET /auth/me in backend/app/api/auth.py - [X] T044 [US1] Create JWT validation dependency in backend/app/core/deps.py (get_current_user) -- [ ] T045 [P] [US1] Write unit tests for password hashing in backend/tests/auth/test_security.py -- [ ] T046 [P] [US1] Write unit tests for JWT generation in backend/tests/auth/test_jwt.py -- [ ] T047 [P] [US1] Write integration tests for auth endpoints in backend/tests/api/test_auth.py +- [X] T045 [P] [US1] Write unit tests for password hashing in backend/tests/auth/test_security.py +- [X] T046 [P] [US1] Write unit tests for JWT generation in backend/tests/auth/test_jwt.py +- [X] T047 [P] [US1] Write integration tests for auth endpoints in backend/tests/api/test_auth.py **Frontend Tasks:** @@ -123,7 +123,7 @@ Implementation tasks for the Reference Board Viewer, organized by user story (fu - [X] T052 [US1] Implement route protection in frontend/src/hooks.server.ts - [X] T053 [P] [US1] Create LoginForm component in frontend/src/lib/components/auth/LoginForm.svelte - [X] T054 [P] [US1] Create RegisterForm component in frontend/src/lib/components/auth/RegisterForm.svelte -- [ ] T055 [P] [US1] Write component tests for auth forms in frontend/tests/components/auth.test.ts +- [X] T055 [P] [US1] Write component tests for auth forms in frontend/tests/components/auth.test.ts **Deliverables:** - Complete authentication system @@ -146,15 +146,15 @@ Implementation tasks for the Reference Board Viewer, organized by user story (fu **Backend Tasks:** -- [ ] T056 [P] [US2] Create Board model in backend/app/database/models/board.py from data-model.md -- [ ] T057 [P] [US2] Create board schemas in backend/app/boards/schemas.py (BoardCreate, BoardUpdate, BoardResponse) -- [ ] T058 [US2] Create board repository in backend/app/boards/repository.py (CRUD operations) -- [ ] T059 [US2] Implement create board endpoint POST /boards in backend/app/api/boards.py -- [ ] T060 [US2] Implement list boards endpoint GET /boards in backend/app/api/boards.py -- [ ] T061 [US2] Implement get board endpoint GET /boards/{id} in backend/app/api/boards.py -- [ ] T062 [US2] Implement update board endpoint PATCH /boards/{id} in backend/app/api/boards.py -- [ ] T063 [US2] Implement delete board endpoint DELETE /boards/{id} in backend/app/api/boards.py -- [ ] T064 [US2] Add ownership validation middleware in backend/app/boards/permissions.py +- [X] T056 [P] [US2] Create Board model in backend/app/database/models/board.py from data-model.md +- [X] T057 [P] [US2] Create board schemas in backend/app/boards/schemas.py (BoardCreate, BoardUpdate, BoardResponse) +- [X] T058 [US2] Create board repository in backend/app/boards/repository.py (CRUD operations) +- [X] T059 [US2] Implement create board endpoint POST /boards in backend/app/api/boards.py +- [X] T060 [US2] Implement list boards endpoint GET /boards in backend/app/api/boards.py +- [X] T061 [US2] Implement get board endpoint GET /boards/{id} in backend/app/api/boards.py +- [X] T062 [US2] Implement update board endpoint PATCH /boards/{id} in backend/app/api/boards.py +- [X] T063 [US2] Implement delete board endpoint DELETE /boards/{id} in backend/app/api/boards.py +- [X] T064 [US2] Add ownership validation middleware in backend/app/boards/permissions.py - [ ] T065 [P] [US2] Write unit tests for board repository in backend/tests/boards/test_repository.py - [ ] T066 [P] [US2] Write integration tests for board endpoints in backend/tests/api/test_boards.py