spirv: Simplify shared memory handling

This commit is contained in:
IndecisiveTurtle 2024-08-14 00:33:01 +03:00
parent ad3b6c793c
commit 2c723949a0
8 changed files with 11 additions and 183 deletions

View File

@ -93,15 +93,9 @@ Id EmitUndefU8(EmitContext& ctx);
Id EmitUndefU16(EmitContext& ctx); Id EmitUndefU16(EmitContext& ctx);
Id EmitUndefU32(EmitContext& ctx); Id EmitUndefU32(EmitContext& ctx);
Id EmitUndefU64(EmitContext& ctx); Id EmitUndefU64(EmitContext& ctx);
Id EmitLoadSharedU8(EmitContext& ctx, Id offset);
Id EmitLoadSharedS8(EmitContext& ctx, Id offset);
Id EmitLoadSharedU16(EmitContext& ctx, Id offset);
Id EmitLoadSharedS16(EmitContext& ctx, Id offset);
Id EmitLoadSharedU32(EmitContext& ctx, Id offset); Id EmitLoadSharedU32(EmitContext& ctx, Id offset);
Id EmitLoadSharedU64(EmitContext& ctx, Id offset); Id EmitLoadSharedU64(EmitContext& ctx, Id offset);
Id EmitLoadSharedU128(EmitContext& ctx, Id offset); Id EmitLoadSharedU128(EmitContext& ctx, Id offset);
void EmitWriteSharedU8(EmitContext& ctx, Id offset, Id value);
void EmitWriteSharedU16(EmitContext& ctx, Id offset, Id value);
void EmitWriteSharedU32(EmitContext& ctx, Id offset, Id value); void EmitWriteSharedU32(EmitContext& ctx, Id offset, Id value);
void EmitWriteSharedU64(EmitContext& ctx, Id offset, Id value); void EmitWriteSharedU64(EmitContext& ctx, Id offset, Id value);
void EmitWriteSharedU128(EmitContext& ctx, Id offset, Id value); void EmitWriteSharedU128(EmitContext& ctx, Id offset, Id value);

View File

@ -5,84 +5,15 @@
#include "shader_recompiler/backend/spirv/spirv_emit_context.h" #include "shader_recompiler/backend/spirv/spirv_emit_context.h"
namespace Shader::Backend::SPIRV { namespace Shader::Backend::SPIRV {
namespace {
Id Pointer(EmitContext& ctx, Id pointer_type, Id array, Id offset, u32 shift) {
const Id shift_id{ctx.ConstU32(shift)};
const Id index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)};
return ctx.OpAccessChain(pointer_type, array, ctx.u32_zero_value, index);
}
Id Word(EmitContext& ctx, Id offset) { Id EmitLoadSharedU32(EmitContext& ctx, Id offset) {
const Id shift_id{ctx.ConstU32(2U)}; const Id shift_id{ctx.ConstU32(2U)};
const Id index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)}; const Id index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)};
const Id pointer{ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index)}; const Id pointer{ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index)};
return ctx.OpLoad(ctx.U32[1], pointer); return ctx.OpLoad(ctx.U32[1], pointer);
} }
std::pair<Id, Id> ExtractArgs(EmitContext& ctx, Id offset, u32 mask, u32 count) {
const Id shift{ctx.OpShiftLeftLogical(ctx.U32[1], offset, ctx.ConstU32(3U))};
const Id bit{ctx.OpBitwiseAnd(ctx.U32[1], shift, ctx.ConstU32(mask))};
const Id count_id{ctx.ConstU32(count)};
return {bit, count_id};
}
} // Anonymous namespace
Id EmitLoadSharedU8(EmitContext& ctx, Id offset) {
if (ctx.profile.support_explicit_workgroup_layout) {
const Id pointer{
ctx.OpAccessChain(ctx.shared_u8, ctx.shared_memory_u8, ctx.u32_zero_value, offset)};
return ctx.OpUConvert(ctx.U32[1], ctx.OpLoad(ctx.U8, pointer));
} else {
const auto [bit, count]{ExtractArgs(ctx, offset, 24, 8)};
return ctx.OpBitFieldUExtract(ctx.U32[1], Word(ctx, offset), bit, count);
}
}
Id EmitLoadSharedS8(EmitContext& ctx, Id offset) {
if (ctx.profile.support_explicit_workgroup_layout) {
const Id pointer{
ctx.OpAccessChain(ctx.shared_u8, ctx.shared_memory_u8, ctx.u32_zero_value, offset)};
return ctx.OpSConvert(ctx.U32[1], ctx.OpLoad(ctx.U8, pointer));
} else {
const auto [bit, count]{ExtractArgs(ctx, offset, 24, 8)};
return ctx.OpBitFieldSExtract(ctx.U32[1], Word(ctx, offset), bit, count);
}
}
Id EmitLoadSharedU16(EmitContext& ctx, Id offset) {
if (ctx.profile.support_explicit_workgroup_layout) {
const Id pointer{Pointer(ctx, ctx.shared_u16, ctx.shared_memory_u16, offset, 1)};
return ctx.OpUConvert(ctx.U32[1], ctx.OpLoad(ctx.U16, pointer));
} else {
const auto [bit, count]{ExtractArgs(ctx, offset, 16, 16)};
return ctx.OpBitFieldUExtract(ctx.U32[1], Word(ctx, offset), bit, count);
}
}
Id EmitLoadSharedS16(EmitContext& ctx, Id offset) {
if (ctx.profile.support_explicit_workgroup_layout) {
const Id pointer{Pointer(ctx, ctx.shared_u16, ctx.shared_memory_u16, offset, 1)};
return ctx.OpSConvert(ctx.U32[1], ctx.OpLoad(ctx.U16, pointer));
} else {
const auto [bit, count]{ExtractArgs(ctx, offset, 16, 16)};
return ctx.OpBitFieldSExtract(ctx.U32[1], Word(ctx, offset), bit, count);
}
}
Id EmitLoadSharedU32(EmitContext& ctx, Id offset) {
if (ctx.profile.support_explicit_workgroup_layout) {
const Id pointer{Pointer(ctx, ctx.shared_u32, ctx.shared_memory_u32, offset, 2)};
return ctx.OpLoad(ctx.U32[1], pointer);
} else {
return Word(ctx, offset);
}
}
Id EmitLoadSharedU64(EmitContext& ctx, Id offset) { Id EmitLoadSharedU64(EmitContext& ctx, Id offset) {
if (ctx.profile.support_explicit_workgroup_layout) {
const Id pointer{Pointer(ctx, ctx.shared_u32x2, ctx.shared_memory_u32x2, offset, 3)};
return ctx.OpLoad(ctx.U32[2], pointer);
} else {
const Id shift_id{ctx.ConstU32(2U)}; const Id shift_id{ctx.ConstU32(2U)};
const Id base_index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)}; const Id base_index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)};
const Id next_index{ctx.OpIAdd(ctx.U32[1], base_index, ctx.ConstU32(1U))}; const Id next_index{ctx.OpIAdd(ctx.U32[1], base_index, ctx.ConstU32(1U))};
@ -90,14 +21,9 @@ Id EmitLoadSharedU64(EmitContext& ctx, Id offset) {
const Id rhs_pointer{ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, next_index)}; const Id rhs_pointer{ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, next_index)};
return ctx.OpCompositeConstruct(ctx.U32[2], ctx.OpLoad(ctx.U32[1], lhs_pointer), return ctx.OpCompositeConstruct(ctx.U32[2], ctx.OpLoad(ctx.U32[1], lhs_pointer),
ctx.OpLoad(ctx.U32[1], rhs_pointer)); ctx.OpLoad(ctx.U32[1], rhs_pointer));
}
} }
Id EmitLoadSharedU128(EmitContext& ctx, Id offset) { Id EmitLoadSharedU128(EmitContext& ctx, Id offset) {
if (ctx.profile.support_explicit_workgroup_layout) {
const Id pointer{Pointer(ctx, ctx.shared_u32x4, ctx.shared_memory_u32x4, offset, 4)};
return ctx.OpLoad(ctx.U32[4], pointer);
}
const Id shift_id{ctx.ConstU32(2U)}; const Id shift_id{ctx.ConstU32(2U)};
const Id base_index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)}; const Id base_index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)};
std::array<Id, 4> values{}; std::array<Id, 4> values{};
@ -109,35 +35,14 @@ Id EmitLoadSharedU128(EmitContext& ctx, Id offset) {
return ctx.OpCompositeConstruct(ctx.U32[4], values); return ctx.OpCompositeConstruct(ctx.U32[4], values);
} }
void EmitWriteSharedU8(EmitContext& ctx, Id offset, Id value) {
const Id pointer{
ctx.OpAccessChain(ctx.shared_u8, ctx.shared_memory_u8, ctx.u32_zero_value, offset)};
ctx.OpStore(pointer, ctx.OpUConvert(ctx.U8, value));
}
void EmitWriteSharedU16(EmitContext& ctx, Id offset, Id value) {
const Id pointer{Pointer(ctx, ctx.shared_u16, ctx.shared_memory_u16, offset, 1)};
ctx.OpStore(pointer, ctx.OpUConvert(ctx.U16, value));
}
void EmitWriteSharedU32(EmitContext& ctx, Id offset, Id value) { void EmitWriteSharedU32(EmitContext& ctx, Id offset, Id value) {
Id pointer{};
if (ctx.profile.support_explicit_workgroup_layout) {
pointer = Pointer(ctx, ctx.shared_u32, ctx.shared_memory_u32, offset, 2);
} else {
const Id shift{ctx.ConstU32(2U)}; const Id shift{ctx.ConstU32(2U)};
const Id word_offset{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift)}; const Id word_offset{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift)};
pointer = ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, word_offset); const Id pointer = ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, word_offset);
}
ctx.OpStore(pointer, value); ctx.OpStore(pointer, value);
} }
void EmitWriteSharedU64(EmitContext& ctx, Id offset, Id value) { void EmitWriteSharedU64(EmitContext& ctx, Id offset, Id value) {
if (ctx.profile.support_explicit_workgroup_layout) {
const Id pointer{Pointer(ctx, ctx.shared_u32x2, ctx.shared_memory_u32x2, offset, 3)};
ctx.OpStore(pointer, value);
return;
}
const Id shift{ctx.ConstU32(2U)}; const Id shift{ctx.ConstU32(2U)};
const Id word_offset{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift)}; const Id word_offset{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift)};
const Id next_offset{ctx.OpIAdd(ctx.U32[1], word_offset, ctx.ConstU32(1U))}; const Id next_offset{ctx.OpIAdd(ctx.U32[1], word_offset, ctx.ConstU32(1U))};
@ -148,11 +53,6 @@ void EmitWriteSharedU64(EmitContext& ctx, Id offset, Id value) {
} }
void EmitWriteSharedU128(EmitContext& ctx, Id offset, Id value) { void EmitWriteSharedU128(EmitContext& ctx, Id offset, Id value) {
if (ctx.profile.support_explicit_workgroup_layout) {
const Id pointer{Pointer(ctx, ctx.shared_u32x4, ctx.shared_memory_u32x4, offset, 4)};
ctx.OpStore(pointer, value);
return;
}
const Id shift{ctx.ConstU32(2U)}; const Id shift{ctx.ConstU32(2U)};
const Id base_index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift)}; const Id base_index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift)};
for (u32 i = 0; i < 4; ++i) { for (u32 i = 0; i < 4; ++i) {

View File

@ -513,43 +513,9 @@ void EmitContext::DefineSharedMemory() {
if (info.shared_memory_size == 0) { if (info.shared_memory_size == 0) {
info.shared_memory_size = DefaultSharedMemSize; info.shared_memory_size = DefaultSharedMemSize;
} }
const auto make{[&](Id element_type, u32 element_size) {
const u32 num_elements{Common::DivCeil(info.shared_memory_size, element_size)};
const Id array_type{TypeArray(element_type, ConstU32(num_elements))};
Decorate(array_type, spv::Decoration::ArrayStride, element_size);
const Id struct_type{TypeStruct(array_type)};
MemberDecorate(struct_type, 0U, spv::Decoration::Offset, 0U);
Decorate(struct_type, spv::Decoration::Block);
const Id pointer{TypePointer(spv::StorageClass::Workgroup, struct_type)};
const Id element_pointer{TypePointer(spv::StorageClass::Workgroup, element_type)};
const Id variable{AddGlobalVariable(pointer, spv::StorageClass::Workgroup)};
Decorate(variable, spv::Decoration::Aliased);
interfaces.push_back(variable);
return std::make_tuple(variable, element_pointer, pointer);
}};
if (profile.support_explicit_workgroup_layout) {
AddExtension("SPV_KHR_workgroup_memory_explicit_layout");
AddCapability(spv::Capability::WorkgroupMemoryExplicitLayoutKHR);
if (info.uses_shared_u8) {
AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
std::tie(shared_memory_u8, shared_u8, std::ignore) = make(U8, 1);
}
if (info.uses_shared_u16) {
AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
std::tie(shared_memory_u16, shared_u16, std::ignore) = make(U16, 2);
}
std::tie(shared_memory_u32, shared_u32, shared_memory_u32_type) = make(U32[1], 4);
std::tie(shared_memory_u32x2, shared_u32x2, std::ignore) = make(U32[2], 8);
std::tie(shared_memory_u32x4, shared_u32x4, std::ignore) = make(U32[4], 16);
return;
}
const u32 num_elements{Common::DivCeil(info.shared_memory_size, 4U)}; const u32 num_elements{Common::DivCeil(info.shared_memory_size, 4U)};
const Id type{TypeArray(U32[1], ConstU32(num_elements))}; const Id type{TypeArray(U32[1], ConstU32(num_elements))};
shared_memory_u32_type = TypePointer(spv::StorageClass::Workgroup, type); shared_memory_u32_type = TypePointer(spv::StorageClass::Workgroup, type);
shared_u32 = TypePointer(spv::StorageClass::Workgroup, U32[1]); shared_u32 = TypePointer(spv::StorageClass::Workgroup, U32[1]);
shared_memory_u32 = AddGlobalVariable(shared_memory_u32_type, spv::StorageClass::Workgroup); shared_memory_u32 = AddGlobalVariable(shared_memory_u32_type, spv::StorageClass::Workgroup);
interfaces.push_back(shared_memory_u32); interfaces.push_back(shared_memory_u32);

View File

@ -259,10 +259,6 @@ void IREmitter::SetAttribute(IR::Attribute attribute, const F32& value, u32 comp
Value IREmitter::LoadShared(int bit_size, bool is_signed, const U32& offset) { Value IREmitter::LoadShared(int bit_size, bool is_signed, const U32& offset) {
switch (bit_size) { switch (bit_size) {
case 8:
return Inst<U32>(is_signed ? Opcode::LoadSharedS8 : Opcode::LoadSharedU8, offset);
case 16:
return Inst<U32>(is_signed ? Opcode::LoadSharedS16 : Opcode::LoadSharedU16, offset);
case 32: case 32:
return Inst<U32>(Opcode::LoadSharedU32, offset); return Inst<U32>(Opcode::LoadSharedU32, offset);
case 64: case 64:
@ -276,12 +272,6 @@ Value IREmitter::LoadShared(int bit_size, bool is_signed, const U32& offset) {
void IREmitter::WriteShared(int bit_size, const Value& value, const U32& offset) { void IREmitter::WriteShared(int bit_size, const Value& value, const U32& offset) {
switch (bit_size) { switch (bit_size) {
case 8:
Inst(Opcode::WriteSharedU8, offset, value);
break;
case 16:
Inst(Opcode::WriteSharedU16, offset, value);
break;
case 32: case 32:
Inst(Opcode::WriteSharedU32, offset, value); Inst(Opcode::WriteSharedU32, offset, value);
break; break;

View File

@ -59,8 +59,6 @@ bool Inst::MayHaveSideEffects() const noexcept {
case Opcode::WriteSharedU128: case Opcode::WriteSharedU128:
case Opcode::WriteSharedU64: case Opcode::WriteSharedU64:
case Opcode::WriteSharedU32: case Opcode::WriteSharedU32:
case Opcode::WriteSharedU16:
case Opcode::WriteSharedU8:
case Opcode::ImageWrite: case Opcode::ImageWrite:
case Opcode::ImageAtomicIAdd32: case Opcode::ImageAtomicIAdd32:
case Opcode::ImageAtomicSMin32: case Opcode::ImageAtomicSMin32:

View File

@ -26,15 +26,9 @@ OPCODE(WorkgroupMemoryBarrier, Void,
OPCODE(DeviceMemoryBarrier, Void, ) OPCODE(DeviceMemoryBarrier, Void, )
// Shared memory operations // Shared memory operations
OPCODE(LoadSharedU8, U32, U32, )
OPCODE(LoadSharedS8, U32, U32, )
OPCODE(LoadSharedU16, U32, U32, )
OPCODE(LoadSharedS16, U32, U32, )
OPCODE(LoadSharedU32, U32, U32, ) OPCODE(LoadSharedU32, U32, U32, )
OPCODE(LoadSharedU64, U32x2, U32, ) OPCODE(LoadSharedU64, U32x2, U32, )
OPCODE(LoadSharedU128, U32x4, U32, ) OPCODE(LoadSharedU128, U32x4, U32, )
OPCODE(WriteSharedU8, Void, U32, U32, )
OPCODE(WriteSharedU16, Void, U32, U32, )
OPCODE(WriteSharedU32, Void, U32, U32, ) OPCODE(WriteSharedU32, Void, U32, U32, )
OPCODE(WriteSharedU64, Void, U32, U32x2, ) OPCODE(WriteSharedU64, Void, U32, U32x2, )
OPCODE(WriteSharedU128, Void, U32, U32x4, ) OPCODE(WriteSharedU128, Void, U32, U32x4, )

View File

@ -16,18 +16,6 @@ void Visit(Info& info, IR::Inst& inst) {
info.stores.Set(inst.Arg(0).Attribute(), inst.Arg(2).U32()); info.stores.Set(inst.Arg(0).Attribute(), inst.Arg(2).U32());
break; break;
} }
case IR::Opcode::LoadSharedS8:
case IR::Opcode::LoadSharedU8:
case IR::Opcode::WriteSharedU8:
info.uses_shared_u8 = true;
info.uses_shared = true;
break;
case IR::Opcode::LoadSharedS16:
case IR::Opcode::LoadSharedU16:
case IR::Opcode::WriteSharedU16:
info.uses_shared_u16 = true;
info.uses_shared = true;
break;
case IR::Opcode::LoadSharedU32: case IR::Opcode::LoadSharedU32:
case IR::Opcode::LoadSharedU64: case IR::Opcode::LoadSharedU64:
case IR::Opcode::WriteSharedU32: case IR::Opcode::WriteSharedU32:

View File

@ -195,8 +195,6 @@ struct Info {
bool has_image_query{}; bool has_image_query{};
bool uses_group_quad{}; bool uses_group_quad{};
bool uses_shared{}; bool uses_shared{};
bool uses_shared_u8{};
bool uses_shared_u16{};
bool uses_fp16{}; bool uses_fp16{};
bool uses_step_rates{}; bool uses_step_rates{};
bool translation_failed{}; // indicates that shader has unsupported instructions bool translation_failed{}; // indicates that shader has unsupported instructions