Refactor database connection handling in API endpoints

- Removed direct pool checks and replaced them with a centralized database initialization method in `init_db`.
- Updated API endpoints in `admin.py`, `collection.py`, `pins.py`, and `watched.py` to ensure the database connection pool is initialized before usage.
- Enhanced error handling to raise HTTP exceptions if the database is unavailable.
- Improved the `init_db` function in `database.py` to prevent multiple simultaneous initializations using an asyncio lock.
This commit is contained in:
Danilo Reyes
2025-12-28 21:37:31 -06:00
parent 98622c4119
commit c0371d85ce
6 changed files with 90 additions and 73 deletions

View File

@@ -1,7 +1,6 @@
"""Admin API endpoints""" """Admin API endpoints"""
from fastapi import APIRouter, HTTPException, Header from fastapi import APIRouter, HTTPException, Header
from typing import Optional from typing import Optional
from app.core.database import pool
from app.core.config import settings from app.core.config import settings
from app.services.sync import sync_all_arrs from app.services.sync import sync_all_arrs

View File

@@ -1,7 +1,6 @@
"""Collection API endpoints""" """Collection API endpoints"""
from fastapi import APIRouter, Query from fastapi import APIRouter, Query
from typing import List, Optional from typing import List, Optional
from app.core.database import pool
import json import json
router = APIRouter() router = APIRouter()
@@ -15,17 +14,20 @@ async def get_collection_summary(
Get collection summary by country and media type. Get collection summary by country and media type.
Returns counts per country per media type. Returns counts per country per media type.
""" """
# Pool should be initialized on startup, but check just in case # Ensure pool is initialized
if not pool: from app.core.database import init_db, pool as db_pool
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
from fastapi import HTTPException
raise HTTPException(status_code=503, detail="Database not available")
# Parse types filter # Parse types filter
type_filter = [] type_filter = []
if types: if types:
type_filter = [t.strip() for t in types.split(",") if t.strip() in ["movie", "show", "music"]] type_filter = [t.strip() for t in types.split(",") if t.strip() in ["movie", "show", "music"]]
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
# Build query # Build query
query = """ query = """

View File

@@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
from uuid import UUID from uuid import UUID
from app.core.database import pool from app.core.database import init_db, pool as db_pool
router = APIRouter() router = APIRouter()
@@ -16,12 +16,11 @@ class PinCreate(BaseModel):
@router.get("") @router.get("")
async def list_pins(): async def list_pins():
"""List all manual pins""" """List all manual pins"""
# Pool should be initialized on startup
if not pool:
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise HTTPException(status_code=503, detail="Database not available")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
query = """ query = """
SELECT id, country_code, label, pinned_at SELECT id, country_code, label, pinned_at
@@ -46,12 +45,11 @@ async def list_pins():
@router.post("") @router.post("")
async def create_pin(pin: PinCreate): async def create_pin(pin: PinCreate):
"""Create a new manual pin""" """Create a new manual pin"""
# Pool should be initialized on startup
if not pool:
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise HTTPException(status_code=503, detail="Database not available")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
query = """ query = """
INSERT INTO moviemap.manual_pin (country_code, label) INSERT INTO moviemap.manual_pin (country_code, label)
@@ -68,12 +66,11 @@ async def create_pin(pin: PinCreate):
@router.delete("/{pin_id}") @router.delete("/{pin_id}")
async def delete_pin(pin_id: UUID): async def delete_pin(pin_id: UUID):
"""Delete a manual pin""" """Delete a manual pin"""
# Pool should be initialized on startup
if not pool:
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise HTTPException(status_code=503, detail="Database not available")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
query = "DELETE FROM moviemap.manual_pin WHERE id = %s RETURNING id" query = "DELETE FROM moviemap.manual_pin WHERE id = %s RETURNING id"
await cur.execute(query, (str(pin_id),)) await cur.execute(query, (str(pin_id),))

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel
from typing import Optional from typing import Optional
from datetime import datetime from datetime import datetime
from uuid import UUID from uuid import UUID
from app.core.database import pool from app.core.database import init_db, pool as db_pool
import json import json
router = APIRouter() router = APIRouter()
@@ -30,12 +30,11 @@ class WatchedItemUpdate(BaseModel):
@router.get("/summary") @router.get("/summary")
async def get_watched_summary(): async def get_watched_summary():
"""Get watched items summary by country""" """Get watched items summary by country"""
# Pool should be initialized on startup
if not pool:
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise HTTPException(status_code=503, detail="Database not available")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
query = """ query = """
SELECT SELECT
@@ -63,12 +62,11 @@ async def get_watched_summary():
@router.get("") @router.get("")
async def list_watched_items(): async def list_watched_items():
"""List all watched items""" """List all watched items"""
# Pool should be initialized on startup
if not pool:
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise HTTPException(status_code=503, detail="Database not available")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
query = """ query = """
SELECT SELECT
@@ -100,15 +98,14 @@ async def list_watched_items():
@router.post("") @router.post("")
async def create_watched_item(item: WatchedItemCreate): async def create_watched_item(item: WatchedItemCreate):
"""Create a new watched item""" """Create a new watched item"""
# Pool should be initialized on startup
if not pool:
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise HTTPException(status_code=503, detail="Database not available")
if item.media_type not in ["movie", "show"]: if item.media_type not in ["movie", "show"]:
raise HTTPException(status_code=400, detail="media_type must be 'movie' or 'show'") raise HTTPException(status_code=400, detail="media_type must be 'movie' or 'show'")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
query = """ query = """
INSERT INTO moviemap.watched_item INSERT INTO moviemap.watched_item
@@ -136,12 +133,11 @@ async def create_watched_item(item: WatchedItemCreate):
@router.patch("/{item_id}") @router.patch("/{item_id}")
async def update_watched_item(item_id: UUID, item: WatchedItemUpdate): async def update_watched_item(item_id: UUID, item: WatchedItemUpdate):
"""Update a watched item""" """Update a watched item"""
# Pool should be initialized on startup
if not pool:
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise HTTPException(status_code=503, detail="Database not available")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
# Build dynamic update query # Build dynamic update query
updates = [] updates = []
@@ -189,12 +185,11 @@ async def update_watched_item(item_id: UUID, item: WatchedItemUpdate):
@router.delete("/{item_id}") @router.delete("/{item_id}")
async def delete_watched_item(item_id: UUID): async def delete_watched_item(item_id: UUID):
"""Delete a watched item""" """Delete a watched item"""
# Pool should be initialized on startup
if not pool:
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise HTTPException(status_code=503, detail="Database not available")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
query = "DELETE FROM moviemap.watched_item WHERE id = %s RETURNING id" query = "DELETE FROM moviemap.watched_item WHERE id = %s RETURNING id"
await cur.execute(query, (str(item_id),)) await cur.execute(query, (str(item_id),))

View File

@@ -4,16 +4,37 @@ from psycopg_pool import AsyncConnectionPool
from app.core.config import settings from app.core.config import settings
from typing import Optional from typing import Optional
import logging import logging
import asyncio
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Connection pool # Connection pool
pool: Optional[AsyncConnectionPool] = None pool: Optional[AsyncConnectionPool] = None
_init_lock = asyncio.Lock()
_initializing = False
async def init_db(): async def init_db():
"""Initialize database connection pool""" """Initialize database connection pool"""
global pool global pool, _initializing
# If already initialized, return
if pool is not None:
return
# Use lock to prevent multiple simultaneous initializations
async with _init_lock:
# Double-check after acquiring lock
if pool is not None:
return
if _initializing:
# Wait for other initialization to complete
while _initializing:
await asyncio.sleep(0.1)
return
_initializing = True
try: try:
pool = AsyncConnectionPool( pool = AsyncConnectionPool(
conninfo=settings.database_url, conninfo=settings.database_url,
@@ -25,7 +46,10 @@ async def init_db():
logger.info("Database connection pool initialized") logger.info("Database connection pool initialized")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize database pool: {e}") logger.error(f"Failed to initialize database pool: {e}")
pool = None
raise raise
finally:
_initializing = False
async def close_db(): async def close_db():

View File

@@ -3,7 +3,6 @@ import httpx
import logging import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from app.core.config import settings from app.core.config import settings
from app.core.database import pool
import json import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -130,12 +129,12 @@ def extract_country_from_lidarr(artist: Dict) -> Optional[str]:
async def upsert_media_item(source_kind: str, source_item_id: int, title: str, async def upsert_media_item(source_kind: str, source_item_id: int, title: str,
year: Optional[int], media_type: str, arr_raw: Dict): year: Optional[int], media_type: str, arr_raw: Dict):
"""Upsert a media item into the database""" """Upsert a media item into the database"""
# Pool should be initialized on startup from app.core.database import init_db, pool as db_pool
if not pool:
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise Exception("Database not available")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
# Upsert media item # Upsert media item
query = """ query = """
@@ -273,12 +272,13 @@ async def sync_all_arrs() -> Dict:
logger.error(f"Lidarr sync failed: {e}") logger.error(f"Lidarr sync failed: {e}")
results["lidarr"] = 0 results["lidarr"] = 0
# Update last sync time (pool should be initialized) # Update last sync time
if not pool: from app.core.database import init_db, pool as db_pool
from app.core.database import init_db
await init_db() await init_db()
if db_pool is None:
raise Exception("Database not available")
async with pool.connection() as conn: async with db_pool.connection() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
for source_kind in ["radarr", "sonarr", "lidarr"]: for source_kind in ["radarr", "sonarr", "lidarr"]:
await cur.execute( await cur.execute(