spirv: Address some regressions in buffer loads (#304)

* spirv: Use correct index

* spirv: Fix indices during buffer load

* clang-format fix

* spirv: Index can be const

---------

Co-authored-by: georgemoralis <giorgosmrls@gmail.com>
This commit is contained in:
TheTurtle 2024-07-19 19:36:07 +03:00 committed by GitHub
parent 60b1aa62a3
commit bfe3322977
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 10 additions and 11 deletions

View File

@ -215,20 +215,19 @@ Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
return EmitLoadBufferF32(ctx, inst, handle, address); return EmitLoadBufferF32(ctx, inst, handle, address);
} }
template <int N> template <u32 N>
static Id EmitLoadBufferF32xN(EmitContext& ctx, u32 handle, Id address) { static Id EmitLoadBufferF32xN(EmitContext& ctx, u32 handle, Id address) {
const auto& buffer = ctx.buffers[handle]; const auto& buffer = ctx.buffers[handle];
Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u)); const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
if constexpr (N == 1) { if constexpr (N == 1) {
const Id ptr{ const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, address)};
return ctx.OpLoad(buffer.data_types->Get(1), ptr); return ctx.OpLoad(buffer.data_types->Get(1), ptr);
} else { } else {
boost::container::static_vector<Id, N> ids; boost::container::static_vector<Id, N> ids;
for (u32 i = 0; i < N; i++) { for (u32 i = 0; i < N; i++) {
index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i)); const Id index_i = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
const Id ptr{ const Id ptr{
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)}; ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index_i)};
ids.push_back(ctx.OpLoad(buffer.data_types->Get(1), ptr)); ids.push_back(ctx.OpLoad(buffer.data_types->Get(1), ptr));
} }
return ctx.OpCompositeConstruct(buffer.data_types->Get(N), ids); return ctx.OpCompositeConstruct(buffer.data_types->Get(N), ids);
@ -394,7 +393,7 @@ static Id GetBufferFormatValue(EmitContext& ctx, u32 handle, Id address, u32 com
} }
} }
template <int N> template <u32 N>
static Id EmitLoadBufferFormatF32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) { static Id EmitLoadBufferFormatF32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
if constexpr (N == 1) { if constexpr (N == 1) {
return GetBufferFormatValue(ctx, handle, address, 0); return GetBufferFormatValue(ctx, handle, address, 0);
@ -423,18 +422,18 @@ Id EmitLoadBufferFormatF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id ad
return EmitLoadBufferFormatF32xN<4>(ctx, inst, handle, address); return EmitLoadBufferFormatF32xN<4>(ctx, inst, handle, address);
} }
template <unsigned N> template <u32 N>
static void EmitStoreBufferF32xN(EmitContext& ctx, u32 handle, Id address, Id value) { static void EmitStoreBufferF32xN(EmitContext& ctx, u32 handle, Id address, Id value) {
const auto& buffer = ctx.buffers[handle]; const auto& buffer = ctx.buffers[handle];
Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u)); const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
if constexpr (N == 1) { if constexpr (N == 1) {
const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)}; const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
ctx.OpStore(ptr, value); ctx.OpStore(ptr, value);
} else { } else {
for (u32 i = 0; i < N; i++) { for (u32 i = 0; i < N; i++) {
index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i)); const Id index_i = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
const Id ptr = const Id ptr =
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index); ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index_i);
ctx.OpStore(ptr, ctx.OpCompositeExtract(ctx.F32[1], value, i)); ctx.OpStore(ptr, ctx.OpCompositeExtract(ctx.F32[1], value, i));
} }
} }