spirv: Address some regressions in buffer loads (#304)
* spirv: Use correct index * spirv: Fix indices during buffer load * clang-format fix * spirv: Index can be const --------- Co-authored-by: georgemoralis <giorgosmrls@gmail.com>
This commit is contained in:
parent
60b1aa62a3
commit
bfe3322977
|
@ -215,20 +215,19 @@ Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
|
|||
return EmitLoadBufferF32(ctx, inst, handle, address);
|
||||
}
|
||||
|
||||
template <int N>
|
||||
template <u32 N>
|
||||
static Id EmitLoadBufferF32xN(EmitContext& ctx, u32 handle, Id address) {
|
||||
const auto& buffer = ctx.buffers[handle];
|
||||
Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
|
||||
const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
|
||||
if constexpr (N == 1) {
|
||||
const Id ptr{
|
||||
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, address)};
|
||||
const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
|
||||
return ctx.OpLoad(buffer.data_types->Get(1), ptr);
|
||||
} else {
|
||||
boost::container::static_vector<Id, N> ids;
|
||||
for (u32 i = 0; i < N; i++) {
|
||||
index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
|
||||
const Id index_i = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
|
||||
const Id ptr{
|
||||
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
|
||||
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index_i)};
|
||||
ids.push_back(ctx.OpLoad(buffer.data_types->Get(1), ptr));
|
||||
}
|
||||
return ctx.OpCompositeConstruct(buffer.data_types->Get(N), ids);
|
||||
|
@ -394,7 +393,7 @@ static Id GetBufferFormatValue(EmitContext& ctx, u32 handle, Id address, u32 com
|
|||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
template <u32 N>
|
||||
static Id EmitLoadBufferFormatF32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
|
||||
if constexpr (N == 1) {
|
||||
return GetBufferFormatValue(ctx, handle, address, 0);
|
||||
|
@ -423,18 +422,18 @@ Id EmitLoadBufferFormatF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id ad
|
|||
return EmitLoadBufferFormatF32xN<4>(ctx, inst, handle, address);
|
||||
}
|
||||
|
||||
template <unsigned N>
|
||||
template <u32 N>
|
||||
static void EmitStoreBufferF32xN(EmitContext& ctx, u32 handle, Id address, Id value) {
|
||||
const auto& buffer = ctx.buffers[handle];
|
||||
Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
|
||||
const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
|
||||
if constexpr (N == 1) {
|
||||
const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
|
||||
ctx.OpStore(ptr, value);
|
||||
} else {
|
||||
for (u32 i = 0; i < N; i++) {
|
||||
index = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
|
||||
const Id index_i = ctx.OpIAdd(ctx.U32[1], index, ctx.ConstU32(i));
|
||||
const Id ptr =
|
||||
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index);
|
||||
ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index_i);
|
||||
ctx.OpStore(ptr, ctx.OpCompositeExtract(ctx.F32[1], value, i));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue