"""Base repository with common database operations.""" from typing import TypeVar from uuid import UUID from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session # Type variable for model classes ModelType = TypeVar("ModelType") class BaseRepository[ModelType]: """Base repository with common CRUD operations.""" def __init__(self, model: type[ModelType], db: Session | AsyncSession): """ Initialize repository. Args: model: SQLAlchemy model class db: Database session (sync or async) """ self.model = model self.db = db def get_by_id_sync(self, id: UUID) -> ModelType | None: """ Get entity by ID (synchronous). Args: id: Entity UUID Returns: Entity if found, None otherwise """ return self.db.query(self.model).filter(self.model.id == id).first() async def get_by_id_async(self, id: UUID) -> ModelType | None: """ Get entity by ID (asynchronous). Args: id: Entity UUID Returns: Entity if found, None otherwise """ stmt = select(self.model).where(self.model.id == id) result = await self.db.execute(stmt) return result.scalar_one_or_none() def count_sync(self, **filters) -> int: """ Count entities with optional filters (synchronous). Args: **filters: Column filters (column_name=value) Returns: Count of matching entities """ query = self.db.query(func.count(self.model.id)) for key, value in filters.items(): query = query.filter(getattr(self.model, key) == value) return query.scalar() async def count_async(self, **filters) -> int: """ Count entities with optional filters (asynchronous). Args: **filters: Column filters (column_name=value) Returns: Count of matching entities """ stmt = select(func.count(self.model.id)) for key, value in filters.items(): stmt = stmt.where(getattr(self.model, key) == value) result = await self.db.execute(stmt) return result.scalar_one() def delete_sync(self, id: UUID) -> bool: """ Delete entity by ID (synchronous). Args: id: Entity UUID Returns: True if deleted, False if not found """ entity = self.get_by_id_sync(id) if not entity: return False self.db.delete(entity) self.db.commit() return True async def delete_async(self, id: UUID) -> bool: """ Delete entity by ID (asynchronous). Args: id: Entity UUID Returns: True if deleted, False if not found """ entity = await self.get_by_id_async(id) if not entity: return False await self.db.delete(entity) await self.db.commit() return True