Skip to content

Instantly share code, notes, and snippets.

@z0z0r4
Created February 15, 2026 21:58
Show Gist options
  • Select an option

  • Save z0z0r4/f45d95520091d5e8a82877a92e8b5083 to your computer and use it in GitHub Desktop.

Select an option

Save z0z0r4/f45d95520091d5e8a82877a92e8b5083 to your computer and use it in GitHub Desktop.
Timsort
import random
import sys
from sort import TimSort
class CountedItem:
    """Wrap a value and tally every rich comparison performed on wrappers.

    The tally is kept in the class attribute ``comparisons`` and is shared by
    all instances. Each of ``<``, ``<=``, ``>``, ``>=`` and ``==`` adds one to
    it and then compares the wrapped ``val`` attributes of both operands.
    """

    comparisons = 0

    def __init__(self, val):
        self.val = val

    @classmethod
    def reset(cls):
        """Set the comparison tally on ``cls`` back to zero."""
        cls.comparisons = 0

    @staticmethod
    def _tick():
        """Record one comparison on the shared ``CountedItem`` tally."""
        CountedItem.comparisons += 1

    def __lt__(self, other):
        self._tick()
        return self.val < other.val

    def __le__(self, other):
        self._tick()
        return self.val <= other.val

    def __gt__(self, other):
        self._tick()
        return self.val > other.val

    def __ge__(self, other):
        self._tick()
        return self.val >= other.val

    def __eq__(self, other):
        self._tick()
        return self.val == other.val
def get_test_data(name, n):
    """Build a list of ``CountedItem`` objects for one named input pattern.

    Supported ``name`` values:

    * ``"*sort"``  - ``n`` random floats.
    * ``"\\sort"`` - the integers ``n .. 1`` in descending order.
    * ``"/sort"``  - the integers ``0 .. n-1`` in ascending order.
    * ``"3sort"``  - ascending integers, then 3 random pairwise exchanges.
    * ``"+sort"``  - ascending integers followed by 10 random values.
    * ``"%sort"``  - ascending integers with ``max(1, n // 100)`` random
      positions overwritten by random values.
    * ``"~sort"``  - random integers drawn from ``0 .. 3``.
    * ``"=sort"``  - ``n`` copies of one random value.
    * ``"!sort"``  - a random shuffle of ``0 .. n-1``.

    Any other ``name`` yields an empty list. Every pattern returns at most
    ``n`` items, and tiny sizes (``n < 2``) no longer raise ``ValueError``.
    """
    if name == "*sort":  # Random
        return [CountedItem(random.random()) for _ in range(n)]
    if name == "\\sort":  # Descending
        return [CountedItem(i) for i in range(n, 0, -1)]
    if name == "/sort":  # Ascending
        return [CountedItem(i) for i in range(n)]
    if name == "3sort":  # Ascending + 3 exchanges
        arr = [CountedItem(i) for i in range(n)]
        # random.sample needs a population of at least two to pick a pair.
        if n >= 2:
            for _ in range(3):
                i, j = random.sample(range(n), 2)
                arr[i], arr[j] = arr[j], arr[i]
        return arr
    if name == "+sort":  # Ascending + 10 random at end
        # Shrink the random tail for small n so the result never exceeds n.
        tail = max(0, min(10, n))
        head = [CountedItem(i) for i in range(n - tail)]
        return head + [CountedItem(random.random() * n) for _ in range(tail)]
    if name == "%sort":  # Ascending + 1% random replacements
        arr = [CountedItem(i) for i in range(n)]
        # random.randint(0, -1) would raise on an empty list.
        if n > 0:
            for _ in range(max(1, n // 100)):
                arr[random.randint(0, n - 1)] = CountedItem(random.random() * n)
        return arr
    if name == "~sort":  # Many duplicates (4 unique values)
        return [CountedItem(random.randint(0, 3)) for _ in range(n)]
    if name == "=sort":  # All equal
        v = random.random()
        return [CountedItem(v) for _ in range(n)]
    if name == "!sort":  # Shuffled distinct integers
        arr = [CountedItem(i) for i in range(n)]
        random.shuffle(arr)
        return arr
    return []
def quicksort(arr, low, high):
    """Sort ``arr[low:high + 1]`` in place using ``partition``.

    ``partition`` splits the range into ``[low, p]`` and ``[p + 1, high]``.
    Only the smaller part is handled recursively; the larger part is handled
    by the enclosing loop. This keeps the recursion depth logarithmic in the
    range size instead of linear in the worst case, while performing exactly
    the same partition calls as plain two-way recursion.
    """
    while low < high:
        pivot_idx = partition(arr, low, high)
        if pivot_idx - low < high - pivot_idx:
            # Left part [low, pivot_idx] is the smaller one.
            quicksort(arr, low, pivot_idx)
            low = pivot_idx + 1
        else:
            quicksort(arr, pivot_idx + 1, high)
            high = pivot_idx
def partition(arr, low, high):
    """Partition ``arr[low:high + 1]`` around its middle element.

    Two cursors walk towards each other, swapping elements that sit on the
    wrong side of the pivot value. Returns the index ``p`` at which the
    cursors met; ``arr[low:p + 1]`` holds values not greater than the pivot
    and ``arr[p + 1:high + 1]`` holds values not smaller than it.
    """
    pivot = arr[(low + high) // 2]
    left, right = low, high
    while True:
        while arr[left] < pivot:
            left += 1
        while arr[right] > pivot:
            right -= 1
        if left >= right:
            return right
        arr[left], arr[right] = arr[right], arr[left]
        # Step past the pair that was just exchanged.
        left += 1
        right -= 1
def mergesort(arr):
    """Return the elements of ``arr`` in sorted order via top-down merge sort.

    A list with fewer than two elements is returned as-is (the same object);
    otherwise the two halves are sorted recursively and combined with
    ``merge`` into a new list.
    """
    size = len(arr)
    if size < 2:
        return arr
    half = size // 2
    sorted_left = mergesort(arr[:half])
    sorted_right = mergesort(arr[half:])
    return merge(sorted_left, sorted_right)
def merge(left, right):
    """Merge two sorted lists into a new sorted list.

    On ties the element from ``left`` is emitted first, so equal elements
    keep their left-before-right order.
    """
    merged = []
    li = 0
    ri = 0
    left_len = len(left)
    right_len = len(right)
    while li < left_len and ri < right_len:
        take_left = left[li] <= right[ri]
        if take_left:
            merged.append(left[li])
            li += 1
        else:
            merged.append(right[ri])
            ri += 1
    # At most one of these two tails is non-empty.
    merged += left[li:]
    merged += right[ri:]
    return merged
def run_benchmark(n_list):
    """Print one Markdown table of comparison counts per size in ``n_list``.

    For every size, one dataset per input pattern is generated once and each
    algorithm sorts its own shallow copy of it. All rows of a table, and in
    particular the "Mine vs Built" percentage, therefore refer to identical
    inputs. Comparisons are counted through ``CountedItem``.

    QuickSort and MergeSort cells show ``Recursion`` when the run hits
    ``RecursionError``.
    """
    test_names = ["*sort", "\\sort", "/sort", "+sort", "~sort", "=sort", "%sort", "3sort", "!sort"]
    # Raise the recursion limit for the recursive quick/merge sorts, but
    # never lower a limit that is already higher.
    sys.setrecursionlimit(max(sys.getrecursionlimit(), 5000))

    def count_comparisons(sort_func, dataset):
        """Sort a copy of ``dataset`` and return the comparisons it took."""
        work = list(dataset)
        CountedItem.reset()
        sort_func(work)
        return CountedItem.comparisons

    for n in n_list:
        # One shared dataset per pattern so every algorithm sees equal input.
        datasets = {name: get_test_data(name, n) for name in test_names}
        # Results of all algorithms for this n, used to compute percentages.
        results = {name: {} for name in test_names}
        # Run Mine and Built-in first to collect their counts.
        for algo_label, algo_func in [("Mine", lambda d: TimSort(d).sort()), ("Built", lambda d: d.sort())]:
            for name in test_names:
                results[name][algo_label] = count_comparisons(algo_func, datasets[name])

        # Print the Markdown table.
        print(f"\n### N = {n}")
        header = f"| {'Algorithm':<15} | " + " | ".join(f"{name:^10}" for name in test_names) + " |"
        separator = "| :--- | " + " | ".join(":---:" for _ in test_names) + " |"
        print(header)
        print(separator)

        # 1. Mine, with its count relative to Built-in as a percentage.
        mine_row = "| **TimSort (Mine)** | "
        for name in test_names:
            m, b = results[name]["Mine"], results[name]["Built"]
            perc = (m / b * 100) if b != 0 else 0
            mine_row += f"{m} ({perc:.1f}%) | "
        print(mine_row)

        # 2. Built-in.
        built_row = "| Built-in Sort | "
        for name in test_names:
            built_row += f"{results[name]['Built']} | "
        print(built_row)

        # 3. The other algorithms (Quick/Merge).
        for other_name, other_func in [("QuickSort", lambda d: quicksort(d, 0, len(d) - 1)),
                                       ("MergeSort", lambda d: mergesort(d))]:
            row = f"| {other_name:<15} | "
            for name in test_names:
                try:
                    cell = str(count_comparisons(other_func, datasets[name]))
                except RecursionError:
                    cell = "Recursion"
                row += f"{cell} | "
            print(row)
if __name__ == "__main__":
    # Benchmark the sizes 512, 1024, 2048 and 4096 (2**9 .. 2**12).
    run_benchmark([1 << exponent for exponent in range(9, 13)])
import unittest
import random
from sort import TimSort
class TestTimSort(unittest.TestCase):
    """Correctness, helper-function and stability checks for ``TimSort``."""

    # Fixed seed so a failing randomized case can be replayed exactly.
    SEED = 20260215

    def setUp(self):
        # A private generator keeps the tests independent of global RNG state.
        self.rng = random.Random(self.SEED)

    def assert_sorts_like_builtin(self, arr):
        """Sort ``arr`` in place with TimSort and compare with ``sorted``."""
        expected = sorted(arr)
        TimSort(arr).sort()
        self.assertEqual(arr, expected, f"seed={self.SEED}")

    def test_empty_list(self):
        arr = []
        TimSort(arr).sort()
        self.assertEqual(arr, [])

    def test_single_element(self):
        arr = [1]
        TimSort(arr).sort()
        self.assertEqual(arr, [1])

    def test_already_sorted(self):
        arr = [1, 2, 3, 4, 5]
        TimSort(arr).sort()
        self.assertEqual(arr, [1, 2, 3, 4, 5])

    def test_reverse_sorted(self):
        arr = [5, 4, 3, 2, 1]
        TimSort(arr).sort()
        self.assertEqual(arr, [1, 2, 3, 4, 5])

    def test_duplicates(self):
        self.assert_sorts_like_builtin([3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 5])

    def test_random_list(self):
        self.assert_sorts_like_builtin(
            [self.rng.randint(0, 1000) for _ in range(100)]
        )

    def test_large_random_list(self):
        # Exercising minrun and merge logic
        self.assert_sorts_like_builtin(
            [self.rng.randint(0, 10000) for _ in range(500)]
        )

    def test_long_presorted_runs(self):
        # Long natural runs make the merge step handle large blocks at once.
        self.assert_sorts_like_builtin(list(range(200, 400)) + list(range(200)))
        self.assert_sorts_like_builtin(list(range(300)) + list(range(100, 400)))

    def test_binary_search(self):
        # Testing binary_search_right and binary_search_left
        arr = [1, 3, 5, 7, 9]
        # binary_search_right: index of the first element > key
        self.assertEqual(TimSort.binary_search_right(arr, 0, 4, 0), 0)
        self.assertEqual(TimSort.binary_search_right(arr, 0, 4, 1), 1)
        self.assertEqual(TimSort.binary_search_right(arr, 0, 4, 4), 2)
        self.assertEqual(TimSort.binary_search_right(arr, 0, 4, 10), 5)
        # binary_search_left: index of the first element >= key
        self.assertEqual(TimSort.binary_search_left(arr, 0, 4, 0), 0)
        self.assertEqual(TimSort.binary_search_left(arr, 0, 4, 1), 0)
        self.assertEqual(TimSort.binary_search_left(arr, 0, 4, 3), 1)
        self.assertEqual(TimSort.binary_search_left(arr, 0, 4, 5), 2)

    def test_stability(self):
        # Stability: elements with equal values must keep their original
        # relative order after sorting.
        class IndexedValue:
            def __init__(self, val, original_index):
                self.val = val
                self.original_index = original_index

            # Comparisons look at ``val`` only, like a sort key would.
            def __lt__(self, other):
                return self.val < other.val

            def __le__(self, other):
                return self.val <= other.val

            def __gt__(self, other):
                return self.val > other.val

            def __ge__(self, other):
                return self.val >= other.val

            def __eq__(self, other):
                return self.val == other.val

            def __repr__(self):
                return f"({self.val}, index:{self.original_index})"

        # Objects with repeated ``val`` entries.
        data = [
            IndexedValue(5, 0),
            IndexedValue(2, 1),
            IndexedValue(5, 2),
            IndexedValue(1, 3),
            IndexedValue(2, 4),
            IndexedValue(5, 5),
        ]
        TimSort(data).sort()
        # Verify the result is both ordered and stable.
        for i in range(len(data) - 1):
            # 1. Values are non-decreasing.
            self.assertLessEqual(data[i].val, data[i + 1].val)
            # 2. Equal values keep increasing original indices, i.e. their
            #    relative order did not change.
            if data[i].val == data[i + 1].val:
                self.assertLess(
                    data[i].original_index,
                    data[i + 1].original_index,
                    f"Stability failed at values {data[i]} and {data[i+1]}",
                )
# Run the whole test suite when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
from typing import List
class TimSort:
    """In-place, stable TimSort over a Python list.

    Usage: ``TimSort(items).sort()``. The list passed to the constructor is
    reordered in place; ``sort`` returns ``None``.

    The algorithm splits the list into natural runs (ascending, or strictly
    descending and then reversed), extends short runs to ``minrun`` elements
    with binary insertion sort, and merges adjacent runs kept on ``stack``.
    Merging switches to "galloping" (exponential search followed by binary
    search) once one side has won ``min_gallop`` times in a row.

    Elements only need to support ``<``, ``<=``, ``>`` and ``>=``.
    """

    items: list
    # Pending runs as (start index, length), oldest first.
    stack: list[tuple[int, int]]
    minrun: int
    # Initial and reference streak length for entering/leaving gallop mode.
    MIN_GALLOP = 7
    min_gallop = MIN_GALLOP

    def __init__(self, items: list):
        self.items = items
        self.stack = []
        self.minrun = self.calc_minrun()

    def calc_minrun(self):
        """Return the minimum run length for the current length of ``items``.

        The length is halved until it drops below 32; 1 is added to the
        result if any bit shifted out on the way was set.
        """
        n = len(self.items)
        r = 0
        while n >= 32:
            r |= n & 1
            n >>= 1
        return n + r

    @staticmethod
    def binary_search_right(items: list, low: int, high: int, key):
        """Return the first index in ``[low, high + 1]`` with ``items[i] > key``.

        ``items[low:high + 1]`` must be sorted. Returns ``high + 1`` when no
        element is greater than ``key``, and ``low`` for an empty range.
        """
        while low <= high:
            mid = (low + high) // 2
            if items[mid] > key:
                high = mid - 1
            else:
                low = mid + 1
        return low

    @staticmethod
    def binary_search_left(items: list, low: int, high: int, key):
        """Return the first index in ``[low, high + 1]`` with ``items[i] >= key``.

        ``items[low:high + 1]`` must be sorted. Returns ``high + 1`` when
        every element is smaller than ``key``, and ``low`` for an empty range.
        """
        while low <= high:
            mid = (low + high) // 2
            if items[mid] >= key:
                high = mid - 1
            else:
                low = mid + 1
        return low

    @staticmethod
    def binary_insert_sort(items: list, low: int, high: int, start=None):
        """Stably sort ``items[low:high + 1]`` in place by binary insertion.

        ``start`` is the index of the first element that still has to be
        inserted; ``items[low:start]`` must already be sorted. When omitted,
        only ``items[low]`` is treated as sorted.
        """
        first = low + 1 if start is None else max(start, low + 1)
        for i in range(first, high + 1):
            key = items[i]
            # Searching for the first element > key keeps equal keys in order.
            pos = TimSort.binary_search_right(items, low, i - 1, key)
            if pos < i:
                items[pos : i + 1] = [key] + items[pos:i]

    def gallop_right(self, items: list, low: int, high: int, key, hint: int = -1) -> int:
        """Find the first position in ``items[low:high + 1]`` that is > ``key``.

        ``hint`` is where the search starts; offsets from it grow as
        1, 3, 7, ... towards the side where the answer lies, and the bracket
        found this way is finished with a binary search. A ``hint`` outside
        ``[low, high]`` falls back to ``low``. Returns ``high + 1`` when no
        element is greater than ``key``.
        """
        if low > high:
            return low
        if hint < low or hint > high:
            hint = low
        last_ofs = 0
        ofs = 1
        if items[hint] <= key:
            # Gallop right: items[hint + last_ofs] <= key is known.
            max_ofs = high - hint + 1
            while ofs < max_ofs and items[hint + ofs] <= key:
                last_ofs = ofs
                ofs = (ofs << 1) + 1
            if ofs > max_ofs:
                ofs = max_ofs
            # items[hint + ofs] is either past the range or known to be
            # > key, so only the positions strictly in between are open.
            l_bound = hint + last_ofs + 1
            r_bound = hint + ofs - 1
        else:
            # Gallop left: items[hint - last_ofs] > key is known.
            max_ofs = hint - low + 1
            while ofs < max_ofs and items[hint - ofs] > key:
                last_ofs = ofs
                ofs = (ofs << 1) + 1
            if ofs > max_ofs:
                ofs = max_ofs
            # items[hint - ofs] is either before the range or known to be
            # <= key, so only the positions strictly in between are open.
            l_bound = hint - ofs + 1
            r_bound = hint - last_ofs - 1
        return self.binary_search_right(items, l_bound, r_bound, key)

    def gallop_left(self, items: list, low: int, high: int, key, hint: int = -1) -> int:
        """Find the first position in ``items[low:high + 1]`` that is >= ``key``.

        ``hint`` is where the search starts; offsets from it grow as
        1, 3, 7, ... towards the side where the answer lies, and the bracket
        found this way is finished with a binary search. A ``hint`` outside
        ``[low, high]`` falls back to ``high``. Returns ``high + 1`` when
        every element is smaller than ``key``.
        """
        if low > high:
            return low
        if hint < low or hint > high:
            hint = high
        last_ofs = 0
        ofs = 1
        if items[hint] >= key:
            # Gallop left: items[hint - last_ofs] >= key is known.
            max_ofs = hint - low + 1
            while ofs < max_ofs and items[hint - ofs] >= key:
                last_ofs = ofs
                ofs = (ofs << 1) + 1
            if ofs > max_ofs:
                ofs = max_ofs
            # items[hint - ofs] is either before the range or known to be
            # < key, so only the positions strictly in between are open.
            l_bound = hint - ofs + 1
            r_bound = hint - last_ofs - 1
        else:
            # Gallop right: items[hint + last_ofs] < key is known.
            max_ofs = high - hint + 1
            while ofs < max_ofs and items[hint + ofs] < key:
                last_ofs = ofs
                ofs = (ofs << 1) + 1
            if ofs > max_ofs:
                ofs = max_ofs
            # items[hint + ofs] is either past the range or known to be
            # >= key, so only the positions strictly in between are open.
            l_bound = hint + last_ofs + 1
            r_bound = hint + ofs - 1
        return self.binary_search_left(items, l_bound, r_bound, key)

    def merge_low(self, base_a: int, len_a: int, base_b: int, len_b: int):
        """Merge adjacent runs A and B front to back; A is copied to a buffer.

        A is ``items[base_a:base_a + len_a]`` and B is
        ``items[base_b:base_b + len_b]``, with B starting right after A.
        On ties the element from A is written first, which keeps the merge
        stable.
        """
        items = self.items
        tmp = items[base_a : base_a + len_a]
        cursor_a = 0  # index into tmp
        cursor_b = base_b  # index into items
        dest = base_a
        high_b = base_b + len_b - 1
        min_gallop = self.min_gallop

        while cursor_a < len_a and cursor_b <= high_b:
            acount = 0  # consecutive wins of A
            bcount = 0  # consecutive wins of B
            # One-pair-at-a-time mode until one side wins min_gallop times.
            while cursor_a < len_a and cursor_b <= high_b:
                if items[cursor_b] < tmp[cursor_a]:
                    items[dest] = items[cursor_b]
                    cursor_b += 1
                    dest += 1
                    bcount += 1
                    acount = 0
                    if bcount >= min_gallop:
                        break
                else:
                    items[dest] = tmp[cursor_a]
                    cursor_a += 1
                    dest += 1
                    acount += 1
                    bcount = 0
                    if acount >= min_gallop:
                        break
            if cursor_a >= len_a or cursor_b > high_b:
                break

            # Gallop mode: move whole blocks found by exponential search.
            min_gallop += 1
            while True:
                # Every gallop round makes entering gallop mode cheaper.
                if min_gallop > 1:
                    min_gallop -= 1
                self.min_gallop = min_gallop

                # Elements of A that are <= B's head go out first.
                k = self.gallop_right(tmp, cursor_a, len_a - 1, items[cursor_b], hint=cursor_a)
                acount = k - cursor_a
                if acount > 0:
                    items[dest : dest + acount] = tmp[cursor_a:k]
                    dest += acount
                    cursor_a = k
                    if cursor_a == len_a:
                        break
                items[dest] = items[cursor_b]
                dest += 1
                cursor_b += 1
                if cursor_b > high_b:
                    break

                # Elements of B that are < A's head go out first.
                k = self.gallop_left(items, cursor_b, high_b, tmp[cursor_a], hint=cursor_b)
                bcount = k - cursor_b
                if bcount > 0:
                    items[dest : dest + bcount] = items[cursor_b:k]
                    dest += bcount
                    cursor_b = k
                    if cursor_b > high_b:
                        break
                items[dest] = tmp[cursor_a]
                dest += 1
                cursor_a += 1
                if cursor_a == len_a:
                    break

                if acount < self.MIN_GALLOP and bcount < self.MIN_GALLOP:
                    break
            # Penalty for leaving gallop mode.
            min_gallop += 1
            self.min_gallop = max(1, min_gallop)

        # Whatever is left of A belongs at the end; leftovers of B are
        # already in their final place.
        if cursor_a < len_a:
            items[dest : dest + (len_a - cursor_a)] = tmp[cursor_a:]

    def merge_high(self, base_a: int, len_a: int, base_b: int, len_b: int):
        """Merge adjacent runs A and B back to front; B is copied to a buffer.

        A is ``items[base_a:base_a + len_a]`` and B is
        ``items[base_b:base_b + len_b]``, with B starting right after A.
        On ties the element from B is written first (at the higher index),
        which keeps the merge stable.
        """
        items = self.items
        tmp = items[base_b : base_b + len_b]
        cursor_a = base_a + len_a - 1  # index into items
        cursor_b = len_b - 1  # index into tmp
        dest = base_b + len_b - 1
        min_gallop = self.min_gallop

        while cursor_a >= base_a and cursor_b >= 0:
            acount = 0  # consecutive wins of A
            bcount = 0  # consecutive wins of B
            # One-pair-at-a-time mode until one side wins min_gallop times.
            while cursor_a >= base_a and cursor_b >= 0:
                if items[cursor_a] > tmp[cursor_b]:
                    items[dest] = items[cursor_a]
                    dest -= 1
                    cursor_a -= 1
                    acount += 1
                    bcount = 0
                    if acount >= min_gallop:
                        break
                else:
                    items[dest] = tmp[cursor_b]
                    dest -= 1
                    cursor_b -= 1
                    bcount += 1
                    acount = 0
                    if bcount >= min_gallop:
                        break
            if cursor_a < base_a or cursor_b < 0:
                break

            # Gallop mode: move whole blocks found by exponential search.
            min_gallop += 1
            while True:
                # Every gallop round makes entering gallop mode cheaper.
                if min_gallop > 1:
                    min_gallop -= 1
                self.min_gallop = min_gallop

                # Elements of A that are > B's tail go to the back first.
                k = self.gallop_right(items, base_a, cursor_a, tmp[cursor_b], hint=cursor_a)
                acount = cursor_a - k + 1
                if acount > 0:
                    items[dest - acount + 1 : dest + 1] = items[k : cursor_a + 1]
                    dest -= acount
                    cursor_a -= acount
                    if cursor_a < base_a:
                        break
                items[dest] = tmp[cursor_b]
                dest -= 1
                cursor_b -= 1
                if cursor_b < 0:
                    break

                # Elements of B that are >= A's tail go to the back first.
                k = self.gallop_left(tmp, 0, cursor_b, items[cursor_a], hint=cursor_b)
                bcount = cursor_b - k + 1
                if bcount > 0:
                    items[dest - bcount + 1 : dest + 1] = tmp[k : cursor_b + 1]
                    dest -= bcount
                    cursor_b -= bcount
                    if cursor_b < 0:
                        break
                items[dest] = items[cursor_a]
                dest -= 1
                cursor_a -= 1
                if cursor_a < base_a:
                    break

                if acount < self.MIN_GALLOP and bcount < self.MIN_GALLOP:
                    break
            # Penalty for leaving gallop mode.
            min_gallop += 1
            self.min_gallop = max(1, min_gallop)

        # Whatever is left of B belongs at the front; leftovers of A are
        # already in their final place.
        if cursor_b >= 0:
            items[dest - cursor_b : dest + 1] = tmp[0 : cursor_b + 1]

    def merge_at(self, run_a: tuple[int, int], run_b: tuple[int, int]):
        """Merge two adjacent runs given as ``(start, length)`` tuples.

        ``run_b`` must start right after ``run_a``. Elements at the front of
        A and at the back of B that are already in their final position are
        skipped, then the remainder is merged with ``merge_low`` or
        ``merge_high``, whichever has to buffer the shorter side. The run
        stack is not touched here.
        """
        base_a, len_a = run_a
        base_b, len_b = run_b

        # Leading elements of A that are <= B[0] are already in place.
        k = self.gallop_right(
            self.items, base_a, base_a + len_a - 1, self.items[base_b], hint=base_a
        )
        skipped_a = k - base_a
        new_base_a = base_a + skipped_a
        new_len_a = len_a - skipped_a
        if new_len_a == 0:
            return

        # Trailing elements of B that are >= A[-1] are already in place.
        m = self.gallop_left(
            self.items,
            base_b,
            base_b + len_b - 1,
            self.items[new_base_a + new_len_a - 1],
            hint=base_b + len_b - 1,
        )
        new_len_b = m - base_b
        if new_len_b == 0:
            return

        # Buffer the shorter of the two remaining parts.
        if new_len_a <= new_len_b:
            self.merge_low(new_base_a, new_len_a, base_b, new_len_b)
        else:
            self.merge_high(new_base_a, new_len_a, base_b, new_len_b)

    def _merge_stack_at(self, i: int):
        """Merge the runs at stack positions ``i`` and ``i + 1`` into one entry."""
        start_a, len_a = self.stack[i]
        start_b, len_b = self.stack[i + 1]
        self.merge_at((start_a, len_a), (start_b, len_b))
        self.stack[i] = (start_a, len_a + len_b)
        del self.stack[i + 1]

    def merge_collapse(self):
        """Merge runs on the stack until the run-length invariants hold.

        With W, X, Y, Z the lengths of the four topmost runs (Z on top), the
        loop stops once all of these hold::

            W > X + Y
            X > Y + Z
            Y > Z

        Checking W as well re-establishes the invariant one level deeper
        after a merge. When X or W is too small, Y is merged with the
        smaller of X and Z (ties go to X); otherwise a too-small Y is merged
        with Z.
        """
        stack = self.stack
        while len(stack) > 1:
            n = len(stack) - 2  # stack index of Y
            x_too_small = n > 0 and stack[n - 1][1] <= stack[n][1] + stack[n + 1][1]
            w_too_small = n > 1 and stack[n - 2][1] <= stack[n - 1][1] + stack[n][1]
            if x_too_small or w_too_small:
                if stack[n - 1][1] <= stack[n + 1][1]:
                    n -= 1  # merge X and Y instead of Y and Z
                self._merge_stack_at(n)
            elif stack[n][1] <= stack[n + 1][1]:
                self._merge_stack_at(n)
            else:
                break

    def sort(self):
        """Sort ``self.items`` in place; returns ``None``.

        All per-sort state (run stack, ``minrun``, ``min_gallop``) is reset
        on entry, so the method may be called repeatedly and after the list
        was modified.
        """
        items = self.items
        n = len(items)
        self.stack = []
        self.minrun = self.calc_minrun()
        self.min_gallop = self.MIN_GALLOP

        i = 0
        while i < n:
            start = i
            end = start + 1  # exclusive end of the current run
            if end < n:
                # The first pair decides the direction and is compared once.
                if items[start] <= items[end]:
                    # Ascending (non-decreasing) run.
                    end += 1
                    while end < n and items[end - 1] <= items[end]:
                        end += 1
                else:
                    # Strictly descending run; reversing it is stable because
                    # it contains no equal neighbours.
                    end += 1
                    while end < n and items[end - 1] > items[end]:
                        end += 1
                    items[start:end] = items[start:end][::-1]
            run_len = end - start

            if run_len < self.minrun and end < n:
                # Extend a short run to minrun elements; the natural run at
                # its front is already sorted and is not re-inserted.
                forced = min(self.minrun, n - start)
                self.binary_insert_sort(items, start, start + forced - 1, start=end)
                end = start + forced
                run_len = forced

            self.stack.append((start, run_len))
            self.merge_collapse()
            i = end

        # Force-merge whatever runs remain, from the top of the stack down.
        while len(self.stack) > 1:
            self._merge_stack_at(len(self.stack) - 2)
if __name__ == "__main__":
    # Small demo: sort 100 random integers from 0..100 and print the result.
    import random

    sample = [random.randint(0, 100) for _ in range(100)]
    TimSort(sample).sort()
    print(sample)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment