Refactor database connection handling in API endpoints
- Removed direct pool checks and replaced them with a centralized database initialization method in `init_db`. - Updated API endpoints in `admin.py`, `collection.py`, `pins.py`, and `watched.py` to ensure the database connection pool is initialized before usage. - Enhanced error handling to raise HTTP exceptions if the database is unavailable. - Improved the `init_db` function in `database.py` to prevent multiple simultaneous initializations using an asyncio lock.
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
"""Admin API endpoints"""
|
||||
from fastapi import APIRouter, HTTPException, Header
|
||||
from typing import Optional
|
||||
from app.core.database import pool
|
||||
from app.core.config import settings
|
||||
from app.services.sync import sync_all_arrs
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Collection API endpoints"""
|
||||
from fastapi import APIRouter, Query
|
||||
from typing import List, Optional
|
||||
from app.core.database import pool
|
||||
import json
|
||||
|
||||
router = APIRouter()
|
||||
@@ -15,17 +14,20 @@ async def get_collection_summary(
|
||||
Get collection summary by country and media type.
|
||||
Returns counts per country per media type.
|
||||
"""
|
||||
# Pool should be initialized on startup, but check just in case
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
# Ensure pool is initialized
|
||||
from app.core.database import init_db, pool as db_pool
|
||||
await init_db()
|
||||
|
||||
if db_pool is None:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
# Parse types filter
|
||||
type_filter = []
|
||||
if types:
|
||||
type_filter = [t.strip() for t in types.split(",") if t.strip() in ["movie", "show", "music"]]
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
# Build query
|
||||
query = """
|
||||
|
||||
@@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from app.core.database import pool
|
||||
from app.core.database import init_db, pool as db_pool
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -16,12 +16,11 @@ class PinCreate(BaseModel):
|
||||
@router.get("")
|
||||
async def list_pins():
|
||||
"""List all manual pins"""
|
||||
# Pool should be initialized on startup
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
query = """
|
||||
SELECT id, country_code, label, pinned_at
|
||||
@@ -46,12 +45,11 @@ async def list_pins():
|
||||
@router.post("")
|
||||
async def create_pin(pin: PinCreate):
|
||||
"""Create a new manual pin"""
|
||||
# Pool should be initialized on startup
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
query = """
|
||||
INSERT INTO moviemap.manual_pin (country_code, label)
|
||||
@@ -68,12 +66,11 @@ async def create_pin(pin: PinCreate):
|
||||
@router.delete("/{pin_id}")
|
||||
async def delete_pin(pin_id: UUID):
|
||||
"""Delete a manual pin"""
|
||||
# Pool should be initialized on startup
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
query = "DELETE FROM moviemap.manual_pin WHERE id = %s RETURNING id"
|
||||
await cur.execute(query, (str(pin_id),))
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
from app.core.database import pool
|
||||
from app.core.database import init_db, pool as db_pool
|
||||
import json
|
||||
|
||||
router = APIRouter()
|
||||
@@ -30,12 +30,11 @@ class WatchedItemUpdate(BaseModel):
|
||||
@router.get("/summary")
|
||||
async def get_watched_summary():
|
||||
"""Get watched items summary by country"""
|
||||
# Pool should be initialized on startup
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
query = """
|
||||
SELECT
|
||||
@@ -63,12 +62,11 @@ async def get_watched_summary():
|
||||
@router.get("")
|
||||
async def list_watched_items():
|
||||
"""List all watched items"""
|
||||
# Pool should be initialized on startup
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
query = """
|
||||
SELECT
|
||||
@@ -100,15 +98,14 @@ async def list_watched_items():
|
||||
@router.post("")
|
||||
async def create_watched_item(item: WatchedItemCreate):
|
||||
"""Create a new watched item"""
|
||||
# Pool should be initialized on startup
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
if item.media_type not in ["movie", "show"]:
|
||||
raise HTTPException(status_code=400, detail="media_type must be 'movie' or 'show'")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
query = """
|
||||
INSERT INTO moviemap.watched_item
|
||||
@@ -136,12 +133,11 @@ async def create_watched_item(item: WatchedItemCreate):
|
||||
@router.patch("/{item_id}")
|
||||
async def update_watched_item(item_id: UUID, item: WatchedItemUpdate):
|
||||
"""Update a watched item"""
|
||||
# Pool should be initialized on startup
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
# Build dynamic update query
|
||||
updates = []
|
||||
@@ -189,12 +185,11 @@ async def update_watched_item(item_id: UUID, item: WatchedItemUpdate):
|
||||
@router.delete("/{item_id}")
|
||||
async def delete_watched_item(item_id: UUID):
|
||||
"""Delete a watched item"""
|
||||
# Pool should be initialized on startup
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise HTTPException(status_code=503, detail="Database not available")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
query = "DELETE FROM moviemap.watched_item WHERE id = %s RETURNING id"
|
||||
await cur.execute(query, (str(item_id),))
|
||||
|
||||
@@ -4,16 +4,37 @@ from psycopg_pool import AsyncConnectionPool
|
||||
from app.core.config import settings
|
||||
from typing import Optional
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Connection pool
|
||||
pool: Optional[AsyncConnectionPool] = None
|
||||
_init_lock = asyncio.Lock()
|
||||
_initializing = False
|
||||
|
||||
|
||||
async def init_db():
|
||||
"""Initialize database connection pool"""
|
||||
global pool
|
||||
global pool, _initializing
|
||||
|
||||
# If already initialized, return
|
||||
if pool is not None:
|
||||
return
|
||||
|
||||
# Use lock to prevent multiple simultaneous initializations
|
||||
async with _init_lock:
|
||||
# Double-check after acquiring lock
|
||||
if pool is not None:
|
||||
return
|
||||
|
||||
if _initializing:
|
||||
# Wait for other initialization to complete
|
||||
while _initializing:
|
||||
await asyncio.sleep(0.1)
|
||||
return
|
||||
|
||||
_initializing = True
|
||||
try:
|
||||
pool = AsyncConnectionPool(
|
||||
conninfo=settings.database_url,
|
||||
@@ -25,7 +46,10 @@ async def init_db():
|
||||
logger.info("Database connection pool initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database pool: {e}")
|
||||
pool = None
|
||||
raise
|
||||
finally:
|
||||
_initializing = False
|
||||
|
||||
|
||||
async def close_db():
|
||||
|
||||
@@ -3,7 +3,6 @@ import httpx
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from app.core.config import settings
|
||||
from app.core.database import pool
|
||||
import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -130,12 +129,12 @@ def extract_country_from_lidarr(artist: Dict) -> Optional[str]:
|
||||
async def upsert_media_item(source_kind: str, source_item_id: int, title: str,
|
||||
year: Optional[int], media_type: str, arr_raw: Dict):
|
||||
"""Upsert a media item into the database"""
|
||||
# Pool should be initialized on startup
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
from app.core.database import init_db, pool as db_pool
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise Exception("Database not available")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
# Upsert media item
|
||||
query = """
|
||||
@@ -273,12 +272,13 @@ async def sync_all_arrs() -> Dict:
|
||||
logger.error(f"Lidarr sync failed: {e}")
|
||||
results["lidarr"] = 0
|
||||
|
||||
# Update last sync time (pool should be initialized)
|
||||
if not pool:
|
||||
from app.core.database import init_db
|
||||
# Update last sync time
|
||||
from app.core.database import init_db, pool as db_pool
|
||||
await init_db()
|
||||
if db_pool is None:
|
||||
raise Exception("Database not available")
|
||||
|
||||
async with pool.connection() as conn:
|
||||
async with db_pool.connection() as conn:
|
||||
async with conn.cursor() as cur:
|
||||
for source_kind in ["radarr", "sonarr", "lidarr"]:
|
||||
await cur.execute(
|
||||
|
||||
Reference in New Issue
Block a user