Skip to content

Instantly share code, notes, and snippets.

@MaskRay
Last active December 15, 2025 02:03
Show Gist options
  • Select an option

  • Save MaskRay/0d2655c81dd45d8fa0f4da2763f4c818 to your computer and use it in GitHub Desktop.

Select an option

Save MaskRay/0d2655c81dd45d8fa0f4da2763f4c818 to your computer and use it in GitHub Desktop.
Weak AVL tree
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <numeric>
#include <random>
#include <vector>
using namespace std;
// Intrusive node of the pointer-based WAVL tree.
//
// par_and_flg packs the parent pointer together with two flag bits stored in
// the pointer's low bits (free because Node is pointer-aligned, checked below).
// Flag bit d set means the rank difference to child d is 2; clear means 1.
// `sum` and `size` cache the key sum and node count of the subtree.
struct Node {
  Node *ch[2]{};
  uintptr_t par_and_flg{};
  int i{}, sum{}, size{};
  // Mask of the two flag bits. It is spelled as uintptr_t rather than 3UL so
  // that ~kFlagMask is as wide as a pointer even where unsigned long is only
  // 32 bits (e.g. LLP64); otherwise parent()/clr_flags() would truncate it.
  static constexpr uintptr_t kFlagMask = 3;
  Node *parent() const { return reinterpret_cast<Node *>(par_and_flg & ~kFlagMask); }
  void set_parent(Node *p) { par_and_flg = (par_and_flg & kFlagMask) | reinterpret_cast<uintptr_t>(p); }
  uintptr_t flags() const { return par_and_flg & kFlagMask; }
  bool rd2(int d) const { return par_and_flg & (uintptr_t{1} << d); }
  void flip(int d) { par_and_flg ^= uintptr_t{1} << d; }
  void clr_flags() { par_and_flg &= ~kFlagMask; }
  // Recomputes sum and size from this node's key and its children's caches.
  void mconcat() {
    sum = i;
    size = 1;
    if (ch[0]) sum += ch[0]->sum, size += ch[0]->size;
    if (ch[1]) sum += ch[1]->sum, size += ch[1]->size;
  }
  bool operator<(const Node &o) const { return i < o.i; }
};
static_assert(alignof(Node) > Node::kFlagMask, "low pointer bits of Node* are used as flags");
struct WAVL {
Node *root{};
~WAVL() {
auto destroy = [](auto &self, Node *n) -> void {
if (!n) return;
self(self, n->ch[0]);
self(self, n->ch[1]);
delete n;
};
destroy(destroy, root);
}
Node *rotate(Node *x, int d) {
auto pivot = x->ch[d];
if ((x->ch[d] = pivot->ch[d^1])) x->ch[d]->set_parent(x);
pivot->set_parent(x->parent());
if (!x->parent()) root = pivot;
else x->parent()->ch[x != x->parent()->ch[0]] = pivot;
pivot->ch[d^1] = x;
x->set_parent(pivot);
x->mconcat();
return pivot;
}
void insert(Node *x) {
Node *p = nullptr;
int d = 0;
for (auto tmp = root; tmp; ) {
p = tmp;
d = *p < *x;
tmp = tmp->ch[d];
}
x->par_and_flg = reinterpret_cast<uintptr_t>(p);
x->ch[0] = x->ch[1] = nullptr;
if (!p) return root = x, x->mconcat();
p->ch[d] = x;
auto *x2 = x;
if (p->rd2(d)) {
p->flip(d);
} else {
assert(p->rd2(d^1) == 0);
p->flip(d^1);
int d1 = d;
for (x = p, p = x->parent(); p; x = p, p = x->parent()) {
d = (p->ch[1] == x);
if (p->rd2(d)) {
p->flip(d);
break;
}
p->flip(d^1);
if (!p->rd2(d ^ 1)) {
if ((d^1) == d1) {
assert(!x->rd2(d1) && (x->ch[d1] == x2 || x->ch[d1]->flags() == 1 || x->ch[d1]->flags() == 2));
x->flip(d);
auto y = rotate(x, d^1); // y is previous x
if (y->rd2(d))
x->flip(d^1);
else if (y->rd2(d^1))
p->flip(d);
x = y;
}
x = rotate(p, d);
x->clr_flags();
break;
}
d1 = d;
}
}
for (; x2; x2 = x2->parent()) x2->mconcat();
}
void remove(Node *x) {
auto old = x;
auto p = x->parent();
auto right = x->ch[1];
Node *child;
if (!x->ch[0]) x = child = right;
else if (!right) x = child = x->ch[0];
else {
if (!(child = right->ch[0])) {
child = right->ch[1];
p = x = right;
} else {
do x = child; while ((child = x->ch[0]));
child = x->ch[1];
p = x->parent();
p->ch[0] = child;
old->ch[1]->set_parent(x);
x->ch[1] = old->ch[1];
}
old->ch[0]->set_parent(x);
x->ch[0] = old->ch[0];
x->par_and_flg = old->par_and_flg;
}
if (!old->parent()) root = x;
else old->parent()->ch[old != old->parent()->ch[0]] = x;
if (child) child->set_parent(p);
Node *x2 = p;
if (p) {
x = child;
if (p->ch[0] == x && p->ch[1] == x) {
p->clr_flags();
x = p;
p = x->parent();
}
while (p) {
int d2 = (p->ch[1] == x);
if (!p->rd2(d2)) {
p->flip(d2);
break;
}
if (p->rd2(d2 ^ 1)) {
p->flip(d2 ^ 1);
x = p;
p = x->parent();
continue;
}
auto sib = p->ch[d2^1];
if (sib->flags() == 3) {
sib->clr_flags();
x = p;
p = x->parent();
continue;
}
sib->flip(d2^1);
if (sib->rd2(d2))
p->flip(d2);
else if (!sib->rd2(d2^1)) {
p->flip(d2);
x = rotate(sib, d2);
if (x->rd2(d2^1)) sib->flip(d2);
if (x->rd2(d2)) p->flip(d2^1);
x->par_and_flg |= 3;
}
rotate(p, d2^1);
break;
}
}
for (; x2; x2 = x2->parent()) x2->mconcat();
}
Node *find(int key) const {
auto tmp = root;
while (tmp) {
if (key < tmp->i) tmp = tmp->ch[0];
else if (key > tmp->i) tmp = tmp->ch[1];
else return tmp;
}
return nullptr;
}
Node *min() const {
Node *p = nullptr;
for (auto n = root; n; n = n->ch[0]) p = n;
return p;
}
int rank(int key) const {
int r = 0;
for (auto n = root; n; ) {
if (key <= n->i) n = n->ch[0];
else {
r += 1 + (n->ch[0] ? n->ch[0]->size : 0);
n = n->ch[1];
}
}
return r;
}
int select(int k) const {
auto x = root;
while (x) {
int lsz = x->ch[0] ? x->ch[0]->size : 0;
if (k < lsz) x = x->ch[0];
else if (k == lsz) return x->i;
else k -= lsz + 1, x = x->ch[1];
}
return -1;
}
int prev(int key) const {
int res = -1;
for (auto x = root; x; )
if (key <= x->i) x = x->ch[0];
else { res = x->i; x = x->ch[1]; }
return res;
}
int next(int key) const {
int res = -1;
for (auto x = root; x; )
if (key >= x->i) x = x->ch[1];
else { res = x->i; x = x->ch[0]; }
return res;
}
static Node *next(Node *x) {
if (x->ch[1]) {
x = x->ch[1];
while (x->ch[0]) x = x->ch[0];
} else {
while (x->parent() && x == x->parent()->ch[1]) x = x->parent();
x = x->parent();
}
return x;
}
};
// Prints the subtree rooted at `n` in in-order sequence, one node per line,
// indented two spaces per depth level `d`. Each node line shows its key and
// the rank differences (1 or 2) to its left and right child; empty links
// print as "nil".
void print_tree(Node *n, int d = 0) {
  const int indent = 2 * d;
  if (n == nullptr) {
    printf("%*snil\n", indent, "");
    return;
  }
  print_tree(n->ch[0], d + 1);
  const int left_diff = n->rd2(0) ? 2 : 1;
  const int right_diff = n->rd2(1) ? 2 : 1;
  printf("%*s%d (%d,%d)\n", indent, "", n->i, left_diff, right_diff);
  print_tree(n->ch[1], d + 1);
}
// Recursively checks the subtree rooted at `n` and returns its rank (-1 for an
// empty subtree), or -2 if any check fails anywhere in the subtree. Checks:
//   * both children give the same rank once their 1/2 rank differences are added;
//   * leaves carry flags 0, i.e. rank differences (1,1);
//   * the cached sum and size match the node's key and its children's caches;
//   * each child's parent pointer points back to `n`.
// When `debug` is set, a diagnostic line is printed for the failing node.
int compute_rank(Node *n, bool debug = false) {
  if (!n) return -1;
  int lr = compute_rank(n->ch[0], debug), rr = compute_rank(n->ch[1], debug);
  if (lr < -1 || rr < -1) return -2;
  int rank_l = lr + (n->rd2(0) ? 2 : 1);
  int rank_r = rr + (n->rd2(1) ? 2 : 1);
  if (rank_l != rank_r) {
    if (debug) printf("node %d: rank mismatch left=%d right=%d\n", n->i, rank_l, rank_r);
    return -2;
  }
  if (!n->ch[0] && !n->ch[1] && n->flags() != 0) {
    // flags() returns uintptr_t, which is not unsigned long on every platform
    // (e.g. LLP64); cast so the argument always matches %lu.
    if (debug) printf("node %d: leaf must be 1,1 but flags=%lu\n", n->i, static_cast<unsigned long>(n->flags()));
    return -2;
  }
  int expected_sum = n->i + (n->ch[0] ? n->ch[0]->sum : 0) + (n->ch[1] ? n->ch[1]->sum : 0);
  if (n->sum != expected_sum) {
    if (debug) printf("node %d: sum mismatch got=%d expected=%d\n", n->i, n->sum, expected_sum);
    return -2;
  }
  int expected_size = 1 + (n->ch[0] ? n->ch[0]->size : 0) + (n->ch[1] ? n->ch[1]->size : 0);
  if (n->size != expected_size) {
    if (debug) printf("node %d: size mismatch got=%d expected=%d\n", n->i, n->size, expected_size);
    return -2;
  }
  // The tree walks parent links (rebalancing, aggregate refresh, WAVL::next),
  // so a stale parent pointer is as much a corruption as a bad rank.
  for (int d = 0; d < 2; d++) {
    if (n->ch[d] && n->ch[d]->parent() != n) {
      if (debug) printf("node %d: child %d has a stale parent pointer\n", n->i, d);
      return -2;
    }
  }
  return rank_l;
}
// Validates `tree` with compute_rank. On failure, prints "INVALID TREE",
// re-runs the check in debug mode to print per-node diagnostics, and returns
// false. On success, returns true and, if `verbose`, prints the root's rank.
bool verify_tree(const WAVL &tree, bool verbose = false) {
  const int rank = compute_rank(tree.root);
  const bool ok = rank >= -1;
  if (!ok) {
    printf("INVALID TREE\n");
    compute_rank(tree.root, true);  // second pass only for diagnostics
    return false;
  }
  if (verbose) printf("Tree verified, rank = %d\n", rank);
  return true;
}
// Self-test for the pointer-based WAVL tree. It inserts a shuffled 1..20,
// removes a few keys, runs a randomized insert/remove stress test with
// periodic verification, then checks rank/select/prev/next on a small tree.
// Returns 1 as soon as any verification fails.
// NOTE: the final query checks use assert() and are compiled out under NDEBUG.
int main() {
  srand(42);
  WAVL tree;
  int i = 0;
  std::vector<int> a(20);
  std::iota(a.begin(), a.end(), 1);
  std::shuffle(a.begin(), a.end(), std::default_random_engine(42));
  for (int val : a) {
    auto n = new Node;
    n->i = val;
    tree.insert(n);
    if (i++ < 6) {
      printf("-- %d After insertion of %d\n", i, val);
      print_tree(tree.root);
    }
  }
  printf("\nSum\tof values = %d\n", tree.root->sum);
  // Previously the result was ignored, so an invalid tree here did not fail the run.
  if (!verify_tree(tree, true)) return 1;
  for (int val : {5, 10, 15}) {
    if (auto found = tree.find(val)) {
      tree.remove(found);
      delete found;
    }
  }
  printf("After removing 5, 10, 15:\n");
  printf("\nSum\tof values = %d\n", tree.root->sum);
  if (!verify_tree(tree, true)) return 1;
  // ref holds every node currently linked into the tree.
  std::vector<Node*> ref;
  for (auto n = tree.min(); n; n = WAVL::next(n)) ref.push_back(n);
  for (int iter = 0; iter < 100000; iter++) {
    if (ref.size() < 5 || (ref.size() < 1000 && rand() % 2 == 0)) {
      auto n = new Node;
      n->i = rand() % 100000;
      tree.insert(n);
      ref.push_back(n);
    } else {
      int idx = rand() % ref.size();
      tree.remove(ref[idx]);
      delete ref[idx];
      ref[idx] = ref.back();
      ref.pop_back();
    }
    if (iter % 100 == 0 && !verify_tree(tree)) {
      printf("FAILED at iteration %d\n", iter);
      return 1;
    }
  }
  while (!ref.empty()) {
    tree.remove(ref.back());
    delete ref.back();
    ref.pop_back();
    if (tree.root && !verify_tree(tree)) {
      printf("FAILED during final cleanup\n");
      return 1;
    }
  }
  printf("Stress test passed\n");
  // Test rank, select, prev, next
  printf("\nTesting rank/select/prev/next...\n");
  std::vector<int> vals = {10, 20, 30, 40, 50};
  for (int v : vals) {
    auto n = new Node;
    n->i = v;
    tree.insert(n);
  }
  // rank tests (number of elements < key)
  assert(tree.rank(5) == 0);
  assert(tree.rank(10) == 0);
  assert(tree.rank(15) == 1);
  assert(tree.rank(20) == 1);
  assert(tree.rank(25) == 2);
  assert(tree.rank(50) == 4);
  assert(tree.rank(55) == 5);
  // select tests (0-indexed)
  assert(tree.select(0) == 10);
  assert(tree.select(1) == 20);
  assert(tree.select(2) == 30);
  assert(tree.select(3) == 40);
  assert(tree.select(4) == 50);
  assert(tree.select(5) == -1);
  // prev tests (largest < key)
  assert(tree.prev(10) == -1);
  assert(tree.prev(11) == 10);
  assert(tree.prev(20) == 10);
  assert(tree.prev(21) == 20);
  assert(tree.prev(50) == 40);
  assert(tree.prev(55) == 50);
  // next tests (smallest > key)
  assert(tree.next(5) == 10);
  assert(tree.next(10) == 20);
  assert(tree.next(15) == 20);
  assert(tree.next(40) == 50);
  assert(tree.next(50) == -1);
  assert(tree.next(55) == -1);
  printf("rank/select/prev/next tests passed\n");
}
// Luogu P3369 【模板】普通平衡树 ("[Template] Ordinary Balanced Tree"): index-based WAVL solution.
#include <algorithm>
#include <cstdint>
#include <cstdio>
using namespace std;
// Pool-allocated WAVL node addressed by 32-bit index (0 means "no node").
// par_and_flg layout: bits 0..29 hold the parent index; bit 30 + d set means
// the rank difference to child d is 2 (clear means 1).
struct Node {
  uint32_t ch[2]{};
  uint32_t par_and_flg{};
  int i{}, size{};
  uint32_t parent() const { return par_and_flg & ((1U << 30) - 1); }
  void set_parent(uint32_t p) { par_and_flg = p | (par_and_flg & ~((1U << 30) - 1)); }
  uint32_t flags() const { return (par_and_flg >> 30) & 3U; }
  bool rd2(int d) const { return ((par_and_flg >> 30) >> d) & 1U; }
  void toggle(int d) { par_and_flg ^= (1U << d) << 30; }
  void clr_flags() { par_and_flg = parent(); }
};
// Fixed-capacity bump allocator for tree nodes. pool_cnt starts at 1, so slot
// 0 is never handed out and serves as the null index. Its zero `size` lets
// WAVL::mconcat read absent children without a branch. Slots are never
// recycled (remove() does not free them), so every slot returned here is
// still in its zero-initialized state and only the key needs setting. There
// is no bounds check: at most 100000 nodes can be allocated.
Node pool[100001];
uint32_t pool_cnt = 1;
uint32_t new_node(int i) {
  const uint32_t id = pool_cnt++;
  pool[id].i = i;
  return id;
}
// Index-based WAVL tree over the global `pool`; index 0 is the null link.
// Node::rd2(d) set means the rank difference to child d is 2, clear means 1;
// a leaf has flags 0. Each node's `size` caches its subtree node count.
struct WAVL {
uint32_t root{};
// Recomputes size of x. Relies on pool[0].size staying 0 so absent children add nothing.
static void mconcat(uint32_t x) {
pool[x].size = 1 + pool[pool[x].ch[0]].size + pool[pool[x].ch[1]].size;
}
// Rotates child d of x into x's position and returns it (the pivot). Parent
// links and root are updated; only x's size is recomputed here, and callers
// refresh the pivot and its ancestors afterwards.
uint32_t rotate(uint32_t x, int d) {
auto pivot = pool[x].ch[d];
if ((pool[x].ch[d] = pool[pivot].ch[d^1])) pool[pool[x].ch[d]].set_parent(x);
pool[pivot].set_parent(pool[x].parent());
if (!pool[x].parent()) root = pivot;
else pool[pool[x].parent()].ch[x != pool[pool[x].parent()].ch[0]] = pivot;
pool[pivot].ch[d^1] = x;
pool[x].set_parent(pivot);
mconcat(x);
return pivot;
}
// Links freshly allocated node x as a leaf (equal keys go left), then restores
// the rank rule by promotions and at most one single or double rotation.
// Finally refreshes sizes from x up to the root.
void insert(uint32_t x) {
uint32_t p = 0;
int d = 0;
for (auto tmp = root; tmp; ) {
p = tmp;
d = pool[p].i < pool[x].i;
tmp = pool[tmp].ch[d];
}
// Parent p with both flag bits clear: a (1,1) leaf.
pool[x].par_and_flg = p;
if (!p) { root = x; mconcat(x); return; }
pool[p].ch[d] = x;
auto x2 = x;
if (pool[p].rd2(d)) {
// The empty side of p was a 2-child; the new leaf makes it a 1-child.
pool[p].toggle(d);
} else {
// p was a leaf and is promoted: its other (empty) side becomes a 2-child.
pool[p].toggle(d^1);
int d1 = d;
for (x = p, p = pool[x].parent(); p; x = p, p = pool[x].parent()) {
d = (pool[p].ch[1] == x);
if (pool[p].rd2(d)) {
pool[p].toggle(d);
break;
}
// Tentatively promote p; if the sibling side was already 2, rotate instead.
pool[p].toggle(d^1);
if (!pool[p].rd2(d^1)) {
if ((d^1) == d1) {
// Zig-zag case: rotate x's inner child up first (double rotation).
pool[x].toggle(d);
auto y = rotate(x, d^1);
if (pool[y].rd2(d))
pool[x].toggle(d^1);
else if (pool[y].rd2(d^1))
pool[p].toggle(d);
x = y;
}
x = rotate(p, d);
pool[x].clr_flags();
break;
}
d1 = d;
}
}
for (; x2; x2 = pool[x2].parent()) mconcat(x2);
}
// Unlinks node x and returns it; the pool slot is not recycled. A node with
// two children is replaced by its in-order successor, which takes over x's
// parent link and flags. Rebalances with demotions and at most one single or
// double rotation, then refreshes sizes up to the root.
uint32_t remove(uint32_t x) {
auto old = x;
auto p = pool[x].parent();
auto right = pool[x].ch[1];
uint32_t child;
if (!pool[x].ch[0]) x = child = right;
else if (!right) x = child = pool[x].ch[0];
else {
if (!(child = pool[right].ch[0])) {
// The successor is x's right child itself.
child = pool[right].ch[1];
p = x = right;
} else {
// The successor is the leftmost node of the right subtree.
do x = child; while ((child = pool[x].ch[0]));
child = pool[x].ch[1];
p = pool[x].parent();
pool[p].ch[0] = child;
pool[pool[old].ch[1]].set_parent(x);
pool[x].ch[1] = pool[old].ch[1];
}
pool[pool[old].ch[0]].set_parent(x);
pool[x].ch[0] = pool[old].ch[0];
pool[x].par_and_flg = pool[old].par_and_flg;
}
if (!pool[old].parent()) root = x;
else pool[pool[old].parent()].ch[old != pool[pool[old].parent()].ch[0]] = x;
if (child) pool[child].set_parent(p);
uint32_t x2 = p;
if (p) {
x = child;
// p became a leaf; leaves must be (1,1), so demote it.
if (pool[p].ch[0] == x && pool[p].ch[1] == x) {
pool[p].clr_flags();
x = p;
p = pool[x].parent();
}
while (p) {
int d2 = (pool[p].ch[1] == x);
// x was a 1-child and becomes a 2-child: done.
if (!pool[p].rd2(d2)) {
pool[p].toggle(d2);
break;
}
// x would be a 3-child and its sibling is a 2-child: demote p and continue.
if (pool[p].rd2(d2 ^ 1)) {
pool[p].toggle(d2 ^ 1);
x = p;
p = pool[x].parent();
continue;
}
auto sib = pool[p].ch[d2^1];
// The sibling is (2,2): demote the sibling and p, then continue upward.
if (pool[sib].flags() == 3) {
pool[sib].clr_flags();
x = p;
p = pool[x].parent();
continue;
}
// Otherwise a single or double rotation finishes the rebalance.
pool[sib].toggle(d2^1);
if (pool[sib].rd2(d2))
pool[p].toggle(d2);
else if (!pool[sib].rd2(d2^1)) {
pool[p].toggle(d2);
x = rotate(sib, d2);
if (pool[x].rd2(d2^1)) pool[sib].toggle(d2);
if (pool[x].rd2(d2)) pool[p].toggle(d2^1);
pool[x].par_and_flg |= 0xC0000000U;
}
rotate(p, d2^1);
break;
}
}
for (; x2; x2 = pool[x2].parent()) mconcat(x2);
return old;
}
// Returns the index of a node with key `key`, or 0 if absent.
uint32_t find(int key) const {
auto tmp = root;
while (tmp) {
if (key < pool[tmp].i) tmp = pool[tmp].ch[0];
else if (key > pool[tmp].i) tmp = pool[tmp].ch[1];
else return tmp;
}
return 0;
}
// Number of keys strictly less than `key`.
int rank(int key) const {
int r = 0;
for (auto n = root; n; ) {
if (key <= pool[n].i) n = pool[n].ch[0];
else {
r += 1 + (pool[n].ch[0] ? pool[pool[n].ch[0]].size : 0);
n = pool[n].ch[1];
}
}
return r;
}
// The k-th smallest key (0-indexed), or -1 if k is out of range.
int select(int k) const {
auto x = root;
while (x) {
int lsz = pool[x].ch[0] ? pool[pool[x].ch[0]].size : 0;
if (k < lsz) x = pool[x].ch[0];
else if (k == lsz) return pool[x].i;
else k -= lsz + 1, x = pool[x].ch[1];
}
return -1;
}
// Largest key strictly less than `key`, or -1 if none.
int prev(int key) const {
int res = -1;
for (auto x = root; x; )
if (key <= pool[x].i) x = pool[x].ch[0];
else { res = pool[x].i; x = pool[x].ch[1]; }
return res;
}
// Smallest key strictly greater than `key`, or -1 if none.
int next(int key) const {
int res = -1;
for (auto x = root; x; )
if (key >= pool[x].i) x = pool[x].ch[1];
else { res = pool[x].i; x = pool[x].ch[0]; }
return res;
}
};
// Reads one decimal integer from stdin into x, skipping any leading
// non-digit characters; a '-' seen while skipping makes the value negative.
// The character right after the last digit is consumed. If end of input is
// reached before any digit, x is set to 0. Previously the skip loop spun
// forever at EOF.
template <class T> void read(T &x) {
  // Accumulate in T so wider types (e.g. long long) do not overflow an int.
  T s = 0;
  int w = 1;
  // int, not char: getchar_unlocked() signals end of input with EOF, a
  // negative int that a char cannot hold distinctly.
  int c = getchar_unlocked();
  while (c != EOF && (c > '9' || c < '0')) {
    if (c == '-') w = -1;
    c = getchar_unlocked();
  }
  while (c >= '0' && c <= '9')
    s = s * 10 + (c - '0'), c = getchar_unlocked();
  x = s * w;
}
// Writes x in decimal to `out` (stdout by default), with no trailing
// separator. The original emitted digits least-significant first (123 ->
// "321") and negated INT_MIN, which overflows; both issues are fixed here.
inline void print(int x, FILE *out = stdout) {
  // Work in unsigned so that the magnitude of INT_MIN is representable.
  unsigned v = static_cast<unsigned>(x);
  if (x < 0) {
    putc_unlocked('-', out);
    v = 0U - v;
  }
  char digits[24];
  int len = 0;
  do {
    digits[len++] = static_cast<char>('0' + v % 10);
    v /= 10;
  } while (v != 0);
  // Digits were collected least-significant first; emit them in reverse.
  while (len > 0) putc_unlocked(digits[--len], out);
}
// Luogu P3369 driver. It reads n operations "opt x" and applies them to the
// tree: 1 inserts x; 2 deletes one occurrence of x; 3 prints the 1-based rank
// of x; 4 prints the x-th smallest key; 5 prints the predecessor of x; 6
// prints the successor of x. Unknown opcodes are ignored.
int main() {
  WAVL rb;
  int n;
  read(n);
  int opt, x;
  for (int q = 0; q < n; ++q) {
    read(opt);
    read(x);
    if (opt == 1) {
      rb.insert(new_node(x));
    } else if (opt == 2) {
      if (auto it = rb.find(x)) rb.remove(it);
    } else if (opt == 3) {
      printf("%d\n", rb.rank(x) + 1);
    } else if (opt == 4) {
      printf("%d\n", rb.select(x - 1));
    } else if (opt == 5) {
      printf("%d\n", rb.prev(x));
    } else if (opt == 6) {
      printf("%d\n", rb.next(x));
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment