#include <iostream>
#include <vector>
using namespace std;
// A binary-tree node holding an int value and pointers to its left and right
// children (nullptr when a child is absent). Node has no destructor, so it
// does not free its children; whoever allocated them is responsible.
class Node {
  int val;
  Node* left;
  Node* right;
public:
  // Creates a node holding `value` with no children. `explicit` prevents an
  // int from silently converting into a Node.
  explicit Node(int value) : val(value), left(nullptr), right(nullptr) {}

  // Replaces both child pointers; either argument may be nullptr.
  void setChild(Node* leftnode, Node* rightnode) {
    left = leftnode;
    right = rightnode;
  }

  // Value stored in this node.
  int getVal() const {
    return val;
  }

  // Left child, or nullptr if there is none.
  Node* getLeftChild() const {
    return left;
  }

  // Right child, or nullptr if there is none.
  Node* getRightChild() const {
    return right;
  }
};
class Tree {
Node* root;
public:
Tree() {
root = NULL;
}
Node* constructTree(vector<int>, int);
void printPreOrder(Node*);
};
// Recursively builds the subtree whose top is at `index` in a level-order
// encoded array. The children of position i are at 2i+1 and 2i+2, and a -1
// entry (or a position past the end) means "no node". The array is taken by
// const reference so the recursion does not copy it at every level.
static Node* buildSubtree(const vector<int>& values,
                          vector<int>::size_type index) {
  if (index >= values.size() || values[index] == -1)
    return nullptr;
  Node* node = new Node(values[index]);
  node->setChild(buildSubtree(values, 2 * index + 1),
                 buildSubtree(values, 2 * index + 2));
  return node;
}

// Builds the tree encoded in level order in `values`, starting at position
// `iter` (normally 0), and returns its top node. Returns NULL when `iter` is
// negative, out of range, or refers to a -1 entry.
// If this Tree has no recorded root yet, the returned node becomes its root.
Node* Tree::constructTree(vector<int> values, int iter) {
  if (iter < 0)
    return nullptr;
  Node* top = buildSubtree(values, static_cast<vector<int>::size_type>(iter));
  if (root == nullptr)
    root = top;
  return top;
}
// Writes the subtree rooted at `node` to stdout in pre-order (node, then left,
// then right). Each value is followed by a space, and each missing child is
// shown as "NULL ". When `node` is this Tree's recorded root, including the
// case where both are null, the output line is terminated with endl.
void Tree::printPreOrder(Node* node) {
  if (node == nullptr) {
    cout << "NULL ";
  } else {
    cout << node->getVal() << " ";
    Node* leftChild = node->getLeftChild();
    Node* rightChild = node->getRightChild();
    printPreOrder(leftChild);
    printPreOrder(rightChild);
  }

  if (node == root)
    cout << endl;
}
// Reports whether the subtree rooted at `node` is unival, meaning every node in
// it holds the same value. Also adds to *count the number of unival subtrees
// contained in it, including the subtree at `node` itself when it qualifies.
//
// A null `node` (an empty tree) is treated as unival and adds nothing to
// *count. Previously this case dereferenced a null pointer, e.g. when
// constructTree() returned NULL for an empty or "-1"-rooted input.
//
// @param node  subtree root; may be nullptr.
// @param count accumulator that is incremented, never reset; must not be null.
// @return true if the subtree rooted at `node` is unival.
bool isUnivalTree(Node* node, int *count) {
  if (node == nullptr)
    return true;

  Node* left = node->getLeftChild();
  Node* right = node->getRightChild();

  // Evaluate both children unconditionally (no short-circuit) so that every
  // unival subtree below this node is counted even if this node fails.
  bool leftUnival = isUnivalTree(left, count);
  bool rightUnival = isUnivalTree(right, count);

  bool isUnival = leftUnival && rightUnival &&
                  (left == nullptr || left->getVal() == node->getVal()) &&
                  (right == nullptr || right->getVal() == node->getVal());

  // The subtree rooted here is itself unival: count it.
  if (isUnival)
    *count += 1;
  return isUnival;
}
int main() {
vector<int> values = {0,1,0,-1,-1,1,0,-1,-1,-1,-1,1,1};
Tree tree = Tree();
Node* root = tree.constructTree(values, 0);
tree.printPreOrder(root);
int count = 0;
isUnivalTree(root, &count);
cout << "Number of unival subtrees: " << count << endl;
}