2022年9月13日

[C++] Metaprogram for Getting Bit Size & Converting to Power of 2

大家有沒有遇過這種狀況呢? 一些資料結構 (ex: heap, binary tree) 的節點數量往往是 2 的冪次方 - 1;或者是為了程式效率,把陣列或是 struct 的大小設定成 2 的冪次方。以往在寫這類數值時通常會利用 #define 或者是 constexpr 等等之類的方法把這種數值變成某個有意義的變數,像是這樣:

constexpr int BITSIZE = 10;
constexpr int BUFSIZE = 1024;

但這種時候會遇到一個小麻煩:開發過程中為了程式效率會常常調整這些數值,所以為了盡可能減少錯誤會有各種方式來做這些基礎設定。比方說 BUFSIZE 就可以改成 (1 << BITSIZE) 來避免改了 BITSIZE 忘記改掉 BUFSIZE。

當然,現在 compiler 已經很厲害了,如果某些基礎設定值在編譯期 (compile-time) 就是常數,現在也有不少方法可以讓後續的衍伸運算也都變成編譯期的常數,從而減少執行期的時間。這篇文章要做的主要是介紹用 metaprogramming 把這些運算通通轉成編譯期的常數。

常用的第一個運算就是給定一個 2 的冪次方數,想要算出他的刺方數值。最簡單的方法其實就是寫個 loop:

int exp = 0;
for (; N != 0; N >>= 1, ++exp);

簡單,直覺,但是這會真的變成一段程式碼在執行期運算。
所以我們反過來想,反正一個變數最多也就 64 bits,常用的也是 32 bits 而已,我們直接建表就好。在 C++ 最簡單的建表方法就是利用 template:

template <unsigned N>
struct Exp2;

template <>
struct Exp2<0x0000'0001> {
    static constexpr int val = 0;
};
template <>
struct Exp2<0x0000'0002> {
    static constexpr int val = 1;
};
template <>
struct Exp2<0x0000'0004> {
    static constexpr int val = 2;
};
...

下一個問題是如果給定的數值不是 2 的冪次方呢? (比方說希望最小數值有 10,000,但這數值不是 2 的冪次方) 這時候我們可以把他轉成最接近的 2 的冪次方。

template <unsigned N>
struct Exp2 {
    // N & (~N + 1): get the right most set bit
    // ex: 0x1010 -> N & (~N + 1) will obtain 0x0010
    // Hence N ^ (N & (~N + 1)) will unset the rightmost set bit
    // In the end, this value will be the exponent part of associated with the leftmost set bit
    // ex: 0x1010 -> leftmost set bit is 0x1000 -> exponent part is 3 (2^3 = 0x1000)
    static constexpr int val = Exp2<N ^ (N & (~N +1))>::val;
};

template <unsigned N>
inline constexpr unsigned toPow2()
{
    // convert the given value to the nearest (but not less than original) value
    // which is power of 2.
    if constexpr (N == Exp2<N>::val) {
        return N;
    }
    else {
        // ex: let say N = 8200, MASK = (1 << (13 + 1)) - 1 = 0x3FFF
        // (N + MASK) & (~MASK) = (0x2008 + 0x3FFF) & (0xFFFF'C000) = 0x4000
        constexpr auto MASK = (1 << (Exp2<N>::val + 1)) - 1;
        return (N + MASK) & (~MASK);
    }
}

結合起來後我們就可以簡單地把任意數值在編譯期算出最接近且比他大的 2 的冪次方數:

auto i = toPow2<8200>();
// assembly code:
// mov     DWORD PTR [rbp-4], 16384

從對應的組合語言就可以發現編譯器已經直接把最接近 8200 的 2 的冪次方數 16384 算出來,並且避免產生多餘的程式碼

完整程式碼如下:

namespace util
{
template <unsigned N>
struct Exp2 {
// N & (~N + 1): get the right most set bit
// ex: 0x1010 -> N & (~N + 1) will obtain 0x0010
// Hence N ^ (N & (~N + 1)) will unset the rightmost set bit
// In the end, this value will be the exponent part of associated with the leftmost set bit
// ex: 0x1010 -> leftmost set bit is 0x1000 -> exponent part is 3 (2^3 = 0x1000)
static constexpr int val = Exp2<N ^ (N & (~N +1))>::val;
};
template <>
struct Exp2<0x0000'0000> {
static constexpr int val = 0;
};
template <>
struct Exp2<0x0000'0001> {
static constexpr int val = 0;
};
template <>
struct Exp2<0x0000'0002> {
static constexpr int val = 1;
};
template <>
struct Exp2<0x0000'0004> {
static constexpr int val = 2;
};
template <>
struct Exp2<0x0000'0008> {
static constexpr int val = 3;
};
template <>
struct Exp2<0x0000'0010> {
static constexpr int val = 4;
};
template <>
struct Exp2<0x0000'0020> {
static constexpr int val = 5;
};
template <>
struct Exp2<0x0000'0040> {
static constexpr int val = 6;
};
template <>
struct Exp2<0x0000'0080> {
static constexpr int val = 7;
};
template <>
struct Exp2<0x0000'0100> {
static constexpr int val = 8;
};
template <>
struct Exp2<0x0000'0200> {
static constexpr int val = 9;
};
template <>
struct Exp2<0x0000'0400> {
static constexpr int val = 10;
};
template <>
struct Exp2<0x0000'0800> {
static constexpr int val = 11;
};
template <>
struct Exp2<0x0000'1000> {
static constexpr int val = 12;
};
template <>
struct Exp2<0x0000'2000> {
static constexpr int val = 13;
};
template <>
struct Exp2<0x0000'4000> {
static constexpr int val = 14;
};
template <>
struct Exp2<0x0000'8000> {
static constexpr int val = 15;
};
template <>
struct Exp2<0x0001'0000> {
static constexpr int val = 16;
};
template <>
struct Exp2<0x0002'0000> {
static constexpr int val = 17;
};
template <>
struct Exp2<0x0004'0000> {
static constexpr int val = 18;
};
template <>
struct Exp2<0x0008'0000> {
static constexpr int val = 19;
};
template <>
struct Exp2<0x0010'0000> {
static constexpr int val = 20;
};
template <>
struct Exp2<0x0020'0000> {
static constexpr int val = 21;
};
template <>
struct Exp2<0x0040'0000> {
static constexpr int val = 22;
};
template <>
struct Exp2<0x0080'0000> {
static constexpr int val = 23;
};
template <>
struct Exp2<0x0100'0000> {
static constexpr int val = 24;
};
template <>
struct Exp2<0x0200'0000> {
static constexpr int val = 25;
};
template <>
struct Exp2<0x0400'0000> {
static constexpr int val = 26;
};
template <>
struct Exp2<0x0800'0000> {
static constexpr int val = 27;
};
template <>
struct Exp2<0x1000'0000> {
static constexpr int val = 28;
};
template <>
struct Exp2<0x2000'0000> {
static constexpr int val = 29;
};
template <>
struct Exp2<0x4000'0000> {
static constexpr int val = 30;
};
template <>
struct Exp2<0x8000'0000> {
static constexpr int val = 31;
};
}
template <unsigned N>
inline constexpr unsigned toPow2()
{
// convert the given value to the nearest (but not less than original) value
// which is power of 2.
if constexpr (N == util::Exp2<N>::val) {
return N;
}
else {
// ex: let say N = 8200, MASK = (1 << (13 + 1)) - 1 = 0x3FFF
// (N + MASK) & (~MASK) = (0x2008 + 0x3FFF) & (0xFFFF'C000) = 0x4000
constexpr auto MASK = (1 << (util::Exp2<N>::val + 1)) - 1;
return (N + MASK) & (~MASK);
}
}
int main()
{
auto i = toPow2<8200>();
return 0;
}
/* Associated assembly code: (by x86-64 gcc-12.2)
main:
push rbp
mov rbp, rsp
# toPow2<8200>() becomes a constant value 16384 here
mov DWORD PTR [rbp-4], 16384
mov eax, 0
pop rbp
ret
*/
view raw exponent_2.cpp hosted with ❤ by GitHub

沒有留言:

張貼留言