#include <algorithm>
#include <iostream>
/// A single AVL tree node: key of type S, payload of type T.
///
/// Besides the usual left/right child links, each node carries pointers to its
/// in-order neighbours (nearest_left_node / nearest_right_node) that
/// avl_tree::insert / avl_tree::remove keep up to date, so keys can be walked
/// in sorted order without a tree traversal.
template<class S, class T>
class node {
public:
    S key{};
    T value{};
    node<S, T> *right_node = nullptr;
    node<S, T> *left_node = nullptr;
    node<S, T> *nearest_right_node = nullptr;
    node<S, T> *nearest_left_node = nullptr;
    // Height of the subtree rooted here, counting a lone leaf as 0.
    // Initialized in-class so a default-constructed node is never indeterminate.
    int height = 0;

    template<typename, typename> friend
    class avl_tree;

    node() = default;

    node(const S &key, const T &value, node<S, T> *right_node = nullptr, node<S, T> *left_node = nullptr,
         node<S, T> *nearest_right_node = nullptr, node<S, T> *nearest_left_node = nullptr)
        : key(key),
          value(value),
          right_node(right_node),
          left_node(left_node),
          nearest_right_node(nearest_right_node),
          nearest_left_node(nearest_left_node),
          height(0) {
    }
};
/// Self-balancing (AVL) binary search tree mapping keys of type S to values of type T.
///
/// insert()/remove() additionally link every node to its in-order neighbours via
/// node::nearest_left_node / node::nearest_right_node, and refresh minimal_node
/// (the node with the smallest key, nullptr when the tree is empty).
/// Keys are compared with operator< and operator==.
///
/// The tree holds its nodes through raw owning pointers, so copying is disabled:
/// a member-wise copy would make two trees share (and both try to free) the same nodes.
template<class S, class T>
class avl_tree {
public:
    // Number of levels in the subtree at ptr (0 for nullptr).
    int get_height(const node<S, T> *ptr);
    // Single rotations around ptr; each returns the new subtree root.
    node<S, T> *rotate_right(node<S, T> *ptr);
    node<S, T> *rotate_left(node<S, T> *ptr);
    // Right-most (largest-key) node of the non-empty subtree at ptr.
    node<S, T> *maximum_node(node<S, T> *ptr);
    // Recursive helpers; each returns the (possibly new) root of the subtree.
    node<S, T> *internal_remove(node<S, T> *ptr, const S &key);
    node<S, T> *internal_insert(node<S, T> *ptr, const S &key, const T &value);
    // Floor / ceiling lookups: greatest node with key <= `key`, smallest node with key >= `key`.
    node<S, T> *nearest_left(node<S, T> *ptr, const S &key);
    node<S, T> *nearest_right(node<S, T> *ptr, const S &key);
    // Recomputes minimal_node by walking left from root.
    void update_minimum();
    // Intended to release every node of the subtree at ptr.
    void clear_tree(node<S, T> *ptr);
    node<S, T> *root = nullptr;
    node<S, T> *minimal_node = nullptr;
    avl_tree();
    ~avl_tree();
    avl_tree(const avl_tree &) = delete;
    avl_tree &operator=(const avl_tree &) = delete;
    // Node whose key equals `key`, or nullptr.
    node<S, T> *search(const S &key);
    void remove(const S &key);
    void insert(const S &key, const T &value);
};
template<class S, class T>
avl_tree<S, T>::avl_tree() {
root = nullptr;
minimal_node = nullptr;
}
/// Number of levels in the subtree rooted at ptr: 0 for an empty subtree,
/// otherwise the stored height plus one (a fresh leaf stores 0 and spans one level).
template<class S, class T>
int avl_tree<S, T>::get_height(const node<S, T> *ptr) {
    return (ptr != nullptr) ? ptr->height + 1 : 0;
}
/// Right rotation around ptr: its left child becomes the new subtree root and
/// ptr becomes that child's right child. Heights of both moved nodes are refreshed.
/// @return the new subtree root, or ptr unchanged when ptr is nullptr or has no
///         left child (there is nothing to rotate, and dereferencing it would crash).
template<class S, class T>
node<S, T> *avl_tree<S, T>::rotate_right(node<S, T> *ptr) {
    if (ptr == nullptr || ptr->left_node == nullptr) {
        return ptr;
    }
    node<S, T> *pivot = ptr->left_node;
    // The pivot's inner (right) subtree is re-hung as ptr's left subtree.
    ptr->left_node = pivot->right_node;
    pivot->right_node = ptr;
    // ptr now sits below pivot, so its height must be recomputed first.
    ptr->height = std::max(get_height(ptr->left_node), get_height(ptr->right_node));
    pivot->height = std::max(get_height(pivot->left_node), get_height(pivot->right_node));
    return pivot;
}
/// Left rotation around ptr: its right child becomes the new subtree root and
/// ptr becomes that child's left child. Heights of both moved nodes are refreshed.
/// @return the new subtree root, or ptr unchanged when ptr is nullptr or has no
///         right child (previously both cases dereferenced nullptr).
template<class S, class T>
node<S, T> *avl_tree<S, T>::rotate_left(node<S, T> *ptr) {
    if (ptr == nullptr || ptr->right_node == nullptr) {
        return ptr;
    }
    node<S, T> *pivot = ptr->right_node;
    // The pivot's inner (left) subtree is re-hung as ptr's right subtree.
    ptr->right_node = pivot->left_node;
    pivot->left_node = ptr;
    // ptr now sits below pivot, so its height must be recomputed first.
    ptr->height = std::max(get_height(ptr->left_node), get_height(ptr->right_node));
    pivot->height = std::max(get_height(pivot->left_node), get_height(pivot->right_node));
    return pivot;
}
/// Looks `key` up with an iterative BST descent from root.
/// @return the node whose key equals `key`, or nullptr when it is absent.
template<class S, class T>
node<S, T> *avl_tree<S, T>::search(const S &key) {
    node<S, T> *cursor = root;
    while (cursor != nullptr && !(key == cursor->key)) {
        cursor = (key < cursor->key) ? cursor->left_node : cursor->right_node;
    }
    return cursor;
}
/// Recursively inserts (key, value) below ptr and rebalances on the way back up.
/// If `key` is already present the subtree is returned untouched (the stored
/// value is NOT overwritten).
/// @return the root of the (possibly rotated) subtree.
template<class S, class T>
node<S, T> *avl_tree<S, T>::internal_insert(node<S, T> *ptr, const S &key, const T &value) {
    if (ptr == nullptr) {
        return new node<S, T>(key, value);
    }
    if (key == ptr->key) {
        return ptr;
    }

    // Descend into the side the key belongs to and re-attach whatever comes back.
    node<S, T> *&branch = (key < ptr->key) ? ptr->left_node : ptr->right_node;
    branch = internal_insert(branch, key, value);

    const int left_levels = get_height(ptr->left_node);
    const int right_levels = get_height(ptr->right_node);
    ptr->height = std::max(left_levels, right_levels);

    switch (left_levels - right_levels) {
    case 2:
        // Left-heavy. A key that went into the left child's right side (LR case)
        // needs the child straightened first.
        if (!(key < ptr->left_node->key)) {
            ptr->left_node = rotate_left(ptr->left_node);
        }
        return rotate_right(ptr);
    case -2:
        // Right-heavy. A key that went into the right child's left side (RL case)
        // needs the child straightened first.
        if (key < ptr->right_node->key) {
            ptr->right_node = rotate_right(ptr->right_node);
        }
        return rotate_left(ptr);
    default:
        return ptr;
    }
}
/// Inserts (key, value), keeps the tree balanced, refreshes minimal_node and
/// links the new node between its in-order predecessor and successor.
///
/// Inserting a key that already exists is a no-op (the stored value is kept).
/// The early return matters: for an existing key, nearest_right and
/// nearest_left both resolve to that very node, and the linking below would
/// make it its own neighbour.
template<class S, class T>
void avl_tree<S, T>::insert(const S &key, const T &value) {
    if (search(key) != nullptr) {
        return;
    }
    // Neighbours must be found before insertion; afterwards both lookups would hit the new node.
    node<S, T> *successor = nearest_right(root, key);
    node<S, T> *predecessor = nearest_left(root, key);
    root = internal_insert(root, key, value);
    update_minimum();
    node<S, T> *new_node = search(key);
    new_node->nearest_right_node = successor;
    new_node->nearest_left_node = predecessor;
    if (successor != nullptr) {
        successor->nearest_left_node = new_node;
    }
    if (predecessor != nullptr) {
        predecessor->nearest_right_node = new_node;
    }
}
/// Recursively removes `key` from the subtree at ptr and rebalances on the way
/// back up. Absent keys leave the subtree unchanged.
///
/// Every node that is physically freed is first spliced out of the in-order
/// thread (nearest_left_node / nearest_right_node), so no surviving node is
/// left pointing at freed memory.
/// @return the root of the (possibly rotated) subtree.
template<class S, class T>
node<S, T> *avl_tree<S, T>::internal_remove(node<S, T> *ptr, const S &key) {
    if (ptr == nullptr) {
        return nullptr;
    }
    if (key == ptr->key) {
        if (ptr->left_node != nullptr && ptr->right_node != nullptr) {
            // Two children: take over the in-order predecessor's payload and free
            // the predecessor node instead. ptr keeps its own thread links. Its
            // right neighbour is still correct, and the recursive call splices
            // the predecessor node out, reconnecting ptr to that node's left
            // neighbour.
            node<S, T> *predecessor = maximum_node(ptr->left_node);
            ptr->key = predecessor->key;
            ptr->value = predecessor->value;
            // Pass ptr->key, not predecessor->key: the predecessor is freed during
            // the recursion, which would leave that reference dangling.
            ptr->left_node = internal_remove(ptr->left_node, ptr->key);
        } else {
            // Zero or one child: unlink ptr from the thread, then let its only
            // child (or nullptr) take its place.
            node<S, T> *child = (ptr->left_node != nullptr) ? ptr->left_node : ptr->right_node;
            if (ptr->nearest_left_node != nullptr) {
                ptr->nearest_left_node->nearest_right_node = ptr->nearest_right_node;
            }
            if (ptr->nearest_right_node != nullptr) {
                ptr->nearest_right_node->nearest_left_node = ptr->nearest_left_node;
            }
            delete ptr;
            return child;
        }
    } else if (key < ptr->key) {
        ptr->left_node = internal_remove(ptr->left_node, key);
    } else {
        ptr->right_node = internal_remove(ptr->right_node, key);
    }

    const int left_node_height = get_height(ptr->left_node);
    const int right_node_height = get_height(ptr->right_node);
    ptr->height = std::max(left_node_height, right_node_height);
    const int balance_factor = left_node_height - right_node_height;
    if (balance_factor > 1) {
        // Left-heavy. When the left child is balanced (equal grandchild heights)
        // the single rotation is required. A double rotation there can leave the
        // old left child with a balance factor of +2.
        if (get_height(ptr->left_node->left_node) >= get_height(ptr->left_node->right_node)) {
            return rotate_right(ptr);
        }
        ptr->left_node = rotate_left(ptr->left_node);
        return rotate_right(ptr);
    }
    if (balance_factor < -1) {
        // Right-heavy: double rotation only when the right child leans left.
        if (get_height(ptr->right_node->left_node) > get_height(ptr->right_node->right_node)) {
            ptr->right_node = rotate_right(ptr->right_node);
        }
        return rotate_left(ptr);
    }
    return ptr;
}
/// Removes `key` (no-op if absent), refreshes minimal_node and re-joins the
/// in-order neighbours that bracket the removed key.
///
/// The key is copied first. A caller may legitimately pass a reference to a key
/// stored inside the tree (e.g. minimal_node->key), and internal_remove can
/// overwrite or free that storage before the neighbour lookups below.
template<class S, class T>
void avl_tree<S, T>::remove(const S &key) {
    const S target = key;
    root = internal_remove(root, target);
    update_minimum();
    node<S, T> *successor = nearest_right(root, target);
    node<S, T> *predecessor = nearest_left(root, target);
    if (successor != nullptr) {
        successor->nearest_left_node = predecessor;
    }
    if (predecessor != nullptr) {
        predecessor->nearest_right_node = successor;
    }
}
/// Returns the right-most (largest-key) node of the subtree rooted at ptr.
/// ptr must be non-null: it is dereferenced unconditionally.
template<class S, class T>
node<S, T> *avl_tree<S, T>::maximum_node(node<S, T> *ptr) {
    node<S, T> *rightmost = ptr;
    while (rightmost->right_node != nullptr) {
        rightmost = rightmost->right_node;
    }
    return rightmost;
}
/// Points minimal_node at the left-most node reachable from root
/// (nullptr for an empty tree).
template<class S, class T>
void avl_tree<S, T>::update_minimum() {
    node<S, T> *leftmost = root;
    if (leftmost != nullptr) {
        while (leftmost->left_node != nullptr) {
            leftmost = leftmost->left_node;
        }
    }
    minimal_node = leftmost;
}
/// Floor lookup in the subtree at ptr: the node with the greatest key that is
/// not greater than `key` (the node holding `key` itself when present), or
/// nullptr when every key in the subtree is larger.
template<class S, class T>
node<S, T> *avl_tree<S, T>::nearest_left(node<S, T> *ptr, const S &key) {
    node<S, T> *best = nullptr;
    for (node<S, T> *cur = ptr; cur != nullptr;) {
        if (!(key < cur->key)) {
            // cur->key <= key: a candidate; anything better lies to the right.
            best = cur;
            cur = cur->right_node;
        } else {
            cur = cur->left_node;
        }
    }
    return best;
}
/// Ceiling lookup in the subtree at ptr: the node with the smallest key that is
/// not less than `key` (the node holding `key` itself when present), or
/// nullptr when every key in the subtree is smaller.
template<class S, class T>
node<S, T> *avl_tree<S, T>::nearest_right(node<S, T> *ptr, const S &key) {
    node<S, T> *best = nullptr;
    for (node<S, T> *cur = ptr; cur != nullptr;) {
        const bool at_or_above = key < cur->key || key == cur->key;
        if (at_or_above) {
            // cur->key >= key: a candidate; anything better lies to the left.
            best = cur;
            cur = cur->left_node;
        } else {
            cur = cur->right_node;
        }
    }
    return best;
}
using namespace std;
int main()
{
avl_tree<int, int> *tree = new avl_tree<int, int>();
tree->insert(6, 17);
tree->insert(4, 6);
tree->insert(3, 10);
tree->insert(1, 7);
tree->insert(2, 8);
tree->insert(5, 99);
cout<<tree->search(3)->nearest_right_node->key <<endl;
cout<<tree->minimal_node->nearest_right_node->nearest_right_node->nearest_right_node->key<<endl;
tree->remove(4);
cout<<tree->search(3)->nearest_right_node->key <<endl;
cout<<tree->minimal_node->nearest_right_node->nearest_right_node->nearest_right_node->key<<endl;
return 0;
}