I have a WGSL compute shader that checks pairs of datapoints and, for each point, counts the number of pairs that meet a certain requirement (i.e. for which `passes(i, j)` is true). For example, if there are 5 passing pairs that include datapoint #3, then the 3rd element in the output buffer would be set to 5.
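A CPU reference of the counting rule can help check the GPU output on small inputs. This is a minimal sketch, assuming one count per ordered pair `(i, j)` credited to `i` only, mirroring `atomicAdd(&output[global_id.x], 1u)` in the shader. `count_passing_pairs` and `approximation` are hypothetical helpers, and `approximation` is only a stand-in for the "`data[i] < data[j]`"-style check described below.

```rust
/// CPU reference for the counting kernel: for every ordered pair (i, j)
/// with `passes(i, j)` true, increment `counts[i]` (the shader's
/// `output[global_id.x]`, with x = i and y = j).
pub fn count_passing_pairs<F>(n: usize, passes: F) -> Vec<u32>
where
    F: Fn(usize, usize) -> bool,
{
    let mut counts = vec![0u32; n];
    for i in 0..n {
        for j in 0..n {
            if passes(i, j) {
                counts[i] += 1;
            }
        }
    }
    counts
}

/// Hypothetical stand-in for the cheap approximation
/// ("something along the lines of data[i] < data[j]").
pub fn approximation(data: &[f32], i: usize, j: usize) -> bool {
    data[i] < data[j]
}
```

For example, `count_passing_pairs(data.len(), |i, j| approximation(&data, i, j))` with `data = [3.0, 1.0, 2.0]` gives `[0, 2, 1]`, which can be compared element by element against the mapped output buffer.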
Here is a simplified MRE of the shader:
```wgsl
@group(0) @binding(0)
var<uniform> uniforms: Uniforms;
@group(0) @binding(1)
var<storage, read> data: array<f32>;
@group(0) @binding(2)
var<storage, read_write> output: array<atomic<u32>>;

fn passes(i: u32, j: u32) -> bool {
    // To be discussed
}

@compute
@workgroup_size(32, 32, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // One invocation per (i, j) pair; each passing pair counts toward point i.
    if (passes(global_id.x, global_id.y)) {
        atomicAdd(&output[global_id.x], 1u);
    }
}
```
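The dispatch call isn't shown, so this sketch assumes one invocation per ordered pair `(i, j)` over `n` datapoints, with `n` and the helper names being hypothetical. With `@workgroup_size(32, 32, 1)`, covering an `n × n` grid takes `ceil(n / 32)` workgroups per axis (e.g. passed to `dispatch_workgroups(x, y, 1)` in recent wgpu versions). When `n` is not a multiple of 32, some invocations get `global_id.x` or `global_id.y >= n`. WGSL doesn't make those out-of-bounds accesses undefined behaviour, but they may read or write an unintended element, so an early-return guard in `main` is usually advisable.

```rust
/// Workgroup dimensions declared on `main`: @workgroup_size(32, 32, 1).
pub const WORKGROUP_X: u32 = 32;
pub const WORKGROUP_Y: u32 = 32;

/// Workgroups needed along one axis so that `group_size * count >= n`.
pub fn workgroups_for(n: u32, group_size: u32) -> u32 {
    (n + group_size - 1) / group_size
}

/// (x, y) workgroup counts covering every pair (i, j) with i, j < n.
pub fn dispatch_size(n: u32) -> (u32, u32) {
    (workgroups_for(n, WORKGROUP_X), workgroups_for(n, WORKGROUP_Y))
}

/// Number of launched invocations with global_id.x >= n or global_id.y >= n,
/// i.e. the ones a bounds check in `main` would need to skip.
pub fn out_of_range_invocations(n: u32) -> u64 {
    let (gx, gy) = dispatch_size(n);
    let launched_x = u64::from(gx) * u64::from(WORKGROUP_X);
    let launched_y = u64::from(gy) * u64::from(WORKGROUP_Y);
    launched_x * launched_y - u64::from(n) * u64::from(n)
}
```

For `n = 100`, this gives a 4 × 4 dispatch (128 × 128 = 16384 invocations), of which 6384 fall outside the 100 × 100 pair grid.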
I had been using a cheap approximation of this requirement, and all results came back as expected. However, when I implemented the actual requirement, which is somewhat more expensive to compute, the shader's output suggested that no pairs passed (the output buffer was all 0s). I assumed this meant the function checking the requirement was always returning false because of faulty logic. However, if I replaced the final `return result;` in the `passes` function with
```wgsl
// Calculate expensive result
let result = ...; // Lots of vector flops, but no branching
// Calculate cheap approximation
let approximation = ...; // Something along the lines of data[i] < data[j], something simple
return result || approximation;
```
I was still getting all 0s. This is very odd, as the approximation had worked fine on its own, and even if `result` were always false because of a bug, the approximation should still carry through to the return value. My next intuition was that it could be a shader timeout (perhaps similar to what was discussed here or here). However, if I modify the code again to resemble this:
```wgsl
// Calculate expensive result
let result = ...; // Lots of vector flops, but no branching
// Calculate cheap approximation
let approximation = ...; // Something along the lines of data[i] < data[j], something simple
return approximation;
```
It runs just the same as when the whole function was just `return approximation;`. The amount of code executed hasn't really changed (I think), but it now works just fine.
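The argument that `result || approximation` can only add passing pairs, never remove them, can be checked with a small CPU-side sketch. It assumes each predicate is stored as an `n × n` row-major grid of pass/fail flags (`grid[i * n + j]` for pair `(i, j)`, with `n > 0`); the function names are hypothetical.

```rust
/// Per-point counts from an n x n row-major grid of pass/fail flags,
/// where grid[i * n + j] is the outcome for pair (i, j). Requires n > 0.
pub fn row_counts(grid: &[bool], n: usize) -> Vec<u32> {
    grid.chunks(n)
        .map(|row| row.iter().filter(|&&passed| passed).count() as u32)
        .collect()
}

/// Element-wise `result || approximation`, as in the modified `passes`.
pub fn or_grids(result: &[bool], approximation: &[bool]) -> Vec<bool> {
    result.iter().zip(approximation).map(|(&r, &a)| r || a).collect()
}

/// True if every per-point count under `result || approximation` is at
/// least the count under `approximation` alone.
pub fn or_never_lowers_counts(result: &[bool], approximation: &[bool], n: usize) -> bool {
    let with_or = row_counts(&or_grids(result, approximation), n);
    let approx_only = row_counts(approximation, n);
    with_or.iter().zip(&approx_only).all(|(o, a)| o >= a)
}
```

Since this holds for every possible `result`, an all-zero GPU output with `result || approximation`, when `approximation` alone produces non-zero counts, points away from the boolean logic of `result` and toward something else about how the longer shader runs.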
For completeness, here is the Rust code that reads the output buffer back:
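```rust
// Copy the storage buffer into a MAP_READ staging buffer and submit.
encoder.copy_buffer_to_buffer(&output_buffer, 0, &output_staging_buffer, 0, output_buffer.size());
queue.submit(Some(encoder.finish()));

// Request a read mapping; the callback reports whether mapping succeeded.
let output_staging_slice = output_staging_buffer.slice(..);
let (s1, r1) = flume::bounded(1);
output_staging_slice.map_async(wgpu::MapMode::Read, move |data| s1.send(data).unwrap());
// Block until the submitted work completes and the map callback has run.
device.poll(wgpu::Maintain::Wait);

if let Ok(Ok(())) = r1.recv() {
    let output_data = output_staging_slice.get_mapped_range();
    // Reinterpret the mapped bytes as one u32 count per datapoint.
    let output_result: Vec<u32> = bytemuck::cast_slice(&output_data).to_vec();
    // The mapped view must be dropped before the buffer can be unmapped.
    drop(output_data);
    output_staging_buffer.unmap();
    Some(output_result)
} else {
    None
}
```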
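To separate a readback problem from a shader problem, the byte-to-count conversion can be exercised without the GPU. This std-only sketch performs the same native-endian reinterpretation that `bytemuck::cast_slice::<u8, u32>` does, but it copies each 4-byte chunk, so it doesn't require the bytes to be 4-byte aligned. `decode_counts` and `nonzero_points` are hypothetical helpers.

```rust
/// Decode mapped bytes into u32 counts (native endianness, like bytemuck).
/// Panics if the length is not a multiple of 4.
pub fn decode_counts(bytes: &[u8]) -> Vec<u32> {
    assert!(bytes.len() % 4 == 0, "byte length must be a multiple of 4");
    bytes
        .chunks_exact(4)
        .map(|c| u32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}

/// Number of datapoints with at least one passing pair.
pub fn nonzero_points(counts: &[u32]) -> usize {
    counts.iter().filter(|&&c| c != 0).count()
}
```

A quick `nonzero_points(&output_result) == 0` check after each run makes the "all 0s" case easy to spot when switching between `passes` variants.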