Files
local-deep-research/tests/database/test_benchmark_models.py
LearningCircuit c587eb423c fix(tests): dispose SQLAlchemy engines in database model test fixtures (#3204)
Change engine fixtures from return to yield with engine.dispose()
teardown to prevent connection pool leaks during test runs.
2026-03-28 01:07:03 +01:00

456 lines
15 KiB
Python

"""Tests for benchmark-related database models."""
from datetime import datetime, timedelta, timezone
import pytest
from sqlalchemy import create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import sessionmaker
from local_deep_research.database.models import (
Base,
BenchmarkConfig,
BenchmarkProgress,
BenchmarkResult,
BenchmarkRun,
BenchmarkStatus,
DatasetType,
)
class TestBenchmarkModels:
"""Test suite for benchmark-related models."""
@pytest.fixture
def engine(self):
"""Create an in-memory SQLite database for testing."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
yield engine
engine.dispose()
@pytest.fixture
def session(self, engine):
"""Create a database session for testing."""
Session = sessionmaker(bind=engine)
session = Session()
yield session
session.close()
def test_benchmark_run_creation(self, session):
"""Test creating a BenchmarkRun."""
run = BenchmarkRun(
run_name="GPT-4 vs Llama Comparison",
config_hash="abc123def456",
query_hash_list=["hash1", "hash2", "hash3"],
search_config={"engine": "google", "max_results": 10},
evaluation_config={"model": "gpt-4", "temperature": 0.0},
datasets_config={"simpleqa": 100, "browsecomp": 50},
status=BenchmarkStatus.PENDING,
total_examples=150,
)
session.add(run)
session.commit()
# Verify the run
saved = session.query(BenchmarkRun).first()
assert saved is not None
assert saved.run_name == "GPT-4 vs Llama Comparison"
assert saved.config_hash == "abc123def456"
assert len(saved.query_hash_list) == 3
assert saved.total_examples == 150
assert saved.status == BenchmarkStatus.PENDING
def test_benchmark_status_progression(self, session):
"""Test status progression of a benchmark run."""
run = BenchmarkRun(
config_hash="test123",
query_hash_list=[],
search_config={},
evaluation_config={},
datasets_config={},
status=BenchmarkStatus.PENDING,
)
session.add(run)
session.commit()
# Progress through statuses
run.status = BenchmarkStatus.IN_PROGRESS
run.start_time = datetime.now(timezone.utc)
session.commit()
assert run.status == BenchmarkStatus.IN_PROGRESS
assert run.start_time is not None
# Complete the run
run.status = BenchmarkStatus.COMPLETED
run.end_time = datetime.now(timezone.utc)
run.overall_accuracy = 0.85
run.processing_rate = 5.2 # examples per minute
session.commit()
assert run.status == BenchmarkStatus.COMPLETED
assert run.end_time is not None
assert run.overall_accuracy == 0.85
def test_benchmark_result_creation(self, session):
"""Test creating BenchmarkResult records."""
# Create a benchmark run first
run = BenchmarkRun(
config_hash="test123",
query_hash_list=["q1", "q2"],
search_config={},
evaluation_config={},
datasets_config={},
status=BenchmarkStatus.IN_PROGRESS,
)
session.add(run)
session.commit()
# Create results
result1 = BenchmarkResult(
benchmark_run_id=run.id,
example_id="simpleqa_001",
query_hash="hash_q1",
dataset_type=DatasetType.SIMPLEQA,
research_id="research-uuid-123",
question="What is the capital of France?",
correct_answer="Paris",
response="The capital of France is Paris.",
extracted_answer="Paris",
confidence="high",
processing_time=2.5,
is_correct=True,
graded_confidence="high",
)
result2 = BenchmarkResult(
benchmark_run_id=run.id,
example_id="browsecomp_001",
query_hash="hash_q2",
dataset_type=DatasetType.BROWSECOMP,
research_id="research-uuid-124",
question="Compare Python and JavaScript",
correct_answer="Python is interpreted, JavaScript runs in browsers",
response="Python and JavaScript are both popular languages...",
extracted_answer="Different use cases",
confidence="medium",
processing_time=5.2,
is_correct=False,
research_error="Timeout during search",
)
session.add_all([result1, result2])
session.commit()
# Verify results
results = session.query(BenchmarkResult).all()
assert len(results) == 2
correct_result = (
session.query(BenchmarkResult).filter_by(is_correct=True).first()
)
assert correct_result.extracted_answer == "Paris"
assert correct_result.dataset_type == DatasetType.SIMPLEQA
failed_result = (
session.query(BenchmarkResult).filter_by(is_correct=False).first()
)
assert failed_result.research_error == "Timeout during search"
def test_benchmark_config_management(self, session):
"""Test BenchmarkConfig for saving and reusing configurations."""
config = BenchmarkConfig(
name="High Accuracy Config",
description="Configuration optimized for accuracy over speed",
config_hash="config_abc123",
search_config={
"engines": ["google", "bing", "semantic_scholar"],
"max_results": 20,
"timeout": 30,
},
evaluation_config={
"model": "gpt-4",
"temperature": 0.0,
"max_retries": 3,
},
datasets_config={
"simpleqa": 200,
"browsecomp": 100,
"sample_ratio": 0.1,
},
is_default=True,
is_public=True,
)
session.add(config)
session.commit()
# Test updating usage stats
config.usage_count += 1
config.last_used = datetime.now(timezone.utc)
config.best_accuracy = 0.92
config.avg_processing_rate = 4.5
session.commit()
# Verify config
saved = (
session.query(BenchmarkConfig).filter_by(is_default=True).first()
)
assert saved is not None
assert saved.name == "High Accuracy Config"
assert saved.usage_count == 1
assert saved.best_accuracy == 0.92
assert "semantic_scholar" in saved.search_config["engines"]
def test_benchmark_progress_tracking(self, session):
"""Test BenchmarkProgress for real-time tracking."""
# Create a benchmark run
run = BenchmarkRun(
config_hash="test123",
query_hash_list=["q1", "q2", "q3"],
search_config={},
evaluation_config={},
datasets_config={},
status=BenchmarkStatus.IN_PROGRESS,
total_examples=100,
)
session.add(run)
session.commit()
# Add progress updates
progress1 = BenchmarkProgress(
benchmark_run_id=run.id,
completed_examples=25,
total_examples=100,
overall_accuracy=0.88,
dataset_accuracies={"simpleqa": 0.90, "browsecomp": 0.85},
processing_rate=3.2,
estimated_completion=datetime.now(timezone.utc)
+ timedelta(minutes=20),
current_dataset=DatasetType.SIMPLEQA,
current_example_id="simpleqa_025",
memory_usage=512.5,
cpu_usage=45.2,
)
progress2 = BenchmarkProgress(
benchmark_run_id=run.id,
completed_examples=50,
total_examples=100,
overall_accuracy=0.86,
dataset_accuracies={"simpleqa": 0.89, "browsecomp": 0.83},
processing_rate=3.5,
estimated_completion=datetime.now(timezone.utc)
+ timedelta(minutes=15),
current_dataset=DatasetType.BROWSECOMP,
current_example_id="browsecomp_010",
)
session.add_all([progress1, progress2])
session.commit()
# Query progress updates
progress_updates = (
session.query(BenchmarkProgress)
.filter_by(benchmark_run_id=run.id)
.order_by(BenchmarkProgress.timestamp)
.all()
)
assert len(progress_updates) == 2
assert progress_updates[0].completed_examples == 25
assert progress_updates[1].completed_examples == 50
assert (
progress_updates[1].processing_rate
> progress_updates[0].processing_rate
)
def test_benchmark_relationships(self, session):
"""Test relationships between benchmark models."""
# Create a run
run = BenchmarkRun(
run_name="Test Run",
config_hash="test123",
query_hash_list=["q1"],
search_config={},
evaluation_config={},
datasets_config={},
status=BenchmarkStatus.COMPLETED,
)
session.add(run)
session.commit()
# Add multiple results
for i in range(5):
result = BenchmarkResult(
benchmark_run_id=run.id,
example_id=f"example_{i}",
query_hash=f"hash_{i}",
dataset_type=DatasetType.SIMPLEQA,
question=f"Question {i}",
correct_answer=f"Answer {i}",
is_correct=i % 2 == 0, # Alternate correct/incorrect
)
session.add(result)
session.commit()
# Test relationship queries
run_with_results = session.query(BenchmarkRun).first()
results = run_with_results.results.all()
assert len(results) == 5
# Count correct results
correct_count = run_with_results.results.filter_by(
is_correct=True
).count()
assert correct_count == 3
def test_unique_constraints(self, session):
"""Test unique constraints on benchmark models."""
run = BenchmarkRun(
config_hash="test123",
query_hash_list=["q1"],
search_config={},
evaluation_config={},
datasets_config={},
)
session.add(run)
session.commit()
# Add a result
result1 = BenchmarkResult(
benchmark_run_id=run.id,
example_id="test_001",
query_hash="unique_hash",
dataset_type=DatasetType.SIMPLEQA,
question="Test question",
correct_answer="Test answer",
)
session.add(result1)
session.commit()
# Try to add duplicate (same run_id and query_hash)
result2 = BenchmarkResult(
benchmark_run_id=run.id,
example_id="test_002",
query_hash="unique_hash", # Same hash
dataset_type=DatasetType.SIMPLEQA,
question="Different question",
correct_answer="Different answer",
)
session.add(result2)
with pytest.raises(IntegrityError):
session.commit()
def test_benchmark_error_handling(self, session):
"""Test error tracking in benchmark runs."""
run = BenchmarkRun(
config_hash="test123",
query_hash_list=[],
search_config={},
evaluation_config={},
datasets_config={},
status=BenchmarkStatus.FAILED,
error_message="Failed to connect to evaluation model API",
)
session.add(run)
session.commit()
# Add failed results
result = BenchmarkResult(
benchmark_run_id=run.id,
example_id="failed_001",
query_hash="hash_failed",
dataset_type=DatasetType.SIMPLEQA,
question="What is 2+2?",
correct_answer="4",
research_error="Search timeout after 30 seconds",
evaluation_error="Could not parse model response",
)
session.add(result)
session.commit()
# Verify error tracking
failed_run = (
session.query(BenchmarkRun)
.filter_by(status=BenchmarkStatus.FAILED)
.first()
)
assert (
failed_run.error_message
== "Failed to connect to evaluation model API"
)
failed_result = (
session.query(BenchmarkResult)
.filter(BenchmarkResult.research_error.isnot(None))
.first()
)
assert "timeout" in failed_result.research_error
assert failed_result.evaluation_error is not None
def test_benchmark_statistics(self, session):
"""Test calculating statistics from benchmark results."""
# Create a completed run
run = BenchmarkRun(
config_hash="test123",
query_hash_list=["q1", "q2", "q3", "q4", "q5"],
search_config={},
evaluation_config={},
datasets_config={},
status=BenchmarkStatus.COMPLETED,
total_examples=5,
completed_examples=5,
)
session.add(run)
session.commit()
# Add results with varying accuracy
accuracies = [True, True, False, True, False] # 3/5 = 60% accuracy
for i, is_correct in enumerate(accuracies):
result = BenchmarkResult(
benchmark_run_id=run.id,
example_id=f"test_{i}",
query_hash=f"hash_{i}",
dataset_type=DatasetType.SIMPLEQA
if i < 3
else DatasetType.BROWSECOMP,
question=f"Question {i}",
correct_answer=f"Answer {i}",
is_correct=is_correct,
processing_time=2.0 + i * 0.5,
)
session.add(result)
session.commit()
# Calculate statistics
results = (
session.query(BenchmarkResult)
.filter_by(benchmark_run_id=run.id)
.all()
)
correct_count = sum(1 for r in results if r.is_correct)
accuracy = correct_count / len(results)
assert accuracy == 0.6
# Calculate per-dataset accuracy
simpleqa_results = [
r for r in results if r.dataset_type == DatasetType.SIMPLEQA
]
simpleqa_accuracy = sum(
1 for r in simpleqa_results if r.is_correct
) / len(simpleqa_results)
assert simpleqa_accuracy == 2 / 3 # ~0.667
# Calculate average processing time
avg_time = sum(r.processing_time for r in results) / len(results)
assert avg_time == 3.0 # (2.0 + 2.5 + 3.0 + 3.5 + 4.0) / 5