// SPDX-FileCopyrightText: Copyright 2024 shadPS4 Emulator Project // SPDX-License-Identifier: GPL-2.0-or-later #include #include #include #include #include "common/func_traits.h" #include "shader_recompiler/backend/spirv/emit_spirv.h" #include "shader_recompiler/backend/spirv/emit_spirv_instructions.h" #include "shader_recompiler/backend/spirv/spirv_emit_context.h" #include "shader_recompiler/ir/basic_block.h" #include "shader_recompiler/ir/program.h" namespace Shader::Backend::SPIRV { namespace { template void SetDefinition(EmitContext& ctx, IR::Inst* inst, Args... args) { inst->SetDefinition(func(ctx, std::forward(args)...)); } template ArgType Arg(EmitContext& ctx, const IR::Value& arg) { if constexpr (std::is_same_v) { return ctx.Def(arg); } else if constexpr (std::is_same_v) { return arg; } else if constexpr (std::is_same_v) { return arg.U32(); } else if constexpr (std::is_same_v) { return arg.Attribute(); } else if constexpr (std::is_same_v) { return arg.ScalarReg(); } else if constexpr (std::is_same_v) { return arg.VectorReg(); } } template void Invoke(EmitContext& ctx, IR::Inst* inst, std::index_sequence) { using Traits = Common::FuncTraits; if constexpr (std::is_same_v) { if constexpr (is_first_arg_inst) { SetDefinition( ctx, inst, inst, Arg>(ctx, inst->Arg(I))...); } else { SetDefinition( ctx, inst, Arg>(ctx, inst->Arg(I))...); } } else { if constexpr (is_first_arg_inst) { func(ctx, inst, Arg>(ctx, inst->Arg(I))...); } else { func(ctx, Arg>(ctx, inst->Arg(I))...); } } } template void Invoke(EmitContext& ctx, IR::Inst* inst) { using Traits = Common::FuncTraits; static_assert(Traits::NUM_ARGS >= 1, "Insufficient arguments"); if constexpr (Traits::NUM_ARGS == 1) { Invoke(ctx, inst, std::make_index_sequence<0>{}); } else { using FirstArgType = typename Traits::template ArgType<1>; static constexpr bool is_first_arg_inst = std::is_same_v; using Indices = std::make_index_sequence; Invoke(ctx, inst, Indices{}); } } void EmitInst(EmitContext& ctx, IR::Inst* inst) { switch (inst->GetOpcode()) { #define OPCODE(name, result_type, ...) \ case IR::Opcode::name: \ return Invoke<&Emit##name>(ctx, inst); #include "shader_recompiler/ir/opcodes.inc" #undef OPCODE } throw LogicError("Invalid opcode {}", inst->GetOpcode()); } Id TypeId(const EmitContext& ctx, IR::Type type) { switch (type) { case IR::Type::U1: return ctx.U1[1]; case IR::Type::U32: return ctx.U32[1]; default: throw NotImplementedException("Phi node type {}", type); } } void Traverse(EmitContext& ctx, IR::Program& program) { IR::Block* current_block{}; for (const IR::AbstractSyntaxNode& node : program.syntax_list) { switch (node.type) { case IR::AbstractSyntaxNode::Type::Block: { const Id label{node.data.block->Definition()}; if (current_block) { ctx.OpBranch(label); } current_block = node.data.block; ctx.AddLabel(label); for (IR::Inst& inst : node.data.block->Instructions()) { EmitInst(ctx, &inst); } break; } case IR::AbstractSyntaxNode::Type::If: { const Id if_label{node.data.if_node.body->Definition()}; const Id endif_label{node.data.if_node.merge->Definition()}; ctx.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone); ctx.OpBranchConditional(ctx.Def(node.data.if_node.cond), if_label, endif_label); break; } case IR::AbstractSyntaxNode::Type::Loop: { const Id body_label{node.data.loop.body->Definition()}; const Id continue_label{node.data.loop.continue_block->Definition()}; const Id endloop_label{node.data.loop.merge->Definition()}; ctx.OpLoopMerge(endloop_label, continue_label, spv::LoopControlMask::MaskNone); ctx.OpBranch(body_label); break; } case IR::AbstractSyntaxNode::Type::Break: { const Id break_label{node.data.break_node.merge->Definition()}; const Id skip_label{node.data.break_node.skip->Definition()}; ctx.OpBranchConditional(ctx.Def(node.data.break_node.cond), break_label, skip_label); break; } case IR::AbstractSyntaxNode::Type::EndIf: if (current_block) { ctx.OpBranch(node.data.end_if.merge->Definition()); } break; case IR::AbstractSyntaxNode::Type::Repeat: { Id cond{ctx.Def(node.data.repeat.cond)}; const Id loop_header_label{node.data.repeat.loop_header->Definition()}; const Id merge_label{node.data.repeat.merge->Definition()}; ctx.OpBranchConditional(cond, loop_header_label, merge_label); break; } case IR::AbstractSyntaxNode::Type::Return: ctx.OpReturn(); break; case IR::AbstractSyntaxNode::Type::Unreachable: ctx.OpUnreachable(); break; } if (node.type != IR::AbstractSyntaxNode::Type::Block) { current_block = nullptr; } } } Id DefineMain(EmitContext& ctx, IR::Program& program) { const Id void_function{ctx.TypeFunction(ctx.void_id)}; const Id main{ctx.OpFunction(ctx.void_id, spv::FunctionControlMask::MaskNone, void_function)}; for (IR::Block* const block : program.blocks) { block->SetDefinition(ctx.OpLabel()); } Traverse(ctx, program); ctx.OpFunctionEnd(); return main; } void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) { const std::span interfaces(ctx.interfaces.data(), ctx.interfaces.size()); spv::ExecutionModel execution_model{}; switch (program.info.stage) { case Stage::Compute: { // const std::array workgroup_size{program.workgroup_size}; // execution_model = spv::ExecutionModel::GLCompute; // ctx.AddExecutionMode(main, spv::ExecutionMode::LocalSize, workgroup_size[0], // workgroup_size[1], workgroup_size[2]); break; } case Stage::Vertex: execution_model = spv::ExecutionModel::Vertex; break; case Stage::Fragment: execution_model = spv::ExecutionModel::Fragment; if (ctx.profile.lower_left_origin_mode) { ctx.AddExecutionMode(main, spv::ExecutionMode::OriginLowerLeft); } else { ctx.AddExecutionMode(main, spv::ExecutionMode::OriginUpperLeft); } // if (program.info.stores_frag_depth) { // ctx.AddExecutionMode(main, spv::ExecutionMode::DepthReplacing); // } break; default: throw NotImplementedException("Stage {}", u32(program.info.stage)); } ctx.AddEntryPoint(execution_model, main, "main", interfaces); } void PatchPhiNodes(IR::Program& program, EmitContext& ctx) { auto inst{program.blocks.front()->begin()}; size_t block_index{0}; ctx.PatchDeferredPhi([&](size_t phi_arg) { if (phi_arg == 0) { ++inst; if (inst == program.blocks[block_index]->end() || inst->GetOpcode() != IR::Opcode::Phi) { do { ++block_index; inst = program.blocks[block_index]->begin(); } while (inst->GetOpcode() != IR::Opcode::Phi); } } return ctx.Def(inst->Arg(phi_arg)); }); } } // Anonymous namespace std::vector EmitSPIRV(const Profile& profile, IR::Program& program, Bindings& bindings) { EmitContext ctx{profile, program, bindings}; const Id main{DefineMain(ctx, program)}; DefineEntryPoint(program, ctx, main); if (program.info.stage == Stage::Vertex) { ctx.AddExtension("SPV_KHR_shader_draw_parameters"); ctx.AddCapability(spv::Capability::DrawParameters); } PatchPhiNodes(program, ctx); return ctx.Assemble(); } Id EmitPhi(EmitContext& ctx, IR::Inst* inst) { const size_t num_args{inst->NumArgs()}; boost::container::small_vector blocks; blocks.reserve(num_args); for (size_t index = 0; index < num_args; ++index) { blocks.push_back(inst->PhiBlock(index)->Definition()); } // The type of a phi instruction is stored in its flags const Id result_type{TypeId(ctx, inst->Flags())}; return ctx.DeferredOpPhi(result_type, std::span(blocks.data(), blocks.size())); } void EmitVoid(EmitContext&) {} Id EmitIdentity(EmitContext& ctx, const IR::Value& value) { throw NotImplementedException("Forward identity declaration"); } Id EmitConditionRef(EmitContext& ctx, const IR::Value& value) { throw NotImplementedException("Forward identity declaration"); } void EmitReference(EmitContext&) {} void EmitPhiMove(EmitContext&) { throw LogicError("Unreachable instruction"); } void EmitGetZeroFromOp(EmitContext&) { throw LogicError("Unreachable instruction"); } void EmitGetSignFromOp(EmitContext&) { throw LogicError("Unreachable instruction"); } void EmitGetCarryFromOp(EmitContext&) { throw LogicError("Unreachable instruction"); } void EmitGetOverflowFromOp(EmitContext&) { throw LogicError("Unreachable instruction"); } void EmitSetVcc(EmitContext& ctx) { throw LogicError("Unreachable instruction"); } void EmitGetVcc(EmitContext& ctx) { throw LogicError("Unreachable instruction"); } } // namespace Shader::Backend::SPIRV