//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// \file /// Contains matchers for matching SelectionDAG nodes and values. /// //===----------------------------------------------------------------------===// #ifndef LLVM_CODEGEN_SDPATTERNMATCH_H #define LLVM_CODEGEN_SDPATTERNMATCH_H #include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/CodeGen/SelectionDAG.h" #include "llvm/CodeGen/SelectionDAGNodes.h" #include "llvm/CodeGen/TargetLowering.h" namespace llvm { namespace SDPatternMatch { /// MatchContext can repurpose existing patterns to behave differently under /// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes /// in normal circumstances, but matches VP_ADD nodes under a custom /// VPMatchContext. This design is meant to facilitate code / pattern reusing. class BasicMatchContext { const SelectionDAG *DAG; const TargetLowering *TLI; public: explicit BasicMatchContext(const SelectionDAG *DAG) : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {} explicit BasicMatchContext(const TargetLowering *TLI) : DAG(nullptr), TLI(TLI) {} // A valid MatchContext has to implement the following functions. const SelectionDAG *getDAG() const { return DAG; } const TargetLowering *getTLI() const { return TLI; } /// Return true if N effectively has opcode Opcode. bool match(SDValue N, unsigned Opcode) const { return N->getOpcode() == Opcode; } }; template [[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx, Pattern &&P) { return P.match(Ctx, N); } template [[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx, Pattern &&P) { return sd_context_match(SDValue(N, 0), Ctx, P); } template [[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) { return sd_context_match(N, BasicMatchContext(DAG), P); } template [[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) { return sd_context_match(N, BasicMatchContext(DAG), P); } template [[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) { return sd_match(N, nullptr, P); } template [[nodiscard]] bool sd_match(SDValue N, Pattern &&P) { return sd_match(N, nullptr, P); } // === Utilities === struct Value_match { SDValue MatchVal; Value_match() = default; explicit Value_match(SDValue Match) : MatchVal(Match) {} template bool match(const MatchContext &, SDValue N) { if (MatchVal) return MatchVal == N; return N.getNode(); } }; /// Match any valid SDValue. inline Value_match m_Value() { return Value_match(); } inline Value_match m_Specific(SDValue N) { assert(N); return Value_match(N); } struct DeferredValue_match { SDValue &MatchVal; explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {} template bool match(const MatchContext &, SDValue N) { return N == MatchVal; } }; /// Similar to m_Specific, but the specific value to match is determined by /// another sub-pattern in the same sd_match() expression. For instance, /// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since /// `X` is not initialized at the time it got copied into `m_Specific`. Instead, /// we should use `m_Add(m_Value(X), m_Deferred(X))`. inline DeferredValue_match m_Deferred(SDValue &V) { return DeferredValue_match(V); } struct Opcode_match { unsigned Opcode; explicit Opcode_match(unsigned Opc) : Opcode(Opc) {} template bool match(const MatchContext &Ctx, SDValue N) { return Ctx.match(N, Opcode); } }; inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); } template struct NUses_match { Pattern P; explicit NUses_match(const Pattern &P) : P(P) {} template bool match(const MatchContext &Ctx, SDValue N) { // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces // multiple results, hence we check the subsequent pattern here before // checking the number of value users. return P.match(Ctx, N) && N->hasNUsesOfValue(NumUses, N.getResNo()); } }; template inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) { return NUses_match<1, Pattern>(P); } template inline NUses_match m_NUses(const Pattern &P) { return NUses_match(P); } inline NUses_match<1, Value_match> m_OneUse() { return NUses_match<1, Value_match>(m_Value()); } template inline NUses_match m_NUses() { return NUses_match(m_Value()); } struct Value_bind { SDValue &BindVal; explicit Value_bind(SDValue &N) : BindVal(N) {} template bool match(const MatchContext &, SDValue N) { BindVal = N; return true; } }; inline Value_bind m_Value(SDValue &N) { return Value_bind(N); } template struct TLI_pred_match { Pattern P; PredFuncT PredFunc; TLI_pred_match(const PredFuncT &Pred, const Pattern &P) : P(P), PredFunc(Pred) {} template bool match(const MatchContext &Ctx, SDValue N) { assert(Ctx.getTLI() && "TargetLowering is required for this pattern."); return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N); } }; // Explicit deduction guide. template TLI_pred_match(const PredFuncT &Pred, const Pattern &P) -> TLI_pred_match; /// Match legal SDNodes based on the information provided by TargetLowering. template inline auto m_LegalOp(const Pattern &P) { return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) { return TLI.isOperationLegal(N->getOpcode(), N.getValueType()); }, P}; } /// Switch to a different MatchContext for subsequent patterns. template struct SwitchContext { const NewMatchContext &Ctx; Pattern P; template bool match(const OrigMatchContext &, SDValue N) { return P.match(Ctx, N); } }; template inline SwitchContext m_Context(const MatchContext &Ctx, Pattern &&P) { return SwitchContext{Ctx, std::move(P)}; } // === Value type === struct ValueType_bind { EVT &BindVT; explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {} template bool match(const MatchContext &, SDValue N) { BindVT = N.getValueType(); return true; } }; /// Retreive the ValueType of the current SDValue. inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); } template struct ValueType_match { PredFuncT PredFunc; Pattern P; ValueType_match(const PredFuncT &Pred, const Pattern &P) : PredFunc(Pred), P(P) {} template bool match(const MatchContext &Ctx, SDValue N) { return PredFunc(N.getValueType()) && P.match(Ctx, N); } }; // Explicit deduction guide. template ValueType_match(const PredFuncT &Pred, const Pattern &P) -> ValueType_match; /// Match a specific ValueType. template inline auto m_SpecificVT(EVT RefVT, const Pattern &P) { return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P}; } inline auto m_SpecificVT(EVT RefVT) { return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()}; } inline auto m_Glue() { return m_SpecificVT(MVT::Glue); } inline auto m_OtherVT() { return m_SpecificVT(MVT::Other); } /// Match any integer ValueTypes. template inline auto m_IntegerVT(const Pattern &P) { return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P}; } inline auto m_IntegerVT() { return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()}; } /// Match any floating point ValueTypes. template inline auto m_FloatingPointVT(const Pattern &P) { return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P}; } inline auto m_FloatingPointVT() { return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, m_Value()}; } /// Match any vector ValueTypes. template inline auto m_VectorVT(const Pattern &P) { return ValueType_match{[](EVT VT) { return VT.isVector(); }, P}; } inline auto m_VectorVT() { return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()}; } /// Match fixed-length vector ValueTypes. template inline auto m_FixedVectorVT(const Pattern &P) { return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P}; } inline auto m_FixedVectorVT() { return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, m_Value()}; } /// Match scalable vector ValueTypes. template inline auto m_ScalableVectorVT(const Pattern &P) { return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P}; } inline auto m_ScalableVectorVT() { return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, m_Value()}; } /// Match legal ValueTypes based on the information provided by TargetLowering. template inline auto m_LegalType(const Pattern &P) { return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) { return TLI.isTypeLegal(N.getValueType()); }, P}; } // === Patterns combinators === template struct And { template bool match(const MatchContext &, SDValue N) { return true; } }; template struct And : And { Pred P; And(const Pred &p, const Preds &...preds) : And(preds...), P(p) {} template bool match(const MatchContext &Ctx, SDValue N) { return P.match(Ctx, N) && And::match(Ctx, N); } }; template struct Or { template bool match(const MatchContext &, SDValue N) { return false; } }; template struct Or : Or { Pred P; Or(const Pred &p, const Preds &...preds) : Or(preds...), P(p) {} template bool match(const MatchContext &Ctx, SDValue N) { return P.match(Ctx, N) || Or::match(Ctx, N); } }; template struct Not { Pred P; explicit Not(const Pred &P) : P(P) {} template bool match(const MatchContext &Ctx, SDValue N) { return !P.match(Ctx, N); } }; // Explicit deduction guide. template Not(const Pred &P) -> Not; /// Match if the inner pattern does NOT match. template inline Not m_Unless(const Pred &P) { return Not{P}; } template And m_AllOf(const Preds &...preds) { return And(preds...); } template Or m_AnyOf(const Preds &...preds) { return Or(preds...); } template auto m_NoneOf(const Preds &...preds) { return m_Unless(m_AnyOf(preds...)); } // === Generic node matching === template struct Operands_match { template bool match(const MatchContext &Ctx, SDValue N) { // Returns false if there are more operands than predicates; return N->getNumOperands() == OpIdx; } }; template struct Operands_match : Operands_match { OpndPred P; Operands_match(const OpndPred &p, const OpndPreds &...preds) : Operands_match(preds...), P(p) {} template bool match(const MatchContext &Ctx, SDValue N) { if (OpIdx < N->getNumOperands()) return P.match(Ctx, N->getOperand(OpIdx)) && Operands_match::match(Ctx, N); // This is the case where there are more predicates than operands. return false; } }; template auto m_Node(unsigned Opcode, const OpndPreds &...preds) { return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>(preds...)); } /// Provide number of operands that are not chain or glue, as well as the first /// index of such operand. template struct EffectiveOperands { unsigned Size = 0; unsigned FirstIndex = 0; explicit EffectiveOperands(SDValue N) { const unsigned TotalNumOps = N->getNumOperands(); FirstIndex = TotalNumOps; for (unsigned I = 0; I < TotalNumOps; ++I) { // Count the number of non-chain and non-glue nodes (we ignore chain // and glue by default) and retreive the operand index offset. EVT VT = N->getOperand(I).getValueType(); if (VT != MVT::Glue && VT != MVT::Other) { ++Size; if (FirstIndex == TotalNumOps) FirstIndex = I; } } } }; template <> struct EffectiveOperands { unsigned Size = 0; unsigned FirstIndex = 0; explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {} }; // === Ternary operations === template struct TernaryOpc_match { unsigned Opcode; T0_P Op0; T1_P Op1; T2_P Op2; TernaryOpc_match(unsigned Opc, const T0_P &Op0, const T1_P &Op1, const T2_P &Op2) : Opcode(Opc), Op0(Op0), Op1(Op1), Op2(Op2) {} template bool match(const MatchContext &Ctx, SDValue N) { if (sd_context_match(N, Ctx, m_Opc(Opcode))) { EffectiveOperands EO(N); assert(EO.Size == 3); return ((Op0.match(Ctx, N->getOperand(EO.FirstIndex)) && Op1.match(Ctx, N->getOperand(EO.FirstIndex + 1))) || (Commutable && Op0.match(Ctx, N->getOperand(EO.FirstIndex + 1)) && Op1.match(Ctx, N->getOperand(EO.FirstIndex)))) && Op2.match(Ctx, N->getOperand(EO.FirstIndex + 2)); } return false; } }; template inline TernaryOpc_match m_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) { return TernaryOpc_match(ISD::SETCC, LHS, RHS, CC); } template inline TernaryOpc_match m_c_SetCC(const T0_P &LHS, const T1_P &RHS, const T2_P &CC) { return TernaryOpc_match(ISD::SETCC, LHS, RHS, CC); } // === Binary operations === template struct BinaryOpc_match { unsigned Opcode; LHS_P LHS; RHS_P RHS; BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R) : Opcode(Opc), LHS(L), RHS(R) {} template bool match(const MatchContext &Ctx, SDValue N) { if (sd_context_match(N, Ctx, m_Opc(Opcode))) { EffectiveOperands EO(N); assert(EO.Size == 2); return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) && RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) || (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) && RHS.match(Ctx, N->getOperand(EO.FirstIndex))); } return false; } }; template inline BinaryOpc_match m_BinOp(unsigned Opc, const LHS &L, const RHS &R) { return BinaryOpc_match(Opc, L, R); } template inline BinaryOpc_match m_c_BinOp(unsigned Opc, const LHS &L, const RHS &R) { return BinaryOpc_match(Opc, L, R); } template inline BinaryOpc_match m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) { return BinaryOpc_match(Opc, L, R); } template inline BinaryOpc_match m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) { return BinaryOpc_match(Opc, L, R); } // Common binary operations template inline BinaryOpc_match m_Add(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::ADD, L, R); } template inline BinaryOpc_match m_Sub(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::SUB, L, R); } template inline BinaryOpc_match m_Mul(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::MUL, L, R); } template inline BinaryOpc_match m_And(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::AND, L, R); } template inline BinaryOpc_match m_Or(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::OR, L, R); } template inline BinaryOpc_match m_Xor(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::XOR, L, R); } template inline BinaryOpc_match m_SMin(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::SMIN, L, R); } template inline BinaryOpc_match m_SMax(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::SMAX, L, R); } template inline BinaryOpc_match m_UMin(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::UMIN, L, R); } template inline BinaryOpc_match m_UMax(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::UMAX, L, R); } template inline BinaryOpc_match m_UDiv(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::UDIV, L, R); } template inline BinaryOpc_match m_SDiv(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::SDIV, L, R); } template inline BinaryOpc_match m_URem(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::UREM, L, R); } template inline BinaryOpc_match m_SRem(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::SREM, L, R); } template inline BinaryOpc_match m_Shl(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::SHL, L, R); } template inline BinaryOpc_match m_Sra(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::SRA, L, R); } template inline BinaryOpc_match m_Srl(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::SRL, L, R); } template inline BinaryOpc_match m_FAdd(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::FADD, L, R); } template inline BinaryOpc_match m_FSub(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::FSUB, L, R); } template inline BinaryOpc_match m_FMul(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::FMUL, L, R); } template inline BinaryOpc_match m_FDiv(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::FDIV, L, R); } template inline BinaryOpc_match m_FRem(const LHS &L, const RHS &R) { return BinaryOpc_match(ISD::FREM, L, R); } // === Unary operations === template struct UnaryOpc_match { unsigned Opcode; Opnd_P Opnd; UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {} template bool match(const MatchContext &Ctx, SDValue N) { if (sd_context_match(N, Ctx, m_Opc(Opcode))) { EffectiveOperands EO(N); assert(EO.Size == 1); return Opnd.match(Ctx, N->getOperand(EO.FirstIndex)); } return false; } }; template inline UnaryOpc_match m_UnaryOp(unsigned Opc, const Opnd &Op) { return UnaryOpc_match(Opc, Op); } template inline UnaryOpc_match m_ChainedUnaryOp(unsigned Opc, const Opnd &Op) { return UnaryOpc_match(Opc, Op); } template inline UnaryOpc_match m_BitReverse(const Opnd &Op) { return UnaryOpc_match(ISD::BITREVERSE, Op); } template inline UnaryOpc_match m_ZExt(const Opnd &Op) { return UnaryOpc_match(ISD::ZERO_EXTEND, Op); } template inline auto m_SExt(const Opnd &Op) { return UnaryOpc_match(ISD::SIGN_EXTEND, Op); } template inline UnaryOpc_match m_AnyExt(const Opnd &Op) { return UnaryOpc_match(ISD::ANY_EXTEND, Op); } template inline UnaryOpc_match m_Trunc(const Opnd &Op) { return UnaryOpc_match(ISD::TRUNCATE, Op); } /// Match a zext or identity /// Allows to peek through optional extensions template inline auto m_ZExtOrSelf(const Opnd &Op) { return m_AnyOf(m_ZExt(Op), Op); } /// Match a sext or identity /// Allows to peek through optional extensions template inline auto m_SExtOrSelf(const Opnd &Op) { return m_AnyOf(m_SExt(Op), Op); } /// Match a aext or identity /// Allows to peek through optional extensions template inline Or, Opnd> m_AExtOrSelf(const Opnd &Op) { return Or, Opnd>(m_AnyExt(Op), Op); } /// Match a trunc or identity /// Allows to peek through optional truncations template inline Or, Opnd> m_TruncOrSelf(const Opnd &Op) { return Or, Opnd>(m_Trunc(Op), Op); } // === Constants === struct ConstantInt_match { APInt *BindVal; explicit ConstantInt_match(APInt *V) : BindVal(V) {} template bool match(const MatchContext &, SDValue N) { // The logics here are similar to that in // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also // treats GlobalAddressSDNode as a constant, which is difficult to turn into // APInt. if (auto *C = dyn_cast_or_null(N.getNode())) { if (BindVal) *BindVal = C->getAPIntValue(); return true; } APInt Discard; return ISD::isConstantSplatVector(N.getNode(), BindVal ? *BindVal : Discard); } }; /// Match any interger constants or splat of an integer constant. inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); } /// Match any interger constants or splat of an integer constant; return the /// specific constant or constant splat value. inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); } struct SpecificInt_match { APInt IntVal; explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {} template bool match(const MatchContext &Ctx, SDValue N) { APInt ConstInt; if (sd_context_match(N, Ctx, m_ConstInt(ConstInt))) return APInt::isSameValue(IntVal, ConstInt); return false; } }; /// Match a specific integer constant or constant splat value. inline SpecificInt_match m_SpecificInt(APInt V) { return SpecificInt_match(std::move(V)); } inline SpecificInt_match m_SpecificInt(uint64_t V) { return SpecificInt_match(APInt(64, V)); } inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); } inline SpecificInt_match m_One() { return m_SpecificInt(1U); } struct AllOnes_match { AllOnes_match() = default; template bool match(const MatchContext &, SDValue N) { return isAllOnesOrAllOnesSplat(N); } }; inline AllOnes_match m_AllOnes() { return AllOnes_match(); } /// Match true boolean value based on the information provided by /// TargetLowering. inline auto m_True() { return TLI_pred_match{ [](const TargetLowering &TLI, SDValue N) { APInt ConstVal; if (sd_match(N, m_ConstInt(ConstVal))) switch (TLI.getBooleanContents(N.getValueType())) { case TargetLowering::ZeroOrOneBooleanContent: return ConstVal.isOne(); case TargetLowering::ZeroOrNegativeOneBooleanContent: return ConstVal.isAllOnes(); case TargetLowering::UndefinedBooleanContent: return (ConstVal & 0x01) == 1; } return false; }, m_Value()}; } /// Match false boolean value based on the information provided by /// TargetLowering. inline auto m_False() { return TLI_pred_match{ [](const TargetLowering &TLI, SDValue N) { APInt ConstVal; if (sd_match(N, m_ConstInt(ConstVal))) switch (TLI.getBooleanContents(N.getValueType())) { case TargetLowering::ZeroOrOneBooleanContent: case TargetLowering::ZeroOrNegativeOneBooleanContent: return ConstVal.isZero(); case TargetLowering::UndefinedBooleanContent: return (ConstVal & 0x01) == 0; } return false; }, m_Value()}; } struct CondCode_match { std::optional CCToMatch; ISD::CondCode *BindCC = nullptr; explicit CondCode_match(ISD::CondCode CC) : CCToMatch(CC) {} explicit CondCode_match(ISD::CondCode *CC) : BindCC(CC) {} template bool match(const MatchContext &, SDValue N) { if (auto *CC = dyn_cast(N.getNode())) { if (CCToMatch && *CCToMatch != CC->get()) return false; if (BindCC) *BindCC = CC->get(); return true; } return false; } }; /// Match any conditional code SDNode. inline CondCode_match m_CondCode() { return CondCode_match(nullptr); } /// Match any conditional code SDNode and return its ISD::CondCode value. inline CondCode_match m_CondCode(ISD::CondCode &CC) { return CondCode_match(&CC); } /// Match a conditional code SDNode with a specific ISD::CondCode. inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) { return CondCode_match(CC); } /// Match a negate as a sub(0, v) template inline BinaryOpc_match m_Neg(const ValTy &V) { return m_Sub(m_Zero(), V); } /// Match a Not as a xor(v, -1) or xor(-1, v) template inline BinaryOpc_match m_Not(const ValTy &V) { return m_Xor(V, m_AllOnes()); } } // namespace SDPatternMatch } // namespace llvm #endif