6 #include <gtest/gtest.h>
// Node hoisting on a single gamma. Every node inside the gamma depends only on
// the gamma's entry variable or on other nodes with no operands. The expectations
// at the end require both gamma subregions to end up empty and the lambda
// subregion to hold four nodes (presumably the gamma plus the three hoisted nodes).
22 TEST(NodeHoistingTests, simpleGamma)
28 const auto controlType = ControlType::Create(2);
29 const auto valueType = TestType::createValueType();
30 const auto functionType = FunctionType::Create(
38 auto & rvsdg = rvsdgModule.Rvsdg();
40 auto lambdaNode = LambdaNode::Create(
41 rvsdg.GetRootRegion(),
43 auto controlArgument = lambdaNode->GetFunctionArguments()[0];
44 auto valueArgument = lambdaNode->GetFunctionArguments()[1];
// Two-way gamma predicated on the control argument; the value argument is
// routed into both branches as a single entry variable.
46 auto gammaNode = GammaNode::create(controlArgument, 2);
47 auto entryVar = gammaNode->AddEntryVar(valueArgument);
// Subregion 0: a node with no operands feeding a two-operand node together
// with the entry variable.
50 auto constantNode = TestOperation::createNode(gammaNode->subregion(0), {}, { valueType });
51 auto binaryNode = TestOperation::createNode(
52 gammaNode->subregion(0),
53 { entryVar.branchArgument[0], constantNode->output(0) },
// Subregion 1: a single-operand node on the entry variable.
57 auto unaryNode = TestOperation::createNode(
58 gammaNode->subregion(1),
59 { entryVar.branchArgument[1] },
// The two branch results leave the gamma through one exit variable, which
// becomes the lambda's result.
62 auto exitVar = gammaNode->AddExitVar({ binaryNode->output(0), unaryNode->output(0) });
64 auto lambdaOutput = lambdaNode->finalize({ exitVar.output });
66 GraphExport::Create(*lambdaOutput,
"x");
// Expected state after hoisting: nothing remains inside the gamma.
79 EXPECT_EQ(lambdaNode->subregion()->numNodes(), 4u);
82 EXPECT_EQ(gammaNode->subregion(0)->numNodes(), 0u);
83 EXPECT_EQ(gammaNode->subregion(1)->numNodes(), 0u);
// Node hoisting across two nested gammas. gammaNode2 is created from
// gammaNode1's control entry variable inside gammaNode1's subregion 0. Its
// subregions hold a two-operand node (subregion 0) and a single-operand node
// (subregion 1). Their operands come only from gammaNode2's entry variables,
// which carry the lambda's value argument and constantNode1.
// After hoisting, the test expects:
// - both gammaNode2 subregions to be empty;
// - gammaNode1's subregion 0 to hold a single node (presumably gammaNode2 itself);
// - gammaNode1's subregion 1 to be empty;
// - the lambda subregion to hold five nodes.
86 TEST(NodeHoistingTests, nestedGamma)
92 const auto controlType = ControlType::Create(2);
93 const auto valueType = TestType::createValueType();
94 const auto functionType = FunctionType::Create(
102 auto & rvsdg = rvsdgModule.Rvsdg();
104 auto lambdaNode = LambdaNode::Create(
105 rvsdg.GetRootRegion(),
107 auto controlArgument = lambdaNode->GetFunctionArguments()[0];
108 auto valueArgument = lambdaNode->GetFunctionArguments()[1];
// Outer gamma. The control argument is both its predicate and an entry
// variable, so the inner gamma can also be predicated on it.
110 auto gammaNode1 = GammaNode::create(controlArgument, 2);
111 auto controlEntryVar = gammaNode1->AddEntryVar(controlArgument);
112 auto valueEntryVar1 = gammaNode1->AddEntryVar(valueArgument);
// gammaNode1, subregion 0: a node with no operands, plus the inner gamma.
115 auto constantNode1 = TestOperation::createNode(gammaNode1->subregion(0), {}, { valueType });
117 auto gammaNode2 = GammaNode::create(controlEntryVar.branchArgument[0], 2);
118 auto valueEntryVar2 = gammaNode2->AddEntryVar(valueEntryVar1.branchArgument[0]);
119 auto valueEntryVar3 = gammaNode2->AddEntryVar(constantNode1->output(0));
// gammaNode2, subregion 0. The operands are gammaNode2's branch arguments, and
// the result leaves through gammaNode2->AddExitVar below. The node therefore
// belongs in gammaNode2's subregion: an operand must originate in the region
// of the node that consumes it.
122 auto binaryNode = TestOperation::createNode(
123 gammaNode2->subregion(0),
124 { valueEntryVar2.branchArgument[0], valueEntryVar3.branchArgument[0] },
// gammaNode2, subregion 1 (same region requirement as above).
128 auto unaryNode = TestOperation::createNode(
129 gammaNode2->subregion(1),
130 { valueEntryVar2.branchArgument[1] },
133 auto exitVar1 = gammaNode2->AddExitVar({ binaryNode->output(0), unaryNode->output(0) });
// gammaNode1, subregion 1: a second node with no operands. It is paired with
// the inner gamma's result in gammaNode1's exit variable.
136 auto constantNode2 = TestOperation::createNode(gammaNode1->subregion(1), {}, { valueType });
138 auto exitVar2 = gammaNode1->AddExitVar({ exitVar1.output, constantNode2->output(0) });
140 auto lambdaOutput = lambdaNode->finalize({ exitVar2.output });
142 GraphExport::Create(*lambdaOutput,
"x");
// Expected state after hoisting.
155 EXPECT_EQ(lambdaNode->subregion()->numNodes(), 5u);
158 EXPECT_EQ(gammaNode1->subregion(0)->numNodes(), 1u);
159 EXPECT_EQ(gammaNode1->subregion(1)->numNodes(), 0u);
162 EXPECT_EQ(gammaNode2->subregion(0)->numNodes(), 0u);
163 EXPECT_EQ(gammaNode2->subregion(1)->numNodes(), 0u);
// Node hoisting out of a theta (loop).
// Loop variables:
// - lv1 carries the control argument and is the loop predicate.
// - lv2, lv3 and lv4 all start from the value argument.
// - Only lv2 and lv4 get new post-values (node3 and node4), so lv3 stays loop-invariant.
// Nodes:
// - node1 has no operands; node2 only reads node1 and lv3.pre.
// - node3 and node4 read loop-varying values.
// The expectations (three nodes in the lambda, two left in the theta, post-values
// still wired to node3/node4) are consistent with node1 and node2 being hoisted
// while node3 and node4 stay.
166 TEST(NodeHoistingTests, simpleTheta)
172 auto controlType = ControlType::Create(2);
173 const auto valueType = TestType::createValueType();
174 const auto functionType = FunctionType::Create(
182 auto & rvsdg = rvsdgModule.Rvsdg();
184 auto lambdaNode = LambdaNode::Create(
185 rvsdg.GetRootRegion(),
187 auto controlArgument = lambdaNode->GetFunctionArguments()[0];
188 auto valueArgument = lambdaNode->GetFunctionArguments()[1];
190 auto thetaNode = ThetaNode::create(lambdaNode->subregion());
// lv1 is the predicate; lv2..lv4 are value loop variables.
192 auto lv1 = thetaNode->AddLoopVar(controlArgument);
193 auto lv2 = thetaNode->AddLoopVar(valueArgument);
194 auto lv3 = thetaNode->AddLoopVar(valueArgument);
195 auto lv4 = thetaNode->AddLoopVar(valueArgument);
// Loop-invariant chain: a node with no operands, and a node combining it with lv3.
197 auto node1 = TestOperation::createNode(thetaNode->subregion(), {}, { valueType });
198 auto node2 = TestOperation::createNode(
199 thetaNode->subregion(),
200 { node1->output(0), lv3.pre },
// node3 and node4 read lv2.pre / lv4.pre. They become those variables'
// post-values below, so they vary per iteration.
202 auto node3 = TestOperation::createNode(
203 thetaNode->subregion(),
204 { lv2.pre, node2->output(0) },
207 TestOperation::createNode(thetaNode->subregion(), { lv3.pre, lv4.pre }, { valueType });
209 lv2.post->divert_to(node3->output(0));
210 lv4.post->divert_to(node4->output(0));
212 thetaNode->set_predicate(lv1.pre);
214 lambdaNode->finalize({ thetaNode->output(1) });
// Expected state after hoisting: the loop-varying nodes stay, and the
// post-values are still connected to them.
227 EXPECT_EQ(lambdaNode->subregion()->numNodes(), 3u);
228 EXPECT_EQ(thetaNode->subregion()->numNodes(), 2u);
230 EXPECT_EQ(lv2.post->origin(), node3->output(0));
231 EXPECT_EQ(lv4.post->origin(), node4->output(0));
// Node hoisting of a memory operation out of a theta. The function takes a
// control, a pointer, a value and a memory state, and returns a memory state.
// All four enter the theta as loop variables. Only the memory-state loop
// variable (lvs) is re-bound, to storeNode's first result. The expectations
// (two nodes in the lambda, none left in the theta) are consistent with the
// store being moved out of the loop.
// NOTE(review): assumes storeNode writes lvv.pre through lva.pre using lvs.pre
// as its memory state; confirm against its construction.
234 TEST(NodeHoistingTests, invariantMemoryOperation)
242 const auto controlType = ControlType::Create(2);
243 const auto valueType = TestType::createValueType();
244 const auto functionType = FunctionType::Create(
245 { controlType, pointerType, valueType, memoryStateType },
246 { memoryStateType });
249 auto & rvsdg = rvsdgModule.Rvsdg();
251 auto lambdaNode = LambdaNode::Create(
252 rvsdg.GetRootRegion(),
254 auto controlArgument = lambdaNode->GetFunctionArguments()[0];
255 auto pointerArgument = lambdaNode->GetFunctionArguments()[1];
256 auto valueArgument = lambdaNode->GetFunctionArguments()[2];
257 auto memoryStateArgument = lambdaNode->GetFunctionArguments()[3];
259 auto thetaNode = ThetaNode::create(lambdaNode->subregion());
// One loop variable per function argument; lvc is used as the predicate.
261 auto lvc = thetaNode->AddLoopVar(controlArgument);
262 auto lva = thetaNode->AddLoopVar(pointerArgument);
263 auto lvv = thetaNode->AddLoopVar(valueArgument);
264 auto lvs = thetaNode->AddLoopVar(memoryStateArgument);
// The store's result becomes the memory state carried to the next iteration.
268 lvs.post->divert_to(storeNode.output(0));
269 thetaNode->set_predicate(lvc.pre);
271 lambdaNode->finalize({ lvs.output });
// Expected state after hoisting: the theta body is empty.
284 EXPECT_EQ(lambdaNode->subregion()->numNodes(), 2u);
285 EXPECT_EQ(thetaNode->subregion()->numNodes(), 0u);
// Node hoisting in the presence of a stateful operation.
// - In gammaNode1's subregion 0, stateNode consumes both a value entry
//   variable and a state entry variable.
// - stateNode's first result enters the nested gammaNode2.
// - In gammaNode2's subregion 0, binaryNode combines that result with the
//   value entry variable.
// Expectations: gammaNode1 is still the only node in the lambda, gammaNode1's
// subregion 0 holds three nodes, and both gammaNode2 subregions end up empty.
// This is consistent with binaryNode being hoisted one level, into gammaNode1's
// subregion 0, and stateNode not being hoisted. The latter is presumably
// because it consumes a state; verify against the pass.
288 TEST(NodeHoistingTests, statefulOperations)
294 auto controlType = ControlType::Create(2);
295 auto valueType = TestType::createValueType();
296 auto stateType = TestType::createStateType();
297 const auto functionType = FunctionType::Create(
306 auto & rvsdg = rvsdgModule.Rvsdg();
308 auto lambdaNode = LambdaNode::Create(
309 rvsdg.GetRootRegion(),
311 auto controlArgument = lambdaNode->GetFunctionArguments()[0];
312 auto valueArgument = lambdaNode->GetFunctionArguments()[1];
313 auto stateArgument = lambdaNode->GetFunctionArguments()[2];
// Outer gamma: control, value and state arguments all enter as entry variables.
315 auto gammaNode1 = GammaNode::create(controlArgument, 2);
316 auto controlEntryVar = gammaNode1->AddEntryVar(controlArgument);
317 auto valueEntryVar1 = gammaNode1->AddEntryVar(valueArgument);
318 auto stateEntryVar = gammaNode1->AddEntryVar(stateArgument);
// gammaNode1, subregion 0: a node that reads both a value and a state.
320 auto stateNode = TestOperation::createNode(
321 gammaNode1->subregion(0),
322 { valueEntryVar1.branchArgument[0], stateEntryVar.branchArgument[0] },
// Inner gamma, predicated on gammaNode1's control entry variable. Its entry
// variables carry stateNode's first result and the outer value entry variable.
325 auto gammaNode2 = GammaNode::create(controlEntryVar.branchArgument[0], 2);
326 auto valueEntryVar2 = gammaNode2->AddEntryVar(stateNode->output(0));
327 auto valueEntryVar3 = gammaNode2->AddEntryVar(valueEntryVar1.branchArgument[0]);
// gammaNode2, subregion 0: combines both entry variables.
329 auto binaryNode = TestOperation::createNode(
330 gammaNode2->subregion(0),
331 { valueEntryVar2.branchArgument[0], valueEntryVar3.branchArgument[0] },
335 gammaNode2->AddExitVar({ binaryNode->output(0), valueEntryVar2.branchArgument[1] });
// The outer exit variable forwards the inner gamma's result on branch 0 and
// the value entry variable on branch 1.
337 auto exitVar = gammaNode1->AddExitVar({ exitVar2.output, valueEntryVar1.branchArgument[1] });
339 lambdaNode->finalize({ exitVar.output });
// Expected state after hoisting.
355 EXPECT_EQ(lambdaNode->subregion()->numNodes(), 1u);
358 EXPECT_EQ(gammaNode1->subregion(0)->numNodes(), 3u);
359 EXPECT_EQ(gammaNode1->subregion(1)->numNodes(), 0u);
361 EXPECT_EQ(gammaNode2->subregion(0)->numNodes(), 0u);
362 EXPECT_EQ(gammaNode2->subregion(1)->numNodes(), 0u);
static jlm::util::StatisticsCollector statisticsCollector
TEST(NodeHoistingTests, simpleGamma)
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::CallingConvention callingConvention, jlm::llvm::AttributeSet attributes)
static std::shared_ptr< const MemoryStateType > Create()
Node Hoisting Transformation.
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static std::shared_ptr< const PointerType > Create()
static rvsdg::SimpleNode & CreateNode(rvsdg::Output &address, rvsdg::Output &value, const std::vector< rvsdg::Output * > &memoryStates, size_t alignment)
Global memory state passed between functions.
std::string view(const rvsdg::Region *region)