test: add 134 high-value tests for token counter, repository, meta search, and path validator (#2498)

Target modules with complex untested logic rather than chasing coverage %.
Each test targets real bugs that could occur in production:

- token_counter: model detection fallback chain, context overflow at 95%
  boundary, token extraction from 4 different response formats, error tracking
- repository: knowledge truncation thresholds, LLM error type classification
  (5 error types), format_findings_to_text exception fallback, old_formatting path
- meta_search_engine: engine filtering (meta/auto exclusion, API key combos,
  auto_search toggle), query analysis with specialized domains, content
  retrieval failure paths, engine instance caching
- path_validator: traversal blocking, extension case sensitivity, null byte
  handling, sanitize_for_filesystem_ops (previously 0% covered),
  restricted prefix enforcement, home dir expansion edge cases
This commit is contained in:
LearningCircuit
2026-03-01 14:05:43 +01:00
committed by GitHub
parent 0c37179aa4
commit 5efdbdf361
4 changed files with 1768 additions and 0 deletions

View File

@@ -0,0 +1,477 @@
"""Tests for FindingsRepository — truncation, error detection, untested methods, formatting fallback."""
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.documents import Document
from local_deep_research.advanced_search_system.findings.repository import (
FindingsRepository,
format_links,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_repo():
    """Return a FindingsRepository wired to a MagicMock in place of a real LLM."""
    return FindingsRepository(model=MagicMock())
# ===========================================================================
# 1. format_links utility
# ===========================================================================
class TestFormatLinks:
    """Exercise the module-level format_links helper."""

    def test_single_link(self):
        """A lone link is rendered with index 1 together with its URL."""
        only_link = {"title": "Example", "url": "https://example.com"}
        rendered = format_links([only_link])
        assert "1. Example" in rendered
        assert "https://example.com" in rendered

    def test_multiple_links(self):
        """Links are numbered in the order they were supplied."""
        rendered = format_links(
            [
                {"title": "A", "url": "https://a.com"},
                {"title": "B", "url": "https://b.com"},
            ]
        )
        assert "1. A" in rendered
        assert "2. B" in rendered

    def test_empty_links(self):
        """An empty list of links produces the empty string."""
        rendered = format_links([])
        assert rendered == ""
# ===========================================================================
# 2. add_documents() — entirely untested method
# ===========================================================================
class TestAddDocuments:
    """Check that add_documents grows the repository's ``documents`` list."""

    def test_add_documents_extends_list(self):
        """Both supplied documents end up in the list, in order."""
        repo = _make_repo()
        first = Document(page_content="Doc 1")
        second = Document(page_content="Doc 2")
        repo.add_documents([first, second])
        assert len(repo.documents) == 2
        assert repo.documents[0].page_content == "Doc 1"

    def test_add_documents_accumulates(self):
        """Successive calls add to, rather than replace, earlier documents."""
        repo = _make_repo()
        batches = (
            [Document(page_content="A")],
            [Document(page_content="B"), Document(page_content="C")],
        )
        for batch in batches:
            repo.add_documents(batch)
        assert len(repo.documents) == 3

    def test_add_empty_documents(self):
        """An empty batch leaves the list empty."""
        repo = _make_repo()
        repo.add_documents([])
        assert len(repo.documents) == 0
# ===========================================================================
# 3. set_questions_by_iteration() — entirely untested method
# ===========================================================================
class TestSetQuestionsByIteration:
    """Check that set_questions_by_iteration keeps its own copy of the mapping."""

    def test_stores_questions(self):
        """The stored mapping equals the one that was passed in."""
        repo = _make_repo()
        repo.set_questions_by_iteration({1: ["Q1", "Q2"], 2: ["Q3"]})
        expected = {1: ["Q1", "Q2"], 2: ["Q3"]}
        assert repo.questions_by_iteration == expected

    def test_stores_copy_not_reference(self):
        """Adding a key to the caller's dict afterwards is not visible in the repo."""
        repo = _make_repo()
        caller_mapping = {1: ["Q1"]}
        repo.set_questions_by_iteration(caller_mapping)
        # Mutate the caller's dict only after the hand-off.
        caller_mapping[2] = ["Q2"]
        assert 2 not in repo.questions_by_iteration

    def test_overwrites_previous(self):
        """A second call replaces what the first call stored."""
        repo = _make_repo()
        for payload in ({1: ["old"]}, {1: ["new"]}):
            repo.set_questions_by_iteration(payload)
        assert repo.questions_by_iteration[1] == ["new"]
# ===========================================================================
# 4. format_findings_to_text — exception handler
# ===========================================================================
class TestFormatFindingsToText:
    """Cover format_findings_to_text on both the success and the failure path."""

    # Dotted path of the formatter as looked up by the repository module.
    _FORMAT_FINDINGS = (
        "local_deep_research.advanced_search_system.findings.repository.format_findings"
    )

    @patch(_FORMAT_FINDINGS)
    def test_successful_formatting(self, mock_format):
        """The formatter's return value is passed straight back to the caller."""
        mock_format.return_value = "Formatted report"
        repo = _make_repo()
        repo.questions_by_iteration = {1: ["Q1"]}
        findings = [{"phase": "test", "content": "data"}]
        report = repo.format_findings_to_text(findings, "synthesized content")
        assert report == "Formatted report"
        mock_format.assert_called_once()

    @patch(_FORMAT_FINDINGS)
    def test_exception_returns_fallback_message(self, mock_format):
        """A formatter exception yields an error notice that still carries the raw synthesis."""
        mock_format.side_effect = RuntimeError("formatting broke")
        repo = _make_repo()
        report = repo.format_findings_to_text([], "raw synthesis")
        for fragment in ("Error during final formatting", "raw synthesis"):
            assert fragment in report
# ===========================================================================
# 5. Knowledge truncation in synthesize_findings
# ===========================================================================
class TestKnowledgeTruncation:
    """Verify truncation logic when knowledge exceeds limits.

    All assertions are made against the prompt handed to ``model.invoke``:
    the truncation marker is written into the prompt, whereas the value
    returned by ``synthesize_findings`` is the mocked LLM response and
    therefore can never contain the marker.
    """

    @staticmethod
    def _capture_prompts(content):
        """Run synthesize_findings on a single finding and return the prompts sent to the LLM.

        Args:
            content: Text used as the ``content`` of the only finding.

        Returns:
            List of every prompt passed to ``model.invoke``, in call order.
        """
        repo = _make_repo()
        captured_prompts = []

        def capture_invoke(prompt):
            captured_prompts.append(prompt)
            return MagicMock(content="result")

        repo.model.invoke.side_effect = capture_invoke
        repo.synthesize_findings(
            query="test",
            sub_queries=["sq1"],
            findings=[{"content": content}],
        )
        return captured_prompts

    def test_no_truncation_under_limit(self):
        """Content under 24000 chars is not truncated."""
        captured_prompts = self._capture_prompts("x" * 20000)
        # Check the prompt, not the mocked response, so the assertion can fail.
        assert len(captured_prompts) == 1
        assert "content truncated" not in captured_prompts[0]

    def test_truncation_over_24000_chars(self):
        """Content over 24000 chars triggers truncation (also needs estimated_tokens > 12000)."""
        # Need > 48000 chars so estimated_tokens (len/4) > max_safe_tokens (12000)
        # AND len > 24000 for the actual truncation
        content = "A" * 25000 + "B" * 25000  # 50000 chars total
        captured_prompts = self._capture_prompts(content)
        # The prompt should contain the truncation marker
        assert len(captured_prompts) == 1
        assert "content truncated due to length" in captured_prompts[0]

    def test_exactly_24000_chars_no_truncation(self):
        """At 24000 chars estimated_tokens=6000 < max_safe_tokens=12000, so no truncation."""
        captured_prompts = self._capture_prompts("x" * 24000)
        assert len(captured_prompts) == 1
        assert "content truncated" not in captured_prompts[0]

    def test_truncation_preserves_start_and_end(self):
        """Truncated content keeps first 12000 and last 12000 chars."""
        # Need > 48000 chars for estimated_tokens > max_safe_tokens AND > 24000 for truncation
        content = "START" + "x" * 50000 + "END__"  # well over both thresholds
        captured_prompts = self._capture_prompts(content)
        assert len(captured_prompts) == 1
        prompt = captured_prompts[0]
        # Without this check the test would also pass if nothing was truncated,
        # because an untruncated prompt contains both sentinels as well.
        assert "content truncated due to length" in prompt
        assert "START" in prompt  # beginning preserved
        assert "END__" in prompt  # ending preserved
# ===========================================================================
# 6. LLM exception type detection (lines 430-466)
# ===========================================================================
class TestLLMExceptionTypeDetection:
    """Check how synthesize_findings words its result when the LLM call raises."""

    @staticmethod
    def _synthesize_with_failing_llm(error_msg):
        """Return synthesize_findings' result when model.invoke raises Exception(error_msg)."""
        repo = _make_repo()
        repo.model.invoke.side_effect = Exception(error_msg)
        return repo.synthesize_findings(
            query="test",
            sub_queries=["sq1"],
            findings=[{"content": "data"}],
        )

    @pytest.mark.parametrize(
        "error_msg, expected_fragment",
        [
            ("Request timed out after 30s", "LLM timeout"),
            ("too many tokens in request", "token limit"),
            ("context length exceeded", "token limit"),
            ("rate limit exceeded", "rate limit"),
            ("rate_limit_error", "rate limit"),
            ("connection refused", "connection issues"),
            ("network error occurred", "connection issues"),
            ("invalid api key provided", "authentication"),
            ("authentication failed", "authentication"),
        ],
    )
    def test_error_type_detection(self, error_msg, expected_fragment):
        """Every recognised error wording surfaces its matching user-facing fragment."""
        outcome = self._synthesize_with_failing_llm(error_msg)
        assert expected_fragment in outcome

    def test_unknown_error_includes_details(self):
        """An unrecognised error is reported verbatim behind an 'LLM error:' label."""
        outcome = self._synthesize_with_failing_llm(
            "something completely unexpected"
        )
        assert "something completely unexpected" in outcome
        assert "LLM error:" in outcome
# ===========================================================================
# 7. old_formatting=True path
# ===========================================================================
class TestOldFormattingPath:
    """Verify the old_formatting=True branch in synthesize_findings."""

    @staticmethod
    def _findings_passed_to(mock_format):
        """Return the findings list that was handed to the patched format_findings.

        Accepts the argument either as the ``findings_list`` keyword or as the
        first positional argument, so both tests share one extraction rule.
        ``call_args.kwargs`` and ``call_args[1]`` are the same mapping, so only
        one of them needs to be consulted.
        """
        assert mock_format.called, "format_findings was never called"
        call = mock_format.call_args
        if "findings_list" in call.kwargs:
            return call.kwargs["findings_list"]
        return call.args[0]

    @patch(
        "local_deep_research.advanced_search_system.findings.repository.format_findings"
    )
    def test_old_formatting_converts_strings(self, mock_format):
        """String findings are converted to dicts with phase labels."""
        mock_format.return_value = "old-formatted"
        repo = _make_repo()
        repo.questions_by_iteration = {1: ["Q1"]}
        result = repo.synthesize_findings(
            query="test",
            sub_queries=["sq1"],
            findings=["finding one", "finding two"],
            old_formatting=True,
        )
        assert result == "old-formatted"
        findings_list = self._findings_passed_to(mock_format)
        assert findings_list[0]["phase"] == "Finding 1"
        assert findings_list[1]["phase"] == "Finding 2"

    @patch(
        "local_deep_research.advanced_search_system.findings.repository.format_findings"
    )
    def test_old_formatting_preserves_dicts(self, mock_format):
        """Dict findings are passed through as-is."""
        mock_format.return_value = "old-formatted"
        repo = _make_repo()
        repo.synthesize_findings(
            query="test",
            sub_queries=[],
            findings=[{"phase": "Analysis", "content": "data"}],
            old_formatting=True,
        )
        findings_list = self._findings_passed_to(mock_format)
        assert findings_list[0]["phase"] == "Analysis"
# ===========================================================================
# 8. accumulated_knowledge handling
# ===========================================================================
class TestAccumulatedKnowledge:
    """Cover the accumulated_knowledge argument and the handling of LLM response shapes."""

    def test_none_accumulated_joins_findings(self):
        """With accumulated_knowledge=None every finding's content reaches the prompt."""
        repo = _make_repo()
        seen = []

        def record_prompt(prompt):
            seen.append(prompt)
            return MagicMock(content="result")

        repo.model.invoke.side_effect = record_prompt
        repo.synthesize_findings(
            query="test",
            sub_queries=["sq1"],
            findings=[{"content": "part1"}, {"content": "part2"}],
            accumulated_knowledge=None,
        )
        first_prompt = seen[0]
        # Each finding's text has to show up in what the LLM was asked.
        assert "part1" in first_prompt
        assert "part2" in first_prompt

    def test_response_with_content_attr(self):
        """A response exposing ``.content`` yields exactly that content."""
        repo = _make_repo()
        llm_reply = MagicMock(content="synthesized answer")
        repo.model.invoke.return_value = llm_reply
        answer = repo.synthesize_findings(
            query="test",
            sub_queries=[],
            findings=[{"content": "data"}],
        )
        assert answer == "synthesized answer"

    def test_response_without_content_attr(self):
        """A response lacking ``.content`` is rendered through its ``__str__``."""

        class PlainResponse:
            def __str__(self):
                return "string response"

        repo = _make_repo()
        repo.model.invoke.return_value = PlainResponse()
        answer = repo.synthesize_findings(
            query="test",
            sub_queries=[],
            findings=[{"content": "data"}],
        )
        assert "string response" in answer
# ===========================================================================
# 9. add_finding edge cases
# ===========================================================================
class TestAddFinding:
    """Cover add_finding, get_findings and clear_findings for different inputs."""

    def test_string_finding_converted_to_dict(self):
        """A plain string is stored as a dict with phase 'Synthesis'."""
        repo = _make_repo()
        repo.add_finding("query1", "some text")
        stored = repo.get_findings("query1")
        assert len(stored) == 1
        entry = stored[0]
        assert entry["content"] == "some text"
        assert entry["phase"] == "Synthesis"

    def test_dict_finding_stored_directly(self):
        """Extra keys of a dict finding are still present when it is read back."""
        repo = _make_repo()
        repo.add_finding(
            "query1",
            {"phase": "Analysis", "content": "data", "extra": "field"},
        )
        stored = repo.get_findings("query1")
        assert stored[0]["extra"] == "field"

    def test_final_synthesis_creates_synthesis_key(self):
        """A 'Final synthesis' finding is also retrievable under '<query>_synthesis'."""
        repo = _make_repo()
        final = {"phase": "Final synthesis", "content": "final answer"}
        repo.add_finding("query1", final)
        stored = repo.get_findings("query1_synthesis")
        assert len(stored) == 1
        assert stored[0]["content"] == "final answer"

    def test_clear_findings(self):
        """After clear_findings the query has no findings left."""
        repo = _make_repo()
        repo.add_finding("q", "data")
        repo.clear_findings("q")
        remaining = repo.get_findings("q")
        assert remaining == []

    def test_clear_nonexistent_query_is_noop(self):
        """Clearing an unknown query must simply not raise."""
        repo = _make_repo()
        repo.clear_findings("nonexistent")  # should not raise

View File

@@ -0,0 +1,523 @@
"""Tests for token_counter.py — model detection, context overflow, token extraction, and rate limiting metrics."""
import time
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.outputs import LLMResult
from local_deep_research.metrics.token_counter import TokenCountingCallback
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_callback(**kw):
    """Build a TokenCountingCallback from optional keyword overrides.

    Recognised keys are ``research_id`` (defaults to None) and
    ``research_context`` (defaults to a fresh empty dict); any other key is
    ignored.
    """
    research_id = kw.get("research_id")
    research_context = kw.get("research_context", {})
    return TokenCountingCallback(
        research_id=research_id,
        research_context=research_context,
    )
def _make_llm_result(llm_output=None, generations=None):
    """Return a MagicMock specced as LLMResult for use in on_llm_end tests.

    ``generations`` falls back to an empty list when it is None.
    """
    fake_result = MagicMock(spec=LLMResult)
    fake_result.llm_output = llm_output
    fake_result.generations = [] if generations is None else generations
    return fake_result
# ===========================================================================
# 1. Model detection fallback chain
# ===========================================================================
class TestModelDetectionFallback:
    """Walk through the sources on_llm_start consults to fill ``current_model``."""

    @staticmethod
    def _resolved_model(serialized, **start_kwargs):
        """Start a fresh callback with one 'hello' prompt and return its current_model."""
        callback = _make_callback()
        callback.on_llm_start(serialized, ["hello"], **start_kwargs)
        return callback.current_model

    def test_preset_model_takes_priority(self):
        """A preset model/provider wins over serialized data and invocation_params."""
        callback = _make_callback()
        callback.preset_model = "my-preset-model"
        callback.preset_provider = "my-provider"
        callback.on_llm_start(
            {"kwargs": {"model": "should-be-ignored"}},
            ["hello"],
            invocation_params={"model": "also-ignored"},
        )
        assert callback.current_model == "my-preset-model"
        assert callback.current_provider == "my-provider"

    def test_model_from_invocation_params(self):
        """Without a preset, invocation_params['model'] is used."""
        model = self._resolved_model(
            {"_type": "ChatOpenAI"},
            invocation_params={"model": "gpt-4"},
        )
        assert model == "gpt-4"

    def test_model_from_kwargs_directly(self):
        """A bare ``model`` keyword argument is picked up."""
        model = self._resolved_model(
            {"_type": "ChatOpenAI"},
            model="gpt-3.5-turbo",
        )
        assert model == "gpt-3.5-turbo"

    def test_model_from_serialized_kwargs(self):
        """serialized['kwargs']['model'] is used when no keyword argument names a model."""
        model = self._resolved_model(
            {"_type": "ChatOpenAI", "kwargs": {"model": "from-serialized"}},
        )
        assert model == "from-serialized"

    def test_model_from_serialized_name(self):
        """serialized['name'] is used when serialized['kwargs'] is empty."""
        model = self._resolved_model(
            {"_type": "ChatOpenAI", "name": "my-name", "kwargs": {}},
        )
        assert model == "my-name"

    def test_ollama_specific_extraction(self):
        """For ChatOllama the model is read from serialized['kwargs']."""
        model = self._resolved_model(
            {"_type": "ChatOllama", "kwargs": {"model": "llama3"}},
        )
        assert model == "llama3"

    def test_ollama_fallback_to_ollama_string(self):
        """ChatOllama with empty kwargs resolves to the literal 'ollama'."""
        model = self._resolved_model({"_type": "ChatOllama", "kwargs": {}})
        assert model == "ollama"

    def test_final_fallback_to_type(self):
        """With no model name anywhere, the ``_type`` string becomes the model."""
        model = self._resolved_model({"_type": "SomeCustomLLM", "kwargs": {}})
        assert model == "SomeCustomLLM"

    def test_final_fallback_to_unknown(self):
        """An empty serialized dict resolves to 'unknown'."""
        model = self._resolved_model({})
        assert model == "unknown"
# ===========================================================================
# 2. Provider detection
# ===========================================================================
class TestProviderDetection:
    """Check how on_llm_start derives ``current_provider``."""

    @staticmethod
    def _resolved_provider(serialized, **start_kwargs):
        """Start a fresh callback with one 'hello' prompt and return its current_provider."""
        callback = _make_callback()
        callback.on_llm_start(serialized, ["hello"], **start_kwargs)
        return callback.current_provider

    @pytest.mark.parametrize(
        "type_str, expected",
        [
            ("ChatOllama", "ollama"),
            ("ChatOpenAI", "openai"),
            ("ChatAnthropic", "anthropic"),
        ],
    )
    def test_known_providers(self, type_str, expected):
        """Each recognised ``_type`` string maps onto its provider name."""
        provider = self._resolved_provider(
            {"_type": type_str, "kwargs": {"model": "test"}},
        )
        assert provider == expected

    def test_unknown_provider_from_kwargs(self):
        """An unrecognised ``_type`` defers to the ``provider`` keyword argument."""
        provider = self._resolved_provider(
            {"_type": "CustomLLM", "kwargs": {"model": "test"}},
            provider="custom-prov",
        )
        assert provider == "custom-prov"

    def test_unknown_provider_no_kwarg(self):
        """Neither ``_type`` nor a provider keyword results in 'unknown'."""
        provider = self._resolved_provider({"kwargs": {"model": "test"}})
        assert provider == "unknown"
# ===========================================================================
# 3. Token extraction from different response formats
# ===========================================================================
class TestTokenExtraction:
    """Check which token counts on_llm_end records for differently shaped results."""

    @staticmethod
    def _started_callback(llm_type, model):
        """Return a fresh callback on which on_llm_start has run with a single 'hi' prompt."""
        callback = _make_callback()
        callback.on_llm_start(
            {"_type": llm_type, "kwargs": {"model": model}}, ["hi"]
        )
        return callback

    @staticmethod
    def _result_with_message(message):
        """Wrap ``message`` in one mocked generation inside a mocked LLMResult."""
        generation = MagicMock()
        generation.message = message
        return _make_llm_result(generations=[[generation]])

    def test_tokens_from_llm_output_token_usage(self):
        """Counts are read from llm_output['token_usage']."""
        callback = self._started_callback("ChatOpenAI", "gpt-4")
        usage = {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
        }
        callback.on_llm_end(_make_llm_result(llm_output={"token_usage": usage}))
        assert callback.counts["total_prompt_tokens"] == 10
        assert callback.counts["total_completion_tokens"] == 20
        assert callback.counts["total_tokens"] == 30

    def test_tokens_from_llm_output_usage_key(self):
        """Counts are also read from the alternative llm_output['usage'] key."""
        callback = self._started_callback("ChatOpenAI", "gpt-4")
        usage = {
            "prompt_tokens": 5,
            "completion_tokens": 15,
            "total_tokens": 20,
        }
        callback.on_llm_end(_make_llm_result(llm_output={"usage": usage}))
        assert callback.counts["total_tokens"] == 20

    def test_tokens_from_usage_metadata_in_generations(self):
        """Counts are read from a generation message's usage_metadata."""
        callback = self._started_callback("ChatOllama", "llama3")
        message = MagicMock()
        message.usage_metadata = {
            "input_tokens": 8,
            "output_tokens": 12,
            "total_tokens": 20,
        }
        message.response_metadata = {}
        callback.on_llm_end(self._result_with_message(message))
        assert callback.counts["total_tokens"] == 20

    def test_tokens_from_response_metadata_ollama(self):
        """prompt_eval_count / eval_count in response_metadata supply the counts."""
        callback = self._started_callback("ChatOllama", "llama3")
        message = MagicMock()
        message.usage_metadata = None
        message.response_metadata = {
            "prompt_eval_count": 100,
            "eval_count": 50,
            "total_duration": 1000,
        }
        callback.on_llm_end(self._result_with_message(message))
        assert callback.counts["total_prompt_tokens"] == 100
        assert callback.counts["total_completion_tokens"] == 50
        assert callback.counts["total_tokens"] == 150

    def test_missing_usage_entirely(self):
        """With no llm_output and no generations the total stays at zero."""
        callback = self._started_callback("ChatOpenAI", "gpt-4")
        callback.on_llm_end(_make_llm_result(llm_output=None, generations=[]))
        assert callback.counts["total_tokens"] == 0

    def test_empty_generations_list(self):
        """A generations list holding one empty inner list leaves the total at zero."""
        callback = self._started_callback("ChatOllama", "llama3")
        callback.on_llm_end(_make_llm_result(llm_output=None, generations=[[]]))
        assert callback.counts["total_tokens"] == 0

    def test_usage_metadata_present_but_response_metadata_absent(self):
        """usage_metadata alone is enough when the message has no response_metadata attribute."""
        callback = self._started_callback("ChatOllama", "llama3")
        # spec limits the mock to usage_metadata, so response_metadata is absent.
        message = MagicMock(spec=["usage_metadata"])
        message.usage_metadata = {
            "input_tokens": 3,
            "output_tokens": 7,
            "total_tokens": 10,
        }
        callback.on_llm_end(self._result_with_message(message))
        assert callback.counts["total_tokens"] == 10

    def test_by_model_accumulation(self):
        """Three start/end cycles on one callback add up in the per-model stats."""
        callback = _make_callback()
        serialized = {"_type": "ChatOpenAI", "kwargs": {"model": "gpt-4"}}
        cycles = 0
        while cycles < 3:
            callback.on_llm_start(serialized, ["hi"])
            outcome = _make_llm_result(
                llm_output={
                    "token_usage": {
                        "prompt_tokens": 10,
                        "completion_tokens": 5,
                        "total_tokens": 15,
                    }
                },
            )
            callback.on_llm_end(outcome)
            cycles += 1
        per_model = callback.counts["by_model"]["gpt-4"]
        assert per_model["calls"] == 3
        assert per_model["total_tokens"] == 45
# ===========================================================================
# 4. Context overflow detection
# ===========================================================================
class TestContextOverflowDetection:
    """Check the context-overflow flags on_llm_end sets from Ollama response metrics."""

    @staticmethod
    def _end_with_ollama_counts(callback, prompt_eval_count):
        """Feed on_llm_end a mocked Ollama result carrying the given prompt_eval_count."""
        message = MagicMock()
        message.usage_metadata = None
        message.response_metadata = {
            "prompt_eval_count": prompt_eval_count,
            "eval_count": 10,
        }
        generation = MagicMock()
        generation.message = message
        callback.on_llm_end(_make_llm_result(generations=[[generation]]))

    def _trigger_overflow(
        self, context_limit, prompt_eval_count, original_prompt_estimate
    ):
        """Run one start/end cycle with the given limit and counts; return the callback."""
        callback = _make_callback(
            research_context={"context_limit": context_limit}
        )
        # Four characters per estimated token.
        prompt_text = "x" * (original_prompt_estimate * 4)
        callback.on_llm_start(
            {"_type": "ChatOllama", "kwargs": {"model": "llama3"}},
            [prompt_text],
        )
        self._end_with_ollama_counts(callback, prompt_eval_count)
        return callback

    def test_overflow_detected_at_95_percent(self):
        """96% of the limit flags truncation and records the token shortfall."""
        callback = self._trigger_overflow(
            context_limit=1000,
            prompt_eval_count=960,  # 96% of the limit
            original_prompt_estimate=1200,
        )
        assert callback.context_truncated is True
        assert callback.tokens_truncated == 1200 - 960

    def test_no_overflow_below_threshold(self):
        """94% of the limit does not flag truncation."""
        callback = self._trigger_overflow(
            context_limit=1000,
            prompt_eval_count=940,  # 94% of the limit
            original_prompt_estimate=1200,
        )
        assert callback.context_truncated is False

    def test_exact_95_boundary(self):
        """Exactly 95% of the limit already counts as overflow."""
        callback = self._trigger_overflow(
            context_limit=1000,
            prompt_eval_count=950,  # exactly 95% of the limit
            original_prompt_estimate=1200,
        )
        assert callback.context_truncated is True

    def test_no_overflow_when_context_limit_none(self):
        """Without a context_limit even a huge prompt_eval_count flags nothing."""
        callback = _make_callback(research_context={})
        callback.on_llm_start(
            {"_type": "ChatOllama", "kwargs": {"model": "llama3"}},
            ["hello"],
        )
        self._end_with_ollama_counts(callback, 9999)
        assert callback.context_truncated is False

    def test_truncation_ratio_zero_prompt_estimate(self):
        """Overflow is flagged but tokens_truncated is 0 when the estimate is below the eval count."""
        callback = self._trigger_overflow(
            context_limit=100,
            prompt_eval_count=96,  # 96% of the limit
            original_prompt_estimate=90,  # smaller than prompt_eval_count
        )
        # Flag is raised, yet there is no positive shortfall to report.
        assert callback.context_truncated is True
        assert callback.tokens_truncated == 0

    def test_prompt_eval_count_zero_skips_overflow(self):
        """A prompt_eval_count of 0 never flags truncation."""
        callback = _make_callback(research_context={"context_limit": 1000})
        callback.on_llm_start(
            {"_type": "ChatOllama", "kwargs": {"model": "llama3"}},
            ["hello"],
        )
        self._end_with_ollama_counts(callback, 0)
        assert callback.context_truncated is False
# ===========================================================================
# 5. on_llm_error tracking
# ===========================================================================
class TestOnLLMError:
    """Check what on_llm_error records on the callback."""

    def test_error_sets_status_and_type(self):
        """Status, exception class name and a response time are all recorded."""
        callback = _make_callback()
        callback.start_time = time.time()
        failure = ValueError("boom")
        callback.on_llm_error(failure)
        assert callback.success_status == "error"
        assert callback.error_type == "ValueError"
        assert callback.response_time_ms is not None

    def test_error_saves_to_db_when_research_id_set(self):
        """With a research_id set, _save_to_db is called once with (0, 0)."""
        callback = _make_callback(research_id="test-123")
        callback.start_time = time.time()
        with patch.object(callback, "_save_to_db") as save_spy:
            callback.on_llm_error(RuntimeError("fail"))
        save_spy.assert_called_once_with(0, 0)
# ===========================================================================
# 6. _get_context_overflow_fields
# ===========================================================================
class TestContextOverflowFields:
    """Check the dict returned by _get_context_overflow_fields."""

    def test_fields_when_no_overflow(self):
        """A fresh callback reports no truncation and None for the detail fields."""
        overflow = _make_callback()._get_context_overflow_fields()
        assert overflow["context_truncated"] is False
        assert overflow["tokens_truncated"] is None
        assert overflow["truncation_ratio"] is None

    def test_fields_when_overflow_detected(self):
        """Truncation state set on the callback is reflected in the returned fields."""
        callback = _make_callback()
        callback.context_limit = 1000
        callback.context_truncated = True
        callback.tokens_truncated = 200
        callback.truncation_ratio = 0.2
        callback.ollama_metrics = {"prompt_eval_count": 800, "eval_count": 50}
        overflow = callback._get_context_overflow_fields()
        assert overflow["context_truncated"] is True
        assert overflow["tokens_truncated"] == 200
        assert overflow["truncation_ratio"] == 0.2
        assert overflow["ollama_prompt_eval_count"] == 800
# ===========================================================================
# 7. Prompt estimate calculation
# ===========================================================================
class TestPromptEstimate:
    """Check the original_prompt_estimate that on_llm_start derives from prompt length."""

    def test_estimate_from_multiple_prompts(self):
        """Two prompts of 4 and 8 characters give an estimate of 3."""
        callback = _make_callback()
        prompts = ["aaaa", "bbbbbbbb"]  # 12 characters in total
        callback.on_llm_start({}, prompts)
        assert callback.original_prompt_estimate == 3

    def test_estimate_empty_prompts(self):
        """No prompts at all give an estimate of 0."""
        callback = _make_callback()
        callback.on_llm_start({}, [])
        assert callback.original_prompt_estimate == 0

View File

@@ -0,0 +1,433 @@
"""Tests for MetaSearchEngine — engine filtering, API key validation, query analysis, content retrieval."""
from unittest.mock import MagicMock, patch
import pytest
from local_deep_research.web_search_engines.engines.meta_search_engine import (
MetaSearchEngine,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_settings_snapshot(engines, auto_settings=None):
"""Build a settings_snapshot dict that MetaSearchEngine expects.
Args:
engines: dict mapping engine_name -> {config keys like requires_api_key, api_key, ...}
auto_settings: dict of per-engine use_in_auto_search overrides (default all True)
"""
snapshot = {}
for name, config in engines.items():
for key, value in config.items():
snapshot[f"search.engine.web.{name}.{key}"] = value
# Enable for auto search by default unless overridden
if auto_settings and name in auto_settings:
snapshot[f"search.engine.web.{name}.use_in_auto_search"] = (
auto_settings[name]
)
else:
snapshot[f"search.engine.web.{name}.use_in_auto_search"] = True
return snapshot
def _make_meta_engine(settings_snapshot, use_api_key_services=True, llm=None):
    """Construct a MetaSearchEngine while WikipediaSearchEngine is patched out.

    A MagicMock stands in for the LLM when none is supplied, and
    ``max_results`` is fixed at 5.
    NOTE(review): the patch presumably keeps construction from building a
    real Wikipedia engine — confirm against MetaSearchEngine.__init__.
    """
    wikipedia_target = (
        "local_deep_research.web_search_engines.engines.meta_search_engine.WikipediaSearchEngine"
    )
    engine_llm = MagicMock() if llm is None else llm
    with patch(wikipedia_target):
        return MetaSearchEngine(
            llm=engine_llm,
            max_results=5,
            use_api_key_services=use_api_key_services,
            settings_snapshot=settings_snapshot,
        )
# ===========================================================================
# 1. _get_available_engines — basic filtering
# ===========================================================================
class TestGetAvailableEngines:
    """Check which engines end up in MetaSearchEngine.available_engines."""

    def test_meta_and_auto_excluded(self):
        """'meta' and 'auto' never appear among the available engines."""
        configured = {"searxng": {}, "meta": {}, "auto": {}}
        engine = _make_meta_engine(_make_settings_snapshot(configured))
        available = engine.available_engines
        assert "meta" not in available
        assert "auto" not in available
        assert "searxng" in available

    def test_engine_disabled_for_auto_search(self):
        """An engine with use_in_auto_search=False is left out."""
        snapshot = _make_settings_snapshot(
            {"searxng": {}, "brave": {}},
            auto_settings={"brave": False},
        )
        available = _make_meta_engine(snapshot).available_engines
        assert "brave" not in available
        assert "searxng" in available

    def test_no_engines_available_raises(self):
        """Construction raises RuntimeError when every engine is filtered out."""
        snapshot = _make_settings_snapshot(
            {"brave": {}},
            auto_settings={"brave": False},
        )
        with pytest.raises(RuntimeError, match="No search engines enabled"):
            _make_meta_engine(snapshot)
# ===========================================================================
# 2. API key validation combinations
# ===========================================================================
class TestAPIKeyValidation:
    """Check how API-key requirements affect which engines are available."""

    def test_engine_requires_key_and_key_present(self):
        """A key-requiring engine with a key configured stays available."""
        brave_config = {"requires_api_key": True, "api_key": "sk-xxx"}
        settings = _make_settings_snapshot({"brave": brave_config})
        meta_engine = _make_meta_engine(settings)
        assert "brave" in meta_engine.available_engines

    def test_engine_requires_key_but_key_missing(self):
        """A key-requiring engine with no key configured is dropped."""
        engines = {
            "brave": {"requires_api_key": True},
            "searxng": {},
        }
        meta_engine = _make_meta_engine(_make_settings_snapshot(engines))
        assert "brave" not in meta_engine.available_engines

    def test_engine_requires_key_but_api_services_disabled(self):
        """With use_api_key_services=False a keyed engine is dropped even if its key is set."""
        engines = {
            "brave": {"requires_api_key": True, "api_key": "sk-xxx"},
            "searxng": {},
        }
        settings = _make_settings_snapshot(engines)
        meta_engine = _make_meta_engine(settings, use_api_key_services=False)
        assert "brave" not in meta_engine.available_engines
        # The engine that needs no key must be unaffected by the toggle.
        assert "searxng" in meta_engine.available_engines
# ===========================================================================
# 3. Local engine filtering
# ===========================================================================
class TestLocalEngineFiltering:
    """Cover engine names carrying the ``local.`` prefix in _get_available_engines."""

    def test_local_engine_uses_local_setting_path(self):
        """A local.* engine is kept when its search.engine.local.* flag is True."""
        # Original author's note: the _get_search_config parser splits on dots,
        # so "local.myindex" cannot survive as one engine name when derived from
        # settings keys. _get_search_config is therefore stubbed so the dotted
        # name reaches the _get_available_engines logic directly.
        settings = {
            "search.engine.local.myindex.use_in_auto_search": True,
            "search.engine.web.searxng.use_in_auto_search": True,
        }
        wikipedia_patch = patch(
            "local_deep_research.web_search_engines.engines.meta_search_engine.WikipediaSearchEngine"
        )
        config_patch = patch.object(
            MetaSearchEngine,
            "_get_search_config",
            return_value={"local.myindex": {}, "searxng": {}},
        )
        with wikipedia_patch, config_patch:
            meta_engine = MetaSearchEngine(
                llm=MagicMock(),
                max_results=5,
                settings_snapshot=settings,
            )
        assert "local.myindex" in meta_engine.available_engines
# ===========================================================================
# 4. analyze_query — specialized domain detection
# ===========================================================================
class TestAnalyzeQuery:
    """Check which engine analyze_query ranks first (or includes) per query type."""

    def test_scientific_paper_query(self):
        """A 'scientific paper' query ranks arxiv first when arxiv is available."""
        settings = _make_settings_snapshot({"arxiv": {}, "searxng": {}})
        meta_engine = _make_meta_engine(settings)
        ranked = meta_engine.analyze_query(
            "latest scientific paper on quantum computing"
        )
        assert ranked[0] == "arxiv"

    def test_code_query_selects_github(self):
        """A code-related query includes github when github is available."""
        settings = _make_settings_snapshot({"github": {}, "searxng": {}})
        meta_engine = _make_meta_engine(settings)
        ranked = meta_engine.analyze_query("python code for sorting algorithms")
        assert "github" in ranked

    def test_arxiv_keyword_prioritizes_arxiv(self):
        """A query mentioning 'arxiv' ranks arxiv first."""
        settings = _make_settings_snapshot({"arxiv": {}, "searxng": {}})
        meta_engine = _make_meta_engine(settings)
        ranked = meta_engine.analyze_query("find arxiv papers on transformers")
        assert ranked[0] == "arxiv"

    def test_pubmed_keyword_prioritizes_pubmed(self):
        """A query mentioning 'pubmed' ranks pubmed first."""
        settings = _make_settings_snapshot({"pubmed": {}, "searxng": {}})
        meta_engine = _make_meta_engine(settings)
        ranked = meta_engine.analyze_query("search pubmed for covid studies")
        assert ranked[0] == "pubmed"

    def test_general_query_prefers_searxng(self):
        """A general query ranks searxng first when it is available."""
        settings = _make_settings_snapshot({"searxng": {}, "brave": {}})
        meta_engine = _make_meta_engine(settings)
        ranked = meta_engine.analyze_query("what is the weather today")
        assert ranked[0] == "searxng"

    def test_no_llm_falls_back_to_reliability(self):
        """With no LLM and no SearXNG, the higher-reliability engine comes first."""
        engines = {
            "brave": {"reliability": 0.9},
            "wikipedia": {"reliability": 0.7},
        }
        settings = _make_settings_snapshot(engines)
        # The helper substitutes a MagicMock when llm is None, so the LLM has to
        # be cleared on the instance after construction.
        meta_engine = _make_meta_engine(settings)
        meta_engine.llm = None
        ranked = meta_engine.analyze_query("anything")
        # Expected ordering is by reliability, highest first.
        assert ranked[0] == "brave"
# ===========================================================================
# 5. analyze_query — exception handling
# ===========================================================================
class TestAnalyzeQueryExceptions:
    """Check analyze_query when the LLM raises or returns unusable output."""

    def test_llm_exception_falls_back_to_searxng(self):
        """An exception from llm.invoke still yields searxng as the first engine."""
        engines = {
            "searxng": {"strengths": "general", "description": "meta"},
            "brave": {"strengths": "web", "description": "search"},
        }
        failing_llm = MagicMock()
        failing_llm.invoke.side_effect = RuntimeError("LLM down")
        meta_engine = _make_meta_engine(
            _make_settings_snapshot(engines), llm=failing_llm
        )
        ranked = meta_engine.analyze_query("test query")
        assert ranked[0] == "searxng"

    def test_llm_returns_empty_response(self):
        """An empty LLM response still leaves searxng in the result."""
        engines = {
            "searxng": {"strengths": "general", "description": "meta"},
        }
        silent_llm = MagicMock()
        silent_llm.invoke.return_value = MagicMock(content="")
        meta_engine = _make_meta_engine(
            _make_settings_snapshot(engines), llm=silent_llm
        )
        ranked = meta_engine.analyze_query("test query")
        # searxng is expected to be supplied by the fallback path.
        assert "searxng" in ranked

    def test_llm_returns_invalid_engine_names(self):
        """Unknown engine names from the LLM still leave searxng in the result."""
        engines = {
            "searxng": {"strengths": "general", "description": "meta"},
        }
        confused_llm = MagicMock()
        confused_llm.invoke.return_value = MagicMock(
            content="nonexistent1,nonexistent2"
        )
        meta_engine = _make_meta_engine(
            _make_settings_snapshot(engines), llm=confused_llm
        )
        ranked = meta_engine.analyze_query("test query")
        # None of the returned names is a configured engine, so searxng is
        # expected to be added as the fallback.
        assert "searxng" in ranked
# ===========================================================================
# 6. _get_full_content — failure scenarios
# ===========================================================================
class TestGetFullContent:
    """Check the paths where _get_full_content hands the items back unchanged."""

    def test_snippets_only_skips_content(self):
        """With search.snippets_only=True the items are returned as given."""
        settings = _make_settings_snapshot({"searxng": {}})
        settings["search.snippets_only"] = True
        meta_engine = _make_meta_engine(settings)
        previews = [{"title": "Test", "snippet": "content"}]
        assert meta_engine._get_full_content(previews) == previews

    def test_selected_engine_exception_returns_items(self):
        """An error raised by the selected engine leaves the items unchanged."""
        settings = _make_settings_snapshot({"searxng": {}})
        settings["search.snippets_only"] = False
        meta_engine = _make_meta_engine(settings)
        broken_engine = MagicMock()
        broken_engine._get_full_content.side_effect = RuntimeError("failed")
        meta_engine._selected_engine = broken_engine
        meta_engine._selected_engine_name = "searxng"
        previews = [{"title": "Test"}]
        assert meta_engine._get_full_content(previews) == previews

    def test_no_selected_engine_returns_items(self):
        """Without a selected engine the items are returned unchanged."""
        settings = _make_settings_snapshot({"searxng": {}})
        settings["search.snippets_only"] = False
        meta_engine = _make_meta_engine(settings)
        # Remove the attribute if construction happened to set it.
        if hasattr(meta_engine, "_selected_engine"):
            del meta_engine._selected_engine
        previews = [{"title": "Test"}]
        assert meta_engine._get_full_content(previews) == previews
# ===========================================================================
# 7. _get_engine_instance — caching and failure
# ===========================================================================
class TestGetEngineInstance:
    """Verify engine instance caching and handling of creation failures."""

    def test_cached_engine_returned(self):
        """An instance already present in engine_cache is returned as-is."""
        snapshot = _make_settings_snapshot({"searxng": {}})
        engine = _make_meta_engine(snapshot)
        mock_cached = MagicMock()
        engine.engine_cache["test_engine"] = mock_cached
        result = engine._get_engine_instance("test_engine")
        assert result is mock_cached

    @patch(
        "local_deep_research.web_search_engines.engines.meta_search_engine.create_search_engine"
    )
    def test_creation_failure_returns_none(self, mock_create):
        """An exception from create_search_engine results in None."""
        mock_create.side_effect = RuntimeError("init failed")
        snapshot = _make_settings_snapshot({"searxng": {}})
        engine = _make_meta_engine(snapshot)
        result = engine._get_engine_instance("bad_engine")
        assert result is None

    @patch(
        "local_deep_research.web_search_engines.engines.meta_search_engine.create_search_engine"
    )
    def test_max_filtered_results_zero_passed(self, mock_create):
        """max_filtered_results=0 is still forwarded to create_search_engine.

        Zero is falsy but it is not None, so a forwarding guard written as
        ``is not None`` must let it through; a plain truthiness check would
        silently drop it.
        """
        mock_create.return_value = MagicMock()
        snapshot = _make_settings_snapshot({"searxng": {}})
        with patch(
            "local_deep_research.web_search_engines.engines.meta_search_engine.WikipediaSearchEngine"
        ):
            engine = MetaSearchEngine(
                llm=MagicMock(),
                max_results=5,
                max_filtered_results=0,
                settings_snapshot=snapshot,
            )
        engine._get_engine_instance("new_engine")
        # Fail with a clear message if the factory was never reached, rather
        # than with an AttributeError on a None call_args.
        mock_create.assert_called()
        # call_args.kwargs is the keyword-argument dict of the last call.
        assert "max_filtered_results" in mock_create.call_args.kwargs
# ===========================================================================
# 8. _get_previews — engine fallthrough
# ===========================================================================
class TestGetPreviews:
    """Check that _get_previews uses the ranked engines and then the fallback."""

    def test_first_engine_succeeds(self):
        """Results from the first ranked engine are returned."""
        settings = _make_settings_snapshot({"searxng": {}})
        meta_engine = _make_meta_engine(settings)
        working_engine = MagicMock()
        working_engine._get_previews.return_value = [{"title": "Result 1"}]
        meta_engine._get_engine_instance = MagicMock(return_value=working_engine)
        ranking_patch = patch.object(
            meta_engine, "analyze_query", return_value=["searxng"]
        )
        socket_patch = patch(
            "local_deep_research.web_search_engines.engines.meta_search_engine.SocketIOService"
        )
        with ranking_patch, socket_patch:
            previews = meta_engine._get_previews("test query")
        assert len(previews) == 1

    def test_all_engines_fail_uses_fallback(self):
        """When no engine instance can be obtained, the fallback engine's results are used."""
        settings = _make_settings_snapshot({"searxng": {}})
        meta_engine = _make_meta_engine(settings)
        # Every engine lookup yields None, so no ranked engine can run.
        meta_engine._get_engine_instance = MagicMock(return_value=None)
        fallback = MagicMock()
        fallback._get_previews.return_value = [
            {"title": "Wikipedia result"}
        ]
        meta_engine.fallback_engine = fallback
        with patch.object(meta_engine, "analyze_query", return_value=["searxng"]):
            previews = meta_engine._get_previews("test query")
        assert previews[0]["title"] == "Wikipedia result"

View File

@@ -0,0 +1,335 @@
"""Tests for PathValidator — security-sensitive path validation edge cases."""
from pathlib import Path
import pytest
from local_deep_research.security.path_validator import PathValidator
# ===========================================================================
# 1. validate_safe_path
# ===========================================================================
class TestValidateSafePath:
    """Check validate_safe_path input checks, traversal blocking and extension filter."""

    def test_non_string_input_rejected(self):
        """An integer input raises ValueError ('Invalid path input')."""
        with pytest.raises(ValueError, match="Invalid path input"):
            PathValidator.validate_safe_path(123, "/tmp")

    def test_none_input_rejected(self):
        """None as input raises ValueError ('Invalid path input')."""
        with pytest.raises(ValueError, match="Invalid path input"):
            PathValidator.validate_safe_path(None, "/tmp")

    def test_empty_string_rejected(self):
        """An empty string raises ValueError ('Invalid path input')."""
        with pytest.raises(ValueError, match="Invalid path input"):
            PathValidator.validate_safe_path("", "/tmp")

    def test_traversal_blocked(self, tmp_path):
        """A '..' escape out of the base directory raises ValueError."""
        base_dir = str(tmp_path)
        with pytest.raises(ValueError):
            PathValidator.validate_safe_path("../../etc/passwd", base_dir)

    def test_valid_relative_path(self, tmp_path):
        """A well-formed relative path yields a result located under the base."""
        base_dir = str(tmp_path)
        validated = PathValidator.validate_safe_path(
            "subdir/file.txt", base_dir
        )
        assert validated is not None
        assert base_dir in str(validated)

    def test_extension_check_rejects_wrong_extension(self, tmp_path):
        """An extension outside required_extensions raises 'Invalid file type'."""
        allowed = (".json", ".yaml")
        with pytest.raises(ValueError, match="Invalid file type"):
            PathValidator.validate_safe_path(
                "file.txt",
                str(tmp_path),
                required_extensions=allowed,
            )

    def test_extension_check_accepts_correct_extension(self, tmp_path):
        """An extension listed in required_extensions is accepted."""
        allowed = (".json", ".yaml")
        validated = PathValidator.validate_safe_path(
            "config.json",
            str(tmp_path),
            required_extensions=allowed,
        )
        assert validated.suffix == ".json"

    def test_extension_case_sensitivity(self, tmp_path):
        """The extension comparison is case-sensitive: '.JSON' is not '.json'."""
        allowed = (".json",)
        with pytest.raises(ValueError, match="Invalid file type"):
            PathValidator.validate_safe_path(
                "config.JSON",
                str(tmp_path),
                required_extensions=allowed,
            )

    def test_whitespace_stripped(self, tmp_path):
        """Surrounding whitespace in the input does not prevent validation."""
        validated = PathValidator.validate_safe_path(" file.txt ", str(tmp_path))
        assert validated is not None
        assert "file.txt" in str(validated)
# ===========================================================================
# 2. validate_local_filesystem_path — home dir expansion
# ===========================================================================
class TestHomeDirExpansion:
    """Verify validate_local_filesystem_path: ~ expansion and input rejection.

    Besides home-directory expansion, this class also covers null bytes,
    control characters, '..' traversal and restricted directories.
    """

    def test_tilde_slash_expands_to_home(self):
        """'~/documents' expands to a path under the resolved home directory."""
        # Compare against the resolved home, as the other '~' tests in this
        # class do, so a symlinked home directory cannot cause a spurious
        # prefix mismatch.
        home = Path.home().resolve()
        result = PathValidator.validate_local_filesystem_path(
            "~/documents",
            restricted_dirs=[],
        )
        assert str(result).startswith(str(home))

    def test_tilde_alone_expands_to_home(self):
        """'~' alone expands to the resolved home directory."""
        home = Path.home().resolve()
        result = PathValidator.validate_local_filesystem_path(
            "~",
            restricted_dirs=[],
        )
        assert result == home

    def test_tilde_with_trailing_slashes(self):
        """'~/' with no relative part yields the resolved home directory."""
        home = Path.home().resolve()
        result = PathValidator.validate_local_filesystem_path(
            "~/",
            restricted_dirs=[],
        )
        assert result == home

    def test_null_bytes_rejected(self):
        """A null byte in the path raises ValueError ('Null bytes')."""
        with pytest.raises(ValueError, match="Null bytes"):
            PathValidator.validate_local_filesystem_path("/tmp/file\x00.txt")

    def test_control_characters_rejected(self):
        """A control character in the path raises ValueError ('Control characters')."""
        with pytest.raises(ValueError, match="Control characters"):
            PathValidator.validate_local_filesystem_path("/tmp/file\x01name")

    def test_path_traversal_blocked(self):
        """'..' in the path raises ValueError mentioning traversal."""
        with pytest.raises(ValueError, match="traversal"):
            PathValidator.validate_local_filesystem_path("/tmp/../etc/passwd")

    def test_restricted_dir_blocked(self):
        """A path under /etc is rejected when no restricted_dirs are given."""
        with pytest.raises(ValueError, match="system directories"):
            PathValidator.validate_local_filesystem_path("/etc/passwd")

    def test_custom_restricted_dir(self, tmp_path):
        """A caller-supplied restricted directory is enforced."""
        restricted = tmp_path / "secret"
        restricted.mkdir()
        target = restricted / "file.txt"
        target.touch()
        with pytest.raises(ValueError, match="system directories"):
            PathValidator.validate_local_filesystem_path(
                str(target),
                restricted_dirs=[restricted],
            )
# ===========================================================================
# 3. sanitize_for_filesystem_ops — entirely uncovered
# ===========================================================================
class TestSanitizeForFilesystemOps:
    """Check what sanitize_for_filesystem_ops accepts and what it returns."""

    def test_relative_path_rejected(self):
        """A relative Path raises ValueError ('must be absolute')."""
        relative = Path("relative/path")
        with pytest.raises(ValueError, match="must be absolute"):
            PathValidator.sanitize_for_filesystem_ops(relative)

    def test_absolute_path_passes(self, tmp_path):
        """An absolute path comes back as an absolute Path instance."""
        sanitized = PathValidator.sanitize_for_filesystem_ops(tmp_path)
        assert isinstance(sanitized, Path)
        assert sanitized.is_absolute()

    def test_root_path(self):
        """The root path '/' is returned unchanged."""
        root = Path("/")
        assert PathValidator.sanitize_for_filesystem_ops(root) == Path("/")

    def test_deep_path_preserved(self, tmp_path):
        """The trailing components of a nested path survive sanitization."""
        nested = tmp_path / "a" / "b" / "c"
        sanitized = PathValidator.sanitize_for_filesystem_ops(nested)
        assert str(sanitized).endswith("a/b/c")
# ===========================================================================
# 4. validate_model_path
# ===========================================================================
class TestValidateModelPath:
    """Check validate_model_path existence, file-type and traversal handling."""

    def test_nonexistent_model_file(self, tmp_path):
        """A missing model file raises ValueError ('Model file not found')."""
        root = str(tmp_path)
        with pytest.raises(ValueError, match="Model file not found"):
            PathValidator.validate_model_path("model.bin", model_root=root)

    def test_directory_not_file_rejected(self, tmp_path):
        """Pointing at a directory raises ValueError ('not a file')."""
        (tmp_path / "model_dir").mkdir()
        root = str(tmp_path)
        with pytest.raises(ValueError, match="not a file"):
            PathValidator.validate_model_path("model_dir", model_root=root)

    def test_valid_model_file(self, tmp_path):
        """An existing model file is returned as its resolved path."""
        existing = tmp_path / "model.gguf"
        existing.touch()
        validated = PathValidator.validate_model_path(
            "model.gguf", model_root=str(tmp_path)
        )
        assert validated == existing.resolve()

    def test_traversal_in_model_path(self, tmp_path):
        """A '..' escape in the model path raises ValueError."""
        root = str(tmp_path)
        with pytest.raises(ValueError):
            PathValidator.validate_model_path(
                "../../../etc/passwd", model_root=root
            )

    def test_model_root_created_if_missing(self, tmp_path):
        """A model root that does not exist yet is created by the call."""
        fresh_root = tmp_path / "new_models"
        # The lookup is expected to fail because the file is absent; only the
        # side effect on the root directory matters here.
        try:
            PathValidator.validate_model_path(
                "test.bin", model_root=str(fresh_root)
            )
        except ValueError:
            pass
        assert fresh_root.exists()
# ===========================================================================
# 5. validate_config_path — restricted prefixes and null byte stripping
# ===========================================================================
class TestValidateConfigPath:
    """Check the security-related behaviour of validate_config_path."""

    def test_non_string_rejected(self):
        """An integer input raises ValueError ('Invalid config path')."""
        with pytest.raises(ValueError, match="Invalid config path"):
            PathValidator.validate_config_path(123)

    def test_empty_string_rejected(self):
        """An empty string raises ValueError ('Invalid config path')."""
        with pytest.raises(ValueError, match="Invalid config path"):
            PathValidator.validate_config_path("")

    def test_null_bytes_stripped(self, tmp_path):
        """A config path containing a null byte is accepted, not rejected."""
        # A real config file exists under the root.
        (tmp_path / "config.json").write_text("{}")
        # Original author's note: validate_config_path strips null bytes
        # instead of rejecting them, so "config\x00.json" becomes "config.json".
        validated = PathValidator.validate_config_path(
            "config\x00.json",
            config_root=str(tmp_path),
        )
        assert validated is not None

    def test_traversal_rejected(self):
        """'..' in a config path raises ValueError mentioning traversal."""
        with pytest.raises(ValueError, match="traversal"):
            PathValidator.validate_config_path(
                "../etc/passwd", config_root="/tmp"
            )

    @pytest.mark.parametrize(
        "restricted", ["etc/passwd", "proc/self", "sys/class", "dev/null"]
    )
    def test_restricted_prefixes_blocked(self, restricted):
        """Paths under etc, proc, sys and dev are rejected as restricted."""
        candidate = f"/{restricted}"
        with pytest.raises(ValueError, match="restricted system directory"):
            PathValidator.validate_config_path(candidate)

    def test_restricted_prefix_exact_match(self):
        """A bare restricted directory with no trailing path is rejected too."""
        with pytest.raises(ValueError, match="restricted system directory"):
            PathValidator.validate_config_path("/etc")

    def test_invalid_extension_rejected(self, tmp_path):
        """A file whose extension is not a config extension is rejected."""
        (tmp_path / "file.txt").touch()
        root = str(tmp_path)
        with pytest.raises(ValueError, match="Invalid"):
            PathValidator.validate_config_path("file.txt", config_root=root)

    def test_valid_config_extensions(self, tmp_path):
        """Each of the listed config extensions is accepted."""
        root = str(tmp_path)
        for suffix in (".json", ".yaml", ".yml", ".toml", ".ini", ".conf"):
            filename = f"config{suffix}"
            (tmp_path / filename).write_text("")
            validated = PathValidator.validate_config_path(
                filename,
                config_root=root,
            )
            assert validated is not None

    def test_absolute_config_path_rejected_by_safe_join(self, tmp_path):
        """An absolute config path raises ValueError ('Invalid absolute path')."""
        absolute_config = tmp_path / "app.json"
        absolute_config.write_text("{}")
        # Original author's note: safe_join("/", "/tmp/.../app.json") returns
        # None because its second argument is absolute.
        with pytest.raises(ValueError, match="Invalid absolute path"):
            PathValidator.validate_config_path(str(absolute_config))

    def test_restricted_prefix_case_insensitive(self):
        """The restricted-prefix check ignores case: '/ETC/passwd' is rejected."""
        with pytest.raises(ValueError, match="restricted system directory"):
            PathValidator.validate_config_path("/ETC/passwd")
# ===========================================================================
# 6. validate_data_path
# ===========================================================================
class TestValidateDataPath:
    """Basic checks for validate_data_path."""

    def test_valid_data_path(self, tmp_path):
        """A plain relative name yields a result that contains that name."""
        data_root = str(tmp_path)
        validated = PathValidator.validate_data_path("data.db", data_root)
        assert "data.db" in str(validated)

    def test_traversal_in_data_path(self, tmp_path):
        """A '..' escape out of the data root raises ValueError."""
        data_root = str(tmp_path)
        with pytest.raises(ValueError):
            PathValidator.validate_data_path("../../etc/passwd", data_root)