spirv: Define buffer offsets upfront

* Saves a lot of shader instructions
This commit is contained in:
IndecisiveTurtle 2024-08-10 03:36:42 +03:00
parent 0bac6e8e1c
commit 594606c1c3
4 changed files with 20 additions and 26 deletions

View File

@ -128,11 +128,7 @@ Id EmitReadConst(EmitContext& ctx) {
Id EmitReadConstBuffer(EmitContext& ctx, u32 handle, Id index) { Id EmitReadConstBuffer(EmitContext& ctx, u32 handle, Id index) {
auto& buffer = ctx.buffers[handle]; auto& buffer = ctx.buffers[handle];
if (!Sirit::ValidId(buffer.offset)) { index = ctx.OpIAdd(ctx.U32[1], index, buffer.offset_dwords);
buffer.offset = ctx.GetBufferOffset(buffer.global_binding);
}
const Id offset_dwords{ctx.OpShiftRightLogical(ctx.U32[1], buffer.offset, ctx.ConstU32(2U))};
index = ctx.OpIAdd(ctx.U32[1], index, offset_dwords);
const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)}; const Id ptr{ctx.OpAccessChain(buffer.pointer_type, buffer.id, ctx.u32_zero_value, index)};
return ctx.OpLoad(buffer.data_types->Get(1), ptr); return ctx.OpLoad(buffer.data_types->Get(1), ptr);
} }
@ -229,9 +225,6 @@ Id EmitLoadBufferU32(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
template <u32 N> template <u32 N>
static Id EmitLoadBufferF32xN(EmitContext& ctx, u32 handle, Id address) { static Id EmitLoadBufferF32xN(EmitContext& ctx, u32 handle, Id address) {
auto& buffer = ctx.buffers[handle]; auto& buffer = ctx.buffers[handle];
if (!Sirit::ValidId(buffer.offset)) {
buffer.offset = ctx.GetBufferOffset(buffer.global_binding);
}
address = ctx.OpIAdd(ctx.U32[1], address, buffer.offset); address = ctx.OpIAdd(ctx.U32[1], address, buffer.offset);
const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u)); const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
if constexpr (N == 1) { if constexpr (N == 1) {
@ -404,9 +397,6 @@ static Id GetBufferFormatValue(EmitContext& ctx, u32 handle, Id address, u32 com
template <u32 N> template <u32 N>
static Id EmitLoadBufferFormatF32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) { static Id EmitLoadBufferFormatF32xN(EmitContext& ctx, IR::Inst* inst, u32 handle, Id address) {
auto& buffer = ctx.buffers[handle]; auto& buffer = ctx.buffers[handle];
if (!Sirit::ValidId(buffer.offset)) {
buffer.offset = ctx.GetBufferOffset(buffer.global_binding);
}
address = ctx.OpIAdd(ctx.U32[1], address, buffer.offset); address = ctx.OpIAdd(ctx.U32[1], address, buffer.offset);
if constexpr (N == 1) { if constexpr (N == 1) {
return GetBufferFormatValue(ctx, handle, address, 0); return GetBufferFormatValue(ctx, handle, address, 0);
@ -438,9 +428,6 @@ Id EmitLoadBufferFormatF32x4(EmitContext& ctx, IR::Inst* inst, u32 handle, Id ad
template <u32 N> template <u32 N>
static void EmitStoreBufferF32xN(EmitContext& ctx, u32 handle, Id address, Id value) { static void EmitStoreBufferF32xN(EmitContext& ctx, u32 handle, Id address, Id value) {
auto& buffer = ctx.buffers[handle]; auto& buffer = ctx.buffers[handle];
if (!Sirit::ValidId(buffer.offset)) {
buffer.offset = ctx.GetBufferOffset(buffer.global_binding);
}
address = ctx.OpIAdd(ctx.U32[1], address, buffer.offset); address = ctx.OpIAdd(ctx.U32[1], address, buffer.offset);
const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u)); const Id index = ctx.OpShiftRightLogical(ctx.U32[1], address, ctx.ConstU32(2u));
if constexpr (N == 1) { if constexpr (N == 1) {

View File

@ -6,7 +6,9 @@
namespace Shader::Backend::SPIRV { namespace Shader::Backend::SPIRV {
void EmitPrologue(EmitContext& ctx) {} void EmitPrologue(EmitContext& ctx) {
ctx.DefineBufferOffsets();
}
void EmitEpilogue(EmitContext& ctx) {} void EmitEpilogue(EmitContext& ctx) {}

View File

@ -165,14 +165,18 @@ EmitContext::SpirvAttribute EmitContext::GetAttributeInfo(AmdGpu::NumberFormat f
throw InvalidArgument("Invalid attribute type {}", fmt); throw InvalidArgument("Invalid attribute type {}", fmt);
} }
Id EmitContext::GetBufferOffset(u32 binding) { void EmitContext::DefineBufferOffsets() {
for (auto& buffer : buffers) {
const u32 binding = buffer.binding;
const u32 half = Shader::PushData::BufOffsetIndex + (binding >> 4); const u32 half = Shader::PushData::BufOffsetIndex + (binding >> 4);
const u32 comp = (binding & 0xf) >> 2; const u32 comp = (binding & 0xf) >> 2;
const u32 offset = (binding & 0x3) << 3; const u32 offset = (binding & 0x3) << 3;
const Id ptr{OpAccessChain(TypePointer(spv::StorageClass::PushConstant, U32[1]), const Id ptr{OpAccessChain(TypePointer(spv::StorageClass::PushConstant, U32[1]),
push_data_block, ConstU32(half), ConstU32(comp))}; push_data_block, ConstU32(half), ConstU32(comp))};
const Id value{OpLoad(U32[1], ptr)}; const Id value{OpLoad(U32[1], ptr)};
return OpBitFieldUExtract(U32[1], value, ConstU32(offset), ConstU32(8U)); buffer.offset = OpBitFieldUExtract(U32[1], value, ConstU32(offset), ConstU32(8U));
buffer.offset_dwords = OpShiftRightLogical(U32[1], buffer.offset, ConstU32(2U));
}
} }
Id MakeDefaultValue(EmitContext& ctx, u32 default_value) { Id MakeDefaultValue(EmitContext& ctx, u32 default_value) {
@ -354,7 +358,7 @@ void EmitContext::DefineBuffers() {
buffers.push_back({ buffers.push_back({
.id = id, .id = id,
.global_binding = binding++, .binding = binding++,
.data_types = data_types, .data_types = data_types,
.pointer_type = pointer_type, .pointer_type = pointer_type,
.buffer = buffer.GetVsharp(info), .buffer = buffer.GetVsharp(info),

View File

@ -40,7 +40,7 @@ public:
~EmitContext(); ~EmitContext();
Id Def(const IR::Value& value); Id Def(const IR::Value& value);
Id GetBufferOffset(u32 binding); void DefineBufferOffsets();
[[nodiscard]] Id DefineInput(Id type, u32 location) { [[nodiscard]] Id DefineInput(Id type, u32 location) {
const Id input_id{DefineVar(type, spv::StorageClass::Input)}; const Id input_id{DefineVar(type, spv::StorageClass::Input)};
@ -203,7 +203,8 @@ public:
struct BufferDefinition { struct BufferDefinition {
Id id; Id id;
Id offset; Id offset;
u32 global_binding; Id offset_dwords;
u32 binding;
const VectorIds* data_types; const VectorIds* data_types;
Id pointer_type; Id pointer_type;
AmdGpu::Buffer buffer; AmdGpu::Buffer buffer;