mirror of
https://github.com/LearningCircuit/local-deep-research.git
synced 2026-06-16 03:51:07 +03:00
Change engine fixtures from return to yield with engine.dispose() teardown to prevent connection pool leaks during test runs.
456 lines
15 KiB
Python
456 lines
15 KiB
Python
"""Tests for benchmark-related database models."""
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
import pytest
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from local_deep_research.database.models import (
|
|
Base,
|
|
BenchmarkConfig,
|
|
BenchmarkProgress,
|
|
BenchmarkResult,
|
|
BenchmarkRun,
|
|
BenchmarkStatus,
|
|
DatasetType,
|
|
)
|
|
|
|
|
|
class TestBenchmarkModels:
|
|
"""Test suite for benchmark-related models."""
|
|
|
|
@pytest.fixture
|
|
def engine(self):
|
|
"""Create an in-memory SQLite database for testing."""
|
|
engine = create_engine("sqlite:///:memory:")
|
|
Base.metadata.create_all(engine)
|
|
yield engine
|
|
engine.dispose()
|
|
|
|
@pytest.fixture
|
|
def session(self, engine):
|
|
"""Create a database session for testing."""
|
|
Session = sessionmaker(bind=engine)
|
|
session = Session()
|
|
yield session
|
|
session.close()
|
|
|
|
def test_benchmark_run_creation(self, session):
|
|
"""Test creating a BenchmarkRun."""
|
|
run = BenchmarkRun(
|
|
run_name="GPT-4 vs Llama Comparison",
|
|
config_hash="abc123def456",
|
|
query_hash_list=["hash1", "hash2", "hash3"],
|
|
search_config={"engine": "google", "max_results": 10},
|
|
evaluation_config={"model": "gpt-4", "temperature": 0.0},
|
|
datasets_config={"simpleqa": 100, "browsecomp": 50},
|
|
status=BenchmarkStatus.PENDING,
|
|
total_examples=150,
|
|
)
|
|
|
|
session.add(run)
|
|
session.commit()
|
|
|
|
# Verify the run
|
|
saved = session.query(BenchmarkRun).first()
|
|
assert saved is not None
|
|
assert saved.run_name == "GPT-4 vs Llama Comparison"
|
|
assert saved.config_hash == "abc123def456"
|
|
assert len(saved.query_hash_list) == 3
|
|
assert saved.total_examples == 150
|
|
assert saved.status == BenchmarkStatus.PENDING
|
|
|
|
def test_benchmark_status_progression(self, session):
|
|
"""Test status progression of a benchmark run."""
|
|
run = BenchmarkRun(
|
|
config_hash="test123",
|
|
query_hash_list=[],
|
|
search_config={},
|
|
evaluation_config={},
|
|
datasets_config={},
|
|
status=BenchmarkStatus.PENDING,
|
|
)
|
|
|
|
session.add(run)
|
|
session.commit()
|
|
|
|
# Progress through statuses
|
|
run.status = BenchmarkStatus.IN_PROGRESS
|
|
run.start_time = datetime.now(timezone.utc)
|
|
session.commit()
|
|
|
|
assert run.status == BenchmarkStatus.IN_PROGRESS
|
|
assert run.start_time is not None
|
|
|
|
# Complete the run
|
|
run.status = BenchmarkStatus.COMPLETED
|
|
run.end_time = datetime.now(timezone.utc)
|
|
run.overall_accuracy = 0.85
|
|
run.processing_rate = 5.2 # examples per minute
|
|
session.commit()
|
|
|
|
assert run.status == BenchmarkStatus.COMPLETED
|
|
assert run.end_time is not None
|
|
assert run.overall_accuracy == 0.85
|
|
|
|
def test_benchmark_result_creation(self, session):
|
|
"""Test creating BenchmarkResult records."""
|
|
# Create a benchmark run first
|
|
run = BenchmarkRun(
|
|
config_hash="test123",
|
|
query_hash_list=["q1", "q2"],
|
|
search_config={},
|
|
evaluation_config={},
|
|
datasets_config={},
|
|
status=BenchmarkStatus.IN_PROGRESS,
|
|
)
|
|
session.add(run)
|
|
session.commit()
|
|
|
|
# Create results
|
|
result1 = BenchmarkResult(
|
|
benchmark_run_id=run.id,
|
|
example_id="simpleqa_001",
|
|
query_hash="hash_q1",
|
|
dataset_type=DatasetType.SIMPLEQA,
|
|
research_id="research-uuid-123",
|
|
question="What is the capital of France?",
|
|
correct_answer="Paris",
|
|
response="The capital of France is Paris.",
|
|
extracted_answer="Paris",
|
|
confidence="high",
|
|
processing_time=2.5,
|
|
is_correct=True,
|
|
graded_confidence="high",
|
|
)
|
|
|
|
result2 = BenchmarkResult(
|
|
benchmark_run_id=run.id,
|
|
example_id="browsecomp_001",
|
|
query_hash="hash_q2",
|
|
dataset_type=DatasetType.BROWSECOMP,
|
|
research_id="research-uuid-124",
|
|
question="Compare Python and JavaScript",
|
|
correct_answer="Python is interpreted, JavaScript runs in browsers",
|
|
response="Python and JavaScript are both popular languages...",
|
|
extracted_answer="Different use cases",
|
|
confidence="medium",
|
|
processing_time=5.2,
|
|
is_correct=False,
|
|
research_error="Timeout during search",
|
|
)
|
|
|
|
session.add_all([result1, result2])
|
|
session.commit()
|
|
|
|
# Verify results
|
|
results = session.query(BenchmarkResult).all()
|
|
assert len(results) == 2
|
|
|
|
correct_result = (
|
|
session.query(BenchmarkResult).filter_by(is_correct=True).first()
|
|
)
|
|
assert correct_result.extracted_answer == "Paris"
|
|
assert correct_result.dataset_type == DatasetType.SIMPLEQA
|
|
|
|
failed_result = (
|
|
session.query(BenchmarkResult).filter_by(is_correct=False).first()
|
|
)
|
|
assert failed_result.research_error == "Timeout during search"
|
|
|
|
def test_benchmark_config_management(self, session):
|
|
"""Test BenchmarkConfig for saving and reusing configurations."""
|
|
config = BenchmarkConfig(
|
|
name="High Accuracy Config",
|
|
description="Configuration optimized for accuracy over speed",
|
|
config_hash="config_abc123",
|
|
search_config={
|
|
"engines": ["google", "bing", "semantic_scholar"],
|
|
"max_results": 20,
|
|
"timeout": 30,
|
|
},
|
|
evaluation_config={
|
|
"model": "gpt-4",
|
|
"temperature": 0.0,
|
|
"max_retries": 3,
|
|
},
|
|
datasets_config={
|
|
"simpleqa": 200,
|
|
"browsecomp": 100,
|
|
"sample_ratio": 0.1,
|
|
},
|
|
is_default=True,
|
|
is_public=True,
|
|
)
|
|
|
|
session.add(config)
|
|
session.commit()
|
|
|
|
# Test updating usage stats
|
|
config.usage_count += 1
|
|
config.last_used = datetime.now(timezone.utc)
|
|
config.best_accuracy = 0.92
|
|
config.avg_processing_rate = 4.5
|
|
session.commit()
|
|
|
|
# Verify config
|
|
saved = (
|
|
session.query(BenchmarkConfig).filter_by(is_default=True).first()
|
|
)
|
|
assert saved is not None
|
|
assert saved.name == "High Accuracy Config"
|
|
assert saved.usage_count == 1
|
|
assert saved.best_accuracy == 0.92
|
|
assert "semantic_scholar" in saved.search_config["engines"]
|
|
|
|
def test_benchmark_progress_tracking(self, session):
|
|
"""Test BenchmarkProgress for real-time tracking."""
|
|
# Create a benchmark run
|
|
run = BenchmarkRun(
|
|
config_hash="test123",
|
|
query_hash_list=["q1", "q2", "q3"],
|
|
search_config={},
|
|
evaluation_config={},
|
|
datasets_config={},
|
|
status=BenchmarkStatus.IN_PROGRESS,
|
|
total_examples=100,
|
|
)
|
|
session.add(run)
|
|
session.commit()
|
|
|
|
# Add progress updates
|
|
progress1 = BenchmarkProgress(
|
|
benchmark_run_id=run.id,
|
|
completed_examples=25,
|
|
total_examples=100,
|
|
overall_accuracy=0.88,
|
|
dataset_accuracies={"simpleqa": 0.90, "browsecomp": 0.85},
|
|
processing_rate=3.2,
|
|
estimated_completion=datetime.now(timezone.utc)
|
|
+ timedelta(minutes=20),
|
|
current_dataset=DatasetType.SIMPLEQA,
|
|
current_example_id="simpleqa_025",
|
|
memory_usage=512.5,
|
|
cpu_usage=45.2,
|
|
)
|
|
|
|
progress2 = BenchmarkProgress(
|
|
benchmark_run_id=run.id,
|
|
completed_examples=50,
|
|
total_examples=100,
|
|
overall_accuracy=0.86,
|
|
dataset_accuracies={"simpleqa": 0.89, "browsecomp": 0.83},
|
|
processing_rate=3.5,
|
|
estimated_completion=datetime.now(timezone.utc)
|
|
+ timedelta(minutes=15),
|
|
current_dataset=DatasetType.BROWSECOMP,
|
|
current_example_id="browsecomp_010",
|
|
)
|
|
|
|
session.add_all([progress1, progress2])
|
|
session.commit()
|
|
|
|
# Query progress updates
|
|
progress_updates = (
|
|
session.query(BenchmarkProgress)
|
|
.filter_by(benchmark_run_id=run.id)
|
|
.order_by(BenchmarkProgress.timestamp)
|
|
.all()
|
|
)
|
|
|
|
assert len(progress_updates) == 2
|
|
assert progress_updates[0].completed_examples == 25
|
|
assert progress_updates[1].completed_examples == 50
|
|
assert (
|
|
progress_updates[1].processing_rate
|
|
> progress_updates[0].processing_rate
|
|
)
|
|
|
|
def test_benchmark_relationships(self, session):
|
|
"""Test relationships between benchmark models."""
|
|
# Create a run
|
|
run = BenchmarkRun(
|
|
run_name="Test Run",
|
|
config_hash="test123",
|
|
query_hash_list=["q1"],
|
|
search_config={},
|
|
evaluation_config={},
|
|
datasets_config={},
|
|
status=BenchmarkStatus.COMPLETED,
|
|
)
|
|
session.add(run)
|
|
session.commit()
|
|
|
|
# Add multiple results
|
|
for i in range(5):
|
|
result = BenchmarkResult(
|
|
benchmark_run_id=run.id,
|
|
example_id=f"example_{i}",
|
|
query_hash=f"hash_{i}",
|
|
dataset_type=DatasetType.SIMPLEQA,
|
|
question=f"Question {i}",
|
|
correct_answer=f"Answer {i}",
|
|
is_correct=i % 2 == 0, # Alternate correct/incorrect
|
|
)
|
|
session.add(result)
|
|
|
|
session.commit()
|
|
|
|
# Test relationship queries
|
|
run_with_results = session.query(BenchmarkRun).first()
|
|
results = run_with_results.results.all()
|
|
assert len(results) == 5
|
|
|
|
# Count correct results
|
|
correct_count = run_with_results.results.filter_by(
|
|
is_correct=True
|
|
).count()
|
|
assert correct_count == 3
|
|
|
|
def test_unique_constraints(self, session):
|
|
"""Test unique constraints on benchmark models."""
|
|
run = BenchmarkRun(
|
|
config_hash="test123",
|
|
query_hash_list=["q1"],
|
|
search_config={},
|
|
evaluation_config={},
|
|
datasets_config={},
|
|
)
|
|
session.add(run)
|
|
session.commit()
|
|
|
|
# Add a result
|
|
result1 = BenchmarkResult(
|
|
benchmark_run_id=run.id,
|
|
example_id="test_001",
|
|
query_hash="unique_hash",
|
|
dataset_type=DatasetType.SIMPLEQA,
|
|
question="Test question",
|
|
correct_answer="Test answer",
|
|
)
|
|
session.add(result1)
|
|
session.commit()
|
|
|
|
# Try to add duplicate (same run_id and query_hash)
|
|
result2 = BenchmarkResult(
|
|
benchmark_run_id=run.id,
|
|
example_id="test_002",
|
|
query_hash="unique_hash", # Same hash
|
|
dataset_type=DatasetType.SIMPLEQA,
|
|
question="Different question",
|
|
correct_answer="Different answer",
|
|
)
|
|
session.add(result2)
|
|
|
|
with pytest.raises(IntegrityError):
|
|
session.commit()
|
|
|
|
def test_benchmark_error_handling(self, session):
|
|
"""Test error tracking in benchmark runs."""
|
|
run = BenchmarkRun(
|
|
config_hash="test123",
|
|
query_hash_list=[],
|
|
search_config={},
|
|
evaluation_config={},
|
|
datasets_config={},
|
|
status=BenchmarkStatus.FAILED,
|
|
error_message="Failed to connect to evaluation model API",
|
|
)
|
|
session.add(run)
|
|
session.commit()
|
|
|
|
# Add failed results
|
|
result = BenchmarkResult(
|
|
benchmark_run_id=run.id,
|
|
example_id="failed_001",
|
|
query_hash="hash_failed",
|
|
dataset_type=DatasetType.SIMPLEQA,
|
|
question="What is 2+2?",
|
|
correct_answer="4",
|
|
research_error="Search timeout after 30 seconds",
|
|
evaluation_error="Could not parse model response",
|
|
)
|
|
session.add(result)
|
|
session.commit()
|
|
|
|
# Verify error tracking
|
|
failed_run = (
|
|
session.query(BenchmarkRun)
|
|
.filter_by(status=BenchmarkStatus.FAILED)
|
|
.first()
|
|
)
|
|
assert (
|
|
failed_run.error_message
|
|
== "Failed to connect to evaluation model API"
|
|
)
|
|
|
|
failed_result = (
|
|
session.query(BenchmarkResult)
|
|
.filter(BenchmarkResult.research_error.isnot(None))
|
|
.first()
|
|
)
|
|
assert "timeout" in failed_result.research_error
|
|
assert failed_result.evaluation_error is not None
|
|
|
|
def test_benchmark_statistics(self, session):
|
|
"""Test calculating statistics from benchmark results."""
|
|
# Create a completed run
|
|
run = BenchmarkRun(
|
|
config_hash="test123",
|
|
query_hash_list=["q1", "q2", "q3", "q4", "q5"],
|
|
search_config={},
|
|
evaluation_config={},
|
|
datasets_config={},
|
|
status=BenchmarkStatus.COMPLETED,
|
|
total_examples=5,
|
|
completed_examples=5,
|
|
)
|
|
session.add(run)
|
|
session.commit()
|
|
|
|
# Add results with varying accuracy
|
|
accuracies = [True, True, False, True, False] # 3/5 = 60% accuracy
|
|
|
|
for i, is_correct in enumerate(accuracies):
|
|
result = BenchmarkResult(
|
|
benchmark_run_id=run.id,
|
|
example_id=f"test_{i}",
|
|
query_hash=f"hash_{i}",
|
|
dataset_type=DatasetType.SIMPLEQA
|
|
if i < 3
|
|
else DatasetType.BROWSECOMP,
|
|
question=f"Question {i}",
|
|
correct_answer=f"Answer {i}",
|
|
is_correct=is_correct,
|
|
processing_time=2.0 + i * 0.5,
|
|
)
|
|
session.add(result)
|
|
|
|
session.commit()
|
|
|
|
# Calculate statistics
|
|
results = (
|
|
session.query(BenchmarkResult)
|
|
.filter_by(benchmark_run_id=run.id)
|
|
.all()
|
|
)
|
|
|
|
correct_count = sum(1 for r in results if r.is_correct)
|
|
accuracy = correct_count / len(results)
|
|
assert accuracy == 0.6
|
|
|
|
# Calculate per-dataset accuracy
|
|
simpleqa_results = [
|
|
r for r in results if r.dataset_type == DatasetType.SIMPLEQA
|
|
]
|
|
simpleqa_accuracy = sum(
|
|
1 for r in simpleqa_results if r.is_correct
|
|
) / len(simpleqa_results)
|
|
assert simpleqa_accuracy == 2 / 3 # ~0.667
|
|
|
|
# Calculate average processing time
|
|
avg_time = sum(r.processing_time for r in results) / len(results)
|
|
assert avg_time == 3.0 # (2.0 + 2.5 + 3.0 + 3.5 + 4.0) / 5
|