spirv: Address some regressions in buffer loads (#304)

* spirv: Use correct index

* spirv: Fix indices during buffer load

* clang-format fix

* spirv: Index can be const

---------

Co-authored-by: georgemoralis <giorgosmrls@gmail.com>
This commit is contained in:
TheTurtle 2024-07-19 19:36:07 +03:00 committed by GitHub
parent 60b1aa62a3
commit bfe3322977
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 10 additions and 11 deletions

View File

@ -215,20 +215,19 @@ Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
return EmitLoadBufferF32(ctx, inst, handle, address);
}
template <int N>
template <u32 N>
static Id EmitLoadBufferF32xN(EmitContext& ctx, u32 handle, Id address) {
const auto& buffer = ctx.buffers[handle];
Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
if constexpr (N == 1) {
const Id ptr{
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, address)};
const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
return ctx.OpLoad(buffer.data_types->Get(1), ptr);
} else {
boost::container::static_vector<Id, N> ids;
for (u32 i = 0; i < N; i++) {
index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
const Id index_i = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
const Id ptr{
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index_i)};
ids.push_back(ctx.OpLoad(buffer.data_types->Get(1), ptr));
}
return ctx.OpCompositeConstruct(buffer.data_types->Get(N), ids);
@ -394,7 +393,7 @@ static Id GetBufferFormatValue(EmitContext& ctx, u32 handle, Id address, u32 com
}
}
template <int N>
template <u32 N>
static Id EmitLoadBufferFormatF32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
if constexpr (N == 1) {
return GetBufferFormatValue(ctx, handle, address, 0);
@ -423,18 +422,18 @@ Id EmitLoadBufferFormatF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id ad
return EmitLoadBufferFormatF32xN<4>(ctx, inst, handle, address);
}
template <unsigned N>
template <u32 N>
static void EmitStoreBufferF32xN(EmitContext& ctx, u32 handle, Id address, Id value) {
const auto& buffer = ctx.buffers[handle];
Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
if constexpr (N == 1) {
const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
ctx.OpStore(ptr, value);
} else {
for (u32 i = 0; i < N; i++) {
index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
const Id index_i = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
const Id ptr =
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index);
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index_i);
ctx.OpStore(ptr, ctx.OpCompositeExtract(ctx.F32[1], value, i));
}
}