Jlm
NodeReductionTests.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2024 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #include <gtest/gtest.h>
7 
12 #include <jlm/llvm/ir/types.hpp>
18 #include <jlm/rvsdg/graph.hpp>
19 #include <jlm/rvsdg/view.hpp>
20 #include <jlm/util/Statistics.hpp>
21 
22 TEST(NodeReductionTests, MultipleReductionsPerRegion)
23 {
24  using namespace jlm::llvm;
25  using namespace jlm::rvsdg;
26 
27  // Arrange
28  const auto bitType = BitType::Create(32);
29  const auto memoryStateType = MemoryStateType::Create();
30 
31  jlm::llvm::LlvmRvsdgModule rvsdgModule(jlm::util::FilePath(""), "", "");
32  auto & graph = rvsdgModule.Rvsdg();
33 
34  auto & sizeArgument = jlm::rvsdg::GraphImport::Create(graph, bitType, "size");
35  auto allocaResults = AllocaOperation::create(bitType, &sizeArgument, 4);
36 
37  const auto c3 =
38  &BitConstantOperation::create(graph.GetRootRegion(), BitValueRepresentation(32, 3));
39  auto storeResults =
40  StoreNonVolatileOperation::Create(allocaResults[0], c3, { allocaResults[1] }, 4);
41  auto loadResults =
42  LoadNonVolatileOperation::Create(allocaResults[0], { storeResults[0] }, bitType, 4);
43 
44  const auto c5 =
45  &BitConstantOperation::create(graph.GetRootRegion(), BitValueRepresentation(32, 5));
46  auto sum = bitadd_op::create(32, loadResults[0], c5);
47 
48  auto & sumExport = jlm::rvsdg::GraphExport::Create(*sum, "sum");
49 
50  view(graph, stdout);
51 
52  // Act
53  NodeReduction nodeReduction;
56  nodeReduction.Run(rvsdgModule, statisticsCollector);
57 
58  view(graph, stdout);
59 
60  // Assert
61  // We expect that two reductions are applied:
62  // 1. NormalizeLoadStore - This ensures that the stored constant value is directly forwarded to
63  // the add operation
64  // 2. Constant folding on the add operation
65  // The result is that a single constant node with value 8 is left in the graph.
66  EXPECT_EQ(graph.GetRootRegion().numNodes(), 1u);
67 
68  auto constantNode = TryGetOwnerNode<SimpleNode>(*sumExport.origin());
69  auto constantOperation =
70  dynamic_cast<const BitConstantOperation *>(&constantNode->GetOperation());
71  EXPECT_EQ(constantOperation->value().to_uint(), 8u);
72 
73  // We expect that the node reductions transformation iterated over the root region 2 times.
74  auto & statistics = *statisticsCollector.CollectedStatistics().begin();
75  auto & nodeReductionStatistics = dynamic_cast<const NodeReduction::Statistics &>(statistics);
76  auto numIterations = nodeReductionStatistics.GetNumIterations(graph.GetRootRegion()).value();
77  EXPECT_EQ(numIterations, 2u);
78 }
static jlm::util::StatisticsCollector statisticsCollector
TEST(NodeReductionTests, MultipleReductionsPerRegion)
static std::vector< rvsdg::Output * > create(std::shared_ptr< const rvsdg::Type > allocatedType, rvsdg::Output *count, const size_t alignment)
Definition: alloca.hpp:131
static std::unique_ptr< llvm::ThreeAddressCode > Create(const Variable *address, const Variable *state, std::shared_ptr< const rvsdg::Type > loadedType, size_t alignment)
Definition: Load.hpp:444
static std::shared_ptr< const MemoryStateType > Create()
Definition: types.cpp:379
std::optional< size_t > GetNumIterations(const rvsdg::Region &region) const noexcept
Definition: reduction.cpp:46
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Definition: reduction.cpp:63
static std::unique_ptr< llvm::ThreeAddressCode > Create(const Variable *address, const Variable *value, const Variable *state, size_t alignment)
Definition: Store.hpp:304
static GraphExport & Create(Output &origin, std::string name)
Definition: graph.cpp:62
static GraphImport & Create(Graph &graph, std::shared_ptr< const rvsdg::Type > type, std::string name)
Definition: graph.cpp:36
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
StatisticsRange CollectedStatistics() const noexcept
Definition: Statistics.hpp:528
Global memory state passed between functions.
std::string view(const rvsdg::Region *region)
Definition: view.cpp:142