chore(hooks): require UtcDateTime in migrations too (#3523)

Tighten check-datetime-timezone so the UtcDateTime rule applies to
both models and migrations. Supersedes the inverted approach in #3515,
which tried to accept sa.DateTime(timezone=True) inside migrations.

- Rewrite the AST walker: handle sa.Column / bare Column, positional
  type arg at any index, bare Column(UtcDateTime) without parens (the
  hook's own example), and ast.IfExp with both branches inspected
  independently so a violation in either arm is still flagged.
- Anchor the path filter on src/local_deep_research/ to stop
  false positives on tests/database/models/ and partial-name matches
  like database/models_backup/.
- Update .pre-commit-config.yaml name/description and the stale
  CI_CD_INFRASTRUCTURE.md hook table entry.
- Add tests/hooks/test_check_datetime_timezone.py with 20 cases:
  violations (models / migrations / conditional types / batch runs /
  bare names), allows (UtcDateTime with import, combo import order,
  empty / syntax-error files), and path-filter boundaries.
This commit is contained in:
LearningCircuit
2026-04-18 21:47:17 +02:00
committed by GitHub
parent 285eb07fb7
commit bab0f61b66
4 changed files with 417 additions and 45 deletions

View File

@@ -1,5 +1,23 @@
#!/usr/bin/env python3
"""Pre-commit hook to ensure all datetime columns use UtcDateTime for SQLite compatibility."""
"""Pre-commit hook: ensure DateTime columns in models and migrations use UtcDateTime.
Scans files under ``src/local_deep_research/database/models/`` and
``src/local_deep_research/database/migrations/versions/`` for
``Column(...)`` / ``sa.Column(...)`` calls whose type argument is a bare
``DateTime`` (with or without ``timezone=True``). Flags them and hints at
the ``UtcDateTime`` replacement from ``sqlalchemy_utc``.
Limitations (accepted gaps, not caught by this hook):
- Raw SQL inside ``op.execute("... DATETIME ...")`` — the hook cannot
parse SQL strings.
- Type-alias indirection: ``dt = sa.DateTime(); sa.Column("x", dt)``.
- Fully-qualified imports without the ``sa`` alias
(e.g. ``import sqlalchemy; sqlalchemy.Column(...)``).
- ``sa.TIMESTAMP`` columns.
- Walrus expressions: ``Column((dt := DateTime()))`` wraps the call in
``ast.NamedExpr``, which the helper does not traverse.
- Import-order variations beyond the two hardcoded substring forms.
"""
import ast
import re
@@ -8,6 +26,33 @@ from pathlib import Path
from typing import List, Tuple
def _callable_name(func_node):
    """Return the short name of a callee expression.

    ``func_node`` is the ``func`` child of an ``ast.Call``. A bare name
    (``Column``) yields its identifier; an attribute access (``sa.Column``)
    yields the right-most attribute. Any other node shape yields ``None``.
    """
    # Map each supported node type to the field that carries its short name.
    name_fields = ((ast.Name, "id"), (ast.Attribute, "attr"))
    for node_type, field in name_fields:
        if isinstance(func_node, node_type):
            return getattr(func_node, field)
    return None
def _resolve_type_arg(arg):
    """Collect the type-like references found in one ``Column`` argument.

    Args:
        arg: An ``ast`` expression node (one positional argument of a
            ``Column(...)`` call).

    Returns:
        A list of ``(kind, payload)`` tuples; empty when ``arg`` is not a
        type reference.

        - ``("call", node)`` for any ``ast.Call`` (e.g. ``DateTime()`` or
          ``sa.DateTime(timezone=True)``); ``node`` is the call itself.
        - ``("name", "DateTime")`` / ``("name", "UtcDateTime")`` for a type
          passed without parentheses, either bare (``DateTime``) or
          attribute-qualified (``sa.DateTime``).

    For ``ast.IfExp`` BOTH branches are resolved and concatenated (body
    first). Returning only the first-resolved branch would silently pass a
    violation that lives in the other branch.
    """
    type_names = {"UtcDateTime", "DateTime"}

    if isinstance(arg, ast.Call):
        return [("call", arg)]
    if isinstance(arg, ast.Name) and arg.id in type_names:
        return [("name", arg.id)]
    # Unparenthesised qualified form, e.g. ``sa.Column("x", sa.DateTime)``.
    # Reported by short name so it is handled like the bare-name case.
    if isinstance(arg, ast.Attribute) and arg.attr in type_names:
        return [("name", arg.attr)]
    if isinstance(arg, ast.IfExp):
        return _resolve_type_arg(arg.body) + _resolve_type_arg(arg.orelse)
    return []
def check_datetime_columns(file_path: Path) -> List[Tuple[int, str, str]]:
"""Check a Python file for DateTime columns that should use UtcDateTime.
@@ -23,57 +68,71 @@ def check_datetime_columns(file_path: Path) -> List[Tuple[int, str, str]]:
print(f"Error reading {file_path}: {e}", file=sys.stderr)
return violations
# Check if file imports UtcDateTime (if it uses any DateTime columns)
has_utc_datetime_import = (
"from sqlalchemy_utc import UtcDateTime" in content
or "from sqlalchemy_utc import utcnow, UtcDateTime" in content
)
# Parse the AST to find Column definitions with DateTime
try:
tree = ast.parse(content)
except SyntaxError:
# Not valid Python, skip
return violations
fix_hint = (
"Use UtcDateTime() instead of DateTime() — "
"import: from sqlalchemy_utc import UtcDateTime"
)
for node in ast.walk(tree):
if isinstance(node, ast.Call):
# Check if this is a Column call
if isinstance(node.func, ast.Name) and node.func.id == "Column":
# Check if first argument is DateTime
if node.args and isinstance(node.args[0], ast.Call):
datetime_call = node.args[0]
if (
isinstance(datetime_call.func, ast.Name)
and datetime_call.func.id == "DateTime"
):
# This should be UtcDateTime instead
if not isinstance(node, ast.Call):
continue
if _callable_name(node.func) != "Column":
continue
type_entries = []
for arg in node.args:
type_entries = _resolve_type_arg(arg)
if type_entries:
break
for kind, payload in type_entries:
if kind == "call":
inner_name = _callable_name(payload.func)
if inner_name == "DateTime":
line_num = node.lineno
if 0 <= line_num - 1 < len(lines):
violations.append(
(line_num, lines[line_num - 1].strip(), fix_hint)
)
elif inner_name == "UtcDateTime":
if not has_utc_datetime_import:
line_num = node.lineno
if 0 <= line_num - 1 < len(lines):
violations.append(
(
line_num,
lines[line_num - 1].strip(),
"Use UtcDateTime instead of DateTime for SQLite compatibility",
"Missing import: from sqlalchemy_utc import UtcDateTime",
)
)
elif (
isinstance(datetime_call.func, ast.Name)
and datetime_call.func.id == "UtcDateTime"
):
# This is correct, but check if import exists
if not has_utc_datetime_import:
line_num = node.lineno
if 0 <= line_num - 1 < len(lines):
violations.append(
(
line_num,
lines[line_num - 1].strip(),
"Missing import: from sqlalchemy_utc import UtcDateTime",
)
)
elif kind == "name":
if payload == "DateTime":
line_num = node.lineno
if 0 <= line_num - 1 < len(lines):
violations.append(
(line_num, lines[line_num - 1].strip(), fix_hint)
)
elif payload == "UtcDateTime" and not has_utc_datetime_import:
line_num = node.lineno
if 0 <= line_num - 1 < len(lines):
violations.append(
(
line_num,
lines[line_num - 1].strip(),
"Missing import: from sqlalchemy_utc import UtcDateTime",
)
)
# Also check for func.now() usage which should be utcnow()
for i, line in enumerate(lines, 1):
if "func.now()" in line and "Column" in line:
violations.append(
@@ -83,7 +142,6 @@ def check_datetime_columns(file_path: Path) -> List[Tuple[int, str, str]]:
"Use utcnow() instead of func.now() for timezone-aware defaults",
)
)
# Check for datetime.utcnow or datetime.now(UTC) in defaults
if re.search(
r"default\s*=\s*(lambda:\s*)?datetime\.(utcnow|now)", line
):
@@ -111,23 +169,31 @@ def main():
for file_path_str in files_to_check:
file_path = Path(file_path_str)
# Only check Python files in database/models directories
if file_path.suffix == ".py" and (
"database/models" in str(file_path) or "models" in file_path.parts
):
violations = check_datetime_columns(file_path)
if violations:
all_violations.append((file_path, violations))
path_str = str(file_path)
in_scope = file_path.suffix == ".py" and (
"src/local_deep_research/database/models/" in path_str
or "src/local_deep_research/database/migrations/versions/"
in path_str
)
if not in_scope:
continue
violations = check_datetime_columns(file_path)
if violations:
all_violations.append((file_path, violations))
if all_violations:
print("\nDateTime column issues found:\n")
print("\nDateTime column issues found:\n")
for file_path, violations in all_violations:
print(f" {file_path}:")
for line_num, line_content, error_msg in violations:
print(f" Line {line_num}: {error_msg}")
print(f" > {line_content}")
print(
"\n Fix: Use UtcDateTime from sqlalchemy_utc for all datetime columns"
"\n Fix: use UtcDateTime from sqlalchemy_utc for all datetime columns"
)
print(
" (applies to both database/models/ and database/migrations/versions/)"
)
print(" Example: ")
print(" from sqlalchemy_utc import UtcDateTime, utcnow")