Jlm
LoadChainSeparationTests.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #include <gtest/gtest.h>
7 
17 #include <jlm/llvm/ir/types.hpp>
19 #include <jlm/rvsdg/gamma.hpp>
20 #include <jlm/rvsdg/graph.hpp>
21 #include <jlm/rvsdg/lambda.hpp>
23 #include <jlm/rvsdg/TestType.hpp>
24 #include <jlm/rvsdg/theta.hpp>
25 #include <jlm/rvsdg/view.hpp>
26 #include <jlm/util/Statistics.hpp>
27 
28 TEST(LoadChainSeparationTests, LoadNonVolatile)
29 {
30  // Arrange
31  using namespace jlm::llvm;
32  using namespace jlm::rvsdg;
33  using namespace jlm::util;
34 
35  const auto pointerType = PointerType::Create();
36  const auto memoryStateType = MemoryStateType::Create();
37  const auto ioStateType = IOStateType::Create();
38  const auto valueType = TestType::createValueType();
39  const auto functionType = FunctionType::Create(
40  { pointerType, ioStateType, memoryStateType },
41  { ioStateType, memoryStateType });
42 
43  jlm::llvm::LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
44  auto & rvsdg = rvsdgModule.Rvsdg();
45 
46  auto lambdaNode = LambdaNode::Create(
47  rvsdg.GetRootRegion(),
49 
50  auto & addressArgument = *lambdaNode->GetFunctionArguments()[0];
51  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[1];
52  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[2];
53 
54  auto & lambdaEntrySplitNode =
55  LambdaEntryMemoryStateSplitOperation::CreateNode(memoryStateArgument, { 0, 1 });
56 
57  auto & loadNode1 = LoadNonVolatileOperation::CreateNode(
58  addressArgument,
59  { lambdaEntrySplitNode.output(0), lambdaEntrySplitNode.output(1) },
60  valueType,
61  4);
62 
63  auto & loadNode2 = LoadNonVolatileOperation::CreateNode(
64  addressArgument,
65  { loadNode1.output(1), loadNode1.output(2) },
66  valueType,
67  4);
68 
69  auto & loadNode3 =
70  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode2.output(2) }, valueType, 4);
71 
72  auto & lambdaExitMergeNode = LambdaExitMemoryStateMergeOperation::CreateNode(
73  *lambdaNode->subregion(),
74  { loadNode2.output(1), loadNode3.output(1) },
75  { 0, 1 });
76 
77  lambdaNode->finalize({ &ioStateArgument, lambdaExitMergeNode.output(0) });
78 
79  view(rvsdg, stdout);
80 
81  // Act
83  LoadChainSeparation loadChainSeparation;
84  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
85 
86  view(rvsdg, stdout);
87 
88  // Assert
89 
90  // We expect the transformation to create two join nodes, one for each memory state chain.
91 
92  // Check transformation for the chain of memory state 1
93  {
94  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
95  *lambdaExitMergeNode.input(0)->origin());
96  EXPECT_TRUE(joinNode && joinOperation);
97  EXPECT_EQ(joinNode->ninputs(), 2u);
98 
99  EXPECT_EQ(TryGetOwnerNode<SimpleNode>(*joinNode->input(0)->origin()), &loadNode2);
100  EXPECT_EQ(TryGetOwnerNode<SimpleNode>(*joinNode->input(1)->origin()), &loadNode1);
101 
102  EXPECT_EQ(loadNode1.input(1)->origin(), lambdaEntrySplitNode.output(0));
103  EXPECT_EQ(loadNode1.input(2)->origin(), lambdaEntrySplitNode.output(1));
104 
105  EXPECT_EQ(loadNode2.input(1)->origin(), lambdaEntrySplitNode.output(0));
106  EXPECT_EQ(loadNode2.input(2)->origin(), lambdaEntrySplitNode.output(1));
107  }
108 
109  // Check transformation for the chain of memory state 2
110  {
111  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
112  *lambdaExitMergeNode.input(1)->origin());
113  EXPECT_TRUE(joinNode && joinOperation);
114  EXPECT_EQ(joinNode->ninputs(), 3u);
115 
116  EXPECT_EQ(TryGetOwnerNode<SimpleNode>(*joinNode->input(0)->origin()), &loadNode3);
117  EXPECT_EQ(TryGetOwnerNode<SimpleNode>(*joinNode->input(1)->origin()), &loadNode2);
118  EXPECT_EQ(TryGetOwnerNode<SimpleNode>(*joinNode->input(2)->origin()), &loadNode1);
119 
120  EXPECT_EQ(loadNode3.input(1)->origin(), lambdaEntrySplitNode.output(1));
121  }
122 }
123 
124 TEST(LoadChainSeparationTests, LoadVolatile)
125 {
126  // Arrange
127  using namespace jlm::llvm;
128  using namespace jlm::rvsdg;
129  using namespace jlm::util;
130 
131  const auto pointerType = PointerType::Create();
132  const auto ioStateType = IOStateType::Create();
133  const auto memoryStateType = MemoryStateType::Create();
134  const auto valueType = TestType::createValueType();
135  const auto functionType = FunctionType::Create(
136  { pointerType, ioStateType, memoryStateType },
137  { ioStateType, memoryStateType });
138 
139  jlm::llvm::LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
140  auto & rvsdg = rvsdgModule.Rvsdg();
141 
142  auto lambdaNode = LambdaNode::Create(
143  rvsdg.GetRootRegion(),
145 
146  auto & addressArgument = *lambdaNode->GetFunctionArguments()[0];
147  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[1];
148  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[2];
149 
150  auto & loadNode1 = LoadVolatileOperation::CreateNode(
151  addressArgument,
152  ioStateArgument,
153  { &memoryStateArgument },
154  valueType,
155  4);
156 
157  auto & loadNode2 = LoadVolatileOperation::CreateNode(
158  addressArgument,
160  { &*LoadOperation::MemoryStateOutputs(loadNode1).begin() },
161  valueType,
162  4);
163 
164  lambdaNode->finalize({ &ioStateArgument, loadNode2.output(2) });
165 
166  view(rvsdg, stdout);
167 
168  // Act
170  LoadChainSeparation loadChainSeparation;
171  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
172 
173  view(rvsdg, stdout);
174 
175  // Assert
176 
177  // We expect the transformation to create a single join node with the memory state outputs of the
178  // two load nodes as operands
179 
180  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
181  *GetMemoryStateRegionResult(*lambdaNode).origin());
182  EXPECT_TRUE(joinNode && joinOperation);
183  EXPECT_EQ(joinNode->ninputs(), 2u);
184 
185  EXPECT_EQ(TryGetOwnerNode<SimpleNode>(*joinNode->input(0)->origin()), &loadNode2);
186  EXPECT_EQ(TryGetOwnerNode<SimpleNode>(*joinNode->input(1)->origin()), &loadNode1);
187 
188  EXPECT_EQ(loadNode1.input(2)->origin(), &memoryStateArgument);
189  EXPECT_EQ(loadNode2.input(2)->origin(), &memoryStateArgument);
190 }
191 
// A single non-volatile load (no chain) must be left untouched: the lambda's memory-state result
// still originates from the load, and the load still consumes the memory-state argument.
192 TEST(LoadChainSeparationTests, SingleLoad)
193 {
194  // Arrange
195  using namespace jlm::llvm;
196  using namespace jlm::rvsdg;
197  using namespace jlm::util;
198 
199  const auto pointerType = PointerType::Create();
200  const auto ioStateType = IOStateType::Create();
201  const auto memoryStateType = MemoryStateType::Create();
202  const auto valueType = TestType::createValueType();
203  const auto functionType = FunctionType::Create(
204  { pointerType, ioStateType, memoryStateType },
205  { ioStateType, memoryStateType });
206 
207  jlm::llvm::LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
208  auto & rvsdg = rvsdgModule.Rvsdg();
209 
210  auto lambdaNode = LambdaNode::Create(
211  rvsdg.GetRootRegion(),
213 
214  auto & addressArgument = *lambdaNode->GetFunctionArguments()[0];
215  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[1];
216  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[2];
217 
218  auto & loadNode =
219  LoadNonVolatileOperation::CreateNode(addressArgument, { &memoryStateArgument }, valueType, 4);
220 
221  lambdaNode->finalize({ &ioStateArgument, loadNode.output(1) });
222 
223  view(rvsdg, stdout);
224 
225  // Act
227  LoadChainSeparation loadChainSeparation;
228  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
229 
230  view(rvsdg, stdout);
231 
232  // Assert
233  // We expect nothing to happen as there is no chain of load nodes
234  EXPECT_EQ(
235  TryGetOwnerNode<SimpleNode>(*GetMemoryStateRegionResult(*lambdaNode).origin()),
236  &loadNode);
237  EXPECT_EQ(LoadOperation::MemoryStateInputs(loadNode).begin()->origin(), &memoryStateArgument);
238 }
239 
240 TEST(LoadChainSeparationTests, LoadAndStore)
241 {
242  // Arrange
243  using namespace jlm::llvm;
244  using namespace jlm::rvsdg;
245  using namespace jlm::util;
246 
247  const auto pointerType = PointerType::Create();
248  const auto memoryStateType = MemoryStateType::Create();
249  const auto ioStateType = IOStateType::Create();
250  const auto valueType = TestType::createValueType();
251  const auto functionType = FunctionType::Create(
252  { pointerType, ioStateType, memoryStateType },
253  { ioStateType, memoryStateType });
254 
255  jlm::llvm::LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
256  auto & rvsdg = rvsdgModule.Rvsdg();
257 
258  auto lambdaNode = LambdaNode::Create(
259  rvsdg.GetRootRegion(),
261 
262  auto & addressArgument = *lambdaNode->GetFunctionArguments()[0];
263  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[1];
264  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[2];
265 
266  auto valueNode = TestOperation::createNode(lambdaNode->subregion(), {}, { valueType });
267 
268  auto & loadNode1 =
269  LoadNonVolatileOperation::CreateNode(addressArgument, { &memoryStateArgument }, valueType, 4);
270 
271  auto & loadNode2 =
272  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode1.output(1) }, valueType, 4);
273 
274  auto & storeNode1 = StoreNonVolatileOperation::CreateNode(
275  addressArgument,
276  *valueNode->output(0),
277  { loadNode2.output(1) },
278  4);
279 
280  auto & loadNode3 =
281  LoadNonVolatileOperation::CreateNode(addressArgument, { storeNode1.output(0) }, valueType, 4);
282 
283  auto & loadNode4 =
284  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode3.output(1) }, valueType, 4);
285 
286  auto & storeNode2 = StoreNonVolatileOperation::CreateNode(
287  addressArgument,
288  *valueNode->output(0),
289  { loadNode4.output(1) },
290  4);
291 
292  auto & loadNode5 =
293  LoadNonVolatileOperation::CreateNode(addressArgument, { storeNode2.output(0) }, valueType, 4);
294 
295  lambdaNode->finalize({ &ioStateArgument, loadNode5.output(1) });
296 
297  view(rvsdg, stdout);
298 
299  // Act
301  LoadChainSeparation loadChainSeparation;
302  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
303 
304  view(rvsdg, stdout);
305 
306  // Assert
307  // We expect two join nodes to appear.
308  {
309  auto [joinNode, joinOperation] =
310  TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(*storeNode2.input(2)->origin());
311  EXPECT_TRUE(joinOperation);
312  EXPECT_EQ(joinNode->ninputs(), 2u);
313  }
314 
315  {
316  auto [joinNode, joinOperation] =
317  TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(*storeNode1.input(2)->origin());
318  EXPECT_TRUE(joinOperation);
319  EXPECT_EQ(joinNode->ninputs(), 2u);
320  }
321 }
322 
323 TEST(LoadChainSeparationTests, GammaWithOnlyLoads)
324 {
325  // Arrange
326  using namespace jlm::llvm;
327  using namespace jlm::rvsdg;
328  using namespace jlm::util;
329 
330  const auto pointerType = PointerType::Create();
331  const auto memoryStateType = MemoryStateType::Create();
332  const auto ioStateType = IOStateType::Create();
333  const auto valueType = TestType::createValueType();
334  const auto controlType = ControlType::Create(2);
335  const auto functionType = FunctionType::Create(
336  { controlType, pointerType, ioStateType, memoryStateType },
337  { ioStateType, memoryStateType });
338 
339  jlm::llvm::LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
340  auto & rvsdg = rvsdgModule.Rvsdg();
341 
342  auto lambdaNode = LambdaNode::Create(
343  rvsdg.GetRootRegion(),
345 
346  auto & controlArgument = *lambdaNode->GetFunctionArguments()[0];
347  auto & addressArgument = *lambdaNode->GetFunctionArguments()[1];
348  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[2];
349  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[3];
350 
351  auto & loadNode1 =
352  LoadNonVolatileOperation::CreateNode(addressArgument, { &memoryStateArgument }, valueType, 4);
353 
354  auto & loadNode2 =
355  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode1.output(1) }, valueType, 4);
356 
357  auto gammaNode = GammaNode::create(&controlArgument, 2);
358  auto addressEntryVar = gammaNode->AddEntryVar(&addressArgument);
359  auto memoryStateEntryVar = gammaNode->AddEntryVar(loadNode2.output(1));
360 
361  // subregion 0
362  auto & loadNode3 = LoadNonVolatileOperation::CreateNode(
363  *addressEntryVar.branchArgument[0],
364  { memoryStateEntryVar.branchArgument[0] },
365  valueType,
366  4);
367 
368  auto & loadNode4 = LoadNonVolatileOperation::CreateNode(
369  *addressEntryVar.branchArgument[0],
370  { loadNode3.output(1) },
371  valueType,
372  4);
373 
374  // subregion 1
375  auto & loadNode5 = LoadNonVolatileOperation::CreateNode(
376  *addressEntryVar.branchArgument[1],
377  { memoryStateEntryVar.branchArgument[1] },
378  valueType,
379  4);
380 
381  auto memoryStateExitVar = gammaNode->AddExitVar({ loadNode4.output(1), loadNode5.output(1) });
382 
383  auto & loadNode6 = LoadNonVolatileOperation::CreateNode(
384  addressArgument,
385  { memoryStateExitVar.output },
386  valueType,
387  4);
388 
389  auto & loadNode7 =
390  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode6.output(1) }, valueType, 4);
391 
392  lambdaNode->finalize({ &ioStateArgument, loadNode7.output(1) });
393 
394  view(rvsdg, stdout);
395 
396  // Act
398  LoadChainSeparation loadChainSeparation;
399  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
400 
401  view(rvsdg, stdout);
402 
403  // Assert
404  // We expect three join nodes to appear
405  {
406  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
407  *GetMemoryStateRegionResult(*lambdaNode).origin());
408  EXPECT_TRUE(joinOperation);
409  EXPECT_EQ(joinNode->ninputs(), 2u);
410  }
411 
412  {
413  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
414  *gammaNode->GetExitVars()[0].branchResult[0]->origin());
415  EXPECT_TRUE(joinOperation);
416  EXPECT_EQ(joinNode->ninputs(), 2u);
417  }
418 
419  {
420  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
421  *gammaNode->GetEntryVars()[1].input->origin());
422  EXPECT_TRUE(joinOperation);
423  EXPECT_EQ(joinNode->ninputs(), 2u);
424  }
425 }
426 
427 TEST(LoadChainSeparationTests, GammaWithLoadsAndStores)
428 {
429  // Arrange
430  using namespace jlm::llvm;
431  using namespace jlm::rvsdg;
432  using namespace jlm::util;
433 
434  const auto pointerType = PointerType::Create();
435  const auto ioStateType = IOStateType::Create();
436  const auto memoryStateType = MemoryStateType::Create();
437  const auto valueType = TestType::createValueType();
438  const auto controlType = ControlType::Create(2);
439  const auto functionType = FunctionType::Create(
440  { controlType, pointerType, ioStateType, memoryStateType },
441  { ioStateType, memoryStateType });
442 
443  jlm::llvm::LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
444  auto & rvsdg = rvsdgModule.Rvsdg();
445 
446  auto lambdaNode = LambdaNode::Create(
447  rvsdg.GetRootRegion(),
449 
450  auto & controlArgument = *lambdaNode->GetFunctionArguments()[0];
451  auto & addressArgument = *lambdaNode->GetFunctionArguments()[1];
452  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[2];
453  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[3];
454 
455  auto & loadNode1 =
456  LoadNonVolatileOperation::CreateNode(addressArgument, { &memoryStateArgument }, valueType, 4);
457 
458  auto gammaNode = GammaNode::create(&controlArgument, 2);
459  auto addressEntryVar = gammaNode->AddEntryVar(&addressArgument);
460  auto memoryStateEntryVar = gammaNode->AddEntryVar(loadNode1.output(1));
461 
462  // subregion 0
463  auto & loadNode2 = LoadNonVolatileOperation::CreateNode(
464  *addressEntryVar.branchArgument[0],
465  { memoryStateEntryVar.branchArgument[0] },
466  valueType,
467  4);
468 
469  auto & loadNode3 = LoadNonVolatileOperation::CreateNode(
470  *addressEntryVar.branchArgument[0],
471  { loadNode2.output(1) },
472  valueType,
473  4);
474 
475  // subregion 1
476  auto value = TestOperation::createNode(gammaNode->subregion(1), {}, { valueType });
477  auto & storeNode = StoreNonVolatileOperation::CreateNode(
478  *addressEntryVar.branchArgument[1],
479  *value->output(0),
480  { memoryStateEntryVar.branchArgument[1] },
481  4);
482 
483  auto & loadNode4 = LoadNonVolatileOperation::CreateNode(
484  *addressEntryVar.branchArgument[1],
485  { storeNode.output(0) },
486  valueType,
487  4);
488 
489  auto memoryStateExitVar = gammaNode->AddExitVar({ loadNode3.output(1), loadNode4.output(1) });
490 
491  auto & loadNode5 = LoadNonVolatileOperation::CreateNode(
492  addressArgument,
493  { memoryStateExitVar.output },
494  valueType,
495  4);
496 
497  auto & loadNode6 =
498  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode5.output(1) }, valueType, 4);
499 
500  lambdaNode->finalize({ &ioStateArgument, loadNode6.output(1) });
501 
502  view(rvsdg, stdout);
503 
504  // Act
506  LoadChainSeparation loadChainSeparation;
507  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
508 
509  view(rvsdg, stdout);
510 
511  // Assert
512  // We expect two join nodes to appear
513  {
514  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
515  *GetMemoryStateRegionResult(*lambdaNode).origin());
516  EXPECT_TRUE(joinOperation);
517  EXPECT_EQ(joinNode->ninputs(), 2u);
518  }
519 
520  {
521  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
522  *gammaNode->GetExitVars()[0].branchResult[0]->origin());
523  EXPECT_TRUE(joinOperation);
524  EXPECT_EQ(joinNode->ninputs(), 2u);
525  }
526 }
527 
528 TEST(LoadChainSeparationTests, ThetaWithLoadsOnly)
529 {
530  // Arrange
531  using namespace jlm::llvm;
532  using namespace jlm::rvsdg;
533  using namespace jlm::util;
534 
535  const auto pointerType = PointerType::Create();
536  const auto memoryStateType = MemoryStateType::Create();
537  const auto ioStateType = IOStateType::Create();
538  const auto valueType = TestType::createValueType();
539  const auto controlType = ControlType::Create(2);
540  const auto functionType = FunctionType::Create(
541  { controlType, pointerType, ioStateType, memoryStateType },
542  { ioStateType, memoryStateType });
543 
544  jlm::llvm::LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
545  auto & rvsdg = rvsdgModule.Rvsdg();
546 
547  auto lambdaNode = LambdaNode::Create(
548  rvsdg.GetRootRegion(),
550 
551  auto & addressArgument = *lambdaNode->GetFunctionArguments()[1];
552  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[2];
553  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[3];
554 
555  auto & loadNode1 =
556  LoadNonVolatileOperation::CreateNode(addressArgument, { &memoryStateArgument }, valueType, 4);
557 
558  auto & loadNode2 =
559  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode1.output(1) }, valueType, 4);
560 
561  auto thetaNode = ThetaNode::create(lambdaNode->subregion());
562 
563  auto addressLoopVar = thetaNode->AddLoopVar(&addressArgument);
564  auto memoryStateLoopVar = thetaNode->AddLoopVar(loadNode2.output(1));
565 
566  auto & loadNode3 = LoadNonVolatileOperation::CreateNode(
567  *addressLoopVar.pre,
568  { memoryStateLoopVar.pre },
569  valueType,
570  4);
571 
572  auto & loadNode4 = LoadNonVolatileOperation::CreateNode(
573  *addressLoopVar.pre,
574  { loadNode3.output(1) },
575  valueType,
576  4);
577 
578  memoryStateLoopVar.post->divert_to(loadNode4.output(1));
579 
580  auto & loadNode5 = LoadNonVolatileOperation::CreateNode(
581  addressArgument,
582  { memoryStateLoopVar.output },
583  valueType,
584  4);
585 
586  auto & loadNode6 =
587  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode5.output(1) }, valueType, 4);
588 
589  lambdaNode->finalize({ &ioStateArgument, loadNode6.output(1) });
590 
591  view(rvsdg, stdout);
592 
593  // Act
595  LoadChainSeparation loadChainSeparation;
596  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
597 
598  view(rvsdg, stdout);
599 
600  // Assert
601  // We expect a single join node in the theta subregion
602  {
603  auto [joinNode, joinOperation] =
604  TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(*memoryStateLoopVar.post->origin());
605  EXPECT_TRUE(joinOperation);
606  EXPECT_EQ(joinNode->ninputs(), 2u);
607  }
608 
609  // We expect a single join node in the lambda subregion
610  {
611  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
612  *GetMemoryStateRegionResult(*lambdaNode).origin());
613  EXPECT_TRUE(joinOperation);
614  EXPECT_EQ(joinNode->ninputs(), 5u);
615  }
616 }
617 
618 TEST(LoadChainSeparationTests, ExternalCall)
619 {
620  // Arrange
621  using namespace jlm::llvm;
622  using namespace jlm::rvsdg;
623  using namespace jlm::util;
624 
625  const auto pointerType = PointerType::Create();
626  const auto memoryStateType = MemoryStateType::Create();
627  const auto ioStateType = IOStateType::Create();
628  const auto valueType = TestType::createValueType();
629  const auto controlType = ControlType::Create(2);
630  const auto functionType = FunctionType::Create(
631  { controlType, pointerType, ioStateType, memoryStateType },
632  { ioStateType, memoryStateType });
633  const auto externalFunctionType =
634  FunctionType::Create({ ioStateType, memoryStateType }, { ioStateType, memoryStateType });
635 
636  jlm::llvm::LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
637  auto & rvsdg = rvsdgModule.Rvsdg();
638 
639  auto & externalFunction = jlm::rvsdg::GraphImport::Create(rvsdg, externalFunctionType, "g");
640 
641  auto lambdaNode = LambdaNode::Create(
642  rvsdg.GetRootRegion(),
644 
645  auto & addressArgument = *lambdaNode->GetFunctionArguments()[1];
646  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[2];
647  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[3];
648  auto externalFunctionCtxVar = lambdaNode->AddContextVar(externalFunction);
649 
650  auto & lambdaEntrySplitNode =
651  LambdaEntryMemoryStateSplitOperation::CreateNode(memoryStateArgument, { 0, 1 });
652 
653  auto & loadNode1 = LoadNonVolatileOperation::CreateNode(
654  addressArgument,
655  { lambdaEntrySplitNode.output(0) },
656  valueType,
657  4);
658 
659  auto & loadNode2 =
660  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode1.output(1) }, valueType, 4);
661 
662  auto & loadNode3 = LoadNonVolatileOperation::CreateNode(
663  addressArgument,
664  { lambdaEntrySplitNode.output(1) },
665  valueType,
666  4);
667 
668  auto & loadNode4 =
669  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode3.output(1) }, valueType, 4);
670 
671  auto & callEntryMergeNode = CallEntryMemoryStateMergeOperation::CreateNode(
672  *lambdaNode->subregion(),
673  { loadNode2.output(1), loadNode4.output(1) },
674  { 0, 1 });
675 
676  auto & callNode = CallOperation::CreateNode(
677  externalFunctionCtxVar.inner,
678  externalFunctionType,
679  { &ioStateArgument, callEntryMergeNode.output(0) });
680 
681  auto & callExitSplitNode =
682  CallExitMemoryStateSplitOperation::CreateNode(*callNode.output(1), { 0, 1 });
683 
684  auto & loadNode5 = LoadNonVolatileOperation::CreateNode(
685  addressArgument,
686  { callExitSplitNode.output(0) },
687  valueType,
688  4);
689 
690  auto & loadNode6 =
691  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode5.output(1) }, valueType, 4);
692 
693  auto & loadNode7 = LoadNonVolatileOperation::CreateNode(
694  addressArgument,
695  { callExitSplitNode.output(1) },
696  valueType,
697  4);
698 
699  auto & loadNode8 =
700  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode7.output(1) }, valueType, 4);
701 
702  auto & lambdaExitMergeNode = LambdaExitMemoryStateMergeOperation::CreateNode(
703  *lambdaNode->subregion(),
704  { loadNode6.output(1), loadNode8.output(1) },
705  { 0, 1 });
706 
707  lambdaNode->finalize({ callNode.output(0), lambdaExitMergeNode.output(0) });
708 
709  view(rvsdg, stdout);
710 
711  // Act
713  LoadChainSeparation loadChainSeparation;
714  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
715 
716  view(rvsdg, stdout);
717 
718  // Assert
719  // We expect 4 MemoryStateJoinOperation nodes in the graph
720  {
721  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
722  *lambdaExitMergeNode.input(0)->origin());
723  EXPECT_TRUE(joinOperation);
724 
725  EXPECT_EQ(joinNode->input(0)->origin(), loadNode6.output(1));
726  EXPECT_EQ(loadNode6.input(1)->origin(), callExitSplitNode.output(0));
727 
728  EXPECT_EQ(joinNode->input(1)->origin(), loadNode5.output(1));
729  EXPECT_EQ(loadNode5.input(1)->origin(), callExitSplitNode.output(0));
730  }
731 
732  {
733  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
734  *lambdaExitMergeNode.input(1)->origin());
735  EXPECT_TRUE(joinOperation);
736 
737  EXPECT_EQ(joinNode->input(0)->origin(), loadNode8.output(1));
738  EXPECT_EQ(loadNode8.input(1)->origin(), callExitSplitNode.output(1));
739 
740  EXPECT_EQ(joinNode->input(1)->origin(), loadNode7.output(1));
741  EXPECT_EQ(loadNode7.input(1)->origin(), callExitSplitNode.output(1));
742  }
743 
744  {
745  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
746  *callEntryMergeNode.input(0)->origin());
747  EXPECT_TRUE(joinOperation);
748 
749  EXPECT_EQ(joinNode->input(0)->origin(), loadNode2.output(1));
750  EXPECT_EQ(loadNode2.input(1)->origin(), lambdaEntrySplitNode.output(0));
751 
752  EXPECT_EQ(joinNode->input(1)->origin(), loadNode1.output(1));
753  EXPECT_EQ(loadNode1.input(1)->origin(), lambdaEntrySplitNode.output(0));
754  }
755 
756  {
757  auto [joinNode, joinOperation] = TryGetSimpleNodeAndOptionalOp<MemoryStateJoinOperation>(
758  *callEntryMergeNode.input(1)->origin());
759  EXPECT_TRUE(joinOperation);
760 
761  EXPECT_EQ(joinNode->input(0)->origin(), loadNode4.output(1));
762  EXPECT_EQ(loadNode4.input(1)->origin(), lambdaEntrySplitNode.output(1));
763 
764  EXPECT_EQ(joinNode->input(1)->origin(), loadNode3.output(1));
765  EXPECT_EQ(loadNode3.input(1)->origin(), lambdaEntrySplitNode.output(1));
766  }
767 }
768 
// A store followed by two chained loads whose memory-state outputs are unused (the lambda returns
// an undef memory state): after the pass, both loads consume the store's memory-state output and
// their own memory-state outputs remain dead.
769 TEST(LoadChainSeparationTests, DeadOutputs)
770 {
771  // Arrange
772  using namespace jlm::llvm;
773  using namespace jlm::rvsdg;
774  using namespace jlm::util;
775 
777  const auto pointerType = PointerType::Create();
778  const auto memoryStateType = MemoryStateType::Create();
779  const auto ioStateType = IOStateType::Create();
780  const auto valueType = TestType::createValueType();
781  const auto functionType = FunctionType::Create(
782  { pointerType, valueType, ioStateType, memoryStateType },
783  { valueType, ioStateType, memoryStateType });
784 
785  jlm::llvm::LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
786  auto & rvsdg = rvsdgModule.Rvsdg();
787 
788  auto lambdaNode = LambdaNode::Create(
789  rvsdg.GetRootRegion(),
791  auto & addressArgument = *lambdaNode->GetFunctionArguments()[0];
792  auto & valueArgument = *lambdaNode->GetFunctionArguments()[1];
793  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[2];
794  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[3];
795 
796  auto & storeNode = StoreNonVolatileOperation::CreateNode(
797  addressArgument,
798  valueArgument,
799  { &memoryStateArgument },
800  4);
801 
802  auto & loadNode1 =
803  LoadNonVolatileOperation::CreateNode(addressArgument, { storeNode.output(0) }, valueType, 4);
804 
805  auto & loadNode2 =
806  LoadNonVolatileOperation::CreateNode(addressArgument, { loadNode1.output(1) }, valueType, 4);
807 
808  auto undefValue = UndefValueOperation::Create(*lambdaNode->subregion(), memoryStateType);
809 
810  lambdaNode->finalize({
811  loadNode2.output(0),
812  &ioStateArgument,
813  undefValue,
814  });
815 
816  view(rvsdg, stdout);
817 
818  // Act
820  LoadChainSeparation loadChainSeparation;
821  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
822 
823  view(rvsdg, stdout);
824 
825  // Assert
826  EXPECT_TRUE(loadNode1.output(1)->IsDead());
827  EXPECT_EQ(loadNode1.input(1)->origin(), storeNode.output(0));
828  EXPECT_TRUE(loadNode2.output(1)->IsDead());
829  EXPECT_EQ(loadNode2.input(1)->origin(), storeNode.output(0));
830 }
831 
832 TEST(LoadChainSeperationTests, StoreInThetaWithMultipleUsers)
833 {
834  // Arrange
835  using namespace jlm::llvm;
836  using namespace jlm::rvsdg;
837  using namespace jlm::util;
838 
839  const auto bit32Type = BitType::Create(32);
840  const auto controlType = ControlType::Create(2);
841  const auto valueType = TestType::createValueType();
842  const auto memoryStateType = MemoryStateType::Create();
843  const auto ioStateType = IOStateType::Create();
844  const auto pointerType = PointerType::Create();
845  const auto functionType = FunctionType::Create(
846  { ioStateType, memoryStateType },
847  { valueType, ioStateType, memoryStateType });
848 
849  LlvmRvsdgModule rvsdgModule(FilePath(""), "", "");
850  auto & rvsdg = rvsdgModule.Rvsdg();
851 
852  auto lambdaNode = LambdaNode::Create(
853  rvsdg.GetRootRegion(),
855  auto & ioStateArgument = *lambdaNode->GetFunctionArguments()[0];
856  auto & memoryStateArgument = *lambdaNode->GetFunctionArguments()[1];
857 
858  auto sizeNode = TestOperation::createNode(lambdaNode->subregion(), {}, { bit32Type });
859  auto allocaResults = AllocaOperation::create(valueType, sizeNode->output(0), 4);
860 
861  auto lambdaValueNode = TestOperation::createNode(lambdaNode->subregion(), {}, { valueType });
862  auto & lambdaStoreNode = StoreNonVolatileOperation::CreateNode(
863  *allocaResults[0],
864  *lambdaValueNode->output(0),
865  { allocaResults[1] },
866  4);
867 
868  // theta node
869  auto thetaNode = ThetaNode::create(lambdaNode->subregion());
870  auto loopVar1 = thetaNode->AddLoopVar(&memoryStateArgument);
871  auto loopVar2 = thetaNode->AddLoopVar(lambdaStoreNode.output(0));
872 
873  auto thetaAddressNode = TestOperation::createNode(thetaNode->subregion(), {}, { pointerType });
874  auto thetaValueNode = TestOperation::createNode(thetaNode->subregion(), {}, { valueType });
875  auto & thetaLoadNode = LoadNonVolatileOperation::CreateNode(
876  *thetaAddressNode->output(0),
877  { loopVar2.pre },
878  valueType,
879  4);
880  auto & thetaStoreNode = StoreNonVolatileOperation::CreateNode(
881  *thetaAddressNode->output(0),
882  *thetaValueNode->output(0),
883  { thetaLoadNode.output(1) },
884  4);
885 
886  // gamma node
887  auto gammaPredicateNode = TestOperation::createNode(thetaNode->subregion(), {}, { controlType });
888  auto gammaNode = GammaNode::create(gammaPredicateNode->output(0), 2);
889 
890  auto entryVar1 = gammaNode->AddEntryVar(thetaStoreNode.output(0));
891  auto entryVar2 = gammaNode->AddEntryVar(loopVar1.pre);
892 
893  auto gammaAddressNode = TestOperation::createNode(gammaNode->subregion(1), {}, { pointerType });
894  auto gammaValueNode = TestOperation::createNode(gammaNode->subregion(1), {}, { valueType });
895  auto & gammaStoreNode = StoreNonVolatileOperation::CreateNode(
896  *gammaAddressNode->output(0),
897  *gammaValueNode->output(0),
898  { entryVar2.branchArgument[1] },
899  4);
900 
901  auto exitVar1 =
902  gammaNode->AddExitVar({ entryVar1.branchArgument[0], entryVar1.branchArgument[1] });
903  auto exitVar2 = gammaNode->AddExitVar({ entryVar2.branchArgument[0], gammaStoreNode.output(0) });
904  // done with gamma gamma node
905 
906  loopVar1.post->divert_to(exitVar2.output);
907  loopVar2.post->divert_to(thetaStoreNode.output(0));
908  // done with theta node
909 
910  auto & lambdaLoadNode =
911  LoadNonVolatileOperation::CreateNode(*allocaResults[0], { loopVar2.output }, valueType, 4);
912 
913  auto lambdaOutput =
914  lambdaNode->finalize({ lambdaLoadNode.output(0), &ioStateArgument, loopVar1.output });
915 
916  GraphExport::Create(*lambdaOutput, "f");
917 
918  view(rvsdg, stdout);
919 
920  // Act
922  LoadChainSeparation loadChainSeparation;
923  loadChainSeparation.Run(rvsdgModule, statisticsCollector);
924 
925  view(rvsdg, stdout);
926 
927  // Assert
928  // We expect no new join nodes added to any of the regions
929  EXPECT_EQ(lambdaNode->subregion()->numNodes(), 6u);
930  EXPECT_EQ(thetaNode->subregion()->numNodes(), 7u);
931  EXPECT_EQ(gammaNode->subregion(0)->numNodes(), 0u);
932  EXPECT_EQ(gammaNode->subregion(1)->numNodes(), 3u);
933 
934  EXPECT_NE(TryGetOwnerNode<ThetaNode>(*GetMemoryStateRegionResult(*lambdaNode).origin()), nullptr);
935  EXPECT_NE(TryGetOwnerNode<ThetaNode>(*lambdaLoadNode.input(1)->origin()), nullptr);
936 }
static jlm::util::StatisticsCollector statisticsCollector
TEST(LoadChainSeparationTests, LoadNonVolatile)
static std::vector< rvsdg::Output * > create(std::shared_ptr< const rvsdg::Type > allocatedType, rvsdg::Output *count, const size_t alignment)
Definition: alloca.hpp:131
static rvsdg::SimpleNode & CreateNode(rvsdg::Region &region, const std::vector< rvsdg::Output * > &operands, std::vector< MemoryNodeId > memoryNodeIds)
static rvsdg::SimpleNode & CreateNode(rvsdg::Output &operand, std::vector< MemoryNodeId > memoryNodeIds)
static rvsdg::SimpleNode & CreateNode(rvsdg::Region &region, std::unique_ptr< CallOperation > callOperation, const std::vector< rvsdg::Output * > &operands)
Definition: call.hpp:489
static std::shared_ptr< const IOStateType > Create()
Definition: types.cpp:343
static rvsdg::SimpleNode & CreateNode(rvsdg::Output &operand, std::vector< MemoryNodeId > memoryNodeIds)
static rvsdg::SimpleNode & CreateNode(rvsdg::Region &region, const std::vector< rvsdg::Output * > &operands, const std::vector< MemoryNodeId > &memoryNodeIds)
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::CallingConvention callingConvention, jlm::llvm::AttributeSet attributes)
Definition: lambda.hpp:84
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static rvsdg::SimpleNode & CreateNode(rvsdg::Region &region, std::unique_ptr< LoadNonVolatileOperation > loadOperation, const std::vector< rvsdg::Output * > &operands)
Definition: Load.hpp:466
static rvsdg::Node::OutputIteratorRange MemoryStateOutputs(const rvsdg::Node &node) noexcept
Definition: Load.hpp:116
static rvsdg::Node::InputIteratorRange MemoryStateInputs(const rvsdg::Node &node) noexcept
Definition: Load.hpp:139
static rvsdg::SimpleNode & CreateNode(rvsdg::Region &region, std::unique_ptr< LoadVolatileOperation > loadOperation, const std::vector< rvsdg::Output * > &operands)
Definition: Load.cpp:388
static rvsdg::Output & IOStateOutput(const rvsdg::Node &node)
Definition: Load.hpp:234
static std::shared_ptr< const MemoryStateType > Create()
Definition: types.cpp:379
static std::shared_ptr< const PointerType > Create()
Definition: types.cpp:45
static rvsdg::SimpleNode & CreateNode(rvsdg::Output &address, rvsdg::Output &value, const std::vector< rvsdg::Output * > &memoryStates, size_t alignment)
Definition: Store.hpp:323
static jlm::rvsdg::Output * Create(rvsdg::Region &region, std::shared_ptr< const jlm::rvsdg::Type > type)
Definition: operators.hpp:1055
static GraphImport & Create(Graph &graph, std::shared_ptr< const rvsdg::Type > type, std::string name)
Definition: graph.cpp:36
Output * origin() const noexcept
Definition: node.hpp:58
Global memory state passed between functions.
rvsdg::Input & GetMemoryStateRegionResult(const rvsdg::LambdaNode &lambdaNode) noexcept
std::string view(const rvsdg::Region *region)
Definition: view.cpp:142