#include <iostream>
#include <type_traits>
#include <array>
#include <cstddef>
namespace wrap {
namespace detail {
/**
 * Print which internal overload was selected together with its dimensions.
 * Shared by both overloads so the float and double traces use one format.
 *
 * @tparam N1 first dimension of A.
 * @tparam N2 second dimension of A.
 * @tparam N3 size of b.
 * @param type_name name of the scalar type, inserted verbatim in the message.
 */
template <std::size_t N1, std::size_t N2, std::size_t N3>
void report_call(const char *type_name) {
    std::cout << "Called internal " << type_name << ". A = " << N1 << ", " << N2
              << " b = " << N3 << '\n';
}
} // namespace detail

/**
 * Single-precision back-end. Currently it only reports the dimensions it was
 * instantiated with. N1 and N2 appear only in the product N1 * N2, which is a
 * non-deduced context, so callers must pass them as explicit template
 * arguments.
 */
template <std::size_t N1, std::size_t N2, std::size_t N3>
void internal(std::array<float, N1 * N2> &, std::array<float, N3> &) {
    detail::report_call<N1, N2, N3>("float");
}

/** Double-precision counterpart of the float overload above. */
template <std::size_t N1, std::size_t N2, std::size_t N3>
void internal(std::array<double, N1 * N2> &, std::array<double, N3> &) {
    detail::report_call<N1, N2, N3>("double");
}
} // namespace wrap
/**
 * Select, at compile time, the larger of two dimensions.
 *
 * Derives from std::integral_constant so that `value` is a constexpr static
 * member whose definition is provided by the standard library. It can
 * therefore be odr-used (e.g. bound to a const reference, as std::max does)
 * without an undefined-reference link error. The old
 * `static std::size_t const value = ...;` had that problem because it was
 * never defined out of class. `type` and `value_type` are exposed as for any
 * integral_constant.
 *
 * @tparam N1 first dimension.
 * @tparam N2 second dimension.
 */
template < std::size_t N1, std::size_t N2 >
struct biggest_dim : std::integral_constant< std::size_t, ((N1 >= N2) ? N1 : N2) > {};
/**
 * Storage type for the vector b that accompanies an N1 x N2 matrix:
 * a std::array of REAL_T whose length is max(N1, N2).
 */
template <typename REAL_T, std::size_t N1, std::size_t N2>
using b_array_t = std::array<REAL_T, biggest_dim<N1, N2>::value>;
/** Here we have the function that allows only the call with b of
* the correct size to continue */
template < typename REAL_T, std::size_t N1, std::size_t N2 >
void solve(std::array< REAL_T, N1 * N2 > & A, b_array_t< REAL_T, N1, N2 > & b) {
wrap::internal< N1, N2, biggest_dim< N1, N2 >::value >(A, b);
}
int main() {
    // Every case builds a matrix A of rows x cols values and a vector b,
    // whose length is expected to be max(rows, cols).

    // Case 1: A is 2 x 2 and b has max(2, 2) = 2 elements -> accepted.
    constexpr std::size_t rows1 = 2;
    constexpr std::size_t cols1 = 2;
    std::array<double, rows1 * cols1> A1{};
    std::array<double, cols1> b1{};
    solve<double, rows1, cols1>(A1, b1);

    // Case 2: A is 2 x 1 and b has max(2, 1) = 2 elements -> accepted.
    constexpr std::size_t rows2 = 2;
    constexpr std::size_t cols2 = 1;
    std::array<double, rows2 * cols2> A2{};
    std::array<double, rows2> b2{};
    solve<double, rows2, cols2>(A2, b2);

    // Case 3: A is 1 x 2 and b has max(1, 2) = 2 elements -> accepted.
    constexpr std::size_t rows3 = 1;
    constexpr std::size_t cols3 = 2;
    std::array<double, rows3 * cols3> A3{};
    std::array<double, cols3> b3{};
    solve<double, rows3, cols3>(A3, b3);

    // Case 4: A is 1 x 2 but b has only 1 element instead of max(1, 2) = 2.
    // This call must not compile. The goal is for it to be rejected by a
    // static_assert with a readable message, rather than by an opaque
    // overload-resolution error. Uncomment the lines below to see the
    // compiler diagnostic.
    //
    // constexpr std::size_t rows4 = 1;
    // constexpr std::size_t cols4 = 2;
    // std::array<double, rows4 * cols4> A4{};
    // std::array<double, rows4> b4{};
    // solve<double, rows4, cols4>(A4, b4);

    return 0;
}