Skip to content

Instantly share code, notes, and snippets.

@vergenzt
Last active November 13, 2024 19:46
Show Gist options
  • Select an option

  • Save vergenzt/a5e9ee7d103f6454a2743eea0b586502 to your computer and use it in GitHub Desktop.

Select an option

Save vergenzt/a5e9ee7d103f6454a2743eea0b586502 to your computer and use it in GitHub Desktop.
Python script that writes the Alembic migration head revision to a head_revision.txt file (it fails if zero or multiple heads are found)
#!/usr/bin/env python3
import ast
import os
import sys
from collections import defaultdict
from graphlib import TopologicalSorter
from pathlib import Path
from subprocess import check_call, check_output
from typing import Any, Iterable, Iterator, Tuple, Union
def get_lit_assignments(migration_file: Path) -> Iterator[Tuple[str, Any]]:
mod: ast.Module = ast.parse(migration_file.read_text(), migration_file)
for stmt in mod.body:
match stmt:
case ast.Assign([ast.Name(var)], val_def):
try:
# F821's = false positive https://stackoverflow.com/a/67548086/718180
val = ast.literal_eval(val_def) # noqa: F821
yield var, val # noqa: F821
except ValueError:
pass
case _:
pass
def get_edges(migration_file: Path) -> Tuple[str, Iterable[str]]:
    """
    Get revision ID & down_revision ID(s) for given migration file.

    Uses the `ast` module & plaintext reads (via ``get_lit_assignments``) for
    speed, so the migration module is never imported.

    Args:
        migration_file: path to a migration script.

    Returns:
        ``(revision, down_revisions)``, where ``down_revisions`` is a tuple of
        zero (``down_revision = None``), one (a string) or several (a tuple or
        list of strings) parent revision IDs.

    Raises:
        ValueError: if ``revision`` or ``down_revision`` is not assigned a
            literal at module level, or is assigned a value of the wrong type.
    """
    lits = dict(get_lit_assignments(migration_file))
    try:
        revision = lits["revision"]
        down_revision = lits["down_revision"]
    except KeyError as err:
        raise ValueError(
            f"Migration {migration_file} did not have const {err.args[0]!r} declaration"
        ) from None

    if not isinstance(revision, str):
        raise ValueError(
            f"Migration {migration_file} has non-string revision {revision!r}"
        )

    # Normalize down_revision to a tuple of parent IDs, rejecting anything that
    # would otherwise only fail later while building the revision graph.
    down_revisions: Tuple[str, ...]
    if down_revision is None:
        down_revisions = ()
    elif isinstance(down_revision, str):
        down_revisions = (down_revision,)
    elif isinstance(down_revision, (tuple, list)) and all(
        isinstance(parent, str) for parent in down_revision
    ):
        down_revisions = tuple(down_revision)
    else:
        raise ValueError(
            f"Migration {migration_file} has invalid down_revision {down_revision!r}"
        )
    return revision, down_revisions
def _git_ls_files(*pathspecs: Union[str, Path]) -> list[Path]:
    """Return the git-tracked files matching *pathspecs*, relative to the cwd."""
    # -z: NUL-separated output with no C-style quoting, so paths containing
    # spaces, quotes or non-ASCII characters come back verbatim.
    out = check_output(["git", "ls-files", "-z", "--", *pathspecs], text=True)
    return [Path(name) for name in out.split("\0") if name]


def _is_migration_script(path: Path) -> bool:
    """Return whether *path* is a migration script rather than a support file."""
    # Non-.py files (e.g. a README) and package markers carry no `revision`
    # and would otherwise make get_edges raise on them.
    return path.suffix == ".py" and path.name != "__init__.py"


def _find_heads(edges: Iterable[Tuple[str, Iterable[str]]]) -> Tuple[str, ...]:
    """
    Return the revisions that no other revision names as a down_revision.

    Args:
        edges: ``(revision, down_revisions)`` pairs, one per migration.
    """
    # prev_revision -> next_revision(s). TopologicalSorter reads this mapping as
    # "node -> predecessors", so the nodes it reports ready (no predecessors)
    # are exactly the revisions with no successor, i.e. the heads.
    graph = defaultdict[str, set[str]](set)
    for revision, prev_revisions in edges:
        # Register every revision as a node, so a root with no successor (e.g.
        # the only migration in the repo) is still reported as a head.
        graph.setdefault(revision, set())
        for prev_revision in prev_revisions:
            graph[prev_revision].add(revision)
    sorter = TopologicalSorter(graph)
    sorter.prepare()  # raises graphlib.CycleError on a revision cycle
    return tuple(sorter.get_ready())


def main():
    """
    Write the migration head revision to ``head_revision.txt`` and ``git add`` it.

    Works from the git repository root. The directory containing the single
    git-tracked ``alembic.ini`` is the migrations directory; migration scripts
    are the git-tracked ``*.py`` files (other than ``__init__.py``) under its
    ``versions`` subdirectory.

    Raises:
        ValueError: if there is not exactly one tracked ``alembic.ini``, a
            migration lacks a literal ``revision``/``down_revision``, or there
            is not exactly one head.
        subprocess.CalledProcessError: if a git command fails.
    """
    os.chdir(check_output(["git", "rev-parse", "--show-toplevel"]).strip())

    ini_files = _git_ls_files("**/alembic.ini")
    if len(ini_files) != 1:
        found = ", ".join(map(str, ini_files)) or "none"
        raise ValueError(f"Expected exactly one git-tracked alembic.ini, found: {found}")
    migrations_path = ini_files[0].parent
    migrations_head_path = migrations_path / "head_revision.txt"
    versions_path = migrations_path / "versions"

    # only check migrations which are in git index
    version_files = filter(_is_migration_script, _git_ls_files(versions_path))
    heads = _find_heads(map(get_edges, version_files))
    if not heads:
        raise ValueError("No migration heads detected")
    if len(heads) > 1:
        raise ValueError(f"Multiple migration heads detected: {heads}")

    print(f"Writing head revisions to {migrations_head_path}:", file=sys.stderr)
    print("\n".join(heads))
    migrations_head_path.write_text("\n".join(heads) + "\n", encoding="utf-8")
    # add to index (in case there are changes)
    check_call(["git", "add", migrations_head_path])
if __name__ == "__main__":
    # Script entry point; importing this module runs nothing.
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment