Files
local-deep-research/tests/database/test_thread_local_session.py

603 lines
21 KiB
Python

"""Tests for database/thread_local_session.py."""
import threading
from unittest.mock import Mock, call, patch
from sqlalchemy.exc import OperationalError, PendingRollbackError
class TestThreadLocalSessionManager:
"""Tests for ThreadLocalSessionManager class."""
def test_init_creates_thread_local_storage(self):
"""Test that initialization creates thread-local storage."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
assert hasattr(manager, "_local")
assert isinstance(manager._local, threading.local)
def test_init_creates_credentials_tracking(self):
"""Test that initialization creates credentials tracking dict."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
assert hasattr(manager, "_thread_credentials")
assert isinstance(manager._thread_credentials, dict)
def test_init_creates_lock(self):
"""Test that initialization creates a threading lock."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
assert hasattr(manager, "_lock")
assert isinstance(manager._lock, type(threading.Lock()))
def test_get_session_creates_new_session(self):
"""Test that get_session creates a new session when none exists."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
with patch(
"local_deep_research.database.thread_local_session.db_manager"
) as mock_db:
mock_engine = Mock()
mock_session = Mock()
mock_db.open_user_database.return_value = mock_engine
mock_db.create_thread_safe_session_for_metrics.return_value = (
mock_session
)
result = manager.get_session("testuser", "testpass")
assert result is mock_session
mock_db.open_user_database.assert_called_once_with(
"testuser", "testpass"
)
def test_get_session_reuses_existing_session(self):
"""Test that get_session reuses an existing valid session."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
mock_existing_session = Mock()
mock_existing_session.execute.return_value = None
# Manually set up an existing session
manager._local.session = mock_existing_session
manager._local.username = "testuser"
result = manager.get_session("testuser", "testpass")
# Should return the existing session
assert result is mock_existing_session
# Verify text() wrapper type and content explicitly
from sqlalchemy.sql.elements import TextClause
call_args = mock_existing_session.execute.call_args[0][0]
assert isinstance(call_args, TextClause)
assert call_args.text == "SELECT 1"
def test_get_session_creates_new_when_existing_invalid(self):
"""Test that get_session creates new session when existing is invalid."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
mock_invalid_session = Mock()
mock_invalid_session.execute.side_effect = OperationalError(
"stmt", {}, Exception("Connection lost")
)
manager._local.session = mock_invalid_session
manager._local.username = "testuser"
with patch(
"local_deep_research.database.thread_local_session.db_manager"
) as mock_db:
mock_engine = Mock()
mock_new_session = Mock()
mock_db.open_user_database.return_value = mock_engine
mock_db.create_thread_safe_session_for_metrics.return_value = (
mock_new_session
)
result = manager.get_session("testuser", "testpass")
# Should create a new session
assert result is mock_new_session
def test_get_session_clears_cross_user_cached_session(self):
"""Cached session belonging to a different user is cleared and a fresh one created."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
stale_session = Mock()
manager._local.session = stale_session
manager._local.username = "alice" # cached for alice
with patch(
"local_deep_research.database.thread_local_session.db_manager"
) as mock_db:
mock_engine = Mock()
mock_db.open_user_database.return_value = mock_engine
new_session = Mock()
mock_db.create_thread_safe_session_for_metrics.return_value = (
new_session
)
result = manager.get_session("bob", "bobpass")
assert result is new_session
assert result is not stale_session # old session was cleared
stale_session.close.assert_called() # cleanup was called
mock_db.open_user_database.assert_called_once_with("bob", "bobpass")
def test_get_session_returns_none_on_db_open_failure(self):
"""Test that get_session returns None when database fails to open."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
with patch(
"local_deep_research.database.thread_local_session.db_manager"
) as mock_db:
mock_db.open_user_database.return_value = None
result = manager.get_session("testuser", "testpass")
assert result is None
def test_get_current_session_returns_none_when_no_session(self):
"""Test that get_current_session returns None when no session exists."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
result = manager.get_current_session()
assert result is None
def test_get_current_session_returns_existing_session(self):
"""Test that get_current_session returns the existing session."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
mock_session = Mock()
manager._local.session = mock_session
result = manager.get_current_session()
assert result is mock_session
def test_cleanup_thread_cleans_current_thread(self):
"""Test that cleanup_thread cleans up the current thread's session."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
mock_session = Mock()
manager._local.session = mock_session
manager._local.username = "testuser"
thread_id = threading.get_ident()
manager._thread_credentials[thread_id] = ("testuser", "testpass")
with patch(
"local_deep_research.database.thread_local_session.db_manager"
):
manager.cleanup_thread()
mock_session.close.assert_called_once()
assert manager._local.session is None
assert thread_id not in manager._thread_credentials
def test_get_session_recovers_from_pending_rollback_error(self):
"""Test that get_session recovers from PendingRollbackError via rollback.
Rollback is called twice in the recovery path: once to clear the
pending-rollback state so the retry SELECT 1 can run, and once
after that SELECT 1 succeeds to release the SHARED lock the
validation transaction held under DEFERRED isolation.
"""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
mock_session = Mock()
# First execute raises PendingRollbackError, after rollback it succeeds
mock_session.execute.side_effect = [
PendingRollbackError("test"), # Initial validation fails
None, # Retry after rollback succeeds
]
manager._local.session = mock_session
manager._local.username = "testuser"
result = manager.get_session("testuser", "testpass")
# Should recover the same session via rollback
assert result is mock_session
assert mock_session.rollback.call_count == 2
assert mock_session.execute.call_count == 2
def test_get_session_recreates_when_rollback_recovery_fails(self):
"""Test that get_session recreates session when rollback recovery fails."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
mock_old_session = Mock()
# Both execute calls fail — rollback doesn't help
mock_old_session.execute.side_effect = PendingRollbackError("test")
mock_old_session.rollback.side_effect = OperationalError(
"stmt", {}, Exception("rollback failed")
)
manager._local.session = mock_old_session
manager._local.username = "testuser"
thread_id = threading.get_ident()
manager._thread_credentials[thread_id] = ("testuser", "testpass")
with patch(
"local_deep_research.database.thread_local_session.db_manager"
) as mock_db:
mock_new_session = Mock()
mock_db.open_user_database.return_value = Mock()
mock_db.create_thread_safe_session_for_metrics.return_value = (
mock_new_session
)
result = manager.get_session("testuser", "testpass")
# Should fall back to creating a new session
assert result is mock_new_session
# close() must still be called on old session even though rollback() failed
mock_old_session.close.assert_called_once()
def test_cleanup_thread_session_calls_rollback_before_close(self):
"""Test that _cleanup_thread_session calls rollback before close."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
mock_session = Mock()
manager._local.session = mock_session
manager._local.username = "testuser"
thread_id = threading.get_ident()
manager._thread_credentials[thread_id] = ("testuser", "testpass")
with patch(
"local_deep_research.database.thread_local_session.db_manager"
):
manager._cleanup_thread_session()
# Verify rollback is called before close
expected_calls = [call.rollback(), call.close()]
mock_session.assert_has_calls(expected_calls, any_order=False)
def test_cleanup_thread_session_still_closes_when_rollback_fails(self):
"""Test that close() is called even when rollback() raises."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
mock_session = Mock()
mock_session.rollback.side_effect = OperationalError(
"stmt", {}, Exception("dead connection")
)
manager._local.session = mock_session
manager._local.username = "testuser"
thread_id = threading.get_ident()
manager._thread_credentials[thread_id] = ("testuser", "testpass")
with patch(
"local_deep_research.database.thread_local_session.db_manager"
):
manager._cleanup_thread_session()
mock_session.close.assert_called_once()
assert manager._local.session is None
def test_cleanup_all_cleans_all_threads(self):
"""Test that cleanup_all cleans up all tracked sessions."""
from local_deep_research.database.thread_local_session import (
ThreadLocalSessionManager,
)
manager = ThreadLocalSessionManager()
manager._thread_credentials = {
1: ("user1", "pass1"),
2: ("user2", "pass2"),
}
# cleanup_all should iterate credentials and call cleanup_thread
# without touching any engine machinery.
manager.cleanup_all()
assert manager._thread_credentials == {}
class TestThreadSessionContext:
"""Tests for ThreadSessionContext context manager."""
def test_context_manager_returns_session(self):
"""Test that context manager returns a session on enter."""
from local_deep_research.database.thread_local_session import (
ThreadSessionContext,
)
with patch(
"local_deep_research.database.thread_local_session.get_metrics_session"
) as mock_get:
mock_session = Mock()
mock_get.return_value = mock_session
with ThreadSessionContext("testuser", "testpass") as session:
assert session is mock_session
def test_context_manager_stores_credentials(self):
"""Test that context manager stores username and password."""
from local_deep_research.database.thread_local_session import (
ThreadSessionContext,
)
ctx = ThreadSessionContext("myuser", "mypass")
assert ctx.username == "myuser"
assert ctx.password == "mypass"
class TestModuleFunctions:
"""Tests for module-level functions."""
def test_get_metrics_session_delegates_to_manager(self):
"""Test that get_metrics_session delegates to thread_session_manager."""
from local_deep_research.database.thread_local_session import (
get_metrics_session,
thread_session_manager,
)
with patch.object(thread_session_manager, "get_session") as mock_get:
mock_session = Mock()
mock_get.return_value = mock_session
result = get_metrics_session("testuser", "testpass")
mock_get.assert_called_once_with("testuser", "testpass")
assert result is mock_session
def test_get_current_thread_session_delegates_to_manager(self):
"""Test that get_current_thread_session delegates to manager."""
from local_deep_research.database.thread_local_session import (
get_current_thread_session,
thread_session_manager,
)
with patch.object(
thread_session_manager, "get_current_session"
) as mock_get:
mock_session = Mock()
mock_get.return_value = mock_session
result = get_current_thread_session()
mock_get.assert_called_once()
assert result is mock_session
def test_cleanup_current_thread_delegates_to_manager(self):
"""Test that cleanup_current_thread delegates to manager."""
from local_deep_research.database.thread_local_session import (
cleanup_current_thread,
thread_session_manager,
)
with patch.object(
thread_session_manager, "cleanup_thread"
) as mock_cleanup:
cleanup_current_thread()
mock_cleanup.assert_called_once()
class TestGlobalInstance:
"""Tests for the global thread_session_manager instance."""
def test_global_instance_exists(self):
"""Test that global thread_session_manager instance exists."""
from local_deep_research.database.thread_local_session import (
thread_session_manager,
)
assert thread_session_manager is not None
def test_global_instance_is_correct_type(self):
"""Test that global instance is ThreadLocalSessionManager."""
from local_deep_research.database.thread_local_session import (
thread_session_manager,
ThreadLocalSessionManager,
)
assert isinstance(thread_session_manager, ThreadLocalSessionManager)
class TestThreadCleanup:
"""Tests for thread_cleanup decorator / context manager."""
def _patch_all_cleanup(self):
"""Return a stack of patches for the three cleanup functions."""
return (
patch(
"local_deep_research.database.thread_local_session.cleanup_current_thread"
),
patch(
"local_deep_research.database.thread_local_session._ThreadCleanup.__exit__",
wraps=None,
),
)
def test_bare_decorator_runs_cleanup(self):
"""@thread_cleanup runs cleanup on exit and returns result."""
from local_deep_research.database.thread_local_session import (
thread_cleanup,
)
mock_cleanup = Mock()
with patch(
"local_deep_research.database.thread_local_session.cleanup_current_thread",
mock_cleanup,
):
@thread_cleanup
def worker():
return 42
result = worker()
assert result == 42
mock_cleanup.assert_called_once()
def test_factory_decorator_runs_cleanup(self):
"""@thread_cleanup() (with parens) runs cleanup on exit and returns result."""
from local_deep_research.database.thread_local_session import (
thread_cleanup,
)
mock_cleanup = Mock()
with patch(
"local_deep_research.database.thread_local_session.cleanup_current_thread",
mock_cleanup,
):
@thread_cleanup()
def worker():
return 99
result = worker()
assert result == 99
mock_cleanup.assert_called_once()
def test_context_manager_runs_cleanup(self):
"""with thread_cleanup(): runs cleanup on exit."""
from local_deep_research.database.thread_local_session import (
thread_cleanup,
)
mock_cleanup = Mock()
with patch(
"local_deep_research.database.thread_local_session.cleanup_current_thread",
mock_cleanup,
):
with thread_cleanup():
pass
mock_cleanup.assert_called_once()
def test_inline_wrapper_runs_cleanup(self):
"""thread_cleanup(func) as inline wrapper runs cleanup on exit."""
from local_deep_research.database.thread_local_session import (
thread_cleanup,
)
mock_cleanup = Mock()
def worker(x):
return x * 2
with patch(
"local_deep_research.database.thread_local_session.cleanup_current_thread",
mock_cleanup,
):
wrapped = thread_cleanup(worker)
result = wrapped(5)
assert result == 10
mock_cleanup.assert_called_once()
def test_cleanup_exception_logged_not_raised(self):
"""Cleanup exceptions are logged at debug level, not raised."""
from local_deep_research.database.thread_local_session import (
thread_cleanup,
)
with (
patch(
"local_deep_research.database.thread_local_session.cleanup_current_thread",
side_effect=RuntimeError("cleanup boom"),
),
patch(
"local_deep_research.database.thread_local_session.logger"
) as mock_logger,
):
@thread_cleanup
def worker():
return "ok"
result = worker()
assert result == "ok"
mock_logger.debug.assert_called()
def test_original_exception_propagates_when_cleanup_fails(self):
"""Original exceptions propagate even when cleanup fails."""
from local_deep_research.database.thread_local_session import (
thread_cleanup,
)
with (
patch(
"local_deep_research.database.thread_local_session.cleanup_current_thread",
side_effect=RuntimeError("cleanup boom"),
),
patch("local_deep_research.database.thread_local_session.logger"),
):
@thread_cleanup
def worker():
raise ValueError("original error")
try:
worker()
assert False, "Should have raised ValueError"
except ValueError as e:
assert str(e) == "original error"
def test_functools_wraps_metadata_preserved(self):
"""functools.wraps metadata preserved on decorated functions."""
from local_deep_research.database.thread_local_session import (
thread_cleanup,
)
@thread_cleanup
def my_worker():
"""My docstring."""
pass
assert my_worker.__name__ == "my_worker"
assert my_worker.__doc__ == "My docstring."