Last active
November 13, 2024 19:46
-
-
Save vergenzt/a5e9ee7d103f6454a2743eea0b586502 to your computer and use it in GitHub Desktop.
Python script that writes the Alembic migration head(s) to a head_revision.txt file
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| import ast | |
| import os | |
| import sys | |
| from collections import defaultdict | |
| from graphlib import TopologicalSorter | |
| from pathlib import Path | |
| from subprocess import check_call, check_output | |
| from typing import Any, Iterable, Iterator, Tuple, Union | |
| def get_lit_assignments(migration_file: Path) -> Iterator[Tuple[str, Any]]: | |
| mod: ast.Module = ast.parse(migration_file.read_text(), migration_file) | |
| for stmt in mod.body: | |
| match stmt: | |
| case ast.Assign([ast.Name(var)], val_def): | |
| try: | |
| # F821's = false positive https://stackoverflow.com/a/67548086/718180 | |
| val = ast.literal_eval(val_def) # noqa: F821 | |
| yield var, val # noqa: F821 | |
| except ValueError: | |
| pass | |
| case _: | |
| pass | |
def get_edges(migration_file: Path) -> Tuple[str, Iterable[str]]:
    """
    Return ``(revision, down_revisions)`` for one migration file.

    The values come from the file's literal ``revision`` and
    ``down_revision`` assignments, read via `get_lit_assignments` (static
    parsing, no import). ``down_revisions`` is always iterable: a single
    string is wrapped in a 1-tuple, a falsy value becomes ``()``, and any
    other value is returned unchanged.

    Raises:
        ValueError: if either assignment is absent from the file.
    """
    literals = dict(get_lit_assignments(migration_file))
    try:
        revision: str = literals["revision"]
        down_revision: Union[str, Iterable[str]] = literals["down_revision"]
    except KeyError as missing:
        raise ValueError(
            f"Migration {migration_file} did not have const {missing.args[0]!r} declaration"
        )

    # Normalise the parent reference(s) to an iterable of revision ids.
    if isinstance(down_revision, str):
        parents = (down_revision,)
    elif down_revision:
        parents = down_revision
    else:
        parents = ()
    return revision, parents
def main():
    """
    Write the current migration head to ``head_revision.txt`` and stage it.

    Steps:
      1. ``chdir`` to the git repository root.
      2. Locate the tracked ``alembic.ini``; its directory holds both the
         ``versions`` directory and the ``head_revision.txt`` output file.
      3. Build the revision graph from the tracked migration files and find
         the revisions that no other revision lists as its ``down_revision``.
      4. Print the head, write it to ``head_revision.txt`` and ``git add`` it.

    Raises:
        ValueError: if more than one ``alembic.ini`` is tracked, or if the
            number of heads is not exactly one.
        subprocess.CalledProcessError: if a git command exits non-zero.
        graphlib.CycleError: if the revisions form a cycle.
    """
    os.chdir(check_output(["git", "rev-parse", "--show-toplevel"]).strip())

    ini_files = check_output(["git", "ls-files", "**/alembic.ini"], text=True).splitlines()
    if len(ini_files) > 1:
        # Several matches would otherwise be fused into one bogus path.
        raise ValueError(f"Multiple alembic.ini files detected: {ini_files}")
    migrations_path = Path("".join(ini_files)).parent
    migrations_head_path = migrations_path / "head_revision.txt"
    versions_path = migrations_path / "versions"

    # only check migrations which are in git index
    tracked = check_output(["git", "ls-files", versions_path], text=True).splitlines()
    # Non-Python files and package markers under versions/ are not migrations;
    # feeding them to the parser would abort the whole run.
    version_files = [
        path
        for path in map(Path, tracked)
        if path.suffix == ".py" and path.name != "__init__.py"
    ]

    # The graph maps each revision to the revisions that come *after* it.
    # TopologicalSorter treats the mapped values as predecessors, so the nodes
    # that are ready first are those with no later revision, i.e. the heads.
    graph = defaultdict[str, set[str]](set)
    for revision, prev_revisions in map(get_edges, version_files):
        # Register the revision itself so that one with neither parents nor
        # children (e.g. a lone initial migration) still counts as a head.
        graph.setdefault(revision, set())
        for prev_revision in prev_revisions:
            # prev_revision -> next_revision(s)
            graph[prev_revision].add(revision)

    sorter = TopologicalSorter(graph)
    sorter.prepare()
    heads = sorter.get_ready()

    if not heads:
        raise ValueError("No migration heads detected")
    if len(heads) > 1:
        raise ValueError(f"Multiple migration heads detected: {heads}")

    print(f"Writing head revisions to {migrations_head_path}:", file=sys.stderr)
    print("\n".join(heads))
    migrations_head_path.write_text("\n".join(heads) + "\n")

    # add to index (in case there are changes)
    check_call(["git", "add", migrations_head_path])
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment