From 680192a0c452d2f0d79582494c4e0a6e47e6e365 Mon Sep 17 00:00:00 2001 From: Vinicius Rangel Date: Sat, 27 Jul 2024 11:23:59 -0300 Subject: [PATCH] 64 bits OP, impl V_ADDC_U32 & V_MAD_U64_U32 (#310) * impl V_ADDC_U32 & V_MAD_U64_U32 * shader recompiler: add 64 bits version to get register / GetSrc * fix V_ADDC_U32 carry * shader recompiler: removed automatic conversion to force_flt in GetSRc * shader recompiler: auto cast between u32 and u64 during ssa pass * shader recompiler: fix SetVectorReg64 & standardize switches-case * shader translate: fix overflow detection in V_ADD_I32 use vcc lo instead of vcc thread bit * shader recompiler: more 64-bit work - removed bit_size parameter from Get[Scalar/Vector]Register - add BitwiseOr64 - add SetDst64 as a replacement for SetScalarReg64 & SetVectorReg64 - add GetSrc64 for 64-bit value * shader recompiler: add V_MAD_U64_U32 vcc output - add V_MAD_U64_U32 vcc output - ILessThan for 64-bits * shader recompiler: removed unnecessary changes & missing consts * shader_recompiler: Add s64 type in constant propagation --- .../backend/spirv/emit_spirv_instructions.h | 8 +- .../backend/spirv/emit_spirv_integer.cpp | 23 +- src/shader_recompiler/frontend/opcodes.h | 8 +- .../frontend/translate/translate.cpp | 208 +++++++++++++++++- .../frontend/translate/translate.h | 8 +- .../frontend/translate/vector_alu.cpp | 44 +++- src/shader_recompiler/ir/ir_emitter.cpp | 67 +++++- src/shader_recompiler/ir/ir_emitter.h | 6 +- src/shader_recompiler/ir/opcodes.inc | 8 +- .../ir/passes/constant_propogation_pass.cpp | 12 +- .../ir/passes/ssa_rewrite_pass.cpp | 8 +- src/shader_recompiler/ir/value.h | 1 + 12 files changed, 361 insertions(+), 40 deletions(-) diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h index e2b411e4..80dd66b1 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h +++ b/src/shader_recompiler/backend/spirv/emit_spirv_instructions.h @@ -258,6 +258,7 @@ Id EmitISub64(EmitContext& ctx, Id a, Id b); Id EmitSMulExt(EmitContext& ctx, Id a, Id b); Id EmitUMulExt(EmitContext& ctx, Id a, Id b); Id EmitIMul32(EmitContext& ctx, Id a, Id b); +Id EmitIMul64(EmitContext& ctx, Id a, Id b); Id EmitSDiv32(EmitContext& ctx, Id a, Id b); Id EmitUDiv32(EmitContext& ctx, Id a, Id b); Id EmitINeg32(EmitContext& ctx, Id value); @@ -271,6 +272,7 @@ Id EmitShiftRightArithmetic32(EmitContext& ctx, Id base, Id shift); Id EmitShiftRightArithmetic64(EmitContext& ctx, Id base, Id shift); Id EmitBitwiseAnd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b); Id EmitBitwiseOr32(EmitContext& ctx, IR::Inst* inst, Id a, Id b); +Id EmitBitwiseOr64(EmitContext& ctx, IR::Inst* inst, Id a, Id b); Id EmitBitwiseXor32(EmitContext& ctx, IR::Inst* inst, Id a, Id b); Id EmitBitFieldInsert(EmitContext& ctx, Id base, Id insert, Id offset, Id count); Id EmitBitFieldSExtract(EmitContext& ctx, IR::Inst* inst, Id base, Id offset, Id count); @@ -286,8 +288,10 @@ Id EmitSMax32(EmitContext& ctx, Id a, Id b); Id EmitUMax32(EmitContext& ctx, Id a, Id b); Id EmitSClamp32(EmitContext& ctx, IR::Inst* inst, Id value, Id min, Id max); Id EmitUClamp32(EmitContext& ctx, IR::Inst* inst, Id value, Id min, Id max); -Id EmitSLessThan(EmitContext& ctx, Id lhs, Id rhs); -Id EmitULessThan(EmitContext& ctx, Id lhs, Id rhs); +Id EmitSLessThan32(EmitContext& ctx, Id lhs, Id rhs); +Id EmitSLessThan64(EmitContext& ctx, Id lhs, Id rhs); +Id EmitULessThan32(EmitContext& ctx, Id lhs, Id rhs); +Id EmitULessThan64(EmitContext& ctx, Id lhs, Id rhs); Id EmitIEqual(EmitContext& ctx, Id lhs, Id rhs); Id EmitSLessThanEqual(EmitContext& ctx, Id lhs, Id rhs); Id EmitULessThanEqual(EmitContext& ctx, Id lhs, Id rhs); diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp index d5a0f276..019ceb01 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp @@ -84,6 +84,10 @@ Id EmitIMul32(EmitContext& ctx, Id a, Id b) { return ctx.OpIMul(ctx.U32[1], a, b); } +Id EmitIMul64(EmitContext& ctx, Id a, Id b) { + return ctx.OpIMul(ctx.U64, a, b); +} + Id EmitSDiv32(EmitContext& ctx, Id a, Id b) { return ctx.OpSDiv(ctx.U32[1], a, b); } @@ -142,6 +146,13 @@ Id EmitBitwiseOr32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { return result; } +Id EmitBitwiseOr64(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { + const Id result{ctx.OpBitwiseOr(ctx.U64, a, b)}; + SetZeroFlag(ctx, inst, result); + SetSignFlag(ctx, inst, result); + return result; +} + Id EmitBitwiseXor32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { const Id result{ctx.OpBitwiseXor(ctx.U32[1], a, b)}; SetZeroFlag(ctx, inst, result); @@ -231,11 +242,19 @@ Id EmitUClamp32(EmitContext& ctx, IR::Inst* inst, Id value, Id min, Id max) { return result; } -Id EmitSLessThan(EmitContext& ctx, Id lhs, Id rhs) { +Id EmitSLessThan32(EmitContext& ctx, Id lhs, Id rhs) { return ctx.OpSLessThan(ctx.U1[1], lhs, rhs); } -Id EmitULessThan(EmitContext& ctx, Id lhs, Id rhs) { +Id EmitSLessThan64(EmitContext& ctx, Id lhs, Id rhs) { + return ctx.OpSLessThan(ctx.U1[1], lhs, rhs); +} + +Id EmitULessThan32(EmitContext& ctx, Id lhs, Id rhs) { + return ctx.OpULessThan(ctx.U1[1], lhs, rhs); +} + +Id EmitULessThan64(EmitContext& ctx, Id lhs, Id rhs) { return ctx.OpULessThan(ctx.U1[1], lhs, rhs); } diff --git a/src/shader_recompiler/frontend/opcodes.h b/src/shader_recompiler/frontend/opcodes.h index d38140d8..cdc1e474 100644 --- a/src/shader_recompiler/frontend/opcodes.h +++ b/src/shader_recompiler/frontend/opcodes.h @@ -2392,10 +2392,10 @@ enum class OperandField : u32 { ConstFloatPos_4_0, ConstFloatNeg_4_0, VccZ = 251, - ExecZ, - Scc, - LdsDirect, - LiteralConst, + ExecZ = 252, + Scc = 253, + LdsDirect = 254, + LiteralConst = 255, VectorGPR, Undefined = 0xFFFFFFFF, diff --git a/src/shader_recompiler/frontend/translate/translate.cpp b/src/shader_recompiler/frontend/translate/translate.cpp index 15052b2a..c4c6e505 100644 --- a/src/shader_recompiler/frontend/translate/translate.cpp +++ b/src/shader_recompiler/frontend/translate/translate.cpp @@ -76,21 +76,21 @@ void Translator::EmitPrologue() { } } +template <> IR::U32F32 Translator::GetSrc(const InstOperand& operand, bool force_flt) { - // Input modifiers work on float values. - force_flt |= operand.input_modifier.abs | operand.input_modifier.neg; - IR::U32F32 value{}; + + const bool is_float = operand.type == ScalarType::Float32 || force_flt; switch (operand.field) { case OperandField::ScalarGPR: - if (operand.type == ScalarType::Float32 || force_flt) { + if (is_float) { value = ir.GetScalarReg(IR::ScalarReg(operand.code)); } else { value = ir.GetScalarReg(IR::ScalarReg(operand.code)); } break; case OperandField::VectorGPR: - if (operand.type == ScalarType::Float32 || force_flt) { + if (is_float) { value = ir.GetVectorReg(IR::VectorReg(operand.code)); } else { value = ir.GetVectorReg(IR::VectorReg(operand.code)); @@ -164,15 +164,160 @@ IR::U32F32 Translator::GetSrc(const InstOperand& operand, bool force_flt) { UNREACHABLE(); } - if (operand.input_modifier.abs) { - value = ir.FPAbs(value); - } - if (operand.input_modifier.neg) { - value = ir.FPNeg(value); + if (is_float) { + if (operand.input_modifier.abs) { + value = ir.FPAbs(value); + } + if (operand.input_modifier.neg) { + value = ir.FPNeg(value); + } } return value; } +template <> +IR::U32 Translator::GetSrc(const InstOperand& operand, bool force_flt) { + return GetSrc(operand, force_flt); +} + +template <> +IR::F32 Translator::GetSrc(const InstOperand& operand, bool) { + return GetSrc(operand, true); +} + +template <> +IR::U64F64 Translator::GetSrc64(const InstOperand& operand, bool force_flt) { + IR::Value value_hi{}; + IR::Value value_lo{}; + + bool immediate = false; + const bool is_float = operand.type == ScalarType::Float64 || force_flt; + switch (operand.field) { + case OperandField::ScalarGPR: + if (is_float) { + value_lo = ir.GetScalarReg(IR::ScalarReg(operand.code)); + value_hi = ir.GetScalarReg(IR::ScalarReg(operand.code + 1)); + } else if (operand.type == ScalarType::Uint64 || operand.type == ScalarType::Sint64) { + value_lo = ir.GetScalarReg(IR::ScalarReg(operand.code)); + value_hi = ir.GetScalarReg(IR::ScalarReg(operand.code + 1)); + } else { + UNREACHABLE(); + } + break; + case OperandField::VectorGPR: + if (is_float) { + value_lo = ir.GetVectorReg(IR::VectorReg(operand.code)); + value_hi = ir.GetVectorReg(IR::VectorReg(operand.code + 1)); + } else if (operand.type == ScalarType::Uint64 || operand.type == ScalarType::Sint64) { + value_lo = ir.GetVectorReg(IR::VectorReg(operand.code)); + value_hi = ir.GetVectorReg(IR::VectorReg(operand.code + 1)); + } else { + UNREACHABLE(); + } + break; + case OperandField::ConstZero: + immediate = true; + if (force_flt) { + value_lo = ir.Imm64(0.0); + } else { + value_lo = ir.Imm64(u64(0U)); + } + break; + case OperandField::SignedConstIntPos: + ASSERT(!force_flt); + immediate = true; + value_lo = ir.Imm64(s64(operand.code) - SignedConstIntPosMin + 1); + break; + case OperandField::SignedConstIntNeg: + ASSERT(!force_flt); + immediate = true; + value_lo = ir.Imm64(-s64(operand.code) + SignedConstIntNegMin - 1); + break; + case OperandField::LiteralConst: + immediate = true; + if (force_flt) { + UNREACHABLE(); // There is a literal double? + } else { + value_lo = ir.Imm64(u64(operand.code)); + } + break; + case OperandField::ConstFloatPos_1_0: + immediate = true; + if (force_flt) { + value_lo = ir.Imm64(1.0); + } else { + value_lo = ir.Imm64(std::bit_cast(f64(1.0))); + } + break; + case OperandField::ConstFloatPos_0_5: + immediate = true; + value_lo = ir.Imm64(0.5); + break; + case OperandField::ConstFloatPos_2_0: + immediate = true; + value_lo = ir.Imm64(2.0); + break; + case OperandField::ConstFloatPos_4_0: + immediate = true; + value_lo = ir.Imm64(4.0); + break; + case OperandField::ConstFloatNeg_0_5: + immediate = true; + value_lo = ir.Imm64(-0.5); + break; + case OperandField::ConstFloatNeg_1_0: + immediate = true; + value_lo = ir.Imm64(-1.0); + break; + case OperandField::ConstFloatNeg_2_0: + immediate = true; + value_lo = ir.Imm64(-2.0); + break; + case OperandField::ConstFloatNeg_4_0: + immediate = true; + value_lo = ir.Imm64(-4.0); + break; + case OperandField::VccLo: { + value_lo = ir.GetVccLo(); + value_hi = ir.GetVccHi(); + } break; + case OperandField::VccHi: + UNREACHABLE(); + default: + UNREACHABLE(); + } + + IR::Value value; + + if (immediate) { + value = value_lo; + } else if (is_float) { + throw NotImplementedException("required OpPackDouble2x32 implementation"); + } else { + IR::Value packed = ir.CompositeConstruct(value_lo, value_hi); + value = ir.PackUint2x32(packed); + } + + if (is_float) { + if (operand.input_modifier.abs) { + value = ir.FPAbs(IR::F32F64(value)); + } + if (operand.input_modifier.neg) { + value = ir.FPNeg(IR::F32F64(value)); + } + } + return IR::U64F64(value); +} + +template <> +IR::U64 Translator::GetSrc64(const InstOperand& operand, bool force_flt) { + return GetSrc64(operand, force_flt); +} +template <> +IR::F64 Translator::GetSrc64(const InstOperand& operand, bool) { + return GetSrc64(operand, true); +} + void Translator::SetDst(const InstOperand& operand, const IR::U32F32& value) { IR::U32F32 result = value; if (operand.output_modifier.multiplier != 0.f) { @@ -197,6 +342,43 @@ void Translator::SetDst(const InstOperand& operand, const IR::U32F32& value) { } } +void Translator::SetDst64(const InstOperand& operand, const IR::U64F64& value_raw) { + IR::U64F64 value_untyped = value_raw; + + const bool is_float = value_raw.Type() == IR::Type::F64 || value_raw.Type() == IR::Type::F32; + if (is_float) { + if (operand.output_modifier.multiplier != 0.f) { + value_untyped = + ir.FPMul(value_untyped, ir.Imm64(f64(operand.output_modifier.multiplier))); + } + if (operand.output_modifier.clamp) { + value_untyped = ir.FPSaturate(value_raw); + } + } + const IR::U64 value = + is_float ? ir.BitCast(IR::F64{value_untyped}) : IR::U64{value_untyped}; + + const IR::Value unpacked{ir.UnpackUint2x32(value)}; + const IR::U32 lo{ir.CompositeExtract(unpacked, 0U)}; + const IR::U32 hi{ir.CompositeExtract(unpacked, 1U)}; + switch (operand.field) { + case OperandField::ScalarGPR: + ir.SetScalarReg(IR::ScalarReg(operand.code + 1), hi); + return ir.SetScalarReg(IR::ScalarReg(operand.code), lo); + case OperandField::VectorGPR: + ir.SetVectorReg(IR::VectorReg(operand.code + 1), hi); + return ir.SetVectorReg(IR::VectorReg(operand.code), lo); + case OperandField::VccLo: + UNREACHABLE(); + case OperandField::VccHi: + UNREACHABLE(); + case OperandField::M0: + break; + default: + UNREACHABLE(); + } +} + void Translator::EmitFetch(const GcnInst& inst) { // Read the pointer to the fetch shader assembly. const u32 sgpr_base = inst.src[0].code; @@ -320,6 +502,9 @@ void Translate(IR::Block* block, u32 block_base, std::span inst_l case Opcode::V_ADD_I32: translator.V_ADD_I32(inst); break; + case Opcode::V_ADDC_U32: + translator.V_ADDC_U32(inst); + break; case Opcode::V_CVT_F32_I32: translator.V_CVT_F32_I32(inst); break; @@ -470,6 +655,9 @@ void Translate(IR::Block* block, u32 block_base, std::span inst_l case Opcode::IMAGE_LOAD: translator.IMAGE_LOAD(false, inst); break; + case Opcode::V_MAD_U64_U32: + translator.V_MAD_U64_U32(inst); + break; case Opcode::V_CMP_GE_I32: translator.V_CMP_U32(ConditionOp::GE, true, false, inst); break; diff --git a/src/shader_recompiler/frontend/translate/translate.h b/src/shader_recompiler/frontend/translate/translate.h index 2aa6f712..3203ad73 100644 --- a/src/shader_recompiler/frontend/translate/translate.h +++ b/src/shader_recompiler/frontend/translate/translate.h @@ -100,6 +100,7 @@ public: void V_AND_B32(const GcnInst& inst); void V_LSHLREV_B32(const GcnInst& inst); void V_ADD_I32(const GcnInst& inst); + void V_ADDC_U32(const GcnInst& inst); void V_CVT_F32_I32(const GcnInst& inst); void V_CVT_F32_U32(const GcnInst& inst); void V_MAD_F32(const GcnInst& inst); @@ -129,6 +130,7 @@ public: void V_CVT_U32_F32(const GcnInst& inst); void V_SUBREV_F32(const GcnInst& inst); void V_SUBREV_I32(const GcnInst& inst); + void V_MAD_U64_U32(const GcnInst& inst); void V_CMP_U32(ConditionOp op, bool is_signed, bool set_exec, const GcnInst& inst); void V_LSHRREV_B32(const GcnInst& inst); void V_MUL_HI_U32(bool is_signed, const GcnInst& inst); @@ -186,8 +188,12 @@ public: void EXP(const GcnInst& inst); private: - IR::U32F32 GetSrc(const InstOperand& operand, bool flt_zero = false); + template + [[nodiscard]] T GetSrc(const InstOperand& operand, bool flt_zero = false); + template + [[nodiscard]] T GetSrc64(const InstOperand& operand, bool flt_zero = false); void SetDst(const InstOperand& operand, const IR::U32F32& value); + void SetDst64(const InstOperand& operand, const IR::U64F64& value_raw); private: IR::IREmitter ir; diff --git a/src/shader_recompiler/frontend/translate/vector_alu.cpp b/src/shader_recompiler/frontend/translate/vector_alu.cpp index ca648f88..1b2024f8 100644 --- a/src/shader_recompiler/frontend/translate/vector_alu.cpp +++ b/src/shader_recompiler/frontend/translate/vector_alu.cpp @@ -67,7 +67,8 @@ void Translator::V_OR_B32(bool is_xor, const GcnInst& inst) { const IR::U32 src0{GetSrc(inst.src[0])}; const IR::U32 src1{ir.GetVectorReg(IR::VectorReg(inst.src[1].code))}; const IR::VectorReg dst_reg{inst.dst[0].code}; - ir.SetVectorReg(dst_reg, is_xor ? ir.BitwiseXor(src0, src1) : ir.BitwiseOr(src0, src1)); + ir.SetVectorReg(dst_reg, + is_xor ? ir.BitwiseXor(src0, src1) : IR::U32(ir.BitwiseOr(src0, src1))); } void Translator::V_AND_B32(const GcnInst& inst) { @@ -92,6 +93,30 @@ void Translator::V_ADD_I32(const GcnInst& inst) { // TODO: Carry } +void Translator::V_ADDC_U32(const GcnInst& inst) { + + const auto src0 = GetSrc(inst.src[0]); + const auto src1 = GetSrc(inst.src[1]); + + IR::U32 scarry; + if (inst.src_count == 3) { // VOP3 + IR::U1 thread_bit{ir.GetThreadBitScalarReg(IR::ScalarReg(inst.src[2].code))}; + scarry = IR::U32{ir.Select(thread_bit, ir.Imm32(1), ir.Imm32(0))}; + } else { // VOP2 + scarry = ir.GetVccLo(); + } + + const IR::U32 result = ir.IAdd(ir.IAdd(src0, src1), scarry); + + const IR::VectorReg dst_reg{inst.dst[0].code}; + ir.SetVectorReg(dst_reg, result); + + const IR::U1 less_src0 = ir.ILessThan(result, src0, false); + const IR::U1 less_src1 = ir.ILessThan(result, src1, false); + const IR::U1 did_overflow = ir.LogicalOr(less_src0, less_src1); + ir.SetVcc(did_overflow); +} + void Translator::V_CVT_F32_I32(const GcnInst& inst) { const IR::U32 src0{GetSrc(inst.src[0])}; const IR::VectorReg dst_reg{inst.dst[0].code}; @@ -294,6 +319,23 @@ void Translator::V_SUBREV_I32(const GcnInst& inst) { // TODO: Carry-out } +void Translator::V_MAD_U64_U32(const GcnInst& inst) { + + const auto src0 = GetSrc(inst.src[0]); + const auto src1 = GetSrc(inst.src[1]); + const auto src2 = GetSrc64(inst.src[2]); + + const IR::U64 mul_result = ir.UConvert(64, ir.IMul(src0, src1)); + const IR::U64 sum_result = ir.IAdd(mul_result, src2); + + SetDst64(inst.dst[0], sum_result); + + const IR::U1 less_src0 = ir.ILessThan(sum_result, mul_result, false); + const IR::U1 less_src1 = ir.ILessThan(sum_result, src2, false); + const IR::U1 did_overflow = ir.LogicalOr(less_src0, less_src1); + ir.SetVcc(did_overflow); +} + void Translator::V_CMP_U32(ConditionOp op, bool is_signed, bool set_exec, const GcnInst& inst) { const IR::U32 src0{GetSrc(inst.src[0])}; const IR::U32 src1{GetSrc(inst.src[1])}; diff --git a/src/shader_recompiler/ir/ir_emitter.cpp b/src/shader_recompiler/ir/ir_emitter.cpp index cd4fdaa2..6ea3123d 100644 --- a/src/shader_recompiler/ir/ir_emitter.cpp +++ b/src/shader_recompiler/ir/ir_emitter.cpp @@ -964,8 +964,18 @@ IR::Value IREmitter::IMulExt(const U32& a, const U32& b, bool is_signed) { return Inst(is_signed ? Opcode::SMulExt : Opcode::UMulExt, a, b); } -U32 IREmitter::IMul(const U32& a, const U32& b) { - return Inst(Opcode::IMul32, a, b); +U32U64 IREmitter::IMul(const U32U64& a, const U32U64& b) { + if (a.Type() != b.Type()) { + UNREACHABLE_MSG("Mismatching types {} and {}", a.Type(), b.Type()); + } + switch (a.Type()) { + case Type::U32: + return Inst(Opcode::IMul32, a, b); + case Type::U64: + return Inst(Opcode::IMul64, a, b); + default: + ThrowInvalidType(a.Type()); + } } U32 IREmitter::IDiv(const U32& a, const U32& b, bool is_signed) { @@ -1024,8 +1034,18 @@ U32 IREmitter::BitwiseAnd(const U32& a, const U32& b) { return Inst(Opcode::BitwiseAnd32, a, b); } -U32 IREmitter::BitwiseOr(const U32& a, const U32& b) { - return Inst(Opcode::BitwiseOr32, a, b); +U32U64 IREmitter::BitwiseOr(const U32U64& a, const U32U64& b) { + if (a.Type() != b.Type()) { + UNREACHABLE_MSG("Mismatching types {} and {}", a.Type(), b.Type()); + } + switch (a.Type()) { + case Type::U32: + return Inst(Opcode::BitwiseOr32, a, b); + case Type::U64: + return Inst(Opcode::BitwiseOr64, a, b); + default: + ThrowInvalidType(a.Type()); + } } U32 IREmitter::BitwiseXor(const U32& a, const U32& b) { @@ -1095,8 +1115,18 @@ U32 IREmitter::UClamp(const U32& value, const U32& min, const U32& max) { return Inst(Opcode::UClamp32, value, min, max); } -U1 IREmitter::ILessThan(const U32& lhs, const U32& rhs, bool is_signed) { - return Inst(is_signed ? Opcode::SLessThan : Opcode::ULessThan, lhs, rhs); +U1 IREmitter::ILessThan(const U32U64& lhs, const U32U64& rhs, bool is_signed) { + if (lhs.Type() != rhs.Type()) { + UNREACHABLE_MSG("Mismatching types {} and {}", lhs.Type(), rhs.Type()); + } + switch (lhs.Type()) { + case Type::U32: + return Inst(is_signed ? Opcode::SLessThan32 : Opcode::ULessThan32, lhs, rhs); + case Type::U64: + return Inst(is_signed ? Opcode::SLessThan64 : Opcode::ULessThan64, lhs, rhs); + default: + ThrowInvalidType(lhs.Type()); + } } U1 IREmitter::IEqual(const U32U64& lhs, const U32U64& rhs) { @@ -1155,8 +1185,9 @@ U32U64 IREmitter::ConvertFToS(size_t bitsize, const F32F64& value) { ThrowInvalidType(value.Type()); } default: - UNREACHABLE_MSG("Invalid destination bitsize {}", bitsize); + break; } + throw NotImplementedException("Invalid destination bitsize {}", bitsize); } U32U64 IREmitter::ConvertFToU(size_t bitsize, const F32F64& value) { @@ -1183,13 +1214,17 @@ F32F64 IREmitter::ConvertSToF(size_t dest_bitsize, size_t src_bitsize, const Val switch (src_bitsize) { case 32: return Inst(Opcode::ConvertF32S32, value); + default: + break; } - break; case 64: switch (src_bitsize) { case 32: return Inst(Opcode::ConvertF64S32, value); + default: + break; } + default: break; } UNREACHABLE_MSG("Invalid bit size combination dst={} src={}", dest_bitsize, src_bitsize); @@ -1203,13 +1238,17 @@ F32F64 IREmitter::ConvertUToF(size_t dest_bitsize, size_t src_bitsize, const Val return Inst(Opcode::ConvertF32U16, value); case 32: return Inst(Opcode::ConvertF32U32, value); + default: + break; } - break; case 64: switch (src_bitsize) { case 32: return Inst(Opcode::ConvertF64U32, value); + default: + break; } + default: break; } UNREACHABLE_MSG("Invalid bit size combination dst={} src={}", dest_bitsize, src_bitsize); @@ -1227,7 +1266,11 @@ U16U32U64 IREmitter::UConvert(size_t result_bitsize, const U16U32U64& value) { switch (value.Type()) { case Type::U32: return Inst(Opcode::ConvertU16U32, value); + default: + break; } + default: + break; } throw NotImplementedException("Conversion from {} to {} bits", value.Type(), result_bitsize); } @@ -1238,13 +1281,17 @@ F16F32F64 IREmitter::FPConvert(size_t result_bitsize, const F16F32F64& value) { switch (value.Type()) { case Type::F32: return Inst(Opcode::ConvertF16F32, value); + default: + break; } - break; case 32: switch (value.Type()) { case Type::F16: return Inst(Opcode::ConvertF32F16, value); + default: + break; } + default: break; } throw NotImplementedException("Conversion from {} to {} bits", value.Type(), result_bitsize); diff --git a/src/shader_recompiler/ir/ir_emitter.h b/src/shader_recompiler/ir/ir_emitter.h index e7512430..7ee4e824 100644 --- a/src/shader_recompiler/ir/ir_emitter.h +++ b/src/shader_recompiler/ir/ir_emitter.h @@ -159,7 +159,7 @@ public: [[nodiscard]] Value IAddCary(const U32& a, const U32& b); [[nodiscard]] U32U64 ISub(const U32U64& a, const U32U64& b); [[nodiscard]] Value IMulExt(const U32& a, const U32& b, bool is_signed = false); - [[nodiscard]] U32 IMul(const U32& a, const U32& b); + [[nodiscard]] U32U64 IMul(const U32U64& a, const U32U64& b); [[nodiscard]] U32 IDiv(const U32& a, const U32& b, bool is_signed = false); [[nodiscard]] U32U64 INeg(const U32U64& value); [[nodiscard]] U32 IAbs(const U32& value); @@ -167,7 +167,7 @@ public: [[nodiscard]] U32U64 ShiftRightLogical(const U32U64& base, const U32& shift); [[nodiscard]] U32U64 ShiftRightArithmetic(const U32U64& base, const U32& shift); [[nodiscard]] U32 BitwiseAnd(const U32& a, const U32& b); - [[nodiscard]] U32 BitwiseOr(const U32& a, const U32& b); + [[nodiscard]] U32U64 BitwiseOr(const U32U64& a, const U32U64& b); [[nodiscard]] U32 BitwiseXor(const U32& a, const U32& b); [[nodiscard]] U32 BitFieldInsert(const U32& base, const U32& insert, const U32& offset, const U32& count); @@ -188,7 +188,7 @@ public: [[nodiscard]] U32 SClamp(const U32& value, const U32& min, const U32& max); [[nodiscard]] U32 UClamp(const U32& value, const U32& min, const U32& max); - [[nodiscard]] U1 ILessThan(const U32& lhs, const U32& rhs, bool is_signed); + [[nodiscard]] U1 ILessThan(const U32U64& lhs, const U32U64& rhs, bool is_signed); [[nodiscard]] U1 IEqual(const U32U64& lhs, const U32U64& rhs); [[nodiscard]] U1 ILessThanEqual(const U32& lhs, const U32& rhs, bool is_signed); [[nodiscard]] U1 IGreaterThan(const U32& lhs, const U32& rhs, bool is_signed); diff --git a/src/shader_recompiler/ir/opcodes.inc b/src/shader_recompiler/ir/opcodes.inc index 9aefc8b3..628b8d4f 100644 --- a/src/shader_recompiler/ir/opcodes.inc +++ b/src/shader_recompiler/ir/opcodes.inc @@ -227,6 +227,7 @@ OPCODE(IAddCary32, U32x2, U32, OPCODE(ISub32, U32, U32, U32, ) OPCODE(ISub64, U64, U64, U64, ) OPCODE(IMul32, U32, U32, U32, ) +OPCODE(IMul64, U64, U64, U64, ) OPCODE(SMulExt, U32x2, U32, U32, ) OPCODE(UMulExt, U32x2, U32, U32, ) OPCODE(SDiv32, U32, U32, U32, ) @@ -242,6 +243,7 @@ OPCODE(ShiftRightArithmetic32, U32, U32, OPCODE(ShiftRightArithmetic64, U64, U64, U32, ) OPCODE(BitwiseAnd32, U32, U32, U32, ) OPCODE(BitwiseOr32, U32, U32, U32, ) +OPCODE(BitwiseOr64, U64, U64, U64, ) OPCODE(BitwiseXor32, U32, U32, U32, ) OPCODE(BitFieldInsert, U32, U32, U32, U32, U32, ) OPCODE(BitFieldSExtract, U32, U32, U32, U32, ) @@ -258,8 +260,10 @@ OPCODE(SMax32, U32, U32, OPCODE(UMax32, U32, U32, U32, ) OPCODE(SClamp32, U32, U32, U32, U32, ) OPCODE(UClamp32, U32, U32, U32, U32, ) -OPCODE(SLessThan, U1, U32, U32, ) -OPCODE(ULessThan, U1, U32, U32, ) +OPCODE(SLessThan32, U1, U32, U32, ) +OPCODE(SLessThan64, U1, U64, U64, ) +OPCODE(ULessThan32, U1, U32, U32, ) +OPCODE(ULessThan64, U1, U64, U64, ) OPCODE(IEqual, U1, U32, U32, ) OPCODE(SLessThanEqual, U1, U32, U32, ) OPCODE(ULessThanEqual, U1, U32, U32, ) diff --git a/src/shader_recompiler/ir/passes/constant_propogation_pass.cpp b/src/shader_recompiler/ir/passes/constant_propogation_pass.cpp index 7cd896fb..13c0246e 100644 --- a/src/shader_recompiler/ir/passes/constant_propogation_pass.cpp +++ b/src/shader_recompiler/ir/passes/constant_propogation_pass.cpp @@ -21,6 +21,8 @@ template return value.F32(); } else if constexpr (std::is_same_v) { return value.U64(); + } else if constexpr (std::is_same_v) { + return static_cast(value.U64()); } } @@ -281,12 +283,18 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) { return FoldLogicalOr(inst); case IR::Opcode::LogicalNot: return FoldLogicalNot(inst); - case IR::Opcode::SLessThan: + case IR::Opcode::SLessThan32: FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; }); return; - case IR::Opcode::ULessThan: + case IR::Opcode::SLessThan64: + FoldWhenAllImmediates(inst, [](s64 a, s64 b) { return a < b; }); + return; + case IR::Opcode::ULessThan32: FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); return; + case IR::Opcode::ULessThan64: + FoldWhenAllImmediates(inst, [](u64 a, u64 b) { return a < b; }); + return; case IR::Opcode::SLessThanEqual: FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a <= b; }); return; diff --git a/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp index 6a43ad6b..80591492 100644 --- a/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp +++ b/src/shader_recompiler/ir/passes/ssa_rewrite_pass.cpp @@ -348,13 +348,15 @@ void VisitInst(Pass& pass, IR::Block* block, IR::Inst& inst) { case IR::Opcode::GetThreadBitScalarReg: case IR::Opcode::GetScalarRegister: { const IR::ScalarReg reg{inst.Arg(0).ScalarReg()}; - inst.ReplaceUsesWith( - pass.ReadVariable(reg, block, opcode == IR::Opcode::GetThreadBitScalarReg)); + const bool thread_bit = opcode == IR::Opcode::GetThreadBitScalarReg; + const IR::Value value = pass.ReadVariable(reg, block, thread_bit); + inst.ReplaceUsesWith(value); break; } case IR::Opcode::GetVectorRegister: { const IR::VectorReg reg{inst.Arg(0).VectorReg()}; - inst.ReplaceUsesWith(pass.ReadVariable(reg, block)); + const IR::Value value = pass.ReadVariable(reg, block); + inst.ReplaceUsesWith(value); break; } case IR::Opcode::GetGotoVariable: diff --git a/src/shader_recompiler/ir/value.h b/src/shader_recompiler/ir/value.h index a43c17f5..db939eaa 100644 --- a/src/shader_recompiler/ir/value.h +++ b/src/shader_recompiler/ir/value.h @@ -220,6 +220,7 @@ using F16 = TypedValue; using F32 = TypedValue; using F64 = TypedValue; using U32F32 = TypedValue; +using U64F64 = TypedValue; using U32U64 = TypedValue; using U16U32U64 = TypedValue; using F32F64 = TypedValue;