#include <iostream>
// AVL树节点结构
struct AVLNode {
int val;
AVLNode* left;
AVLNode* right;
int height;
AVLNode(int x) : val(x), left(nullptr), right(nullptr), height(1) {}
};
// 获取节点高度
int getHeight(AVLNode* node) {
if (node == nullptr) {
return 0;
}
return node->height;
}
// 更新节点高度
void updateHeight(AVLNode* node) {
node->height = std::max(getHeight(node->left), getHeight(node->right)) + 1;
}
// 右旋
AVLNode* rightRotate(AVLNode* y) {
AVLNode* x = y->left;
AVLNode* T2 = x->right;
// 执行旋转
x->right = y;
y->left = T2;
// 更新高度
updateHeight(y);
updateHeight(x);
return x;
}
// 左旋
AVLNode* leftRotate(AVLNode* x) {
AVLNode* y = x->right;
AVLNode* T2 = y->left;
// 执行旋转
y->left = x;
x->right = T2;
// 更新高度
updateHeight(x);
updateHeight(y);
return y;
}
// 获取平衡因子
int getBalanceFactor(AVLNode* node) {
if (node == nullptr) {
return 0;
}
return getHeight(node->left) - getHeight(node->right);
}
// 插入节点
AVLNode* insertAVL(AVLNode* root, int key) {
// 正常BST插入
if (root == nullptr) {
return new AVLNode(key);
}
if (key < root->val) {
root->left = insertAVL(root->left, key);
} else if (key > root->val) {
root->right = insertAVL(root->right, key);
} else {
// 重复键不插入
return root;
}
// 更新高度
updateHeight(root);
// 获取平衡因子
int balance = getBalanceFactor(root);
// 左左情况
if (balance > 1 && key < root->left->val) {
return rightRotate(root);
}
// 右右情况
if (balance < -1 && key > root->right->val) {
return leftRotate(root);
}
// 左右情况
if (balance > 1 && key > root->left->val) {
root->left = leftRotate(root->left);
return rightRotate(root);
}
// 右左情况
if (balance < -1 && key < root->right->val) {
root->right = rightRotate(root->right);
return leftRotate(root);
}
return root;
}