mirror of
https://github.com/LearningCircuit/local-deep-research.git
synced 2026-06-16 03:51:07 +03:00
603 lines
21 KiB
Python
603 lines
21 KiB
Python
"""Tests for database/thread_local_session.py."""
|
|
|
|
import threading
|
|
from unittest.mock import Mock, call, patch
|
|
|
|
from sqlalchemy.exc import OperationalError, PendingRollbackError
|
|
|
|
|
|
class TestThreadLocalSessionManager:
|
|
"""Tests for ThreadLocalSessionManager class."""
|
|
|
|
def test_init_creates_thread_local_storage(self):
|
|
"""Test that initialization creates thread-local storage."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
assert hasattr(manager, "_local")
|
|
assert isinstance(manager._local, threading.local)
|
|
|
|
def test_init_creates_credentials_tracking(self):
|
|
"""Test that initialization creates credentials tracking dict."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
assert hasattr(manager, "_thread_credentials")
|
|
assert isinstance(manager._thread_credentials, dict)
|
|
|
|
def test_init_creates_lock(self):
|
|
"""Test that initialization creates a threading lock."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
assert hasattr(manager, "_lock")
|
|
assert isinstance(manager._lock, type(threading.Lock()))
|
|
|
|
def test_get_session_creates_new_session(self):
|
|
"""Test that get_session creates a new session when none exists."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.db_manager"
|
|
) as mock_db:
|
|
mock_engine = Mock()
|
|
mock_session = Mock()
|
|
mock_db.open_user_database.return_value = mock_engine
|
|
mock_db.create_thread_safe_session_for_metrics.return_value = (
|
|
mock_session
|
|
)
|
|
|
|
result = manager.get_session("testuser", "testpass")
|
|
|
|
assert result is mock_session
|
|
mock_db.open_user_database.assert_called_once_with(
|
|
"testuser", "testpass"
|
|
)
|
|
|
|
def test_get_session_reuses_existing_session(self):
|
|
"""Test that get_session reuses an existing valid session."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
mock_existing_session = Mock()
|
|
mock_existing_session.execute.return_value = None
|
|
|
|
# Manually set up an existing session
|
|
manager._local.session = mock_existing_session
|
|
manager._local.username = "testuser"
|
|
|
|
result = manager.get_session("testuser", "testpass")
|
|
|
|
# Should return the existing session
|
|
assert result is mock_existing_session
|
|
# Verify text() wrapper type and content explicitly
|
|
from sqlalchemy.sql.elements import TextClause
|
|
|
|
call_args = mock_existing_session.execute.call_args[0][0]
|
|
assert isinstance(call_args, TextClause)
|
|
assert call_args.text == "SELECT 1"
|
|
|
|
def test_get_session_creates_new_when_existing_invalid(self):
|
|
"""Test that get_session creates new session when existing is invalid."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
mock_invalid_session = Mock()
|
|
mock_invalid_session.execute.side_effect = OperationalError(
|
|
"stmt", {}, Exception("Connection lost")
|
|
)
|
|
|
|
manager._local.session = mock_invalid_session
|
|
manager._local.username = "testuser"
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.db_manager"
|
|
) as mock_db:
|
|
mock_engine = Mock()
|
|
mock_new_session = Mock()
|
|
mock_db.open_user_database.return_value = mock_engine
|
|
mock_db.create_thread_safe_session_for_metrics.return_value = (
|
|
mock_new_session
|
|
)
|
|
|
|
result = manager.get_session("testuser", "testpass")
|
|
|
|
# Should create a new session
|
|
assert result is mock_new_session
|
|
|
|
def test_get_session_clears_cross_user_cached_session(self):
|
|
"""Cached session belonging to a different user is cleared and a fresh one created."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
stale_session = Mock()
|
|
manager._local.session = stale_session
|
|
manager._local.username = "alice" # cached for alice
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.db_manager"
|
|
) as mock_db:
|
|
mock_engine = Mock()
|
|
mock_db.open_user_database.return_value = mock_engine
|
|
new_session = Mock()
|
|
mock_db.create_thread_safe_session_for_metrics.return_value = (
|
|
new_session
|
|
)
|
|
|
|
result = manager.get_session("bob", "bobpass")
|
|
|
|
assert result is new_session
|
|
assert result is not stale_session # old session was cleared
|
|
stale_session.close.assert_called() # cleanup was called
|
|
mock_db.open_user_database.assert_called_once_with("bob", "bobpass")
|
|
|
|
def test_get_session_returns_none_on_db_open_failure(self):
|
|
"""Test that get_session returns None when database fails to open."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.db_manager"
|
|
) as mock_db:
|
|
mock_db.open_user_database.return_value = None
|
|
|
|
result = manager.get_session("testuser", "testpass")
|
|
|
|
assert result is None
|
|
|
|
def test_get_current_session_returns_none_when_no_session(self):
|
|
"""Test that get_current_session returns None when no session exists."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
result = manager.get_current_session()
|
|
assert result is None
|
|
|
|
def test_get_current_session_returns_existing_session(self):
|
|
"""Test that get_current_session returns the existing session."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
mock_session = Mock()
|
|
manager._local.session = mock_session
|
|
|
|
result = manager.get_current_session()
|
|
assert result is mock_session
|
|
|
|
def test_cleanup_thread_cleans_current_thread(self):
|
|
"""Test that cleanup_thread cleans up the current thread's session."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
mock_session = Mock()
|
|
manager._local.session = mock_session
|
|
manager._local.username = "testuser"
|
|
thread_id = threading.get_ident()
|
|
manager._thread_credentials[thread_id] = ("testuser", "testpass")
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.db_manager"
|
|
):
|
|
manager.cleanup_thread()
|
|
|
|
mock_session.close.assert_called_once()
|
|
assert manager._local.session is None
|
|
assert thread_id not in manager._thread_credentials
|
|
|
|
def test_get_session_recovers_from_pending_rollback_error(self):
|
|
"""Test that get_session recovers from PendingRollbackError via rollback.
|
|
|
|
Rollback is called twice in the recovery path: once to clear the
|
|
pending-rollback state so the retry SELECT 1 can run, and once
|
|
after that SELECT 1 succeeds to release the SHARED lock the
|
|
validation transaction held under DEFERRED isolation.
|
|
"""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
mock_session = Mock()
|
|
|
|
# First execute raises PendingRollbackError, after rollback it succeeds
|
|
mock_session.execute.side_effect = [
|
|
PendingRollbackError("test"), # Initial validation fails
|
|
None, # Retry after rollback succeeds
|
|
]
|
|
|
|
manager._local.session = mock_session
|
|
manager._local.username = "testuser"
|
|
|
|
result = manager.get_session("testuser", "testpass")
|
|
|
|
# Should recover the same session via rollback
|
|
assert result is mock_session
|
|
assert mock_session.rollback.call_count == 2
|
|
assert mock_session.execute.call_count == 2
|
|
|
|
def test_get_session_recreates_when_rollback_recovery_fails(self):
|
|
"""Test that get_session recreates session when rollback recovery fails."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
mock_old_session = Mock()
|
|
|
|
# Both execute calls fail — rollback doesn't help
|
|
mock_old_session.execute.side_effect = PendingRollbackError("test")
|
|
mock_old_session.rollback.side_effect = OperationalError(
|
|
"stmt", {}, Exception("rollback failed")
|
|
)
|
|
|
|
manager._local.session = mock_old_session
|
|
manager._local.username = "testuser"
|
|
thread_id = threading.get_ident()
|
|
manager._thread_credentials[thread_id] = ("testuser", "testpass")
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.db_manager"
|
|
) as mock_db:
|
|
mock_new_session = Mock()
|
|
mock_db.open_user_database.return_value = Mock()
|
|
mock_db.create_thread_safe_session_for_metrics.return_value = (
|
|
mock_new_session
|
|
)
|
|
|
|
result = manager.get_session("testuser", "testpass")
|
|
|
|
# Should fall back to creating a new session
|
|
assert result is mock_new_session
|
|
# close() must still be called on old session even though rollback() failed
|
|
mock_old_session.close.assert_called_once()
|
|
|
|
def test_cleanup_thread_session_calls_rollback_before_close(self):
|
|
"""Test that _cleanup_thread_session calls rollback before close."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
mock_session = Mock()
|
|
manager._local.session = mock_session
|
|
manager._local.username = "testuser"
|
|
thread_id = threading.get_ident()
|
|
manager._thread_credentials[thread_id] = ("testuser", "testpass")
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.db_manager"
|
|
):
|
|
manager._cleanup_thread_session()
|
|
|
|
# Verify rollback is called before close
|
|
expected_calls = [call.rollback(), call.close()]
|
|
mock_session.assert_has_calls(expected_calls, any_order=False)
|
|
|
|
def test_cleanup_thread_session_still_closes_when_rollback_fails(self):
|
|
"""Test that close() is called even when rollback() raises."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
mock_session = Mock()
|
|
mock_session.rollback.side_effect = OperationalError(
|
|
"stmt", {}, Exception("dead connection")
|
|
)
|
|
manager._local.session = mock_session
|
|
manager._local.username = "testuser"
|
|
thread_id = threading.get_ident()
|
|
manager._thread_credentials[thread_id] = ("testuser", "testpass")
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.db_manager"
|
|
):
|
|
manager._cleanup_thread_session()
|
|
|
|
mock_session.close.assert_called_once()
|
|
assert manager._local.session is None
|
|
|
|
def test_cleanup_all_cleans_all_threads(self):
|
|
"""Test that cleanup_all cleans up all tracked sessions."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
manager = ThreadLocalSessionManager()
|
|
manager._thread_credentials = {
|
|
1: ("user1", "pass1"),
|
|
2: ("user2", "pass2"),
|
|
}
|
|
|
|
# cleanup_all should iterate credentials and call cleanup_thread
|
|
# without touching any engine machinery.
|
|
manager.cleanup_all()
|
|
assert manager._thread_credentials == {}
|
|
|
|
|
|
class TestThreadSessionContext:
|
|
"""Tests for ThreadSessionContext context manager."""
|
|
|
|
def test_context_manager_returns_session(self):
|
|
"""Test that context manager returns a session on enter."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadSessionContext,
|
|
)
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.get_metrics_session"
|
|
) as mock_get:
|
|
mock_session = Mock()
|
|
mock_get.return_value = mock_session
|
|
|
|
with ThreadSessionContext("testuser", "testpass") as session:
|
|
assert session is mock_session
|
|
|
|
def test_context_manager_stores_credentials(self):
|
|
"""Test that context manager stores username and password."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
ThreadSessionContext,
|
|
)
|
|
|
|
ctx = ThreadSessionContext("myuser", "mypass")
|
|
assert ctx.username == "myuser"
|
|
assert ctx.password == "mypass"
|
|
|
|
|
|
class TestModuleFunctions:
|
|
"""Tests for module-level functions."""
|
|
|
|
def test_get_metrics_session_delegates_to_manager(self):
|
|
"""Test that get_metrics_session delegates to thread_session_manager."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
get_metrics_session,
|
|
thread_session_manager,
|
|
)
|
|
|
|
with patch.object(thread_session_manager, "get_session") as mock_get:
|
|
mock_session = Mock()
|
|
mock_get.return_value = mock_session
|
|
|
|
result = get_metrics_session("testuser", "testpass")
|
|
|
|
mock_get.assert_called_once_with("testuser", "testpass")
|
|
assert result is mock_session
|
|
|
|
def test_get_current_thread_session_delegates_to_manager(self):
|
|
"""Test that get_current_thread_session delegates to manager."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
get_current_thread_session,
|
|
thread_session_manager,
|
|
)
|
|
|
|
with patch.object(
|
|
thread_session_manager, "get_current_session"
|
|
) as mock_get:
|
|
mock_session = Mock()
|
|
mock_get.return_value = mock_session
|
|
|
|
result = get_current_thread_session()
|
|
|
|
mock_get.assert_called_once()
|
|
assert result is mock_session
|
|
|
|
def test_cleanup_current_thread_delegates_to_manager(self):
|
|
"""Test that cleanup_current_thread delegates to manager."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
cleanup_current_thread,
|
|
thread_session_manager,
|
|
)
|
|
|
|
with patch.object(
|
|
thread_session_manager, "cleanup_thread"
|
|
) as mock_cleanup:
|
|
cleanup_current_thread()
|
|
mock_cleanup.assert_called_once()
|
|
|
|
|
|
class TestGlobalInstance:
|
|
"""Tests for the global thread_session_manager instance."""
|
|
|
|
def test_global_instance_exists(self):
|
|
"""Test that global thread_session_manager instance exists."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
thread_session_manager,
|
|
)
|
|
|
|
assert thread_session_manager is not None
|
|
|
|
def test_global_instance_is_correct_type(self):
|
|
"""Test that global instance is ThreadLocalSessionManager."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
thread_session_manager,
|
|
ThreadLocalSessionManager,
|
|
)
|
|
|
|
assert isinstance(thread_session_manager, ThreadLocalSessionManager)
|
|
|
|
|
|
class TestThreadCleanup:
|
|
"""Tests for thread_cleanup decorator / context manager."""
|
|
|
|
def _patch_all_cleanup(self):
|
|
"""Return a stack of patches for the three cleanup functions."""
|
|
return (
|
|
patch(
|
|
"local_deep_research.database.thread_local_session.cleanup_current_thread"
|
|
),
|
|
patch(
|
|
"local_deep_research.database.thread_local_session._ThreadCleanup.__exit__",
|
|
wraps=None,
|
|
),
|
|
)
|
|
|
|
def test_bare_decorator_runs_cleanup(self):
|
|
"""@thread_cleanup runs cleanup on exit and returns result."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
thread_cleanup,
|
|
)
|
|
|
|
mock_cleanup = Mock()
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.cleanup_current_thread",
|
|
mock_cleanup,
|
|
):
|
|
|
|
@thread_cleanup
|
|
def worker():
|
|
return 42
|
|
|
|
result = worker()
|
|
|
|
assert result == 42
|
|
mock_cleanup.assert_called_once()
|
|
|
|
def test_factory_decorator_runs_cleanup(self):
|
|
"""@thread_cleanup() (with parens) runs cleanup on exit and returns result."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
thread_cleanup,
|
|
)
|
|
|
|
mock_cleanup = Mock()
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.cleanup_current_thread",
|
|
mock_cleanup,
|
|
):
|
|
|
|
@thread_cleanup()
|
|
def worker():
|
|
return 99
|
|
|
|
result = worker()
|
|
|
|
assert result == 99
|
|
mock_cleanup.assert_called_once()
|
|
|
|
def test_context_manager_runs_cleanup(self):
|
|
"""with thread_cleanup(): runs cleanup on exit."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
thread_cleanup,
|
|
)
|
|
|
|
mock_cleanup = Mock()
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.cleanup_current_thread",
|
|
mock_cleanup,
|
|
):
|
|
with thread_cleanup():
|
|
pass
|
|
|
|
mock_cleanup.assert_called_once()
|
|
|
|
def test_inline_wrapper_runs_cleanup(self):
|
|
"""thread_cleanup(func) as inline wrapper runs cleanup on exit."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
thread_cleanup,
|
|
)
|
|
|
|
mock_cleanup = Mock()
|
|
|
|
def worker(x):
|
|
return x * 2
|
|
|
|
with patch(
|
|
"local_deep_research.database.thread_local_session.cleanup_current_thread",
|
|
mock_cleanup,
|
|
):
|
|
wrapped = thread_cleanup(worker)
|
|
result = wrapped(5)
|
|
|
|
assert result == 10
|
|
mock_cleanup.assert_called_once()
|
|
|
|
def test_cleanup_exception_logged_not_raised(self):
|
|
"""Cleanup exceptions are logged at debug level, not raised."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
thread_cleanup,
|
|
)
|
|
|
|
with (
|
|
patch(
|
|
"local_deep_research.database.thread_local_session.cleanup_current_thread",
|
|
side_effect=RuntimeError("cleanup boom"),
|
|
),
|
|
patch(
|
|
"local_deep_research.database.thread_local_session.logger"
|
|
) as mock_logger,
|
|
):
|
|
|
|
@thread_cleanup
|
|
def worker():
|
|
return "ok"
|
|
|
|
result = worker()
|
|
|
|
assert result == "ok"
|
|
mock_logger.debug.assert_called()
|
|
|
|
def test_original_exception_propagates_when_cleanup_fails(self):
|
|
"""Original exceptions propagate even when cleanup fails."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
thread_cleanup,
|
|
)
|
|
|
|
with (
|
|
patch(
|
|
"local_deep_research.database.thread_local_session.cleanup_current_thread",
|
|
side_effect=RuntimeError("cleanup boom"),
|
|
),
|
|
patch("local_deep_research.database.thread_local_session.logger"),
|
|
):
|
|
|
|
@thread_cleanup
|
|
def worker():
|
|
raise ValueError("original error")
|
|
|
|
try:
|
|
worker()
|
|
assert False, "Should have raised ValueError"
|
|
except ValueError as e:
|
|
assert str(e) == "original error"
|
|
|
|
def test_functools_wraps_metadata_preserved(self):
|
|
"""functools.wraps metadata preserved on decorated functions."""
|
|
from local_deep_research.database.thread_local_session import (
|
|
thread_cleanup,
|
|
)
|
|
|
|
@thread_cleanup
|
|
def my_worker():
|
|
"""My docstring."""
|
|
pass
|
|
|
|
assert my_worker.__name__ == "my_worker"
|
|
assert my_worker.__doc__ == "My docstring."
|