6 #include <gtest/gtest.h>
17 const std::shared_ptr<const jlm::rvsdg::Type> operandType,
18 const std::shared_ptr<const jlm::rvsdg::Type> resultType,
26 const noexcept
override
28 auto n1 = jlm::rvsdg::TryGetOwnerNode<jlm::rvsdg::SimpleNode>(*operand1);
29 auto n2 = jlm::rvsdg::TryGetOwnerNode<jlm::rvsdg::SimpleNode>(*operand2);
31 if (n1 && n2 && jlm::rvsdg::is<jlm::rvsdg::UnaryOperation>(n1->GetOperation())
32 && jlm::rvsdg::is<jlm::rvsdg::UnaryOperation>(n2->GetOperation()))
56 flags() const noexcept
override
67 [[nodiscard]] std::string
70 return "BinaryOperation";
73 [[nodiscard]] std::unique_ptr<Operation>
83 TEST(BinaryOperationTests, ReduceFlattenedBinaryReductionParallel)
100 auto & node = CreateOpNode<FlattenedBinaryOperation>({ i0, i1, i2, i3 }, binaryOperation, 4);
114 auto node0 = TryGetOwnerNode<SimpleNode>(*ex.origin());
115 EXPECT_TRUE(is<TestBinaryOperation>(node0->GetOperation()));
117 auto node1 = TryGetOwnerNode<SimpleNode>(*node0->input(0)->origin());
118 EXPECT_TRUE(is<TestBinaryOperation>(node1->GetOperation()));
120 auto node2 = TryGetOwnerNode<SimpleNode>(*node0->input(1)->origin());
121 EXPECT_TRUE(is<TestBinaryOperation>(node2->GetOperation()));
124 TEST(BinaryOperationTests, ReduceFlattenedBinaryReductionLinear)
141 auto & node = CreateOpNode<FlattenedBinaryOperation>({ i0, i1, i2, i3 }, binaryOperation, 4);
156 auto node0 = TryGetOwnerNode<SimpleNode>(*ex.origin());
157 EXPECT_TRUE(is<TestBinaryOperation>(node0->GetOperation()));
159 auto node1 = TryGetOwnerNode<SimpleNode>(*node0->input(0)->origin());
160 EXPECT_TRUE(is<TestBinaryOperation>(node1->GetOperation()));
162 auto node2 = TryGetOwnerNode<SimpleNode>(*node1->input(0)->origin());
163 EXPECT_TRUE(is<TestBinaryOperation>(node2->GetOperation()));
166 TEST(BinaryOperationTests, FlattenAssociativeBinaryOperation_NotAssociativeBinary)
178 auto o1 = &CreateOpNode<TestBinaryOperation>(
183 auto o2 = &CreateOpNode<TestBinaryOperation>(
184 { o1->output(0), i2 },
194 auto node = TryGetOwnerNode<SimpleNode>(*ex.origin());
200 EXPECT_FALSE(success);
201 EXPECT_EQ(TryGetOwnerNode<SimpleNode>(*ex.origin()), node);
204 TEST(BinaryOperationTests, FlattenAssociativeBinaryOperation_NoNewOperands)
215 auto u1 = &CreateOpNode<TestUnaryOperation>({ i0 }, valueType, valueType);
216 auto u2 = &CreateOpNode<TestUnaryOperation>({ i1 }, valueType, valueType);
217 auto b2 = &CreateOpNode<TestBinaryOperation>(
218 { u1->output(0), u2->output(0) },
228 auto node = TryGetOwnerNode<SimpleNode>(*ex.origin());
234 EXPECT_FALSE(success);
235 EXPECT_EQ(TryGetOwnerNode<SimpleNode>(*ex.origin()), node);
238 TEST(BinaryOperationTests, FlattenAssociativeBinaryOperation_Success)
250 auto o1 = &CreateOpNode<TestBinaryOperation>(
255 auto o2 = &CreateOpNode<TestBinaryOperation>(
256 { o1->output(0), i2 },
266 auto node = TryGetOwnerNode<SimpleNode>(*ex.origin());
272 EXPECT_TRUE(success);
273 auto flattenedBinaryNode = TryGetOwnerNode<SimpleNode>(*ex.origin());
274 EXPECT_TRUE(is<FlattenedBinaryOperation>(flattenedBinaryNode->GetOperation()));
275 EXPECT_EQ(flattenedBinaryNode->ninputs(), 3u);
278 TEST(BinaryOperationTests, NormalizeBinaryOperation_NoNewOperands)
289 auto o1 = &CreateOpNode<TestBinaryOperation>(
300 auto node = TryGetOwnerNode<SimpleNode>(*ex.origin());
306 EXPECT_FALSE(success);
309 TEST(BinaryOperationTests, NormalizeBinaryOperation_SingleOperand)
320 auto u1 = &CreateOpNode<TestUnaryOperation>({ s0 }, valueType, valueType);
321 auto u2 = &CreateOpNode<TestUnaryOperation>({ s1 }, valueType, valueType);
323 auto o1 = &CreateOpNode<::BinaryOperation>(
324 { u1->output(0), u2->output(0) },
334 auto node = TryGetOwnerNode<SimpleNode>(*ex.origin());
340 EXPECT_TRUE(success);
341 EXPECT_EQ(ex.origin(), u2->output(0));
TEST(BinaryOperationTests, ReduceFlattenedBinaryReductionParallel)
std::string debug_string() const override
std::unique_ptr< Operation > copy() const override
enum jlm::rvsdg::BinaryOperation::flags flags() const noexcept override
enum jlm::rvsdg::BinaryOperation::flags Flags_
jlm::rvsdg::Output * reduce_operand_pair(jlm::rvsdg::unop_reduction_path_t path, jlm::rvsdg::Output *, jlm::rvsdg::Output *op2) const override
BinaryOperation(const std::shared_ptr< const jlm::rvsdg::Type > operandType, const std::shared_ptr< const jlm::rvsdg::Type > resultType, const enum jlm::rvsdg::BinaryOperation::flags &flags)
bool operator==(const Operation &) const noexcept override
jlm::rvsdg::binop_reduction_path_t can_reduce_operand_pair(const jlm::rvsdg::Output *operand1, const jlm::rvsdg::Output *operand2) const noexcept override
jlm::rvsdg::Output * reduce(const FlattenedBinaryOperation::reduction &reduction, const std::vector< jlm::rvsdg::Output * > &operands) const
static GraphExport & Create(Output &origin, std::string name)
static GraphImport & Create(Graph &graph, std::shared_ptr< const rvsdg::Type > type, std::string name)
Region & GetRootRegion() const noexcept
size_t numNodes() const noexcept
const std::shared_ptr< const rvsdg::Type > & argument(size_t index) const noexcept
const std::shared_ptr< const rvsdg::Type > & result(size_t index) const noexcept
static std::shared_ptr< const TestType > createValueType()
#define JLM_UNREACHABLE(msg)
size_t unop_reduction_path_t
std::optional< std::vector< rvsdg::Output * > > NormalizeBinaryOperation(const BinaryOperation &operation, const std::vector< rvsdg::Output * > &operands)
Applies the reductions implemented in the binary operations reduction functions.
size_t binop_reduction_path_t
std::string view(const rvsdg::Region *region)
std::optional< std::vector< rvsdg::Output * > > FlattenAssociativeBinaryOperation(const BinaryOperation &operation, const std::vector< rvsdg::Output * > &operands)
Flattens a cascade of the same binary operations into a single flattened binary operation.