mirror of
https://github.com/LearningCircuit/local-deep-research.git
synced 2026-06-16 03:51:07 +03:00
test: add 134 high-value tests for token counter, repository, meta search, and path validator (#2498)
Target modules with complex untested logic rather than chasing coverage %. Each test targets a real bug that could occur in production:

- token_counter: model detection fallback chain, context overflow at the 95% boundary, token extraction from 4 different response formats, error tracking
- repository: knowledge truncation thresholds, LLM error type classification (5 error types), format_findings_to_text exception fallback, old_formatting path
- meta_search_engine: engine filtering (meta/auto exclusion, API key combos, auto_search toggle), query analysis with specialized domains, content retrieval failure paths, engine instance caching
- path_validator: traversal blocking, extension case sensitivity, null byte handling, sanitize_for_filesystem_ops (previously 0% covered), restricted prefix enforcement, home dir expansion edge cases
This commit is contained in:
477
tests/advanced_search_system/test_repository_logic.py
Normal file
477
tests/advanced_search_system/test_repository_logic.py
Normal file
@@ -0,0 +1,477 @@
|
||||
"""Tests for FindingsRepository — truncation, error detection, untested methods, formatting fallback."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from local_deep_research.advanced_search_system.findings.repository import (
|
||||
FindingsRepository,
|
||||
format_links,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_repo():
    """Build a FindingsRepository whose LLM is a stand-in MagicMock."""
    return FindingsRepository(model=MagicMock())
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 1. format_links utility
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestFormatLinks:
    """Exercise the module-level format_links helper."""

    def test_single_link(self):
        """A lone link is rendered with index 1 and its URL."""
        only_link = {"title": "Example", "url": "https://example.com"}

        rendered = format_links([only_link])

        assert "1. Example" in rendered
        assert "https://example.com" in rendered

    def test_multiple_links(self):
        """Two links receive the indices 1 and 2."""
        rendered = format_links(
            [
                {"title": "A", "url": "https://a.com"},
                {"title": "B", "url": "https://b.com"},
            ]
        )

        for numbered_title in ("1. A", "2. B"):
            assert numbered_title in rendered

    def test_empty_links(self):
        """No links produce an empty string."""
        rendered = format_links([])
        assert rendered == ""
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 2. add_documents() — entirely untested method
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestAddDocuments:
    """Check that add_documents grows the repository's document list."""

    def test_add_documents_extends_list(self):
        """A batch of two documents is stored, first document first."""
        repo = _make_repo()
        batch = [Document(page_content=text) for text in ("Doc 1", "Doc 2")]

        repo.add_documents(batch)

        assert len(repo.documents) == 2
        assert repo.documents[0].page_content == "Doc 1"

    def test_add_documents_accumulates(self):
        """Successive batches add up instead of replacing each other."""
        repo = _make_repo()
        first_batch = [Document(page_content="A")]
        second_batch = [
            Document(page_content="B"),
            Document(page_content="C"),
        ]

        repo.add_documents(first_batch)
        repo.add_documents(second_batch)

        assert len(repo.documents) == 3

    def test_add_empty_documents(self):
        """An empty batch leaves the document list empty."""
        repo = _make_repo()

        repo.add_documents([])

        assert len(repo.documents) == 0
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 3. set_questions_by_iteration() — entirely untested method
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSetQuestionsByIteration:
    """Check how set_questions_by_iteration stores the mapping it is given."""

    def test_stores_questions(self):
        """The mapping of iteration number to questions is readable back."""
        repo = _make_repo()

        repo.set_questions_by_iteration({1: ["Q1", "Q2"], 2: ["Q3"]})

        expected = {1: ["Q1", "Q2"], 2: ["Q3"]}
        assert repo.questions_by_iteration == expected

    def test_stores_copy_not_reference(self):
        """Keys added to the caller's dict afterwards do not leak in."""
        source = {1: ["Q1"]}
        repo = _make_repo()
        repo.set_questions_by_iteration(source)

        # Mutate the caller-side dict only after it has been handed over.
        source[2] = ["Q2"]

        assert 2 not in repo.questions_by_iteration

    def test_overwrites_previous(self):
        """A second call replaces what the first call stored."""
        repo = _make_repo()

        for payload in ({1: ["old"]}, {1: ["new"]}):
            repo.set_questions_by_iteration(payload)

        assert repo.questions_by_iteration[1] == ["new"]
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 4. format_findings_to_text — exception handler
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestFormatFindingsToText:
    """Cover format_findings_to_text on its success and failure paths."""

    # Patch target: format_findings as referenced by the repository module.
    _FORMAT_FINDINGS = (
        "local_deep_research.advanced_search_system.findings.repository"
        ".format_findings"
    )

    @patch(_FORMAT_FINDINGS)
    def test_successful_formatting(self, mock_format):
        """The formatter's return value is handed back unchanged."""
        repo = _make_repo()
        repo.questions_by_iteration = {1: ["Q1"]}
        mock_format.return_value = "Formatted report"
        findings = [{"phase": "test", "content": "data"}]

        report = repo.format_findings_to_text(findings, "synthesized content")

        assert report == "Formatted report"
        mock_format.assert_called_once()

    @patch(_FORMAT_FINDINGS)
    def test_exception_returns_fallback_message(self, mock_format):
        """A raising formatter yields an error notice plus the raw synthesis."""
        mock_format.side_effect = RuntimeError("formatting broke")

        report = _make_repo().format_findings_to_text([], "raw synthesis")

        for fragment in ("Error during final formatting", "raw synthesis"):
            assert fragment in report
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 5. Knowledge truncation in synthesize_findings
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestKnowledgeTruncation:
    """Verify truncation logic when knowledge exceeds limits.

    Every check inspects the prompt handed to ``model.invoke``. The mocked
    LLM reply never contains the truncation marker, so asserting on the
    text returned by ``synthesize_findings`` would pass no matter what the
    repository did with the content.
    """

    @staticmethod
    def _prompts_for(content):
        """Synthesize one finding and return every prompt sent to the LLM.

        Args:
            content: Text used as the ``content`` of the single finding.

        Returns:
            The prompts ``model.invoke`` was called with, in call order.
        """
        repo = _make_repo()
        captured_prompts = []

        def capture_invoke(prompt):
            captured_prompts.append(prompt)
            return MagicMock(content="result")

        repo.model.invoke.side_effect = capture_invoke
        repo.synthesize_findings(
            query="test",
            sub_queries=["sq1"],
            findings=[{"content": content}],
        )
        return captured_prompts

    def test_no_truncation_under_limit(self):
        """Content under 24000 chars is not truncated."""
        captured_prompts = self._prompts_for("x" * 20000)

        # Check the prompt, not the (mocked) synthesis result.
        assert len(captured_prompts) == 1
        assert "content truncated" not in captured_prompts[0]

    def test_truncation_over_24000_chars(self):
        """Content over 24000 chars triggers truncation (also needs estimated_tokens > 12000)."""
        # Need > 48000 chars so estimated_tokens (len/4) > max_safe_tokens (12000)
        # AND len > 24000 for the actual truncation
        content = "A" * 25000 + "B" * 25000  # 50000 chars total

        captured_prompts = self._prompts_for(content)

        # The prompt should contain the truncation marker
        assert len(captured_prompts) == 1
        assert "content truncated due to length" in captured_prompts[0]

    def test_exactly_24000_chars_no_truncation(self):
        """At 24000 chars estimated_tokens=6000 < max_safe_tokens=12000, so no truncation."""
        captured_prompts = self._prompts_for("x" * 24000)

        assert len(captured_prompts) == 1
        assert "content truncated" not in captured_prompts[0]

    def test_truncation_preserves_start_and_end(self):
        """Truncated content keeps first 12000 and last 12000 chars."""
        # Need > 48000 chars for estimated_tokens > max_safe_tokens AND > 24000 for truncation
        content = "START" + "x" * 50000 + "END__"  # well over both thresholds

        captured_prompts = self._prompts_for(content)

        assert len(captured_prompts) == 1
        prompt = captured_prompts[0]
        # Confirm truncation really happened: an untruncated prompt would
        # also contain both sentinels and make the checks below meaningless.
        assert "content truncated due to length" in prompt
        assert "START" in prompt  # beginning preserved
        assert "END__" in prompt  # ending preserved
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 6. LLM exception type detection (lines 430-466)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLLMExceptionTypeDetection:
    """Check how synthesize_findings reports different LLM failures."""

    @staticmethod
    def _synthesize_with_llm_error(message):
        """Return synthesize_findings' output when the LLM raises *message*."""
        repo = _make_repo()
        repo.model.invoke.side_effect = Exception(message)
        return repo.synthesize_findings(
            query="test",
            sub_queries=["sq1"],
            findings=[{"content": "data"}],
        )

    @pytest.mark.parametrize(
        "error_msg, expected_fragment",
        [
            ("Request timed out after 30s", "LLM timeout"),
            ("too many tokens in request", "token limit"),
            ("context length exceeded", "token limit"),
            ("rate limit exceeded", "rate limit"),
            ("rate_limit_error", "rate limit"),
            ("connection refused", "connection issues"),
            ("network error occurred", "connection issues"),
            ("invalid api key provided", "authentication"),
            ("authentication failed", "authentication"),
        ],
    )
    def test_error_type_detection(self, error_msg, expected_fragment):
        """The returned text names the category matching the error message."""
        output = self._synthesize_with_llm_error(error_msg)

        assert expected_fragment in output

    def test_unknown_error_includes_details(self):
        """An unrecognised error is echoed back behind an 'LLM error:' label."""
        output = self._synthesize_with_llm_error(
            "something completely unexpected"
        )

        assert "something completely unexpected" in output
        assert "LLM error:" in output
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7. old_formatting=True path
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOldFormattingPath:
    """Verify the old_formatting=True branch in synthesize_findings."""

    @staticmethod
    def _findings_passed_to(mock_format):
        """Return the findings list the patched format_findings received.

        Works whether the list was passed as the ``findings_list`` keyword
        or as the first positional argument. The keyword is detected by key
        presence, not truthiness, so an empty list passed by keyword is
        returned as-is rather than falling through to a positional lookup.

        Args:
            mock_format: The mock standing in for ``format_findings``.

        Returns:
            The findings list from the most recent call.
        """
        assert mock_format.called, "format_findings was never invoked"
        args, kwargs = mock_format.call_args
        if "findings_list" in kwargs:
            return kwargs["findings_list"]
        return args[0]

    @patch(
        "local_deep_research.advanced_search_system.findings.repository.format_findings"
    )
    def test_old_formatting_converts_strings(self, mock_format):
        """String findings are converted to dicts with phase labels."""
        mock_format.return_value = "old-formatted"
        repo = _make_repo()
        repo.questions_by_iteration = {1: ["Q1"]}

        result = repo.synthesize_findings(
            query="test",
            sub_queries=["sq1"],
            findings=["finding one", "finding two"],
            old_formatting=True,
        )

        assert result == "old-formatted"
        findings_list = self._findings_passed_to(mock_format)
        assert findings_list[0]["phase"] == "Finding 1"
        assert findings_list[1]["phase"] == "Finding 2"

    @patch(
        "local_deep_research.advanced_search_system.findings.repository.format_findings"
    )
    def test_old_formatting_preserves_dicts(self, mock_format):
        """Dict findings are passed through as-is."""
        mock_format.return_value = "old-formatted"
        repo = _make_repo()

        repo.synthesize_findings(
            query="test",
            sub_queries=[],
            findings=[{"phase": "Analysis", "content": "data"}],
            old_formatting=True,
        )

        findings_list = self._findings_passed_to(mock_format)
        assert findings_list[0]["phase"] == "Analysis"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 8. accumulated_knowledge handling
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestAccumulatedKnowledge:
    """Exercise synthesize_findings' knowledge input and response handling."""

    def test_none_accumulated_joins_findings(self):
        """With accumulated_knowledge=None both findings reach the prompt."""
        repo = _make_repo()
        seen_prompts = []

        def record_prompt(prompt):
            seen_prompts.append(prompt)
            return MagicMock(content="result")

        repo.model.invoke.side_effect = record_prompt

        repo.synthesize_findings(
            query="test",
            sub_queries=["sq1"],
            findings=[{"content": "part1"}, {"content": "part2"}],
            accumulated_knowledge=None,
        )

        first_prompt = seen_prompts[0]
        for part in ("part1", "part2"):
            assert part in first_prompt

    def test_response_with_content_attr(self):
        """A reply exposing .content is returned as that content."""
        repo = _make_repo()
        reply = MagicMock(content="synthesized answer")
        repo.model.invoke.return_value = reply

        answer = repo.synthesize_findings(
            query="test",
            sub_queries=[],
            findings=[{"content": "data"}],
        )

        assert answer == "synthesized answer"

    def test_response_without_content_attr(self):
        """A reply lacking .content contributes its str() form instead."""

        class PlainResponse:
            def __str__(self):
                return "string response"

        repo = _make_repo()
        repo.model.invoke.return_value = PlainResponse()

        answer = repo.synthesize_findings(
            query="test",
            sub_queries=[],
            findings=[{"content": "data"}],
        )

        assert "string response" in answer
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 9. add_finding edge cases
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestAddFinding:
    """Check add_finding, get_findings and clear_findings together."""

    def test_string_finding_converted_to_dict(self):
        """A plain string is stored as a dict with content and phase."""
        repo = _make_repo()

        repo.add_finding("query1", "some text")

        stored = repo.get_findings("query1")
        assert len(stored) == 1
        entry = stored[0]
        assert entry["content"] == "some text"
        assert entry["phase"] == "Synthesis"

    def test_dict_finding_stored_directly(self):
        """Extra keys on a dict finding survive storage."""
        repo = _make_repo()

        repo.add_finding(
            "query1",
            {"phase": "Analysis", "content": "data", "extra": "field"},
        )

        first_entry = repo.get_findings("query1")[0]
        assert first_entry["extra"] == "field"

    def test_final_synthesis_creates_synthesis_key(self):
        """A 'Final synthesis' finding is also retrievable under '<query>_synthesis'."""
        repo = _make_repo()
        final = {"phase": "Final synthesis", "content": "final answer"}

        repo.add_finding("query1", final)

        under_synthesis_key = repo.get_findings("query1_synthesis")
        assert len(under_synthesis_key) == 1
        assert under_synthesis_key[0]["content"] == "final answer"

    def test_clear_findings(self):
        """After clear_findings the query has no findings left."""
        repo = _make_repo()
        repo.add_finding("q", "data")

        repo.clear_findings("q")

        assert repo.get_findings("q") == []

    def test_clear_nonexistent_query_is_noop(self):
        """Clearing an unknown query must not raise."""
        repo = _make_repo()

        # No assertion: the test passes as long as no exception escapes.
        repo.clear_findings("nonexistent")
|
||||
523
tests/metrics/test_token_counter_logic.py
Normal file
523
tests/metrics/test_token_counter_logic.py
Normal file
@@ -0,0 +1,523 @@
|
||||
"""Tests for token_counter.py — model detection, context overflow, token extraction, and rate limiting metrics."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
from local_deep_research.metrics.token_counter import TokenCountingCallback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_callback(**kw):
    """Build a TokenCountingCallback for tests.

    Only ``research_id`` (default ``None``) and ``research_context``
    (default: a new empty dict) are read from ``kw``; other keys are ignored.
    """
    research_id = kw.get("research_id")
    research_context = kw.get("research_context", {})
    return TokenCountingCallback(
        research_id=research_id,
        research_context=research_context,
    )
|
||||
|
||||
|
||||
def _make_llm_result(llm_output=None, generations=None):
    """Return an LLMResult-spec'd mock for feeding on_llm_end.

    ``generations`` falls back to an empty list when it is ``None``.
    """
    fake_result = MagicMock(spec=LLMResult)
    fake_result.llm_output = llm_output
    fake_result.generations = [] if generations is None else generations
    return fake_result
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 1. Model detection fallback chain
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestModelDetectionFallback:
    """Walk through the sources on_llm_start consults for the model name."""

    @staticmethod
    def _resolved_model(serialized, **start_kwargs):
        """Run on_llm_start on a fresh callback and return current_model."""
        cb = _make_callback()
        cb.on_llm_start(serialized, ["hello"], **start_kwargs)
        return cb.current_model

    def test_preset_model_takes_priority(self):
        """Preset model/provider win over serialized and invocation values."""
        cb = _make_callback()
        cb.preset_model, cb.preset_provider = "my-preset-model", "my-provider"

        cb.on_llm_start(
            {"kwargs": {"model": "should-be-ignored"}},
            ["hello"],
            invocation_params={"model": "also-ignored"},
        )

        assert cb.current_model == "my-preset-model"
        assert cb.current_provider == "my-provider"

    def test_model_from_invocation_params(self):
        """Without a preset, invocation_params['model'] is used."""
        model = self._resolved_model(
            {"_type": "ChatOpenAI"},
            invocation_params={"model": "gpt-4"},
        )
        assert model == "gpt-4"

    def test_model_from_kwargs_directly(self):
        """A bare ``model`` keyword argument is picked up."""
        model = self._resolved_model(
            {"_type": "ChatOpenAI"},
            model="gpt-3.5-turbo",
        )
        assert model == "gpt-3.5-turbo"

    def test_model_from_serialized_kwargs(self):
        """serialized['kwargs']['model'] is used when no keyword supplies one."""
        model = self._resolved_model(
            {"_type": "ChatOpenAI", "kwargs": {"model": "from-serialized"}}
        )
        assert model == "from-serialized"

    def test_model_from_serialized_name(self):
        """serialized['name'] is used when serialized kwargs carry no model."""
        model = self._resolved_model(
            {"_type": "ChatOpenAI", "name": "my-name", "kwargs": {}}
        )
        assert model == "my-name"

    def test_ollama_specific_extraction(self):
        """A ChatOllama payload yields the model from its kwargs."""
        model = self._resolved_model(
            {"_type": "ChatOllama", "kwargs": {"model": "llama3"}}
        )
        assert model == "llama3"

    def test_ollama_fallback_to_ollama_string(self):
        """A ChatOllama payload with empty kwargs resolves to 'ollama'."""
        model = self._resolved_model({"_type": "ChatOllama", "kwargs": {}})
        assert model == "ollama"

    def test_final_fallback_to_type(self):
        """With no model anywhere, the _type string becomes the model."""
        model = self._resolved_model({"_type": "SomeCustomLLM", "kwargs": {}})
        assert model == "SomeCustomLLM"

    def test_final_fallback_to_unknown(self):
        """An empty serialized payload resolves to 'unknown'."""
        assert self._resolved_model({}) == "unknown"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 2. Provider detection
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestProviderDetection:
    """Check which provider on_llm_start records for a serialized payload."""

    @staticmethod
    def _resolved_provider(serialized, **start_kwargs):
        """Run on_llm_start on a fresh callback and return current_provider."""
        cb = _make_callback()
        cb.on_llm_start(serialized, ["hello"], **start_kwargs)
        return cb.current_provider

    @pytest.mark.parametrize(
        "type_str, expected",
        [
            ("ChatOllama", "ollama"),
            ("ChatOpenAI", "openai"),
            ("ChatAnthropic", "anthropic"),
        ],
    )
    def test_known_providers(self, type_str, expected):
        """Each recognised _type maps to its provider name."""
        provider = self._resolved_provider(
            {"_type": type_str, "kwargs": {"model": "test"}}
        )
        assert provider == expected

    def test_unknown_provider_from_kwargs(self):
        """An unrecognised _type defers to the ``provider`` keyword."""
        provider = self._resolved_provider(
            {"_type": "CustomLLM", "kwargs": {"model": "test"}},
            provider="custom-prov",
        )
        assert provider == "custom-prov"

    def test_unknown_provider_no_kwarg(self):
        """Neither _type nor ``provider`` given: the provider is 'unknown'."""
        provider = self._resolved_provider({"kwargs": {"model": "test"}})
        assert provider == "unknown"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 3. Token extraction from different response formats
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestTokenExtraction:
    """Feed on_llm_end differently shaped results and check the counters."""

    @staticmethod
    def _start(cb, llm_type, model):
        """Call on_llm_start on *cb* with a fresh serialized payload."""
        cb.on_llm_start(
            {"_type": llm_type, "kwargs": {"model": model}}, ["hi"]
        )

    @staticmethod
    def _result_wrapping(message):
        """Wrap *message* in a single mocked generation inside an LLM result."""
        generation = MagicMock()
        generation.message = message
        return _make_llm_result(generations=[[generation]])

    def test_tokens_from_llm_output_token_usage(self):
        """Counts are read from llm_output['token_usage']."""
        cb = _make_callback()
        self._start(cb, "ChatOpenAI", "gpt-4")
        usage = {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
        }

        cb.on_llm_end(_make_llm_result(llm_output={"token_usage": usage}))

        assert cb.counts["total_prompt_tokens"] == 10
        assert cb.counts["total_completion_tokens"] == 20
        assert cb.counts["total_tokens"] == 30

    def test_tokens_from_llm_output_usage_key(self):
        """Counts are also read from llm_output['usage']."""
        cb = _make_callback()
        self._start(cb, "ChatOpenAI", "gpt-4")
        usage = {
            "prompt_tokens": 5,
            "completion_tokens": 15,
            "total_tokens": 20,
        }

        cb.on_llm_end(_make_llm_result(llm_output={"usage": usage}))

        assert cb.counts["total_tokens"] == 20

    def test_tokens_from_usage_metadata_in_generations(self):
        """Counts come from a generation message's usage_metadata."""
        cb = _make_callback()
        self._start(cb, "ChatOllama", "llama3")
        message = MagicMock()
        message.usage_metadata = {
            "input_tokens": 8,
            "output_tokens": 12,
            "total_tokens": 20,
        }
        message.response_metadata = {}

        cb.on_llm_end(self._result_wrapping(message))

        assert cb.counts["total_tokens"] == 20

    def test_tokens_from_response_metadata_ollama(self):
        """prompt_eval_count / eval_count in response_metadata are counted."""
        cb = _make_callback()
        self._start(cb, "ChatOllama", "llama3")
        message = MagicMock()
        message.usage_metadata = None
        message.response_metadata = {
            "prompt_eval_count": 100,
            "eval_count": 50,
            "total_duration": 1000,
        }

        cb.on_llm_end(self._result_wrapping(message))

        assert cb.counts["total_prompt_tokens"] == 100
        assert cb.counts["total_completion_tokens"] == 50
        assert cb.counts["total_tokens"] == 150

    def test_missing_usage_entirely(self):
        """With neither llm_output nor generations the total stays 0."""
        cb = _make_callback()
        self._start(cb, "ChatOpenAI", "gpt-4")

        cb.on_llm_end(_make_llm_result(llm_output=None, generations=[]))

        assert cb.counts["total_tokens"] == 0

    def test_empty_generations_list(self):
        """A generations list holding one empty list leaves the total at 0."""
        cb = _make_callback()
        self._start(cb, "ChatOllama", "llama3")

        cb.on_llm_end(_make_llm_result(llm_output=None, generations=[[]]))

        assert cb.counts["total_tokens"] == 0

    def test_usage_metadata_present_but_response_metadata_absent(self):
        """A message with usage_metadata but no response_metadata still counts."""
        cb = _make_callback()
        self._start(cb, "ChatOllama", "llama3")
        # spec restricts the mock so response_metadata is a missing attribute.
        message = MagicMock(spec=["usage_metadata"])
        message.usage_metadata = {
            "input_tokens": 3,
            "output_tokens": 7,
            "total_tokens": 10,
        }

        cb.on_llm_end(self._result_wrapping(message))

        assert cb.counts["total_tokens"] == 10

    def test_by_model_accumulation(self):
        """Three start/end cycles for one model add up under by_model."""
        cb = _make_callback()

        for _ in range(3):
            self._start(cb, "ChatOpenAI", "gpt-4")
            usage = {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15,
            }
            cb.on_llm_end(_make_llm_result(llm_output={"token_usage": usage}))

        assert cb.counts["by_model"]["gpt-4"]["calls"] == 3
        assert cb.counts["by_model"]["gpt-4"]["total_tokens"] == 45
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 4. Context overflow detection
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestContextOverflowDetection:
|
||||
"""Verify context overflow detection in on_llm_end via Ollama metrics."""
|
||||
|
||||
def _trigger_overflow(
|
||||
self, context_limit, prompt_eval_count, original_prompt_estimate
|
||||
):
|
||||
"""Helper to set up and trigger context overflow path."""
|
||||
cb = _make_callback(research_context={"context_limit": context_limit})
|
||||
cb.on_llm_start(
|
||||
{"_type": "ChatOllama", "kwargs": {"model": "llama3"}},
|
||||
["x" * (original_prompt_estimate * 4)], # chars = tokens * 4
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.usage_metadata = None
|
||||
msg.response_metadata = {
|
||||
"prompt_eval_count": prompt_eval_count,
|
||||
"eval_count": 10,
|
||||
}
|
||||
gen = MagicMock()
|
||||
gen.message = msg
|
||||
result = _make_llm_result(generations=[[gen]])
|
||||
cb.on_llm_end(result)
|
||||
return cb
|
||||
|
||||
def test_overflow_detected_at_95_percent(self):
|
||||
"""Context overflow flagged when prompt_eval_count >= 95% of limit."""
|
||||
cb = self._trigger_overflow(
|
||||
context_limit=1000,
|
||||
prompt_eval_count=960, # 96% > 95%
|
||||
original_prompt_estimate=1200,
|
||||
)
|
||||
assert cb.context_truncated is True
|
||||
assert cb.tokens_truncated == 1200 - 960
|
||||
|
||||
def test_no_overflow_below_threshold(self):
|
||||
"""No overflow when below 95% threshold."""
|
||||
cb = self._trigger_overflow(
|
||||
context_limit=1000,
|
||||
prompt_eval_count=940, # 94% < 95%
|
||||
original_prompt_estimate=1200,
|
||||
)
|
||||
assert cb.context_truncated is False
|
||||
|
||||
def test_exact_95_boundary(self):
|
||||
"""Exact 95% threshold should trigger overflow."""
|
||||
cb = self._trigger_overflow(
|
||||
context_limit=1000,
|
||||
prompt_eval_count=950, # exactly 95%
|
||||
original_prompt_estimate=1200,
|
||||
)
|
||||
assert cb.context_truncated is True
|
||||
|
||||
def test_no_overflow_when_context_limit_none(self):
|
||||
"""No overflow detection when context_limit is not set."""
|
||||
cb = _make_callback(research_context={})
|
||||
cb.on_llm_start(
|
||||
{"_type": "ChatOllama", "kwargs": {"model": "llama3"}},
|
||||
["hello"],
|
||||
)
|
||||
|
||||
msg = MagicMock()
|
||||
msg.usage_metadata = None
|
||||
msg.response_metadata = {"prompt_eval_count": 9999, "eval_count": 10}
|
||||
gen = MagicMock()
|
||||
gen.message = msg
|
||||
result = _make_llm_result(generations=[[gen]])
|
||||
cb.on_llm_end(result)
|
||||
|
||||
assert cb.context_truncated is False
|
||||
|
||||
def test_truncation_ratio_zero_prompt_estimate(self):
    """An estimate at or below prompt_eval_count leaves tokens_truncated at 0."""
    scenario = {
        "context_limit": 100,
        "prompt_eval_count": 96,  # 96% of the limit, above the threshold
        "original_prompt_estimate": 90,  # smaller than the evaluated count
    }
    callback = self._trigger_overflow(**scenario)

    # The overflow flag is raised, yet no truncated tokens are reported
    # because the estimate does not exceed the evaluated count.
    assert callback.context_truncated is True
    assert callback.tokens_truncated == 0
|
||||
|
||||
def test_prompt_eval_count_zero_skips_overflow(self):
    """A prompt_eval_count of 0 must not flag a context overflow."""
    callback = _make_callback(research_context={"context_limit": 1000})
    serialized = {"_type": "ChatOllama", "kwargs": {"model": "llama3"}}
    callback.on_llm_start(serialized, ["hello"])

    # Response reports zero evaluated prompt tokens via response_metadata.
    message = MagicMock()
    message.usage_metadata = None
    message.response_metadata = {"prompt_eval_count": 0, "eval_count": 10}
    generation = MagicMock()
    generation.message = message

    callback.on_llm_end(_make_llm_result(generations=[[generation]]))

    assert callback.context_truncated is False
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 5. on_llm_error tracking
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestOnLLMError:
    """Checks what on_llm_error records on the callback."""

    def test_error_sets_status_and_type(self):
        """Status, exception class name and a response time are recorded."""
        callback = _make_callback()
        callback.start_time = time.time()
        failure = ValueError("boom")

        callback.on_llm_error(failure)

        assert callback.success_status == "error"
        assert callback.error_type == "ValueError"
        assert callback.response_time_ms is not None

    def test_error_saves_to_db_when_research_id_set(self):
        """With a research_id set, _save_to_db is called once with (0, 0)."""
        callback = _make_callback(research_id="test-123")
        callback.start_time = time.time()

        save_patcher = patch.object(callback, "_save_to_db")
        with save_patcher as save_mock:
            callback.on_llm_error(RuntimeError("fail"))
            save_mock.assert_called_once_with(0, 0)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 6. _get_context_overflow_fields
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestContextOverflowFields:
    """Checks the dict produced by _get_context_overflow_fields."""

    def test_fields_when_no_overflow(self):
        """A fresh callback reports no truncation and empty truncation details."""
        overflow_fields = _make_callback()._get_context_overflow_fields()

        assert overflow_fields["context_truncated"] is False
        assert overflow_fields["tokens_truncated"] is None
        assert overflow_fields["truncation_ratio"] is None

    def test_fields_when_overflow_detected(self):
        """Truncation state set on the callback is reflected in the fields."""
        callback = _make_callback()
        overflow_state = {
            "context_limit": 1000,
            "context_truncated": True,
            "tokens_truncated": 200,
            "truncation_ratio": 0.2,
            "ollama_metrics": {"prompt_eval_count": 800, "eval_count": 50},
        }
        for attribute, value in overflow_state.items():
            setattr(callback, attribute, value)

        overflow_fields = callback._get_context_overflow_fields()

        assert overflow_fields["context_truncated"] is True
        assert overflow_fields["tokens_truncated"] == 200
        assert overflow_fields["truncation_ratio"] == 0.2
        assert overflow_fields["ollama_prompt_eval_count"] == 800
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7. Prompt estimate calculation
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestPromptEstimate:
    """Checks how on_llm_start derives original_prompt_estimate from prompts."""

    def test_estimate_from_multiple_prompts(self):
        """Two prompts totalling 12 characters produce an estimate of 3."""
        callback = _make_callback()
        prompts = ["aaaa", "bbbbbbbb"]  # 4 + 8 = 12 chars => 3 tokens

        callback.on_llm_start({}, prompts)

        assert callback.original_prompt_estimate == 3

    def test_estimate_empty_prompts(self):
        """An empty prompt list produces an estimate of 0."""
        callback = _make_callback()

        callback.on_llm_start({}, [])

        assert callback.original_prompt_estimate == 0
|
||||
433
tests/search_engines/test_meta_search_logic.py
Normal file
433
tests/search_engines/test_meta_search_logic.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""Tests for MetaSearchEngine — engine filtering, API key validation, query analysis, content retrieval."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from local_deep_research.web_search_engines.engines.meta_search_engine import (
|
||||
MetaSearchEngine,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_settings_snapshot(engines, auto_settings=None):
|
||||
"""Build a settings_snapshot dict that MetaSearchEngine expects.
|
||||
|
||||
Args:
|
||||
engines: dict mapping engine_name -> {config keys like requires_api_key, api_key, ...}
|
||||
auto_settings: dict of per-engine use_in_auto_search overrides (default all True)
|
||||
"""
|
||||
snapshot = {}
|
||||
for name, config in engines.items():
|
||||
for key, value in config.items():
|
||||
snapshot[f"search.engine.web.{name}.{key}"] = value
|
||||
# Enable for auto search by default unless overridden
|
||||
if auto_settings and name in auto_settings:
|
||||
snapshot[f"search.engine.web.{name}.use_in_auto_search"] = (
|
||||
auto_settings[name]
|
||||
)
|
||||
else:
|
||||
snapshot[f"search.engine.web.{name}.use_in_auto_search"] = True
|
||||
return snapshot
|
||||
|
||||
|
||||
def _make_meta_engine(settings_snapshot, use_api_key_services=True, llm=None):
    """Build a MetaSearchEngine while WikipediaSearchEngine is patched out.

    Args:
        settings_snapshot: settings dict handed to MetaSearchEngine.
        use_api_key_services: forwarded unchanged to MetaSearchEngine.
        llm: LLM object to use; a MagicMock is substituted when None.

    Returns:
        The MetaSearchEngine instance, created with max_results=5.
    """
    engine_llm = MagicMock() if llm is None else llm
    wikipedia_patcher = patch(
        "local_deep_research.web_search_engines.engines.meta_search_engine.WikipediaSearchEngine"
    )
    with wikipedia_patcher:
        return MetaSearchEngine(
            llm=engine_llm,
            max_results=5,
            use_api_key_services=use_api_key_services,
            settings_snapshot=settings_snapshot,
        )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 1. _get_available_engines — basic filtering
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGetAvailableEngines:
    """Checks which configured engines end up in available_engines."""

    def test_meta_and_auto_excluded(self):
        """Entries named 'meta' and 'auto' are left out; 'searxng' is kept."""
        configured = {"searxng": {}, "meta": {}, "auto": {}}
        meta_engine = _make_meta_engine(_make_settings_snapshot(configured))

        assert "meta" not in meta_engine.available_engines
        assert "auto" not in meta_engine.available_engines
        assert "searxng" in meta_engine.available_engines

    def test_engine_disabled_for_auto_search(self):
        """An engine with use_in_auto_search=False is left out."""
        settings = _make_settings_snapshot(
            {"searxng": {}, "brave": {}},
            auto_settings={"brave": False},
        )
        meta_engine = _make_meta_engine(settings)

        assert "brave" not in meta_engine.available_engines
        assert "searxng" in meta_engine.available_engines

    def test_no_engines_available_raises(self):
        """Construction raises RuntimeError when every engine is filtered out."""
        settings = _make_settings_snapshot(
            {"brave": {}},
            auto_settings={"brave": False},
        )

        with pytest.raises(RuntimeError, match="No search engines enabled"):
            _make_meta_engine(settings)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 2. API key validation combinations
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestAPIKeyValidation:
    """Checks how API key requirements affect available_engines."""

    def test_engine_requires_key_and_key_present(self):
        """An engine that requires a key and has one is available."""
        configured = {"brave": {"requires_api_key": True, "api_key": "sk-xxx"}}
        meta_engine = _make_meta_engine(_make_settings_snapshot(configured))

        assert "brave" in meta_engine.available_engines

    def test_engine_requires_key_but_key_missing(self):
        """An engine that requires a key but has none is not available."""
        configured = {
            "brave": {"requires_api_key": True},
            "searxng": {},
        }
        meta_engine = _make_meta_engine(_make_settings_snapshot(configured))

        assert "brave" not in meta_engine.available_engines

    def test_engine_requires_key_but_api_services_disabled(self):
        """use_api_key_services=False drops key-requiring engines even with a key."""
        configured = {
            "brave": {"requires_api_key": True, "api_key": "sk-xxx"},
            "searxng": {},
        }
        meta_engine = _make_meta_engine(
            _make_settings_snapshot(configured), use_api_key_services=False
        )

        assert "brave" not in meta_engine.available_engines
        assert "searxng" in meta_engine.available_engines
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 3. Local engine filtering
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestLocalEngineFiltering:
    """Checks handling of engine names carrying the local.* prefix."""

    def test_local_engine_uses_local_setting_path(self):
        """A local.* engine enabled under search.engine.local.* is available."""
        # The _get_search_config parser splits on dots, so "local.myindex"
        # cannot survive as a single engine name when derived from settings
        # keys. _get_search_config is therefore mocked to feed the local
        # engine name into the availability logic directly.
        settings = {
            "search.engine.local.myindex.use_in_auto_search": True,
            "search.engine.web.searxng.use_in_auto_search": True,
        }
        wikipedia_patcher = patch(
            "local_deep_research.web_search_engines.engines.meta_search_engine.WikipediaSearchEngine"
        )
        config_patcher = patch.object(
            MetaSearchEngine,
            "_get_search_config",
            return_value={"local.myindex": {}, "searxng": {}},
        )

        with wikipedia_patcher, config_patcher:
            meta_engine = MetaSearchEngine(
                llm=MagicMock(),
                max_results=5,
                settings_snapshot=settings,
            )

        assert "local.myindex" in meta_engine.available_engines
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 4. analyze_query — specialized domain detection
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestAnalyzeQuery:
    """Checks the engine ordering returned by analyze_query."""

    def test_scientific_paper_query(self):
        """A 'scientific paper' query ranks arxiv first when it is available."""
        meta_engine = _make_meta_engine(
            _make_settings_snapshot({"arxiv": {}, "searxng": {}})
        )
        query = "latest scientific paper on quantum computing"

        ranked = meta_engine.analyze_query(query)

        assert ranked[0] == "arxiv"

    def test_code_query_selects_github(self):
        """A code-related query includes github when it is available."""
        meta_engine = _make_meta_engine(
            _make_settings_snapshot({"github": {}, "searxng": {}})
        )

        ranked = meta_engine.analyze_query("python code for sorting algorithms")

        assert "github" in ranked

    def test_arxiv_keyword_prioritizes_arxiv(self):
        """A query mentioning 'arxiv' ranks arxiv first."""
        meta_engine = _make_meta_engine(
            _make_settings_snapshot({"arxiv": {}, "searxng": {}})
        )

        ranked = meta_engine.analyze_query("find arxiv papers on transformers")

        assert ranked[0] == "arxiv"

    def test_pubmed_keyword_prioritizes_pubmed(self):
        """A query mentioning 'pubmed' ranks pubmed first."""
        meta_engine = _make_meta_engine(
            _make_settings_snapshot({"pubmed": {}, "searxng": {}})
        )

        ranked = meta_engine.analyze_query("search pubmed for covid studies")

        assert ranked[0] == "pubmed"

    def test_general_query_prefers_searxng(self):
        """A general query ranks searxng first when it is available."""
        meta_engine = _make_meta_engine(
            _make_settings_snapshot({"searxng": {}, "brave": {}})
        )

        ranked = meta_engine.analyze_query("what is the weather today")

        assert ranked[0] == "searxng"

    def test_no_llm_falls_back_to_reliability(self):
        """With no LLM and no SearXNG, the higher-reliability engine comes first."""
        configured = {
            "brave": {"reliability": 0.9},
            "wikipedia": {"reliability": 0.7},
        }
        meta_engine = _make_meta_engine(
            _make_settings_snapshot(configured), llm=None
        )
        # The helper substitutes a mock LLM, so clear it after construction.
        meta_engine.llm = None

        ranked = meta_engine.analyze_query("anything")

        # Expected ordering is by reliability, highest first.
        assert ranked[0] == "brave"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 5. analyze_query — exception handling
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestAnalyzeQueryExceptions:
    """Checks analyze_query when the LLM fails or returns unusable output."""

    def test_llm_exception_falls_back_to_searxng(self):
        """An exception from llm.invoke still yields searxng as the first engine."""
        configured = {
            "searxng": {"strengths": "general", "description": "meta"},
            "brave": {"strengths": "web", "description": "search"},
        }
        failing_llm = MagicMock()
        failing_llm.invoke.side_effect = RuntimeError("LLM down")
        meta_engine = _make_meta_engine(
            _make_settings_snapshot(configured), llm=failing_llm
        )

        ranked = meta_engine.analyze_query("test query")

        assert ranked[0] == "searxng"

    def test_llm_returns_empty_response(self):
        """An empty LLM response still yields a result containing searxng."""
        configured = {
            "searxng": {"strengths": "general", "description": "meta"},
        }
        silent_llm = MagicMock()
        silent_llm.invoke.return_value = MagicMock(content="")
        meta_engine = _make_meta_engine(
            _make_settings_snapshot(configured), llm=silent_llm
        )

        ranked = meta_engine.analyze_query("test query")

        # searxng is expected to be supplied by the fallback.
        assert "searxng" in ranked

    def test_llm_returns_invalid_engine_names(self):
        """Unknown engine names from the LLM still yield a result containing searxng."""
        configured = {
            "searxng": {"strengths": "general", "description": "meta"},
        }
        confused_llm = MagicMock()
        confused_llm.invoke.return_value = MagicMock(
            content="nonexistent1,nonexistent2"
        )
        meta_engine = _make_meta_engine(
            _make_settings_snapshot(configured), llm=confused_llm
        )

        ranked = meta_engine.analyze_query("test query")

        # None of the returned names is valid, so searxng is expected
        # to be added as the fallback.
        assert "searxng" in ranked
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 6. _get_full_content — failure scenarios
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGetFullContent:
    """Checks the cases where _get_full_content hands items back unchanged."""

    def test_snippets_only_skips_content(self):
        """With search.snippets_only=True the input items are returned."""
        settings = _make_settings_snapshot({"searxng": {}})
        settings["search.snippets_only"] = True
        meta_engine = _make_meta_engine(settings)
        items = [{"title": "Test", "snippet": "content"}]

        assert meta_engine._get_full_content(items) == items

    def test_selected_engine_exception_returns_items(self):
        """An exception from the selected engine leaves the items unchanged."""
        settings = _make_settings_snapshot({"searxng": {}})
        settings["search.snippets_only"] = False
        meta_engine = _make_meta_engine(settings)

        broken_engine = MagicMock()
        broken_engine._get_full_content.side_effect = RuntimeError("failed")
        meta_engine._selected_engine = broken_engine
        meta_engine._selected_engine_name = "searxng"
        items = [{"title": "Test"}]

        assert meta_engine._get_full_content(items) == items

    def test_no_selected_engine_returns_items(self):
        """Without a _selected_engine attribute the items are returned unchanged."""
        settings = _make_settings_snapshot({"searxng": {}})
        settings["search.snippets_only"] = False
        meta_engine = _make_meta_engine(settings)

        # Guarantee that no engine has been selected on this instance.
        if hasattr(meta_engine, "_selected_engine"):
            delattr(meta_engine, "_selected_engine")
        items = [{"title": "Test"}]

        assert meta_engine._get_full_content(items) == items
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 7. _get_engine_instance — caching and failure
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGetEngineInstance:
    """Verify engine instance caching and creation failure."""

    def test_cached_engine_returned(self):
        """Cached engine instance is reused."""
        snapshot = _make_settings_snapshot({"searxng": {}})
        engine = _make_meta_engine(snapshot)

        mock_cached = MagicMock()
        engine.engine_cache["test_engine"] = mock_cached

        result = engine._get_engine_instance("test_engine")
        assert result is mock_cached

    @patch(
        "local_deep_research.web_search_engines.engines.meta_search_engine.create_search_engine"
    )
    def test_creation_failure_returns_none(self, mock_create):
        """Engine creation failure returns None."""
        mock_create.side_effect = RuntimeError("init failed")
        snapshot = _make_settings_snapshot({"searxng": {}})
        engine = _make_meta_engine(snapshot)

        result = engine._get_engine_instance("bad_engine")
        assert result is None

    @patch(
        "local_deep_research.web_search_engines.engines.meta_search_engine.create_search_engine"
    )
    def test_max_filtered_results_zero_passed(self, mock_create):
        """max_filtered_results=0 is falsy but not None, so it must be forwarded."""
        mock_create.return_value = MagicMock()
        snapshot = _make_settings_snapshot({"searxng": {}})

        with patch(
            "local_deep_research.web_search_engines.engines.meta_search_engine.WikipediaSearchEngine"
        ):
            engine = MetaSearchEngine(
                llm=MagicMock(),
                max_results=5,
                max_filtered_results=0,
                settings_snapshot=snapshot,
            )

        engine._get_engine_instance("new_engine")

        # Fail with a clear message if the factory was never invoked, rather
        # than with an AttributeError on a None call_args below.
        mock_create.assert_called()
        # A truthiness check would drop 0; an "is not None" check keeps it,
        # so the keyword must be present in the factory call.
        assert "max_filtered_results" in mock_create.call_args.kwargs
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 8. _get_previews — engine fallthrough
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestGetPreviews:
    """Checks _get_previews with a working engine and with the fallback engine."""

    def test_first_engine_succeeds(self):
        """Results from the first working engine are returned."""
        meta_engine = _make_meta_engine(_make_settings_snapshot({"searxng": {}}))

        working_engine = MagicMock()
        working_engine._get_previews.return_value = [{"title": "Result 1"}]
        meta_engine._get_engine_instance = MagicMock(return_value=working_engine)

        analyze_patcher = patch.object(
            meta_engine, "analyze_query", return_value=["searxng"]
        )
        socket_patcher = patch(
            "local_deep_research.web_search_engines.engines.meta_search_engine.SocketIOService"
        )
        with analyze_patcher, socket_patcher:
            previews = meta_engine._get_previews("test query")

        assert len(previews) == 1

    def test_all_engines_fail_uses_fallback(self):
        """When no engine instance is available, fallback_engine results are used."""
        meta_engine = _make_meta_engine(_make_settings_snapshot({"searxng": {}}))
        meta_engine._get_engine_instance = MagicMock(return_value=None)

        fallback = MagicMock()
        fallback._get_previews.return_value = [{"title": "Wikipedia result"}]
        meta_engine.fallback_engine = fallback

        analyze_patcher = patch.object(
            meta_engine, "analyze_query", return_value=["searxng"]
        )
        with analyze_patcher:
            previews = meta_engine._get_previews("test query")

        assert previews[0]["title"] == "Wikipedia result"
|
||||
335
tests/security/test_path_validator_security.py
Normal file
335
tests/security/test_path_validator_security.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Tests for PathValidator — security-sensitive path validation edge cases."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from local_deep_research.security.path_validator import PathValidator
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 1. validate_safe_path
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestValidateSafePath:
    """Checks validate_safe_path: input validation, traversal and extensions."""

    def test_non_string_input_rejected(self):
        """An integer input raises ValueError."""
        with pytest.raises(ValueError, match="Invalid path input"):
            PathValidator.validate_safe_path(123, "/tmp")

    def test_none_input_rejected(self):
        """A None input raises ValueError."""
        with pytest.raises(ValueError, match="Invalid path input"):
            PathValidator.validate_safe_path(None, "/tmp")

    def test_empty_string_rejected(self):
        """An empty string raises ValueError."""
        with pytest.raises(ValueError, match="Invalid path input"):
            PathValidator.validate_safe_path("", "/tmp")

    def test_traversal_blocked(self, tmp_path):
        """A '..' traversal out of the base directory raises ValueError."""
        base_dir = str(tmp_path)
        with pytest.raises(ValueError):
            PathValidator.validate_safe_path("../../etc/passwd", base_dir)

    def test_valid_relative_path(self, tmp_path):
        """A relative path inside the base directory is returned under it."""
        base_dir = str(tmp_path)

        validated = PathValidator.validate_safe_path("subdir/file.txt", base_dir)

        assert validated is not None
        assert base_dir in str(validated)

    def test_extension_check_rejects_wrong_extension(self, tmp_path):
        """An extension outside required_extensions raises ValueError."""
        allowed = (".json", ".yaml")
        with pytest.raises(ValueError, match="Invalid file type"):
            PathValidator.validate_safe_path(
                "file.txt",
                str(tmp_path),
                required_extensions=allowed,
            )

    def test_extension_check_accepts_correct_extension(self, tmp_path):
        """An extension listed in required_extensions is accepted."""
        allowed = (".json", ".yaml")

        validated = PathValidator.validate_safe_path(
            "config.json",
            str(tmp_path),
            required_extensions=allowed,
        )

        assert validated.suffix == ".json"

    def test_extension_case_sensitivity(self, tmp_path):
        """The extension check is case-sensitive: '.JSON' does not match '.json'."""
        allowed = (".json",)
        with pytest.raises(ValueError, match="Invalid file type"):
            PathValidator.validate_safe_path(
                "config.JSON",
                str(tmp_path),
                required_extensions=allowed,
            )

    def test_whitespace_stripped(self, tmp_path):
        """Surrounding whitespace in the input does not prevent validation."""
        padded_name = "  file.txt  "

        validated = PathValidator.validate_safe_path(padded_name, str(tmp_path))

        assert validated is not None
        assert "file.txt" in str(validated)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 2. validate_local_filesystem_path — home dir expansion
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestHomeDirExpansion:
    """Verify validate_local_filesystem_path: ~ expansion and rejection cases."""

    def test_tilde_slash_expands_to_home(self):
        """'~/documents' expands to a path located under the home directory."""
        # Resolve the home directory, as the sibling tilde tests do, so a
        # symlinked home compares consistently across all three tests.
        home = Path.home().resolve()
        result = PathValidator.validate_local_filesystem_path(
            "~/documents",
            restricted_dirs=[],
        )
        # Compare path components instead of string prefixes: a prefix match
        # would also accept a sibling directory such as "<home>2/documents".
        assert home in Path(result).parents

    def test_tilde_alone_expands_to_home(self):
        """'~' alone expands to home directory."""
        home = Path.home().resolve()
        result = PathValidator.validate_local_filesystem_path(
            "~",
            restricted_dirs=[],
        )
        assert result == home

    def test_tilde_with_trailing_slashes(self):
        """'~/' with no relative part yields home directory."""
        home = Path.home().resolve()
        result = PathValidator.validate_local_filesystem_path(
            "~/",
            restricted_dirs=[],
        )
        assert result == home

    def test_null_bytes_rejected(self):
        """Null bytes in path raise ValueError."""
        with pytest.raises(ValueError, match="Null bytes"):
            PathValidator.validate_local_filesystem_path("/tmp/file\x00.txt")

    def test_control_characters_rejected(self):
        """Control characters raise ValueError."""
        with pytest.raises(ValueError, match="Control characters"):
            PathValidator.validate_local_filesystem_path("/tmp/file\x01name")

    def test_path_traversal_blocked(self):
        """.. in path raises ValueError."""
        with pytest.raises(ValueError, match="traversal"):
            PathValidator.validate_local_filesystem_path("/tmp/../etc/passwd")

    def test_restricted_dir_blocked(self):
        """Access to /etc is blocked by default restrictions."""
        with pytest.raises(ValueError, match="system directories"):
            PathValidator.validate_local_filesystem_path("/etc/passwd")

    def test_custom_restricted_dir(self, tmp_path):
        """Custom restricted dirs are enforced."""
        restricted = tmp_path / "secret"
        restricted.mkdir()
        target = restricted / "file.txt"
        target.touch()

        with pytest.raises(ValueError, match="system directories"):
            PathValidator.validate_local_filesystem_path(
                str(target),
                restricted_dirs=[restricted],
            )
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 3. sanitize_for_filesystem_ops — entirely uncovered
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestSanitizeForFilesystemOps:
    """Verify sanitize_for_filesystem_ops validation."""

    def test_relative_path_rejected(self):
        """Relative path raises ValueError."""
        with pytest.raises(ValueError, match="must be absolute"):
            PathValidator.sanitize_for_filesystem_ops(Path("relative/path"))

    def test_absolute_path_passes(self, tmp_path):
        """Absolute path is returned as a Path."""
        result = PathValidator.sanitize_for_filesystem_ops(tmp_path)
        assert isinstance(result, Path)
        assert result.is_absolute()

    def test_root_path(self):
        """Root path '/' passes sanitization."""
        result = PathValidator.sanitize_for_filesystem_ops(Path("/"))
        assert result == Path("/")

    def test_deep_path_preserved(self, tmp_path):
        """Deep nested path structure is preserved."""
        deep = tmp_path / "a" / "b" / "c"
        result = PathValidator.sanitize_for_filesystem_ops(deep)
        # Compare trailing components rather than the string form, so the
        # check does not depend on the platform's path separator.
        assert Path(result).parts[-3:] == ("a", "b", "c")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 4. validate_model_path
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestValidateModelPath:
    """Checks validate_model_path for missing files, directories and traversal."""

    def test_nonexistent_model_file(self, tmp_path):
        """A model file that does not exist raises ValueError."""
        models_dir = str(tmp_path)
        with pytest.raises(ValueError, match="Model file not found"):
            PathValidator.validate_model_path("model.bin", model_root=models_dir)

    def test_directory_not_file_rejected(self, tmp_path):
        """A path naming a directory raises the 'not a file' error."""
        (tmp_path / "model_dir").mkdir()
        models_dir = str(tmp_path)
        with pytest.raises(ValueError, match="not a file"):
            PathValidator.validate_model_path("model_dir", model_root=models_dir)

    def test_valid_model_file(self, tmp_path):
        """An existing model file is returned as its resolved path."""
        model_file = tmp_path / "model.gguf"
        model_file.touch()

        validated = PathValidator.validate_model_path(
            "model.gguf", model_root=str(tmp_path)
        )

        assert validated == model_file.resolve()

    def test_traversal_in_model_path(self, tmp_path):
        """A '..' traversal in the model path raises ValueError."""
        models_dir = str(tmp_path)
        with pytest.raises(ValueError):
            PathValidator.validate_model_path(
                "../../../etc/passwd", model_root=models_dir
            )

    def test_model_root_created_if_missing(self, tmp_path):
        """A missing model root directory exists after the call."""
        missing_root = tmp_path / "new_models"

        # The lookup is expected to fail because the file does not exist;
        # only the creation of the root directory matters here.
        try:
            PathValidator.validate_model_path(
                "test.bin", model_root=str(missing_root)
            )
        except ValueError:
            pass

        assert missing_root.exists()
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 5. validate_config_path — restricted prefixes and null byte stripping
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestValidateConfigPath:
    """Verify validate_config_path security checks."""

    def test_non_string_rejected(self):
        """Non-string input raises ValueError."""
        with pytest.raises(ValueError, match="Invalid config path"):
            PathValidator.validate_config_path(123)

    def test_empty_string_rejected(self):
        """Empty string raises ValueError."""
        with pytest.raises(ValueError, match="Invalid config path"):
            PathValidator.validate_config_path("")

    def test_null_bytes_stripped(self, tmp_path):
        """Null bytes are stripped from config paths (not rejected)."""
        # Create the file the stripped path is expected to point at.
        config_file = tmp_path / "config.json"
        config_file.write_text("{}")

        # validate_config_path strips null bytes rather than rejecting.
        # This means "config\x00.json" becomes "config.json" after stripping.
        result = PathValidator.validate_config_path(
            "config\x00.json",
            config_root=str(tmp_path),
        )
        assert result is not None

    def test_traversal_rejected(self):
        """.. in config path raises ValueError."""
        with pytest.raises(ValueError, match="traversal"):
            PathValidator.validate_config_path(
                "../etc/passwd", config_root="/tmp"
            )

    @pytest.mark.parametrize(
        "restricted", ["etc/passwd", "proc/self", "sys/class", "dev/null"]
    )
    def test_restricted_prefixes_blocked(self, restricted):
        """System directory prefixes are blocked."""
        with pytest.raises(ValueError, match="restricted system directory"):
            PathValidator.validate_config_path(f"/{restricted}")

    def test_restricted_prefix_exact_match(self):
        """Exact restricted directory name (no trailing path) is blocked."""
        with pytest.raises(ValueError, match="restricted system directory"):
            PathValidator.validate_config_path("/etc")

    def test_invalid_extension_rejected(self, tmp_path):
        """Non-config extension is rejected."""
        txt_file = tmp_path / "file.txt"
        txt_file.touch()
        with pytest.raises(ValueError, match="Invalid"):
            PathValidator.validate_config_path(
                "file.txt", config_root=str(tmp_path)
            )

    # Parametrized (rather than looped inside one test) so that every
    # extension is reported as its own test case and a failure on one
    # extension does not hide the outcome for the others.
    @pytest.mark.parametrize(
        "ext", [".json", ".yaml", ".yml", ".toml", ".ini", ".conf"]
    )
    def test_valid_config_extensions(self, tmp_path, ext):
        """Each valid config extension is accepted."""
        config_file = tmp_path / f"config{ext}"
        config_file.write_text("")
        result = PathValidator.validate_config_path(
            f"config{ext}",
            config_root=str(tmp_path),
        )
        assert result is not None

    def test_absolute_config_path_rejected_by_safe_join(self, tmp_path):
        """Absolute config paths fail safe_join validation (safe_join rejects absolute second args)."""
        config_file = tmp_path / "app.json"
        config_file.write_text("{}")
        # safe_join("/", "/tmp/.../app.json") returns None because the path is absolute
        with pytest.raises(ValueError, match="Invalid absolute path"):
            PathValidator.validate_config_path(str(config_file))

    def test_restricted_prefix_case_insensitive(self):
        """Restricted prefix check is case-insensitive."""
        with pytest.raises(ValueError, match="restricted system directory"):
            PathValidator.validate_config_path("/ETC/passwd")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 6. validate_data_path
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestValidateDataPath:
    """Basic checks for PathValidator.validate_data_path."""

    def test_valid_data_path(self, tmp_path):
        """A plain relative filename is accepted and shows up in the result."""
        data_dir = str(tmp_path)
        validated = PathValidator.validate_data_path("data.db", data_dir)
        assert "data.db" in str(validated)

    def test_traversal_in_data_path(self, tmp_path):
        """A path containing ``..`` segments raises ValueError."""
        data_dir = str(tmp_path)
        with pytest.raises(ValueError):
            PathValidator.validate_data_path("../../etc/passwd", data_dir)
|
||||
Reference in New Issue
Block a user