The last bit – Part 1
The RISC-V Vector Extension (RVV) includes several instructions that operate
on masks and compute interesting things about them. One of them is vfirst.m,
which computes the lowest-numbered element of the mask that is set.
However, there is no vlast.m instruction that computes the highest-numbered
element of the mask that is set.
The RISC-V Vector Extension
The RISC-V Vector Extension (collectively the V extension and shortened to RVV) aims to provide vector compute capabilities to the RISC-V ISA. Compared to other extensions, the V extension is rather sophisticated and provides a large number of instructions. It also has a few unusual features. One is a vector length that specifies how many elements of the vector are operated on. Another is that the types of the instructions (element width and register grouping) are kept in RVV-specific architectural registers, such as vtype, rather than encoded in each instruction. This makes the extension rather flexible but also more involved than the more straightforward designs of other ISAs.
Typically, SIMD/vector instructions are used when vectorizing code, with the goal of benefiting from their higher throughput (amount of work per unit of time).
Masking
Unfortunately, vectorization does not get along very well with control flow (i.e., branches/jumps in a program). The usual solution is to extend the idea of if-conversion, where control flow is replaced by data flow. After if-conversion, branches have been removed and the code has been flattened into straight-line code, i.e., code without branches.
Extra code is emitted to compute special values, which we can call predicates. Predicates correspond to the conditional branches that we removed during if-conversion, and they gate instructions: depending on the predicate it uses, executing an instruction may or may not have an architecturally visible effect.
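A small scalar C sketch of the idea follows; the function names are made up for illustration. The first loop guards the store with a branch. The if-converted loop instead computes a predicate and lets it gate the store, so every iteration runs the same straight-line code. The conditional expression stands in for a predicated operation. A scalar compiler may or may not actually emit it without a branch.

```c
#include <assert.h>
#include <stddef.h>

/* Branchy form: control flow decides whether b[i] is written. */
void copy_positive_branchy(int *b, const int *a, size_t n) {
    assert(n == 0 || (a != NULL && b != NULL));
    for (size_t i = 0; i < n; i++) {
        if (a[i] > 0)
            b[i] = a[i];
    }
}

/* If-converted form: the condition becomes a data value (the predicate)
   that gates the store. When p is 0 the old value of b[i] is written
   back, so the store has no visible effect. */
void copy_positive_predicated(int *b, const int *a, size_t n) {
    assert(n == 0 || (a != NULL && b != NULL));
    for (size_t i = 0; i < n; i++) {
        int p = a[i] > 0;          /* predicate: 1 = enabled, 0 = disabled */
        b[i] = p ? a[i] : b[i];    /* gated by p instead of guarded by a branch */
    }
}
```

Both functions leave b in the same state. The second has the shape a vectorizer wants, because the per-element decision is now just a value.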
The vectorized version of predicates is effectively a vector of predicates. Because a predicate has only two values (enable/1/true vs disable/0/false) it can be represented with a single bit. Because of this, many SIMD/vector ISAs include support for masks (or simply predicates) to represent these vectors of predicates.
Regular data vectors, containing integers or floating point, have their elements packed across the whole vector register. In contrast, masks in RVV are represented as a packed vector of bits, one bit per element, stored in the lowest bits of the vector register. This layout is called a mask vector.
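To make the mask layout concrete, here is a small scalar C model. The helper names mask_get and mask_set are hypothetical and not part of any intrinsics API. Per the RVV specification, the mask bit for element i is bit i of the register, independently of SEW and LMUL. Viewing the register as an array of bytes, that is bit i % 8 of byte i / 8. A data vector with SEW=32, by contrast, would spend 4 bytes on each element.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Read element i of a mask stored in the RVV mask layout:
   element i is bit (i % 8) of byte (i / 8). */
static int mask_get(const uint8_t *mask, size_t i) {
    return (mask[i / 8] >> (i % 8)) & 1;
}

/* Set element i of the mask to value (0 or 1). */
static void mask_set(uint8_t *mask, size_t i, int value) {
    assert(value == 0 || value == 1);
    uint8_t bit = (uint8_t)(1u << (i % 8));
    if (value)
        mask[i / 8] |= bit;
    else
        mask[i / 8] &= (uint8_t)~bit;
}
```

For example, setting elements 0 and 9 of a cleared two-byte mask gives the bytes 0x01 and 0x02.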
Masks can be used when dealing with control flow such as if statements or even
loops. In particular, in scan-like loops (think of the C library function
strlen) the instruction vfirst.m is useful to check whether a mask contains a
bit set to one (for the strlen case, we can check whether one of the bytes that
we loaded is a zero and act accordingly). More precisely, vfirst.m returns -1
if all elements of the mask are cleared (i.e., zero); otherwise it returns the
index of the lowest-numbered element that is set (set to one, that is). In the
scalar world, this operation is typically called count trailing zeroes (often
shortened to ctz).
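To illustrate the strlen idea without RVV hardware, here is a scalar C model. All names are hypothetical, and the chunk size of 8 is arbitrary. Each iteration builds a mask with one bit per byte, set where the byte is zero; this is what a vmseq.vx against 0 would produce. A loop-based stand-in for vfirst.m then picks the first such byte.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define CHUNK 8  /* hypothetical vector length in bytes, for illustration */

/* Scalar model of vfirst.m on a mask held in the low vl bits of an integer:
   index of the lowest set element, or -1 if none is set. */
static long first_set(uint64_t mask, size_t vl) {
    for (size_t i = 0; i < vl; i++)
        if ((mask >> i) & 1)
            return (long)i;
    return -1;
}

/* strlen-like scan: compare a chunk of bytes against zero to build a mask,
   then use the vfirst.m model to locate the terminator. */
static size_t strlen_chunked(const char *s) {
    assert(s != NULL);
    for (size_t base = 0;; base += CHUNK) {
        uint64_t mask = 0;
        size_t vl = 0;
        while (vl < CHUNK) {
            char c = s[base + vl];
            mask |= (uint64_t)(c == '\0') << vl;
            vl++;
            /* Stop at the terminator so the model never reads past the
               string; a real RVV loop would use a fault-only-first load. */
            if (c == '\0')
                break;
        }
        long idx = first_set(mask, vl);
        if (idx >= 0)
            return base + (size_t)idx;
    }
}
```

The early stop only keeps this C model memory-safe. A real RVV loop would typically rely on a fault-only-first load (such as vle8ff.v) for the same purpose.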
The following table shows some examples, assuming the mask has 8 elements. Writing the elements in descending order makes it clear why this operation is often known as count trailing zeroes: the result is exactly the number of consecutive zeroes starting from the lowest element. The only special case is an all-zero mask, for which the instruction returns -1 instead of 8.
          | element number  |
          |                 |
          | 7 6 5 4 3 2 1 0 | vfirst.m
----------+-----------------+-----------
     mask | 0 0 0 0 0 0 0 0 | -1
          | 0 0 0 0 0 0 0 1 | 0
          | 0 0 0 0 0 0 1 0 | 1
          | 0 0 0 0 0 1 1 0 | 1
          | 1 0 0 0 0 0 0 0 | 7
          | 1 0 0 0 0 1 0 0 | 2
          | 0 1 1 1 0 1 0 0 | 2
          | 1 1 1 1 0 1 0 0 | 2
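The equivalence with ctz can be checked against the rows of the table with a short C sketch. It holds the mask in the low bits of a uint64_t, so it assumes at most 64 elements. It uses the GCC/Clang builtin __builtin_ctzll, whose result is undefined for a zero argument. That is precisely the case vfirst.m maps to -1, so it is handled separately.

```c
#include <assert.h>
#include <stdint.h>

/* vfirst.m expressed with a count-trailing-zeroes primitive. */
static long vfirst_via_ctz(uint64_t mask, unsigned vl) {
    assert(vl <= 64);
    if (vl < 64)
        mask &= (UINT64_C(1) << vl) - 1;   /* ignore elements at or beyond vl */
    return mask == 0 ? -1 : (long)__builtin_ctzll(mask);
}
```

Reading the table rows as hexadecimal (element 0 is the least significant bit), 0x74 gives 2 and 0x80 gives 7.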
The last bit
What if, instead of the lowest-numbered element of a mask that is set, we want
the highest-numbered one? We could intuitively name such an instruction
vlast.m.
          | element number  |
          |                 |
          | 7 6 5 4 3 2 1 0 | vlast.m
----------+-----------------+-----------
     mask | 0 0 0 0 0 0 0 0 | -1
          | 0 0 0 0 0 0 0 1 | 0
          | 0 0 0 0 0 0 1 0 | 1
          | 0 0 0 0 0 1 1 0 | 2
          | 1 0 0 0 0 0 0 0 | 7
          | 1 0 0 0 0 1 0 0 | 7
          | 0 1 1 1 0 1 0 0 | 6
          | 1 1 1 1 0 1 0 0 | 7
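Since vlast.m is not an RVV instruction, it helps to pin down the intended semantics with a plain scalar reference. The sketch below uses a hypothetical function name, vlast_ref. It holds the mask in the low bits of a uint64_t and scans from element vl - 1 downwards, which reproduces the table. It is also a convenient oracle when checking faster variants.

```c
#include <assert.h>
#include <stdint.h>

/* Reference semantics for the hypothetical vlast.m: index of the
   highest-numbered set element among the first vl, or -1. */
static long vlast_ref(uint64_t mask, unsigned vl) {
    assert(vl <= 64);
    for (unsigned i = vl; i-- > 0;)
        if ((mask >> i) & 1)
            return (long)i;
    return -1;
}
```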
Unfortunately, RVV does not provide a vlast.m so we will need to compute it.
Let’s look at different options.
Reversing the mask
The first idea is reversing the mask and then applying vfirst.m. This way
the highest numbered element that is set will be, after the reversal, the lowest
numbered element set and so vfirst.m will compute what we want.
RVV has a very powerful (and, correspondingly, expensive to implement) general
shuffling instruction called vrgather.vv. We can reverse a data vector using
this instruction. Mask vectors have a different layout from data vectors, so we
need to first convert them to data vectors (reinterpreting them is not enough),
shuffle them and convert them back to mask vectors.
An example using vectors of 32-bit elements is shown below (for simplicity, all the examples use LMUL=1 and SEW=32).
#include <riscv_vector.h>

vbool32_t reverse_mask_vrgather(vbool32_t m, size_t vl) {
    // Create a vector of all zeros
    vint32m1_t allzeros = __riscv_vmv_v_x_i32m1(0, vl);
    // Set to 1 where the mask was true
    vint32m1_t extended_mask = __riscv_vmerge_vxm_i32m1(allzeros, 1, m, vl);
    // Create 0, 1, 2, …, vl-1
    vuint32m1_t index = __riscv_vid_v_u32m1(vl);
    // Compute vl-1, vl-2, …, 0
    vuint32m1_t rev_index = __riscv_vrsub_vx_u32m1(index, vl - 1, vl);
    // Now reverse the extended_mask
    vint32m1_t reversed_extended_mask = __riscv_vrgather_vv_i32m1(extended_mask, rev_index, vl);
    // Convert back to a mask
    vbool32_t reversed_mask = __riscv_vmsne_vx_i32m1_b32(reversed_extended_mask, 0, vl);
    return reversed_mask;
}
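Note that one must use vrgatherei16.vv instead of vrgather.vv if the index
vector has 8-bit elements (SEW=8) and there may be more than 256 elements in a
vector: 8-bit indices cannot address element 256 or beyond, whereas
vrgatherei16.vv always uses 16-bit indices.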
Our vlast.m looks like this now:
int vlast_m_reverse(vbool32_t m, size_t vl) {
    vbool32_t reversed_mask = reverse_mask_vrgather(m, vl);
    long l = __riscv_vfirst_m_b32(reversed_mask, vl);
    // Map the index in the reversed mask back to the original numbering.
    return l < 0 ? -1 : (int)((long)vl - 1 - l);
}
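The same pipeline can be modelled in scalar C to check the index arithmetic, in particular the final (vl - 1) - l step. This sketch uses hypothetical names and holds the mask in the low bits of a uint64_t. reverse_low_bits plays the role of vrgather.vv with the index vector vl - 1, vl - 2, ..., 0, and first_set plays the role of vfirst.m.

```c
#include <assert.h>
#include <stdint.h>

/* Reverse the first vl mask elements: element vl - 1 - i moves to i,
   like gathering with the index vector vl - 1, vl - 2, ..., 0. */
static uint64_t reverse_low_bits(uint64_t mask, unsigned vl) {
    uint64_t out = 0;
    for (unsigned i = 0; i < vl; i++) {
        unsigned src = vl - 1 - i;           /* rev_index[i] */
        out |= ((mask >> src) & 1) << i;     /* gather element src into i */
    }
    return out;
}

/* Lowest set element among the first vl, or -1 (model of vfirst.m). */
static long first_set(uint64_t mask, unsigned vl) {
    for (unsigned i = 0; i < vl; i++)
        if ((mask >> i) & 1)
            return (long)i;
    return -1;
}

/* vlast via reversal: position l in the reversed mask corresponds to
   element (vl - 1) - l of the original mask. */
static long vlast_by_reversal(uint64_t mask, unsigned vl) {
    assert(vl <= 64);
    long l = first_set(reverse_low_bits(mask, vl), vl);
    return l < 0 ? -1 : (long)(vl - 1) - l;
}
```

For instance, with vl = 8 the mask 0x06 (elements 1 and 2) reverses to 0x60. Its first set element is 5, giving 7 - 5 = 2.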
Other options for reverse
Here I showed the most generic way of reversing a mask within the V
extension. But if we have the Zbkb (Bit-manipulation for Cryptography)
extension available and we know our mask is not going to be larger than a
general purpose register (i.e., 64 bits on riscv64), we can reverse the bits of
the mask directly without using vrgather.vv.
Something like this should do on riscv64 (I was not able to run this version so there might be some bug lurking in here!):
#include <assert.h>
#include <stdint.h>
#include <riscv_vector.h>

vbool32_t reverse_mask_rev(vbool32_t m, size_t vl) {
    // Reinterpret the mask vector as a data vector.
    vuint64m1_t mask_as_data = __riscv_vreinterpret_v_b32_u64m1(m);
    // Extract the lowest 64-bit element of the vector (i.e., the first element).
    uint64_t first_element = __riscv_vmv_x_s_u64m1_u64(mask_as_data);
    // The two shifts below are undefined for vl == 64 and vl == 0 respectively.
    assert(vl > 0 && vl < 64);
    // Clear the bits beyond vl because they might be garbage.
    first_element &= ~(~(uint64_t)0 << vl);
    uint64_t rev;
    // Reverse the bytes of the register.
    asm("rev8 %0, %1" : "=r"(rev) : "r"(first_element));
    // Reverse the bits of each byte.
    asm("brev8 %0, %1" : "=r"(rev) : "r"(rev));
    // Now the bits are at the wrong end; shift them back to the lower part.
    rev = rev >> (64 - vl);
    // Put the integer back into the first element of a vector
    // (only element 0 needs to be written, hence a vl of 1).
    vuint64m1_t data_as_mask = __riscv_vmv_s_x_u64m1(rev, 1);
    // Reinterpret the data vector as a mask vector.
    vbool32_t reversed_mask = __riscv_vreinterpret_v_u64m1_b32(data_as_mask);
    return reversed_mask;
}
Note that the reinterprets are no-ops and exist only to satisfy the type system of the RVV intrinsics.
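Because the riscv64 version above was not run, a portable C model of the same bit manipulation can at least check the arithmetic on any host. rev8_model and brev8_model are hypothetical helpers that mimic, bit by bit, what the rev8 and brev8 instructions compute; they say nothing about how hardware implements them. Swapping the bytes and then reversing the bits inside each byte is a full 64-bit bit reversal, so bit i ends up at bit 63 - i. The final shift by 64 - vl then moves it to vl - 1 - i.

```c
#include <assert.h>
#include <stdint.h>

/* Portable model of rev8: reverse the order of the 8 bytes. */
static uint64_t rev8_model(uint64_t x) {
    uint64_t r = 0;
    for (int b = 0; b < 8; b++)
        r |= ((x >> (8 * b)) & 0xFF) << (8 * (7 - b));
    return r;
}

/* Portable model of brev8: reverse the bits within each byte. */
static uint64_t brev8_model(uint64_t x) {
    uint64_t r = 0;
    for (int i = 0; i < 64; i++) {
        int byte = i / 8, bit = i % 8;
        r |= ((x >> i) & 1) << (8 * byte + (7 - bit));
    }
    return r;
}

/* Reverse the low vl bits the same way reverse_mask_rev does. */
static uint64_t reverse_low_bits_rev8(uint64_t x, unsigned vl) {
    assert(vl > 0 && vl <= 64);            /* shift by 64 - vl must be < 64 */
    if (vl < 64)
        x &= (UINT64_C(1) << vl) - 1;      /* clear bits beyond vl */
    return brev8_model(rev8_model(x)) >> (64 - vl);
}
```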
Counting the leading bits
If we have the Zbb (Basic bit-manipulation) extension available and our mask still fits in a general purpose register, we can simply count the leading zeroes without having to reverse the mask. We still need to turn that count, which is measured from the top, into an element index measured from element 0. In doing so we must take into account that the set bit itself counts as one extra position, so to speak.
On riscv64 it looks like this:
int vlast_clz(vbool32_t m, size_t vl) {
    // Reinterpret the mask vector as a data vector.
    vuint64m1_t mask_as_data = __riscv_vreinterpret_v_b32_u64m1(m);
    // Extract the lowest 64-bit element of the vector (i.e., the first element).
    uint64_t first_element = __riscv_vmv_x_s_u64m1_u64(mask_as_data);
    // The two shifts below are undefined for vl == 64 and vl == 0 respectively.
    assert(vl > 0 && vl < 64);
    // Clear the bits beyond vl because they might be garbage.
    first_element &= ~(~(uint64_t)0 << vl);
    // Align the relevant bits to the higher part of the register.
    first_element <<= 64 - vl;
    uint64_t clz;
    // Compute the leading zeroes.
    asm("clz %0, %1" : "=r"(clz) : "r"(first_element));
    // clz counts the number of leading zeroes, so the highest set bit is the
    // (clz + 1)-th bit counting from the top. Element vl - 1 sits at the top,
    // so we reverse the index using vl.
    return clz == 64 ? -1 : (int)(vl - (clz + 1));
}
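As a worked example, take vl = 8 and the mask 0 0 0 0 0 1 1 0. Shifting left by 64 - 8 = 56 puts element 2 at bit 58, so clz is 63 - 58 = 5 and the result is 8 - (5 + 1) = 2, as in the table. The portable sketch below reproduces that computation. It uses a hypothetical function name and the GCC/Clang builtin __builtin_clzll in place of the clz instruction. The builtin is undefined for 0, so the empty mask is handled first.

```c
#include <assert.h>
#include <stdint.h>

/* Portable model of vlast_clz for a mask held in the low vl bits. */
static long vlast_by_clz(uint64_t mask, unsigned vl) {
    assert(vl > 0 && vl < 64);
    mask &= (UINT64_C(1) << vl) - 1;     /* drop elements at or beyond vl */
    if (mask == 0)
        return -1;                       /* the clz == 64 case */
    mask <<= 64 - vl;                    /* element vl - 1 is now bit 63 */
    unsigned clz = (unsigned)__builtin_clzll(mask);
    return (long)vl - (long)(clz + 1);
}
```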
Using the prefix sum
This one is a fun and clever approach that again stays inside V and does not
require any assumption about the number of elements of the mask. It relies on
the viota.m instruction, which computes the prefix sum of a mask: the result is
an integer data vector where element i contains the number of set elements
among mask elements 0 to i - 1 (an exclusive prefix sum).
The number of elements that are set in a mask is called the population count,
and can be computed using the vcpop.m instruction (named vpopc.m in earlier
drafts of the specification). So, the prefix sum of an element is the
population count of the subrange that goes from the beginning of the mask up
to, but not including, that element.
          | element number  |
          |                 |
          | 7 6 5 4 3 2 1 0 | viota.m         | vcpop.m
----------+-----------------+-----------------+---------
     mask | 0 0 0 0 0 0 0 0 | 0 0 0 0 0 0 0 0 | 0
          | 0 0 0 0 0 0 0 1 | 1 1 1 1 1 1 1 0 | 1
          | 0 0 0 0 0 0 1 0 | 1 1 1 1 1 1 0 0 | 1
          | 0 0 0 0 0 1 1 0 | 2 2 2 2 2 1 0 0 | 2
          | 1 0 0 0 0 0 0 0 | 0 0 0 0 0 0 0 0 | 1
          | 1 0 0 0 0 1 0 0 | 1 1 1 1 1 0 0 0 | 2
          | 0 1 1 1 0 1 0 0 | 4 3 2 1 1 0 0 0 | 4
          | 1 1 1 1 0 1 0 0 | 4 3 2 1 1 0 0 0 | 5
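A scalar model of both instructions makes the definitions concrete. The helper names are hypothetical and the mask is held in the low bits of a uint64_t. Note that the output array is indexed from element 0 upwards, whereas the table above lists elements from 7 down to 0.

```c
#include <assert.h>
#include <stdint.h>

/* Model of vcpop.m: number of set elements among the first vl. */
static unsigned cpop_model(uint64_t mask, unsigned vl) {
    unsigned n = 0;
    for (unsigned i = 0; i < vl; i++)
        n += (mask >> i) & 1;
    return n;
}

/* Model of viota.m: out[i] = number of set elements among 0 .. i-1
   (an exclusive prefix sum, so element i itself is not counted). */
static void iota_model(uint64_t mask, unsigned vl, unsigned *out) {
    assert(vl <= 64);
    unsigned running = 0;
    for (unsigned i = 0; i < vl; i++) {
        out[i] = running;
        running += (mask >> i) & 1;
    }
}
```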
One interesting property of the prefix sum is that its last element is either the population count of the mask or that value minus one (the latter exactly when the highest-numbered element is set). We will prove this fact in the next part using Lean.
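Until that proof, a brute-force check is a cheap sanity test of the property. The sketch below uses a hypothetical function name and enumerates every mask of a given small length. It is evidence for small vl only, not a proof.

```c
#include <assert.h>
#include <stdint.h>

/* Returns 1 if, for every mask of vl elements, the last element of the
   exclusive prefix sum equals popc - 1 when element vl - 1 is set and
   popc otherwise; returns 0 on the first counterexample. */
static int last_iota_property_holds(unsigned vl) {
    assert(vl >= 1 && vl <= 16);          /* keep the exhaustive loop small */
    for (uint64_t m = 0; m < (UINT64_C(1) << vl); m++) {
        unsigned popc = 0, last_iota = 0;
        for (unsigned i = 0; i < vl; i++) {
            if (i == vl - 1)
                last_iota = popc;         /* set elements before vl - 1 */
            popc += (m >> i) & 1;
        }
        int top_set = (m >> (vl - 1)) & 1;
        if (last_iota != (top_set ? popc - 1 : popc))
            return 0;
    }
    return 1;
}
```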
To compute the last bit set, we first compute the population count. If the population count is zero, there is no last bit and the result is trivially -1. Otherwise, we compute the prefix sum and a mask that compares every element of the prefix sum with the population count.
If the most significant bit (element vl - 1) is not set, one of the elements of
the prefix sum will be exactly the population count, and the first element
where this happens is exactly the element right after the last set bit. So we
only have to subtract one from the result of vfirst.m.
If the most significant bit is set, the population count is not found in the
prefix sum (and vfirst.m returns -1). At this point we know there is at least
one bit set, so the last bit set is exactly the most significant one.
int vlast_prefixsum(vbool32_t m, size_t vl) {
    // Compute the population count.
    uint64_t popc = __riscv_vcpop_m_b32(m, vl);
    // If no bit is set we're done: there is no last bit.
    if (!popc) return -1;
    // Compute the prefix sum of the mask.
    vuint32m1_t prefixsum = __riscv_viota_m_u32m1(m, vl);
    // Compute a mask of elements equal to the population count.
    vbool32_t equal_to_popc = __riscv_vmseq_vx_u32m1_b32(prefixsum, popc, vl);
    // Get the index of the lowest numbered element equal to the population count.
    long idx = __riscv_vfirst_m_b32(equal_to_popc, vl);
    // If there was none, the result is `vl - 1` otherwise `idx - 1` as we
    // are always counting only the previous elements.
    return (idx < 0 ? (long)vl : idx) - 1;
}
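To test the logic of vlast_prefixsum away from RVV hardware, the sketch below replays it in scalar C, one step per intrinsic. It uses a hypothetical function name and holds the mask in the low bits of a uint64_t, so it handles at most 64 elements. The first loop plays the role of vcpop.m, and the running counter in the second loop is the viota.m value of element i. Stopping at the first element equal to the population count corresponds to vmseq.vx followed by vfirst.m.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model of vlast_prefixsum for a mask held in the low vl bits. */
static long vlast_prefixsum_model(uint64_t mask, unsigned vl) {
    assert(vl <= 64);
    unsigned popc = 0;                        /* vcpop.m */
    for (unsigned i = 0; i < vl; i++)
        popc += (mask >> i) & 1;
    if (popc == 0)
        return -1;
    long idx = -1;
    unsigned running = 0;                     /* viota.m value of element i */
    for (unsigned i = 0; i < vl; i++) {
        if (running == popc) {                /* vmseq.vx + vfirst.m */
            idx = (long)i;
            break;
        }
        running += (mask >> i) & 1;
    }
    /* No match means element vl - 1 is the last set one. */
    return (idx < 0 ? (long)vl : idx) - 1;
}
```

On the rows of the tables, 0x06 yields 2 (the match happens at element 3) and 0x80 yields 7 (no match, so vl - 1).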