Jlm
ForkTests.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2024 Magnus Sjalander <work@sjalander.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #include <gtest/gtest.h>
7 
10 #include <jlm/hls/ir/hls.hpp>
15 #include <jlm/rvsdg/lambda.hpp>
16 #include <jlm/rvsdg/view.hpp>
17 
// Fork insertion on an hls::LoopNode with three loop variables.
//
// Arrange: builds a lambda (three bit32 arguments -> three bit32 results)
// whose body is a single hls::LoopNode. The three lambda arguments become
// loop variables. Inside the loop the predicate is
// match(bitult(idv + lvs, lve)).
// Act: NOTE(review): the Act lines (source lines 56-57) are missing from this
// listing. The cross-references point at
// CreateAndRun(module, statisticsCollector) in add-forks.cpp; confirm.
// Assert: the root region still holds exactly one lambda, which holds exactly
// one loop node. The origin of the loop subregion's result 0 must be a
// non-constant hls::ForkOperation node with one input and four outputs.
18 TEST(ForkInsertionTests, ForkInsertion)
19 {
20  using namespace jlm;
21  using namespace jlm::llvm;
22 
23  // Arrange
24  auto bit32Type = rvsdg::BitType::Create(32);
// Function signature: three 32-bit arguments, three 32-bit results.
25  const auto functionType = jlm::rvsdg::FunctionType::Create(
26  { bit32Type, bit32Type, bit32Type },
27  { bit32Type, bit32Type, bit32Type });
28 
29  LlvmRvsdgModule rvsdgModule(util::FilePath(""), "", "");
30  auto & rootRegion = rvsdgModule.Rvsdg().GetRootRegion();
31 
// NOTE(review): source line 34 (the lambda operation argument, presumably an
// LlvmLambdaOperation::Create(functionType, ...) call) is missing from this
// listing, so this call appears truncated here.
32  auto lambda = jlm::rvsdg::LambdaNode::Create(
33  rootRegion,
35 
36  auto loop = hls::LoopNode::create(lambda->subregion());
// Each AddLoopVar routes one lambda argument into the loop. The out-parameter
// receives the Output that the loop body uses for that variable (these
// buffers are the operands below). The names presumably mean induction
// variable (idv), loop step (lvs) and loop end (lve); confirm.
37  rvsdg::Output * idvBuffer = nullptr;
38  loop->AddLoopVar(lambda->GetFunctionArguments()[0], &idvBuffer);
39  rvsdg::Output * lvsBuffer = nullptr;
40  loop->AddLoopVar(lambda->GetFunctionArguments()[1], &lvsBuffer);
41  rvsdg::Output * lveBuffer = nullptr;
42  loop->AddLoopVar(lambda->GetFunctionArguments()[2], &lveBuffer);
43 
// Loop body: arm = idv + lvs, cmp = arm <u lve (both 32-bit).
// The match maps value 1 to alternative 1. Every other value takes the
// default alternative 0, out of 2 alternatives.
44  auto arm = rvsdg::CreateOpNode<rvsdg::bitadd_op>({ idvBuffer, lvsBuffer }, 32).output(0);
45  auto cmp = rvsdg::CreateOpNode<rvsdg::bitult_op>({ arm, lveBuffer }, 32).output(0);
46  auto & matchNode = rvsdg::MatchOperation::CreateNode(*cmp, { { 1, 1 } }, 0, 2);
47 
48  loop->set_predicate(matchNode.output(0));
49 
// Return all three loop outputs from the lambda and export it from the graph.
50  auto lambdaOutput = lambda->finalize({ loop->output(0), loop->output(1), loop->output(2) });
51  rvsdg::GraphExport::Create(*lambdaOutput, "");
52 
// Dump the graph before the Act step (debugging aid only).
53  rvsdg::view(rvsdgModule.Rvsdg(), stdout);
54 
55  // Act
// Dump the graph after the Act step.
58  rvsdg::view(rvsdgModule.Rvsdg(), stdout);
59 
60  // Assert
// The Assert scope re-derives lambda and loop from the graph, shadowing the
// Arrange variables. The checks therefore do not depend on Arrange-time
// pointers.
61  {
62  EXPECT_EQ(rootRegion.numNodes(), 1);
63  auto lambda = util::assertedCast<jlm::rvsdg::LambdaNode>(rootRegion.Nodes().begin().ptr());
// NOTE(review): this check is likely redundant if util::assertedCast already
// verifies the dynamic type (its name suggests so); confirm.
64  EXPECT_NE(dynamic_cast<const jlm::rvsdg::LambdaNode *>(lambda), nullptr);
65 
66  auto lambdaSubregion = lambda->subregion();
67  EXPECT_EQ(lambdaSubregion->numNodes(), 1);
68  auto loop = util::assertedCast<hls::LoopNode>(lambdaSubregion->Nodes().begin().ptr());
69  EXPECT_NE(dynamic_cast<const hls::LoopNode *>(loop), nullptr);
70 
// The node feeding loop subregion result 0 must be a fork. It is presumably
// the fork of the loop predicate; confirm against hls::LoopNode.
71  auto [forkNode, forkOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<hls::ForkOperation>(
72  *loop->subregion()->result(0)->origin());
// NOTE(review): EXPECT_TRUE records a failure but lets the test continue, and
// the next lines dereference forkNode and forkOperation unconditionally.
// ASSERT_TRUE would end the test instead of risking a null dereference.
73  EXPECT_TRUE(forkNode && forkOperation);
74  EXPECT_EQ(forkNode->ninputs(), 1);
75  EXPECT_EQ(forkNode->noutputs(), 4);
76  EXPECT_FALSE(forkOperation->IsConstant());
77  }
78 }
79 
// Constant-fork insertion: a constant with two users inside an hls::LoopNode.
//
// Arrange: builds a lambda (bit32 -> bit32) whose body is a single
// hls::LoopNode with one loop variable. A 32-bit constant created inside the
// loop subregion feeds both the add and the compare of the predicate
// match(bitult(idv + c, c)).
// Act: NOTE(review): the Act lines (source lines 114-115) are missing from
// this listing. The cross-references point at CreateAndRun in add-forks.cpp;
// confirm.
// Assert: the origin of loop subregion result 0 must be a non-constant fork
// with one input and two outputs. Walking back from that fork through the
// match node to the bitult node, bitult's input 1 must come from a constant
// fork with one input and two outputs.
// NOTE(review): the suite name "SinkInsertionTests" looks like a copy-paste;
// this test checks fork insertion, so "ForkInsertionTests" is probably
// intended.
80 TEST(SinkInsertionTests, ConstantForkInsertion)
81 {
82  using namespace jlm;
83  using namespace jlm::llvm;
84 
85  // Arrange
86  auto bit32Type = rvsdg::BitType::Create(32);
87  const auto functionType = rvsdg::FunctionType::Create({ bit32Type }, { bit32Type });
88 
89  LlvmRvsdgModule rvsdgModule(util::FilePath(""), "", "");
90  auto & rootRegion = rvsdgModule.Rvsdg().GetRootRegion();
91 
// NOTE(review): source line 94 (the lambda operation argument) is missing
// from this listing, so this call appears truncated here.
92  auto lambda = rvsdg::LambdaNode::Create(
93  rootRegion,
95 
96  auto loop = hls::LoopNode::create(lambda->subregion());
97  auto subregion = loop->subregion();
98  rvsdg::Output * idvBuffer = nullptr;
99  loop->AddLoopVar(lambda->GetFunctionArguments()[0], &idvBuffer);
// Constant created inside the loop subregion. { 32, 1 } is presumably a
// 32-bit value 1, judging by the variable name; confirm.
100  auto bitConstant1 = &rvsdg::BitConstantOperation::create(*subregion, { 32, 1 });
101 
// The same constant is used twice (the add and the compare). Those two users
// are what the Assert section expects a constant fork to serve.
102  auto arm = rvsdg::CreateOpNode<rvsdg::bitadd_op>({ idvBuffer, bitConstant1 }, 32).output(0);
103  auto cmp = rvsdg::CreateOpNode<rvsdg::bitult_op>({ arm, bitConstant1 }, 32).output(0);
104  auto & matchNode = rvsdg::MatchOperation::CreateNode(*cmp, { { 1, 1 } }, 0, 2);
105 
106  loop->set_predicate(matchNode.output(0));
107 
108  auto lambdaOutput = lambda->finalize({ loop->output(0) });
109  rvsdg::GraphExport::Create(*lambdaOutput, "");
110 
// Dump the graph before the Act step (debugging aid only).
111  rvsdg::view(rvsdgModule.Rvsdg(), stdout);
112 
113  // Act
// Dump the graph after the Act step.
116  rvsdg::view(rvsdgModule.Rvsdg(), stdout);
117 
118  // Assert
119  {
120  EXPECT_EQ(rootRegion.numNodes(), 1);
121  auto lambda = util::assertedCast<jlm::rvsdg::LambdaNode>(rootRegion.Nodes().begin().ptr());
122  EXPECT_TRUE(rvsdg::is<jlm::rvsdg::LambdaOperation>(lambda));
123 
124  auto lambdaRegion = lambda->subregion();
125  EXPECT_EQ(lambdaRegion->numNodes(), 1);
126 
// Reach the loop through the lambda's first result instead of iterating the
// region's nodes.
127  const rvsdg::NodeOutput * loopOutput =
128  dynamic_cast<jlm::rvsdg::NodeOutput *>(lambdaRegion->result(0)->origin());
// NOTE(review): EXPECT_NE lets the test continue, and loopOutput is
// dereferenced on the next line. ASSERT_NE would avoid a null dereference.
129  EXPECT_NE(loopOutput, nullptr);
130  auto loopNode = loopOutput->node();
131  EXPECT_TRUE(rvsdg::is<hls::LoopOperation>(loopNode));
132  auto loop = util::assertedCast<hls::LoopNode>(loopNode);
133 
// Fork feeding loop subregion result 0. Its input is traced back to the match
// node below.
134  auto [forkNode, forkOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<hls::ForkOperation>(
135  *loop->subregion()->result(0)->origin());
// NOTE(review): same issue as above. forkNode and forkOperation are
// dereferenced even if this EXPECT_TRUE fails; consider ASSERT_TRUE.
136  EXPECT_TRUE(forkNode && forkOperation);
137  EXPECT_EQ(forkNode->ninputs(), 1);
138  EXPECT_EQ(forkNode->noutputs(), 2);
139  EXPECT_FALSE(forkOperation->IsConstant());
140 
// Walk back through fork -> match -> bitult. bitult's input 1 (originally the
// constant) must now come from a constant fork.
// NOTE(review): matchNode, bitsUltNode, cForkNode and cForkOperation are
// dereferenced without null checks. A failed lookup crashes the test instead
// of reporting a failure. Consider ASSERT_NE / ASSERT_TRUE guards.
141  auto matchNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*forkNode->input(0)->origin());
142  auto bitsUltNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*matchNode->input(0)->origin());
143  auto [cForkNode, cForkOperation] =
144  rvsdg::TryGetSimpleNodeAndOptionalOp<hls::ForkOperation>(*bitsUltNode->input(1)->origin());
145  EXPECT_EQ(cForkNode->ninputs(), 1);
146  EXPECT_EQ(cForkNode->noutputs(), 2);
147  EXPECT_TRUE(cForkOperation->IsConstant());
148  }
149 }
static jlm::util::StatisticsCollector statisticsCollector
TEST(ForkInsertionTests, ForkInsertion)
Definition: ForkTests.cpp:18
static void CreateAndRun(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector)
Definition: add-forks.cpp:21
static LoopNode * create(rvsdg::Region *parent, bool init=true)
Definition: hls.cpp:294
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::CallingConvention callingConvention, jlm::llvm::AttributeSet attributes)
Definition: lambda.hpp:84
static Output & create(Region &region, BitValueRepresentation value)
Definition: constant.hpp:44
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates bit type of specified width.
Definition: type.cpp:45
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
static GraphExport & Create(Output &origin, std::string name)
Definition: graph.cpp:62
Lambda node.
Definition: lambda.hpp:83
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
static Node & CreateNode(Output &predicate, const std::unordered_map< uint64_t, uint64_t > &mapping, const uint64_t defaultAlternative, const size_t numAlternatives)
Definition: control.hpp:226
Node * node() const noexcept
Definition: node.hpp:572
Global memory state passed between functions.
std::string view(const rvsdg::Region *region)
Definition: view.cpp:142