mirror of
https://github.com/LearningCircuit/local-deep-research.git
synced 2026-06-16 03:51:07 +03:00
Tighten check-datetime-timezone so the UtcDateTime rule applies to both models and migrations. Supersedes the inverted approach in #3515, which tried to accept sa.DateTime(timezone=True) inside migrations. - Rewrite the AST walker: handle sa.Column / bare Column, positional type arg at any index, bare Column(UtcDateTime) without parens (the hook's own example), and ast.IfExp with both branches inspected independently so a violation in either arm is still flagged. - Anchor the path filter on src/local_deep_research/ to stop false-positives on tests/database/models/ and partial-name matches like database/models_backup/. - Update .pre-commit-config.yaml name/description and the stale CI_CD_INFRASTRUCTURE.md hook table entry. - Add tests/hooks/test_check_datetime_timezone.py with 20 cases: violations (models / migrations / conditional types / batch runs / bare names), allows (UtcDateTime with import, combo import order, empty / syntax-error files), and path-filter boundaries.
208 lines
7.1 KiB
Python
Executable File
208 lines
7.1 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Pre-commit hook: ensure DateTime columns in models and migrations use UtcDateTime.
|
|
|
|
Scans files under ``src/local_deep_research/database/models/`` and
|
|
``src/local_deep_research/database/migrations/versions/`` for
|
|
``Column(...)`` / ``sa.Column(...)`` calls whose type argument is a bare
|
|
``DateTime`` (with or without ``timezone=True``). Flags them and hints at
|
|
the ``UtcDateTime`` replacement from ``sqlalchemy_utc``.
|
|
|
|
Limitations (accepted gaps, not caught by this hook):
|
|
- Raw SQL inside ``op.execute("... DATETIME ...")`` — the hook cannot
|
|
parse SQL strings.
|
|
- Type-alias indirection: ``dt = sa.DateTime(); sa.Column("x", dt)``.
|
|
- Fully-qualified imports without the ``sa`` alias
|
|
(e.g. ``import sqlalchemy; sqlalchemy.Column(...)``).
|
|
- ``sa.TIMESTAMP`` columns.
|
|
- Walrus expressions: ``Column((dt := DateTime()))`` wraps the call in
|
|
``ast.NamedExpr``, which the helper does not traverse.
|
|
- Import-order variations beyond the two hardcoded substring forms.
|
|
"""
|
|
|
|
import ast
|
|
import re
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import List, Tuple
|
|
|
|
|
|
def _callable_name(func_node):
|
|
"""Return the callable's short name regardless of ``X`` or ``sa.X`` form."""
|
|
if isinstance(func_node, ast.Name):
|
|
return func_node.id
|
|
if isinstance(func_node, ast.Attribute):
|
|
return func_node.attr
|
|
return None
|
|
|
|
|
|
def _resolve_type_arg(arg):
|
|
"""Return list of ('call', Call) or ('name', str) entries for all
|
|
type-like nodes in arg's subtree. Returns [] when arg is not a type
|
|
reference.
|
|
|
|
For ast.IfExp, BOTH branches are included — returning only the
|
|
first-resolved branch would silently pass a violation that lives
|
|
in the other branch.
|
|
"""
|
|
if isinstance(arg, ast.Call):
|
|
return [("call", arg)]
|
|
if isinstance(arg, ast.Name) and arg.id in {"UtcDateTime", "DateTime"}:
|
|
return [("name", arg.id)]
|
|
if isinstance(arg, ast.IfExp):
|
|
return _resolve_type_arg(arg.body) + _resolve_type_arg(arg.orelse)
|
|
return []
|
|
|
|
|
|
def check_datetime_columns(file_path: Path) -> List[Tuple[int, str, str]]:
|
|
"""Check a Python file for DateTime columns that should use UtcDateTime.
|
|
|
|
Returns a list of (line_number, line_content, error_message) tuples for violations.
|
|
"""
|
|
violations = []
|
|
|
|
try:
|
|
with open(file_path, "r") as f:
|
|
content = f.read()
|
|
lines = content.split("\n")
|
|
except Exception as e:
|
|
print(f"Error reading {file_path}: {e}", file=sys.stderr)
|
|
return violations
|
|
|
|
has_utc_datetime_import = (
|
|
"from sqlalchemy_utc import UtcDateTime" in content
|
|
or "from sqlalchemy_utc import utcnow, UtcDateTime" in content
|
|
)
|
|
|
|
try:
|
|
tree = ast.parse(content)
|
|
except SyntaxError:
|
|
return violations
|
|
|
|
fix_hint = (
|
|
"Use UtcDateTime() instead of DateTime() — "
|
|
"import: from sqlalchemy_utc import UtcDateTime"
|
|
)
|
|
|
|
for node in ast.walk(tree):
|
|
if not isinstance(node, ast.Call):
|
|
continue
|
|
if _callable_name(node.func) != "Column":
|
|
continue
|
|
|
|
type_entries = []
|
|
for arg in node.args:
|
|
type_entries = _resolve_type_arg(arg)
|
|
if type_entries:
|
|
break
|
|
|
|
for kind, payload in type_entries:
|
|
if kind == "call":
|
|
inner_name = _callable_name(payload.func)
|
|
if inner_name == "DateTime":
|
|
line_num = node.lineno
|
|
if 0 <= line_num - 1 < len(lines):
|
|
violations.append(
|
|
(line_num, lines[line_num - 1].strip(), fix_hint)
|
|
)
|
|
elif inner_name == "UtcDateTime":
|
|
if not has_utc_datetime_import:
|
|
line_num = node.lineno
|
|
if 0 <= line_num - 1 < len(lines):
|
|
violations.append(
|
|
(
|
|
line_num,
|
|
lines[line_num - 1].strip(),
|
|
"Missing import: from sqlalchemy_utc import UtcDateTime",
|
|
)
|
|
)
|
|
elif kind == "name":
|
|
if payload == "DateTime":
|
|
line_num = node.lineno
|
|
if 0 <= line_num - 1 < len(lines):
|
|
violations.append(
|
|
(line_num, lines[line_num - 1].strip(), fix_hint)
|
|
)
|
|
elif payload == "UtcDateTime" and not has_utc_datetime_import:
|
|
line_num = node.lineno
|
|
if 0 <= line_num - 1 < len(lines):
|
|
violations.append(
|
|
(
|
|
line_num,
|
|
lines[line_num - 1].strip(),
|
|
"Missing import: from sqlalchemy_utc import UtcDateTime",
|
|
)
|
|
)
|
|
|
|
for i, line in enumerate(lines, 1):
|
|
if "func.now()" in line and "Column" in line:
|
|
violations.append(
|
|
(
|
|
i,
|
|
line.strip(),
|
|
"Use utcnow() instead of func.now() for timezone-aware defaults",
|
|
)
|
|
)
|
|
if re.search(
|
|
r"default\s*=\s*(lambda:\s*)?datetime\.(utcnow|now)", line
|
|
):
|
|
violations.append(
|
|
(
|
|
i,
|
|
line.strip(),
|
|
"Use utcnow() from sqlalchemy_utc instead of datetime functions for defaults",
|
|
)
|
|
)
|
|
|
|
return violations
|
|
|
|
|
|
def main():
|
|
"""Main entry point for the pre-commit hook."""
|
|
files_to_check = sys.argv[1:]
|
|
|
|
if not files_to_check:
|
|
print("No files to check")
|
|
return 0
|
|
|
|
all_violations = []
|
|
|
|
for file_path_str in files_to_check:
|
|
file_path = Path(file_path_str)
|
|
|
|
path_str = str(file_path)
|
|
in_scope = file_path.suffix == ".py" and (
|
|
"src/local_deep_research/database/models/" in path_str
|
|
or "src/local_deep_research/database/migrations/versions/"
|
|
in path_str
|
|
)
|
|
if not in_scope:
|
|
continue
|
|
|
|
violations = check_datetime_columns(file_path)
|
|
if violations:
|
|
all_violations.append((file_path, violations))
|
|
|
|
if all_violations:
|
|
print("\nDateTime column issues found:\n")
|
|
for file_path, violations in all_violations:
|
|
print(f" {file_path}:")
|
|
for line_num, line_content, error_msg in violations:
|
|
print(f" Line {line_num}: {error_msg}")
|
|
print(f" > {line_content}")
|
|
print(
|
|
"\n Fix: use UtcDateTime from sqlalchemy_utc for all datetime columns"
|
|
)
|
|
print(
|
|
" (applies to both database/models/ and database/migrations/versions/)"
|
|
)
|
|
print(" Example: ")
|
|
print(" from sqlalchemy_utc import UtcDateTime, utcnow")
|
|
print(" Column(UtcDateTime, default=utcnow(), ...)\n")
|
|
return 1
|
|
|
|
return 0
|
|
|
|
|
|
if __name__ == "__main__":
|
|
sys.exit(main())
|