Refactor database connection handling in API endpoints
- Removed direct pool checks and replaced them with a centralized database initialization method in `init_db`. - Updated API endpoints in `admin.py`, `collection.py`, `pins.py`, and `watched.py` to ensure the database connection pool is initialized before usage. - Enhanced error handling to raise HTTP exceptions if the database is unavailable. - Improved the `init_db` function in `database.py` to prevent multiple simultaneous initializations using an asyncio lock.
This commit is contained in:
@@ -1,7 +1,6 @@
|
|||||||
"""Admin API endpoints"""
|
"""Admin API endpoints"""
|
||||||
from fastapi import APIRouter, HTTPException, Header
|
from fastapi import APIRouter, HTTPException, Header
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from app.core.database import pool
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.services.sync import sync_all_arrs
|
from app.services.sync import sync_all_arrs
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Collection API endpoints"""
|
"""Collection API endpoints"""
|
||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from app.core.database import pool
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -15,17 +14,20 @@ async def get_collection_summary(
|
|||||||
Get collection summary by country and media type.
|
Get collection summary by country and media type.
|
||||||
Returns counts per country per media type.
|
Returns counts per country per media type.
|
||||||
"""
|
"""
|
||||||
# Pool should be initialized on startup, but check just in case
|
# Ensure pool is initialized
|
||||||
if not pool:
|
from app.core.database import init_db, pool as db_pool
|
||||||
from app.core.database import init_db
|
await init_db()
|
||||||
await init_db()
|
|
||||||
|
if db_pool is None:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
raise HTTPException(status_code=503, detail="Database not available")
|
||||||
|
|
||||||
# Parse types filter
|
# Parse types filter
|
||||||
type_filter = []
|
type_filter = []
|
||||||
if types:
|
if types:
|
||||||
type_filter = [t.strip() for t in types.split(",") if t.strip() in ["movie", "show", "music"]]
|
type_filter = [t.strip() for t in types.split(",") if t.strip() in ["movie", "show", "music"]]
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
# Build query
|
# Build query
|
||||||
query = """
|
query = """
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from fastapi import APIRouter, HTTPException
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from app.core.database import pool
|
from app.core.database import init_db, pool as db_pool
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -16,12 +16,11 @@ class PinCreate(BaseModel):
|
|||||||
@router.get("")
|
@router.get("")
|
||||||
async def list_pins():
|
async def list_pins():
|
||||||
"""List all manual pins"""
|
"""List all manual pins"""
|
||||||
# Pool should be initialized on startup
|
await init_db()
|
||||||
if not pool:
|
if db_pool is None:
|
||||||
from app.core.database import init_db
|
raise HTTPException(status_code=503, detail="Database not available")
|
||||||
await init_db()
|
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
query = """
|
query = """
|
||||||
SELECT id, country_code, label, pinned_at
|
SELECT id, country_code, label, pinned_at
|
||||||
@@ -46,12 +45,11 @@ async def list_pins():
|
|||||||
@router.post("")
|
@router.post("")
|
||||||
async def create_pin(pin: PinCreate):
|
async def create_pin(pin: PinCreate):
|
||||||
"""Create a new manual pin"""
|
"""Create a new manual pin"""
|
||||||
# Pool should be initialized on startup
|
await init_db()
|
||||||
if not pool:
|
if db_pool is None:
|
||||||
from app.core.database import init_db
|
raise HTTPException(status_code=503, detail="Database not available")
|
||||||
await init_db()
|
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
query = """
|
query = """
|
||||||
INSERT INTO moviemap.manual_pin (country_code, label)
|
INSERT INTO moviemap.manual_pin (country_code, label)
|
||||||
@@ -68,12 +66,11 @@ async def create_pin(pin: PinCreate):
|
|||||||
@router.delete("/{pin_id}")
|
@router.delete("/{pin_id}")
|
||||||
async def delete_pin(pin_id: UUID):
|
async def delete_pin(pin_id: UUID):
|
||||||
"""Delete a manual pin"""
|
"""Delete a manual pin"""
|
||||||
# Pool should be initialized on startup
|
await init_db()
|
||||||
if not pool:
|
if db_pool is None:
|
||||||
from app.core.database import init_db
|
raise HTTPException(status_code=503, detail="Database not available")
|
||||||
await init_db()
|
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
query = "DELETE FROM moviemap.manual_pin WHERE id = %s RETURNING id"
|
query = "DELETE FROM moviemap.manual_pin WHERE id = %s RETURNING id"
|
||||||
await cur.execute(query, (str(pin_id),))
|
await cur.execute(query, (str(pin_id),))
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from app.core.database import pool
|
from app.core.database import init_db, pool as db_pool
|
||||||
import json
|
import json
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -30,12 +30,11 @@ class WatchedItemUpdate(BaseModel):
|
|||||||
@router.get("/summary")
|
@router.get("/summary")
|
||||||
async def get_watched_summary():
|
async def get_watched_summary():
|
||||||
"""Get watched items summary by country"""
|
"""Get watched items summary by country"""
|
||||||
# Pool should be initialized on startup
|
await init_db()
|
||||||
if not pool:
|
if db_pool is None:
|
||||||
from app.core.database import init_db
|
raise HTTPException(status_code=503, detail="Database not available")
|
||||||
await init_db()
|
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
query = """
|
query = """
|
||||||
SELECT
|
SELECT
|
||||||
@@ -63,12 +62,11 @@ async def get_watched_summary():
|
|||||||
@router.get("")
|
@router.get("")
|
||||||
async def list_watched_items():
|
async def list_watched_items():
|
||||||
"""List all watched items"""
|
"""List all watched items"""
|
||||||
# Pool should be initialized on startup
|
await init_db()
|
||||||
if not pool:
|
if db_pool is None:
|
||||||
from app.core.database import init_db
|
raise HTTPException(status_code=503, detail="Database not available")
|
||||||
await init_db()
|
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
query = """
|
query = """
|
||||||
SELECT
|
SELECT
|
||||||
@@ -100,15 +98,14 @@ async def list_watched_items():
|
|||||||
@router.post("")
|
@router.post("")
|
||||||
async def create_watched_item(item: WatchedItemCreate):
|
async def create_watched_item(item: WatchedItemCreate):
|
||||||
"""Create a new watched item"""
|
"""Create a new watched item"""
|
||||||
# Pool should be initialized on startup
|
await init_db()
|
||||||
if not pool:
|
if db_pool is None:
|
||||||
from app.core.database import init_db
|
raise HTTPException(status_code=503, detail="Database not available")
|
||||||
await init_db()
|
|
||||||
|
|
||||||
if item.media_type not in ["movie", "show"]:
|
if item.media_type not in ["movie", "show"]:
|
||||||
raise HTTPException(status_code=400, detail="media_type must be 'movie' or 'show'")
|
raise HTTPException(status_code=400, detail="media_type must be 'movie' or 'show'")
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
query = """
|
query = """
|
||||||
INSERT INTO moviemap.watched_item
|
INSERT INTO moviemap.watched_item
|
||||||
@@ -136,12 +133,11 @@ async def create_watched_item(item: WatchedItemCreate):
|
|||||||
@router.patch("/{item_id}")
|
@router.patch("/{item_id}")
|
||||||
async def update_watched_item(item_id: UUID, item: WatchedItemUpdate):
|
async def update_watched_item(item_id: UUID, item: WatchedItemUpdate):
|
||||||
"""Update a watched item"""
|
"""Update a watched item"""
|
||||||
# Pool should be initialized on startup
|
await init_db()
|
||||||
if not pool:
|
if db_pool is None:
|
||||||
from app.core.database import init_db
|
raise HTTPException(status_code=503, detail="Database not available")
|
||||||
await init_db()
|
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
# Build dynamic update query
|
# Build dynamic update query
|
||||||
updates = []
|
updates = []
|
||||||
@@ -189,12 +185,11 @@ async def update_watched_item(item_id: UUID, item: WatchedItemUpdate):
|
|||||||
@router.delete("/{item_id}")
|
@router.delete("/{item_id}")
|
||||||
async def delete_watched_item(item_id: UUID):
|
async def delete_watched_item(item_id: UUID):
|
||||||
"""Delete a watched item"""
|
"""Delete a watched item"""
|
||||||
# Pool should be initialized on startup
|
await init_db()
|
||||||
if not pool:
|
if db_pool is None:
|
||||||
from app.core.database import init_db
|
raise HTTPException(status_code=503, detail="Database not available")
|
||||||
await init_db()
|
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
query = "DELETE FROM moviemap.watched_item WHERE id = %s RETURNING id"
|
query = "DELETE FROM moviemap.watched_item WHERE id = %s RETURNING id"
|
||||||
await cur.execute(query, (str(item_id),))
|
await cur.execute(query, (str(item_id),))
|
||||||
|
|||||||
@@ -4,28 +4,52 @@ from psycopg_pool import AsyncConnectionPool
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import logging
|
import logging
|
||||||
|
import asyncio
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Connection pool
|
# Connection pool
|
||||||
pool: Optional[AsyncConnectionPool] = None
|
pool: Optional[AsyncConnectionPool] = None
|
||||||
|
_init_lock = asyncio.Lock()
|
||||||
|
_initializing = False
|
||||||
|
|
||||||
|
|
||||||
async def init_db():
|
async def init_db():
|
||||||
"""Initialize database connection pool"""
|
"""Initialize database connection pool"""
|
||||||
global pool
|
global pool, _initializing
|
||||||
try:
|
|
||||||
pool = AsyncConnectionPool(
|
# If already initialized, return
|
||||||
conninfo=settings.database_url,
|
if pool is not None:
|
||||||
min_size=1,
|
return
|
||||||
max_size=10,
|
|
||||||
open=False,
|
# Use lock to prevent multiple simultaneous initializations
|
||||||
)
|
async with _init_lock:
|
||||||
await pool.open()
|
# Double-check after acquiring lock
|
||||||
logger.info("Database connection pool initialized")
|
if pool is not None:
|
||||||
except Exception as e:
|
return
|
||||||
logger.error(f"Failed to initialize database pool: {e}")
|
|
||||||
raise
|
if _initializing:
|
||||||
|
# Wait for other initialization to complete
|
||||||
|
while _initializing:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
return
|
||||||
|
|
||||||
|
_initializing = True
|
||||||
|
try:
|
||||||
|
pool = AsyncConnectionPool(
|
||||||
|
conninfo=settings.database_url,
|
||||||
|
min_size=1,
|
||||||
|
max_size=10,
|
||||||
|
open=False,
|
||||||
|
)
|
||||||
|
await pool.open()
|
||||||
|
logger.info("Database connection pool initialized")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to initialize database pool: {e}")
|
||||||
|
pool = None
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
_initializing = False
|
||||||
|
|
||||||
|
|
||||||
async def close_db():
|
async def close_db():
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import httpx
|
|||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.database import pool
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -130,12 +129,12 @@ def extract_country_from_lidarr(artist: Dict) -> Optional[str]:
|
|||||||
async def upsert_media_item(source_kind: str, source_item_id: int, title: str,
|
async def upsert_media_item(source_kind: str, source_item_id: int, title: str,
|
||||||
year: Optional[int], media_type: str, arr_raw: Dict):
|
year: Optional[int], media_type: str, arr_raw: Dict):
|
||||||
"""Upsert a media item into the database"""
|
"""Upsert a media item into the database"""
|
||||||
# Pool should be initialized on startup
|
from app.core.database import init_db, pool as db_pool
|
||||||
if not pool:
|
await init_db()
|
||||||
from app.core.database import init_db
|
if db_pool is None:
|
||||||
await init_db()
|
raise Exception("Database not available")
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
# Upsert media item
|
# Upsert media item
|
||||||
query = """
|
query = """
|
||||||
@@ -273,12 +272,13 @@ async def sync_all_arrs() -> Dict:
|
|||||||
logger.error(f"Lidarr sync failed: {e}")
|
logger.error(f"Lidarr sync failed: {e}")
|
||||||
results["lidarr"] = 0
|
results["lidarr"] = 0
|
||||||
|
|
||||||
# Update last sync time (pool should be initialized)
|
# Update last sync time
|
||||||
if not pool:
|
from app.core.database import init_db, pool as db_pool
|
||||||
from app.core.database import init_db
|
await init_db()
|
||||||
await init_db()
|
if db_pool is None:
|
||||||
|
raise Exception("Database not available")
|
||||||
|
|
||||||
async with pool.connection() as conn:
|
async with db_pool.connection() as conn:
|
||||||
async with conn.cursor() as cur:
|
async with conn.cursor() as cur:
|
||||||
for source_kind in ["radarr", "sonarr", "lidarr"]:
|
for source_kind in ["radarr", "sonarr", "lidarr"]:
|
||||||
await cur.execute(
|
await cur.execute(
|
||||||
|
|||||||
Reference in New Issue
Block a user