@stephanGarland
Created December 16, 2025 18:21
Estimate row / index / table size for InnoDB tables
import argparse
import doctest
import re
import signal
import sys
from enum import Enum
from typing import NoReturn
# Each value is an ordered list of (threshold, bytes) pairs, evaluated like a
# ternary chain: x <= k[0][0] ? k[0][1] : x <= k[1][0] ? k[1][1] : ...
# where x is the column's declared width / fsp / member count.
_SIZE_MAP: dict[str, list[tuple[float | int, int]]] = {
    "BIGINT": [(float("inf"), 8)],
    "BLOB": [(float("inf"), 2)],
    "DATE": [(float("inf"), 3)],
    "DATETIME": [(0, 5), (2, 6), (4, 7), (6, 8)],
    "DOUBLE": [(float("inf"), 8)],
    "ENUM": [
        (255, 1),
        (65535, 2),
    ],
    "FLOAT": [
        (24, 4),
        (53, 8),
    ],
    "INT": [(float("inf"), 4)],
    "INTEGER": [(float("inf"), 4)],
    "LONGBLOB": [(float("inf"), 4)],
    "LONGTEXT": [(float("inf"), 4)],
    "JSON": [(float("inf"), 4)],  # TODO: handle overhead
    "MEDIUMBLOB": [(float("inf"), 3)],
    "MEDIUMINT": [(float("inf"), 3)],
    "MEDIUMTEXT": [(float("inf"), 3)],
    "REAL": [(float("inf"), 8)],
    # SET stores 1 byte for 1-8 members, 2 for 9-16, 3 for 17-24, 4 for 25-32,
    # and 8 for 33-64; thresholds are upper bounds since the lookup tests
    # width <= threshold (note the member count is not currently parsed,
    # so SET columns fall through to the first tier)
    "SET": [
        (8, 1),
        (16, 2),
        (24, 3),
        (32, 4),
        (64, 8),
    ],
    "SMALLINT": [(float("inf"), 2)],
    "TEXT": [(float("inf"), 2)],
    "TIME": [
        (0, 3),
        (2, 4),
        (4, 5),
        (6, 6),
    ],
    "TIMESTAMP": [
        (0, 4),
        (2, 5),
        (4, 6),
        (6, 7),
    ],
    "TINYBLOB": [(float("inf"), 1)],
    "TINYINT": [(float("inf"), 1)],
    "TINYTEXT": [(float("inf"), 1)],
    "VARBINARY": [
        (255, 1),
        (65535, 2),
    ],
    "VARCHAR": [
        (255, 1),
        (65535, 2),
    ],
    "YEAR": [(float("inf"), 1)],
}
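
# Example of the threshold scan above: DATETIME(3) checks (0, 5), (2, 6),
# (4, 7) in order and first passes 3 <= 4, so it costs 7 bytes (a 5-byte
# base value plus 2 bytes for fsp=3 fractional seconds).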

_CREATE_REGEX: str = r"""
    (?:
        `(?P<col_name>\w+)`\s+
        (?P<col_type>[A-Za-z0-9]+)
        (?:\(\s*(?P<col_width>\d+)
            (?:\s*,\s*
                (?P<col_scale>\d+)
            )?\s*\)
        )?
        (?:\s+GENERATED\s+.*
            (?P<col_virtual>VIRTUAL|STORED)
        )?
        (?:\s+COLLATE\s+
            (?P<col_collation>[A-Za-z0-9_]+)
        )?
        (?:\s+
            (?P<col_null>NOT\s+NULL|NULL)
        )?
        (?:\s+AUTO_INCREMENT)?
        (?:\s+DEFAULT\s+
            (?P<col_default>(?:'[^']*'|[^,)]*)
        ))?
        \s*(?=,|\))
    )
    |
    (?:
        (?P<idx_kind>PRIMARY|UNIQUE|KEY)
        (?:\s+KEY)?
        (?:\s+`(?P<idx_name>[^`]+)`)?
        \s*\(
        (?P<idx_cols>[^\)]+)
        \)
    )
"""
_DEC_REGEX = re.compile(r"\((\d+)(?:\s*,\s*(\d+))?\)")
_WIDTH_REGEX = re.compile(r"\((\d+)\)")
_CHARSET_BPC: dict[str, int] = {
    "ascii": 1,
    "big5": 2,
    "binary": 1,
    "euckr": 2,
    "gb2312": 2,
    "gbk": 2,
    "latin1": 1,
    "sjis": 2,
    "ucs2": 2,
    "ujis": 3,
    "utf16": 4,
    "utf32": 4,
    "utf8": 3,
    "utf8mb4": 4,
}
_DEC_DEFAULT_PRECISION: tuple[int, int] = (10, 0)
_DEC_LEFTOVER: tuple[int, ...] = (0, 1, 1, 2, 2, 3, 3, 4, 4)
_DEC_MAX_DIGITS_PER_CHUNK: int = 9
_DEC_MULTIPLICAND: int = 4
_SIGNAL_BASE_ERROR: int = 128
_INNODB_PAGE_DEFAULT_SIZE: int = 16384
_FAST_DDL_SIZE_LIMIT: int = _INNODB_PAGE_DEFAULT_SIZE // 2
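
# Assumption for context: 16384 is InnoDB's default page size, and the
# half-page figure (8192) is used below as the maximum row size for what this
# script calls "FastDDL"; the term itself is not defined in this gist.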

class MsgLevel(Enum):
    # Uses the logging library's level values
    DEBUG = 10
    INFO = 20
    WARNING = 30
    ERROR = 40


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Estimates maximum base row size for an InnoDB table",
        epilog="Expects a text file or piped input containing a CREATE TABLE statement",
    )
    parser.add_argument(
        "-f",
        "--file",
        help=r"File containing the SHOW CREATE TABLE <table>\G output",
    )
    parser.add_argument(
        "--no-timeout", action="store_true", help="Disable the timeout for input"
    )
    parser.add_argument(
        "--stats", action="store_true", help="Show extended statistics for the table"
    )
    return parser.parse_args()

def get_stdin_with_timeout(timeout: int = 5) -> str | None:
    """Reads stdin until EOF, returning None if the alarm fires first (POSIX-only)"""
    signal.signal(signal.SIGALRM, _alarm_handler)
    signal.alarm(timeout)
    try:
        return sys.stdin.read()
    except TimeoutError:
        return None
    finally:
        signal.alarm(0)


def _alarm_handler(signum, frame) -> NoReturn:
    raise TimeoutError


def _exit_handler(
    msg: str | None = None,
    msg_level: MsgLevel = MsgLevel.INFO,
    exit_code: signal.Signals | int = 0,
) -> NoReturn:
    """Provides proper signaling of exit status"""
    if isinstance(exit_code, signal.Signals):
        # shells report death-by-signal as 128 + signum
        exit_code = exit_code + _SIGNAL_BASE_ERROR
    msg = msg or "Bye!"
    print(f"\n{msg_level.name}: {msg}")
    sys.exit(exit_code)


def _charset_bpc(collation: str | None) -> int:
    """Returns the maximum bytes per character for a given collation.
    if not collation:
        return 1
    s = collation.strip().lower()
    charset = s.split("_", 1)[0] if s else ""
    if (ret := _CHARSET_BPC.get(charset)) is not None:
        return ret
    _exit_handler(
        msg=f"unknown charset '{charset}'", msg_level=MsgLevel.ERROR, exit_code=1
    )


def _decimal_ps(col_type: str) -> tuple[int, int]:
    """Finds the precision and scale for DECIMAL / NUMERIC.
    m = _DEC_REGEX.search(col_type)
    if not m:
        return _DEC_DEFAULT_PRECISION
    p = int(m.group(1))
    s = int(m.group(2) or 0)
    if s > p:
        s = p
    return p, s


def _decimal_bytes(p: int, s: int) -> int:
    """Calculates the storage for DECIMAL / NUMERIC.

    Requires 4 bytes (_DEC_MULTIPLICAND) per 9 digits (_DEC_MAX_DIGITS_PER_CHUNK)
    with some fraction of 4 bytes for the remaining number of digits.
    Ref: https://dev.mysql.com/doc/refman/5.7/en/precision-math-decimal-characteristics.html
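
    Worked example: DECIMAL(10,2) has 8 leftover integer digits (4 bytes) and
    2 leftover fractional digits (1 byte), matching the reference above:
    >>> _decimal_bytes(10, 2)
    5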
"""
int_d = p - s
gi, ri = divmod(int_d, _DEC_MAX_DIGITS_PER_CHUNK)
gf, rf = divmod(s, _DEC_MAX_DIGITS_PER_CHUNK)
return _DEC_MULTIPLICAND * (gi + gf) + _DEC_LEFTOVER[ri] + _DEC_LEFTOVER[rf]
def _declared_width(col_type: str | None) -> int:
"""Returns the column width, if applicable"""
    if not col_type:
        return 0
    m = _WIDTH_REGEX.search(col_type)
    return int(m.group(1)) if m else 0


def col_size(
    match: dict[str, str | None],
    default_collation: str | None = None,
) -> int:
    """Calculates the expected size of a column, based on its type.
    col_type = str(match.get("col_type") or "").strip().upper()
    base = col_type.split()[0]
    if base in {"DECIMAL", "NUMERIC"}:
        p = int(match.get("col_width") or _DEC_DEFAULT_PRECISION[0])
        s = int(match.get("col_scale") or _DEC_DEFAULT_PRECISION[1])
        if s > p:
            s = p
        return _decimal_bytes(p, s)
    if base in {"TIME", "TIMESTAMP", "DATETIME"}:
        fsp = int(match.get("col_width") or 0)
        levels = _SIZE_MAP[base]
        for max_fsp, bytes_ in levels:
            if fsp <= max_fsp:
                return bytes_
        return levels[-1][1]
    if base.startswith("VAR"):
        collation = match.get("col_collation") or default_collation or ""
        bpc = _charset_bpc(str(collation))
        w_chars = int(match.get("col_width") or 0)
        w_bytes = w_chars * (1 if base == "VARBINARY" else bpc)
        prefix = 1 if 0 < w_bytes <= 255 else 2 if w_bytes > 255 else 0
        return w_bytes + prefix
    if base == "CHAR":
        collation = match.get("col_collation") or default_collation or ""
        bpc = _charset_bpc(str(collation))
        w_chars = int(match.get("col_width") or 0)
        return w_chars * bpc
    if base == "BINARY":
        return int(match.get("col_width") or 0)
    levels = _SIZE_MAP.get(base, [])
    if levels:
        w = int(match.get("col_width") or 0)
        for max_w, bytes_ in levels:
            if w <= max_w:
                return bytes_
        return levels[-1][1]
    return 0


def _get_matches(
    create_stmt: str, regex_flags: re.RegexFlag = re.I | re.M | re.X
) -> list[dict[str, str | None]]:
    """Returns the column definitions parsed from a CREATE TABLE statement"""
    return [
        m.groupdict(default="")  # type: ignore[arg-type]
        for m in re.finditer(_CREATE_REGEX, create_stmt, regex_flags)
        if m.group("col_name")
    ]


def _get_index_defs(
    create_stmt: str, regex_flags: re.RegexFlag = re.I | re.M | re.X
) -> list[dict[str, str | None]]:
    """Returns the index definitions parsed from a CREATE TABLE statement"""
    return [
        m.groupdict(default="")  # type: ignore[arg-type]
        for m in re.finditer(_CREATE_REGEX, create_stmt, regex_flags)
        if m.group("idx_kind")
    ]


def _parse_index_cols(index_cols: str) -> list[str]:
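    """Extracts backtick-quoted column names from an index column list.

    >>> _parse_index_cols("`col5`,`col7`,`col1`")
    ['col5', 'col7', 'col1']
    """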
    return re.findall(r"`([^`]+)`", index_cols)


_test_create_table: str = """
CREATE TABLE `t_example` (
  `col1` DATETIME NOT NULL,
  `col2` DATETIME NOT NULL,
  `col3` BIGINT(20) NOT NULL AUTO_INCREMENT,
  `col4` VARCHAR(36) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
  `col5` VARCHAR(16) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
  `col6` VARCHAR(64) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
  `col7` VARCHAR(64) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
  `col8` VARCHAR(3) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
  `col9` BIGINT(20) NOT NULL DEFAULT '0',
  `col10` BIGINT(20) NOT NULL DEFAULT '0',
  `col11` BIGINT(20) NOT NULL DEFAULT '0',
  `col12` DATETIME DEFAULT NULL,
  `col13` DATETIME DEFAULT NULL,
  `col14` VARCHAR(32) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
  `col15` VARCHAR(36) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
  `col16` TINYINT(1) DEFAULT NULL,
  `col17` TINYINT(1) DEFAULT '0',
  `col18` TINYINT(1) DEFAULT NULL,
  `col19` VARCHAR(32) COLLATE utf8mb4_bin DEFAULT NULL,
  `col20` VARCHAR(36) COLLATE utf8mb4_bin DEFAULT NULL,
  `col21` VARCHAR(5) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
  `col22` VARCHAR(36) COLLATE utf8mb4_bin NOT NULL DEFAULT '',
  `col23` JSON DEFAULT NULL,
  `col24` JSON DEFAULT NULL,
  `col25` VARCHAR(32) COLLATE utf8mb4_bin DEFAULT '',
  `col26` VARCHAR(32) COLLATE utf8mb4_bin DEFAULT '',
  `col27` VARCHAR(255) COLLATE utf8mb4_bin DEFAULT NULL,
  `col28` VARCHAR(4) COLLATE utf8mb4_bin DEFAULT NULL,
  `col29` BIGINT(20) DEFAULT NULL,
  `col30` TINYINT(1) DEFAULT NULL,
  `col31` VARCHAR(64) COLLATE utf8mb4_bin DEFAULT NULL,
  `col32` VARCHAR(16) COLLATE utf8mb4_bin DEFAULT '',
  `col33` VARCHAR(32) COLLATE utf8mb4_bin DEFAULT NULL,
  `col34` VARCHAR(32) COLLATE utf8mb4_bin DEFAULT NULL,
  `col35` VARCHAR(32) COLLATE utf8mb4_bin DEFAULT NULL,
  `col36` BIGINT(20) DEFAULT NULL,
  `col37` BIGINT(20) DEFAULT NULL,
  `col38` TINYINT(1) DEFAULT NULL,
  PRIMARY KEY (`col3`),
  UNIQUE KEY `uk_col4` (`col4`),
  UNIQUE KEY `uk_col6` (`col6`),
  KEY `idx_col5` (`col5`),
  KEY `idx_col29_a` (`col29`),
  KEY `idx_col29_b` (`col29`),
  KEY `idx_col14` (`col14`),
  KEY `idx_col20` (`col20`),
  KEY `idx_col38` (`col38`),
  KEY `idx_col5_col7_col1` (`col5`,`col7`,`col1`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_bin
"""

def get_total_size(create_stmt: str) -> int:
    """Calculates the maximum possible base row size for a table's columns.

    >>> get_total_size(_test_create_table)
    3549
    """
    return sum(
        col_size(match=match)
        for match in _get_matches(create_stmt)
        if match.get("col_virtual") != "VIRTUAL"
    )

def get_index_row_sizes(create_stmt: str) -> tuple[int, int, int]:
    """
    Return per-row index sizes in bytes as (pk, secondary_indexes, total)
    - PK index entry size = sum of PK column sizes
    - Each secondary index entry = sum of its column sizes plus the PK
      columns not already in the index (InnoDB secondary indexes carry the PK)
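
    A spot-check against this module's test table (its PK is a single BIGINT,
    so each of the nine secondary indexes carries 8 extra PK bytes):
    >>> get_index_row_sizes(_test_create_table)
    (8, 1159, 1167)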
"""
col_defs = _get_matches(create_stmt)
col_sizes: dict[str, int] = {str(d.get("col_name")): col_size(d) for d in col_defs}
idx_defs = _get_index_defs(create_stmt)
pk_cols: list[str] = []
secondary_indexes: list[list[str]] = []
for idx in idx_defs:
kind = (idx.get("idx_kind") or "").upper()
cols = _parse_index_cols(idx.get("idx_cols") or "")
if not cols:
continue
if kind == "PRIMARY":
pk_cols = cols
else:
secondary_indexes.append(cols)
pk_row_size = sum(col_sizes.get(c, 0) for c in pk_cols)
secondary_row_size = 0
for idx_cols in secondary_indexes:
secondary_row_size += sum(col_sizes.get(c, 0) for c in idx_cols)
total_index_row_size = pk_row_size + secondary_row_size
return pk_row_size, secondary_row_size, total_index_row_size
def _human_size(num_bytes: int) -> str:
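    """Formats a byte count using binary (1024-based) units.

    >>> _human_size(16384)
    '16.00 KiB'
    """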
    units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]
    value = float(num_bytes)
    for unit in units:
        if value < 1024.0 or unit == units[-1]:
            return f"{value:.2f} {unit}"
        value /= 1024.0
    return f"{value:.2f} {units[-1]}"


def format_mysql_table(headers: list[str], rows: list[list[str]]) -> str:
    """
    Render a MySQL-style ASCII table, but with
    a left-justified header and right-justified cells.
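
    A small example:
    >>> print(format_mysql_table(["a", "bb"], [["1", "22"]]))
    +---+----+
    | a | bb |
    +---+----+
    | 1 | 22 |
    +---+----+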
"""
cols = len(headers)
col_widths: list[int] = [len(h) for h in headers]
for row in rows:
for i in range(cols):
col_widths[i] = max(col_widths[i], len(row[i]))
def border() -> str:
parts = ["+" + "+".join("-" * (w + 2) for w in col_widths) + "+"]
return parts[0]
def header_row() -> str:
cells = [" " + headers[i].ljust(col_widths[i]) + " " for i in range(cols)]
return "|" + "|".join(cells) + "|"
def data_row(row: list[str]) -> str:
cells = [" " + row[i].rjust(col_widths[i]) + " " for i in range(cols)]
return "|" + "|".join(cells) + "|"
lines: list[str] = [border(), header_row(), border()]
for r in rows:
lines.append(data_row(r))
lines.append(border())
return "\n".join(lines)

def print_stats(create_stmt: str, row_size: int) -> None:
    """Prints per-row sizes, plus projected data/index sizes at several row counts"""
    pk_row_size, sec_row_size, total_index_row_size = get_index_row_sizes(create_stmt)
    per_row_headers = ["Metric", "Bytes"]
    per_row_rows = [
        ["Row data size per row", str(row_size)],
        ["Primary key size per row", str(pk_row_size)],
        ["Secondary indexes size per row (all)", str(sec_row_size)],
        ["All indexes size per row", str(total_index_row_size)],
    ]
    print(f"{format_mysql_table(per_row_headers, per_row_rows)}\n")
    proj_headers = [
        "Rows",
        "Data Size",
        "Index Size",
        "Total Size",
    ]
    proj_rows: list[list[str]] = []
    for num_rows in [1_000_000, 5_000_000, 25_000_000, 100_000_000, 1_000_000_000]:
        data_size = row_size * num_rows
        index_size = total_index_row_size * num_rows
        total_size = data_size + index_size
        proj_rows.append(
            [
                f"{num_rows:,}",
                _human_size(data_size),
                _human_size(index_size),
                _human_size(total_size),
            ]
        )
    print(format_mysql_table(proj_headers, proj_rows))

if __name__ == "__main__":
    doctest.testmod()
    args = get_args()
    if args.file:
        try:
            with open(args.file, "r") as f:
                create_stmt = f.read()
        except OSError as e:
            _exit_handler(
                msg=f"failed to open file: {e}", msg_level=MsgLevel.ERROR, exit_code=1
            )
    else:
        print("Reading from /dev/stdin - expects EOF to finish")
        timeout: int = 5
        if args.no_timeout:
            timeout = 31_536_000  # 1 year
        if (create_stmt := get_stdin_with_timeout(timeout=timeout)) is None:  # type: ignore[assignment]
            _exit_handler(
                msg="timed out while waiting for input",
                msg_level=MsgLevel.ERROR,
                exit_code=signal.Signals.SIGALRM,
            )
        elif not create_stmt:
            _exit_handler(
                msg="CREATE TABLE statement was empty",
                msg_level=MsgLevel.ERROR,
                exit_code=1,
            )
    row_size = get_total_size(create_stmt)
    if args.stats:
        print_stats(create_stmt=create_stmt, row_size=row_size)
        _exit_handler(exit_code=0)
    if row_size >= _FAST_DDL_SIZE_LIMIT:
        _exit_handler(
            msg=(
                f"maximum row size is {row_size} - "
                f"exceeds FastDDL limit of {_FAST_DDL_SIZE_LIMIT}"
            ),
            msg_level=MsgLevel.WARNING,
            exit_code=1,
        )
    else:
        _exit_handler(msg=f"maximum row size is {row_size}")