Jlm
PredicateCorrelationTests.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #include <gtest/gtest.h>
7 
11 #include <jlm/rvsdg/control.hpp>
13 #include <jlm/rvsdg/theta.hpp>
14 #include <jlm/rvsdg/view.hpp>
15 
17 {
21  std::vector<jlm::rvsdg::Node *> controlConstants{};
22 };
23 
26  jlm::rvsdg::Graph & rvsdg,
27  const std::pair<uint64_t, uint64_t> & gammaSubregionControlConstants)
28 {
29  using namespace jlm::llvm;
30  using namespace jlm::rvsdg;
31 
32  auto bitType32 = BitType::Create(32);
33  auto controlType = ControlType::Create(2);
34 
35  auto thetaNode = ThetaNode::create(&rvsdg.GetRootRegion());
36 
37  auto dummy = TestOperation::createNode(thetaNode->subregion(), {}, { bitType32 })->output(0);
38  auto predicate = MatchOperation::Create(*dummy, { { 1, 1 } }, 0, 2);
39 
40  auto gammaNode = GammaNode::create(predicate, 2);
41 
42  auto controlConstant0 = &ControlConstantOperation::create(
43  *gammaNode->subregion(0),
44  2,
45  gammaSubregionControlConstants.first);
46  auto controlConstant1 = &ControlConstantOperation::create(
47  *gammaNode->subregion(1),
48  2,
49  gammaSubregionControlConstants.second);
50 
51  auto controlExitVar = gammaNode->AddExitVar({ controlConstant0, controlConstant1 });
52 
53  thetaNode->predicate()->divert_to(controlExitVar.output);
54 
55  return { *gammaNode,
56  *thetaNode,
57  *TryGetOwnerNode<Node>(*predicate),
58  { TryGetOwnerNode<Node>(*controlConstant0), TryGetOwnerNode<Node>(*controlConstant1) } };
59 }
60 
62 {
66 };
67 
70  jlm::rvsdg::Graph & rvsdg,
71  const std::pair<int64_t, int64_t> & gammaSubregionAlternatives)
72 {
73  using namespace jlm::llvm;
74  using namespace jlm::rvsdg;
75 
76  auto bitType32 = BitType::Create(32);
77  auto controlType = ControlType::Create(2);
78 
79  auto thetaNode = ThetaNode::create(&rvsdg.GetRootRegion());
80 
81  auto predicate =
82  TestOperation::createNode(thetaNode->subregion(), {}, { controlType })->output(0);
83  auto gammaNode = GammaNode::create(predicate, 2);
84 
85  auto constant0 = &BitConstantOperation::create(
86  *gammaNode->subregion(0),
87  { 64, gammaSubregionAlternatives.first });
88  auto constant1 = &BitConstantOperation::create(
89  *gammaNode->subregion(1),
90  { 64, gammaSubregionAlternatives.second });
91 
92  auto exitVar = gammaNode->AddExitVar({ constant0, constant1 });
93 
94  auto & matchNode = MatchOperation::CreateNode(*exitVar.output, { { 1, 1 } }, 0, 2);
95 
96  thetaNode->predicate()->divert_to(matchNode.output(0));
97 
98  return { *gammaNode, *thetaNode, matchNode };
99 }
100 
102 {
106 };
107 
110 {
111  using namespace jlm::llvm;
112  using namespace jlm::rvsdg;
113 
114  auto bitType32 = BitType::Create(32);
115  auto controlType = ControlType::Create(2);
116 
117  auto thetaNode = ThetaNode::create(&rvsdg.GetRootRegion());
118 
119  auto constantNode = TestOperation::createNode(thetaNode->subregion(), {}, { bitType32 });
120  auto & matchNode = MatchOperation::CreateNode(*constantNode->output(0), { { 1, 1 } }, 0, 2);
121 
122  auto gammaNode = GammaNode::create(matchNode.output(0), 2);
123 
124  thetaNode->predicate()->divert_to(matchNode.output(0));
125 
126  return { *gammaNode, *thetaNode, matchNode };
127 }
128 
// Builds a theta whose predicate is a gamma exit variable yielding the control constants
// { 0, 1 } from gamma subregions 0 and 1, runs PredicateCorrelation, and checks that the
// theta predicate ends up connected to the match node that drives the gamma predicate.
129 TEST(PredicateCorrelationTests, testControlConstantCorrelation)
130 {
131  // Arrange
132  using namespace jlm::llvm;
133  using namespace jlm::rvsdg;
134 
135  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
136  auto & rvsdg = rvsdgModule->Rvsdg();
137 
138  auto [gammaNode, thetaNode, matchNode, controlConstants] =
139  setupControlConstantCorrelationTest(rvsdg, { 0, 1 });
140 
141  view(rvsdg, stdout);
142 
143  // Act
145  PredicateCorrelation predicateCorrelation;
146  predicateCorrelation.Run(*rvsdgModule, statisticsCollector);
147 
// Remove nodes in the theta body that became dead after the predicate was rewired.
148  thetaNode.subregion()->prune(true);
149 
150  view(rvsdg, stdout);
151 
152  // Assert
// Two nodes are expected to survive pruning, presumably the test operation and the match
// node feeding the predicate; the gamma is no longer used.
153  EXPECT_EQ(thetaNode.subregion()->numNodes(), 2u);
154  EXPECT_EQ(thetaNode.predicate()->origin(), matchNode.output(0));
155 }
156 
157 TEST(PredicateCorrelationTests, testMatchConstantCorrelationDetection)
158 {
159  // Arrange
160  using namespace jlm::llvm;
161  using namespace jlm::rvsdg;
162 
163  const std::vector<std::pair<uint64_t, uint64_t>> gammaSubregionAlternatives = { { 0, 1 },
164  { 1, 0 } };
165  for (auto alternatives : gammaSubregionAlternatives)
166  {
167  // Arrange
168  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
169  auto & rvsdg = rvsdgModule->Rvsdg();
170 
171  auto [gammaNode, thetaNode, matchNode] = setupMatchConstantCorrelationTest(rvsdg, alternatives);
172 
173  view(rvsdg, stdout);
174 
175  // Act
176  const auto correlationOpt = computeThetaGammaPredicateCorrelation(thetaNode);
177 
178  // Assert
179  EXPECT_NE(correlationOpt.value(), nullptr);
180  EXPECT_EQ(correlationOpt.value()->type(), CorrelationType::MatchConstantCorrelation);
181  EXPECT_EQ(&correlationOpt.value()->thetaNode(), &thetaNode);
182  EXPECT_EQ(&correlationOpt.value()->gammaNode(), &gammaNode);
183 
184  const auto correlationData =
185  std::get<ThetaGammaPredicateCorrelation::MatchConstantCorrelationData>(
186  correlationOpt.value()->data());
187  EXPECT_EQ(correlationData.matchNode, &matchNode);
188  EXPECT_EQ(correlationData.alternatives.size(), 2u);
189  EXPECT_EQ(correlationData.alternatives[0], alternatives.first);
190  EXPECT_EQ(correlationData.alternatives[1], alternatives.second);
191  }
192 }
193 
// Builds the match-constant pattern with gamma subregion constants { 0, 1 }, runs
// PredicateCorrelation, and expects the theta predicate to be rewired directly to the gamma
// predicate, leaving a single node in the theta body after pruning.
194 TEST(PredicateCorrelationTests, testMatchConstantCorrelation_Success)
195 {
196  // Arrange
197  using namespace jlm::llvm;
198  using namespace jlm::rvsdg;
199 
200  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
201  auto & rvsdg = rvsdgModule->Rvsdg();
202 
203  auto [gammaNode, thetaNode, _] = setupMatchConstantCorrelationTest(rvsdg, { 0, 1 });
// Record the gamma predicate before the pass so the rewiring can be checked afterwards.
204  auto gammaPredicate = gammaNode.predicate()->origin();
205  view(rvsdg, stdout);
206 
207  // Act
209  PredicateCorrelation predicateCorrelation;
210  predicateCorrelation.Run(*rvsdgModule, statisticsCollector);
211 
212  thetaNode.subregion()->prune(true);
213 
214  view(rvsdg, stdout);
215 
216  // Assert
217  EXPECT_EQ(thetaNode.subregion()->numNodes(), 1u);
218  EXPECT_EQ(thetaNode.predicate()->origin(), gammaPredicate);
219 }
220 
// Negative case: with gamma subregion constants { 1, 0 } the pass must NOT rewire the theta
// predicate to the gamma predicate (see the explanation in the Assert section).
221 TEST(PredicateCorrelationTests, testMatchConstantCorrelation_Failure)
222 {
223  // Arrange
224  using namespace jlm::llvm;
225  using namespace jlm::rvsdg;
226 
227  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
228  auto & rvsdg = rvsdgModule->Rvsdg();
229 
// std::pair<int, int>; converted to the setup function's std::pair<int64_t, int64_t> parameter.
230  auto gammaSubregionAlternatives = std::make_pair(1, 0);
231  auto [gammaNode, thetaNode, _] =
232  setupMatchConstantCorrelationTest(rvsdg, gammaSubregionAlternatives);
233  auto gammaPredicate = gammaNode.predicate()->origin();
234  view(rvsdg, stdout);
235 
236  // Act
238  PredicateCorrelation predicateCorrelation;
239  predicateCorrelation.Run(*rvsdgModule, statisticsCollector);
240 
241  thetaNode.subregion()->prune(true);
242 
243  view(rvsdg, stdout);
244 
245  // Assert
246  // The theta node predicate is not redirected as the gamma subregion alternatives do not lead to
247  // the same control behavior as the match node that is currently connected to the theta node
248  // predicate. It would be necessary to create a new match node for this instead of just reusing
249  // the gamma node's control predicate.
250  EXPECT_NE(thetaNode.predicate()->origin(), gammaPredicate);
251 }
252 
253 TEST(PredicateCorrelationTests, testThetaGammaMatchCorrelationDetection)
254 {
255  // Arrange
256  using namespace jlm::llvm;
257  using namespace jlm::rvsdg;
258 
259  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
260  auto & rvsdg = rvsdgModule->Rvsdg();
261 
262  auto [gammaNode, thetaNode, matchNode] = setupThetaGammaMatchCorrelationTest(rvsdg);
263 
264  view(rvsdg, stdout);
265 
266  // Act
267  const auto correlationOpt = computeThetaGammaPredicateCorrelation(thetaNode);
268 
269  // Assert
270  EXPECT_NE(correlationOpt.value(), nullptr);
271  EXPECT_EQ(correlationOpt.value()->type(), CorrelationType::MatchCorrelation);
272  EXPECT_EQ(&correlationOpt.value()->thetaNode(), &thetaNode);
273  EXPECT_EQ(&correlationOpt.value()->gammaNode(), &gammaNode);
274 
275  const auto correlationData = std::get<ThetaGammaPredicateCorrelation::MatchCorrelationData>(
276  correlationOpt.value()->data());
277  EXPECT_EQ(correlationData.matchNode, &matchNode);
278 }
279 
// Chains two correlation patterns inside one theta:
//   test predicate -> gamma1 (bit constants 0 / 1) -> match -> gamma2 (control constants 0 / 1)
//   -> theta predicate.
// After PredicateCorrelation and pruning, the theta predicate is expected to be the original
// test predicate with only one node left in the body. Resolving both gammas presumably
// requires the pass to iterate to a fixed point.
280 TEST(PredicateCorrelationTests, testThetaGammaCorrelationFixPoint)
281 {
282  // Arrange
283  using namespace jlm::llvm;
284  using namespace jlm::rvsdg;
285 
// NOTE(review): bitType32 is never used in this test and could be removed.
286  auto bitType32 = BitType::Create(32);
287  auto controlType = ControlType::Create(2);
288 
289  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
290  auto & rvsdg = rvsdgModule->Rvsdg();
291 
292  auto thetaNode = ThetaNode::create(&rvsdg.GetRootRegion());
293 
294  // Arrange first gamma node
295  auto predicate =
296  TestOperation::createNode(thetaNode->subregion(), {}, { controlType })->output(0);
297  auto gammaNode1 = GammaNode::create(predicate, 2);
298 
299  auto constant0 = &BitConstantOperation::create(*gammaNode1->subregion(0), { 64, 0 });
300  auto constant1 = &BitConstantOperation::create(*gammaNode1->subregion(1), { 64, 1 });
301 
302  auto exitVar = gammaNode1->AddExitVar({ constant0, constant1 });
303  auto & matchNode = MatchOperation::CreateNode(*exitVar.output, { { 1, 1 } }, 0, 2);
304 
305  // Arrange second gamma node
306  auto gammaNode2 = GammaNode::create(matchNode.output(0), 2);
307 
308  auto controlConstant0 = &ControlConstantOperation::create(*gammaNode2->subregion(0), 2, 0);
309  auto controlConstant1 = &ControlConstantOperation::create(*gammaNode2->subregion(1), 2, 1);
310 
311  auto controlExitVar = gammaNode2->AddExitVar({ controlConstant0, controlConstant1 });
312 
313  thetaNode->predicate()->divert_to(controlExitVar.output);
314 
315  // Act
317  PredicateCorrelation predicateCorrelation;
318  predicateCorrelation.Run(*rvsdgModule, statisticsCollector);
319 
320  thetaNode->subregion()->prune(true);
321 
322  view(rvsdg, stdout);
323 
324  // Assert
325  EXPECT_EQ(thetaNode->subregion()->numNodes(), 1u);
326  EXPECT_EQ(thetaNode->predicate()->origin(), predicate);
327 }
328 
// Checks determineGammaSubregionRoles() for control-constant correlations. With control
// constants { 0, 1 } the exit subregion is 0 and the repetition subregion is 1; with { 1, 0 }
// the roles swap.
329 TEST(PredicateCorrelationTests, testDetermineGammaSubregionRoles_ControlConstantCorrelation)
330 {
331  using namespace jlm::llvm;
332  using namespace jlm::rvsdg;
333 
334  {
335  // Arrange
336  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
337  auto & rvsdg = rvsdgModule->Rvsdg();
338 
339  constexpr std::pair<uint64_t, uint64_t> controlAlternatives = { 0, 1 };
340  auto [gammaNode, thetaNode, matchNode, controlConstants] =
341  setupControlConstantCorrelationTest(rvsdg, controlAlternatives);
342 
344  thetaNode,
345  gammaNode,
346  { controlAlternatives.first, controlAlternatives.second });
347 
348  // Act
349  const auto gammaSubregionRoles = determineGammaSubregionRoles(*correlation);
350 
351  // Assert
// NOTE(review): determineGammaSubregionRoles() returns std::optional<GammaSubregionRoles>;
// add ASSERT_TRUE(gammaSubregionRoles.has_value()) before dereferencing it.
352  EXPECT_EQ(gammaSubregionRoles->exitSubregion, gammaNode.subregion(0));
353  EXPECT_EQ(gammaSubregionRoles->repetitionSubregion, gammaNode.subregion(1));
354  }
355 
356  {
357  // Arrange
358  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
359  auto & rvsdg = rvsdgModule->Rvsdg();
360 
361  constexpr std::pair<uint64_t, uint64_t> controlAlternatives = { 1, 0 };
362  auto [gammaNode, thetaNode, matchNode, controlConstants] =
363  setupControlConstantCorrelationTest(rvsdg, controlAlternatives);
364 
366  thetaNode,
367  gammaNode,
368  { controlAlternatives.first, controlAlternatives.second });
369 
370  // Act
371  const auto gammaSubregionRoles = determineGammaSubregionRoles(*correlation);
372 
373  // Assert
// NOTE(review): same unchecked optional dereference as above.
374  EXPECT_EQ(gammaSubregionRoles->exitSubregion, gammaNode.subregion(1));
375  EXPECT_EQ(gammaSubregionRoles->repetitionSubregion, gammaNode.subregion(0));
376  }
377 }
378 
// Checks determineGammaSubregionRoles() for match-constant correlations. With gamma subregion
// constants { 0, 1 } the exit subregion is 0 and the repetition subregion is 1; with { 1, 0 }
// the roles swap.
379 TEST(PredicateCorrelationTests, testDetermineGammaSubregionRoles_MatchConstantCorrelation)
380 {
381  using namespace jlm::llvm;
382  using namespace jlm::rvsdg;
383 
384  {
385  // Arrange
386  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
387  auto & rvsdg = rvsdgModule->Rvsdg();
388 
389  constexpr std::pair<uint64_t, uint64_t> gammaSubregionAlternatives = { 0, 1 };
390  auto [gammaNode, thetaNode, matchNode] =
391  setupMatchConstantCorrelationTest(rvsdg, gammaSubregionAlternatives);
392 
394  thetaNode,
395  gammaNode,
396  { &matchNode, { gammaSubregionAlternatives.first, gammaSubregionAlternatives.second } });
397 
398  // Act
399  const auto gammaSubregionRoles = determineGammaSubregionRoles(*correlation);
400 
401  // Assert
// NOTE(review): determineGammaSubregionRoles() returns std::optional<GammaSubregionRoles>;
// add ASSERT_TRUE(gammaSubregionRoles.has_value()) before dereferencing it.
402  EXPECT_EQ(gammaSubregionRoles->exitSubregion, gammaNode.subregion(0));
403  EXPECT_EQ(gammaSubregionRoles->repetitionSubregion, gammaNode.subregion(1));
404  }
405 
406  {
407  // Arrange
408  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
409  auto & rvsdg = rvsdgModule->Rvsdg();
410 
411  constexpr std::pair<uint64_t, uint64_t> gammaSubregionAlternatives = { 1, 0 };
412  auto [gammaNode, thetaNode, matchNode] =
413  setupMatchConstantCorrelationTest(rvsdg, gammaSubregionAlternatives);
414 
416  thetaNode,
417  gammaNode,
418  { &matchNode, { gammaSubregionAlternatives.first, gammaSubregionAlternatives.second } });
419 
420  // Act
421  const auto gammaSubregionRoles = determineGammaSubregionRoles(*correlation);
422 
423  // Assert
// NOTE(review): same unchecked optional dereference as above.
424  EXPECT_EQ(gammaSubregionRoles->exitSubregion, gammaNode.subregion(1));
425  EXPECT_EQ(gammaSubregionRoles->repetitionSubregion, gammaNode.subregion(0));
426  }
427 }
428 
429 TEST(PredicateCorrelationTests, testDetermineGammaSubregionRoles_MatchCorrelation)
430 {
431  using namespace jlm::llvm;
432  using namespace jlm::rvsdg;
433  // Arrange
434  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
435  auto & rvsdg = rvsdgModule->Rvsdg();
436 
437  auto [gammaNode, thetaNode, matchNode] = setupThetaGammaMatchCorrelationTest(rvsdg);
438 
439  const auto correlation =
440  ThetaGammaPredicateCorrelation::CreateMatchCorrelation(thetaNode, gammaNode, { &matchNode });
441 
442  // Act
443  const auto gammaSubregionRoles = determineGammaSubregionRoles(*correlation);
444 
445  // Assert
446  EXPECT_EQ(gammaSubregionRoles->exitSubregion, gammaNode.subregion(0));
447  EXPECT_EQ(gammaSubregionRoles->repetitionSubregion, gammaNode.subregion(1));
448 }
449 
450 TEST(PredicateCorrelationTests, testGammaGammaMatchCorrelationDetection)
451 {
452  // Arrange
453  using namespace jlm::llvm;
454  using namespace jlm::rvsdg;
455 
456  auto bitType32 = BitType::Create(32);
457  auto controlType = ControlType::Create(2);
458 
459  auto rvsdgModule = jlm::llvm::LlvmRvsdgModule::Create(jlm::util::FilePath(""), "", "");
460  auto & rvsdg = rvsdgModule->Rvsdg();
461 
462  auto constantNode = TestOperation::createNode(&rvsdg.GetRootRegion(), {}, { bitType32 });
463  auto & matchNode = MatchOperation::CreateNode(*constantNode->output(0), { { 1, 1 } }, 0, 2);
464 
465  auto gammaNode1 = GammaNode::create(matchNode.output(0), 2);
466 
467  auto gammaNode2 = GammaNode::create(matchNode.output(0), 2);
468 
469  view(rvsdg, stdout);
470 
471  // Act
472  const auto correlationOpt = computeGammaGammaPredicateCorrelation(*gammaNode1);
473 
474  // Assert
475  EXPECT_NE(correlationOpt.value(), nullptr);
476  EXPECT_EQ(correlationOpt.value()->type(), CorrelationType::MatchCorrelation);
477  EXPECT_EQ(&correlationOpt.value()->gammaNode1(), gammaNode1);
478  EXPECT_EQ(&correlationOpt.value()->gammaNode2(), gammaNode2);
479 
480  const auto correlationData = std::get<GammaGammaPredicateCorrelation::MatchCorrelationData>(
481  correlationOpt.value()->correlationData());
482  EXPECT_EQ(correlationData.matchNode, &matchNode);
483 }
static jlm::util::StatisticsCollector statisticsCollector
static ControlConstantCorrelationTest setupControlConstantCorrelationTest(jlm::rvsdg::Graph &rvsdg, const std::pair< uint64_t, uint64_t > &gammaSubregionControlConstants)
static MatchConstantCorrelationTest setupMatchConstantCorrelationTest(jlm::rvsdg::Graph &rvsdg, const std::pair< int64_t, int64_t > &gammaSubregionAlternatives)
static MatchCorrelationTest setupThetaGammaMatchCorrelationTest(jlm::rvsdg::Graph &rvsdg)
TEST(PredicateCorrelationTests, testControlConstantCorrelation)
static std::unique_ptr< LlvmRvsdgModule > Create(const util::FilePath &sourceFileName, const std::string &targetTriple, const std::string &dataLayout)
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static std::unique_ptr< ThetaGammaPredicateCorrelation > CreateControlConstantCorrelation(rvsdg::ThetaNode &thetaNode, rvsdg::GammaNode &gammaNode, ControlConstantCorrelationData data)
static std::unique_ptr< ThetaGammaPredicateCorrelation > CreateMatchConstantCorrelation(rvsdg::ThetaNode &thetaNode, rvsdg::GammaNode &gammaNode, MatchConstantCorrelationData data)
static std::unique_ptr< ThetaGammaPredicateCorrelation > CreateMatchCorrelation(rvsdg::ThetaNode &thetaNode, rvsdg::GammaNode &gammaNode, MatchCorrelationData data)
Conditional operator / pattern matching.
Definition: gamma.hpp:99
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Global memory state passed between functions.
std::optional< std::unique_ptr< ThetaGammaPredicateCorrelation > > computeThetaGammaPredicateCorrelation(rvsdg::ThetaNode &thetaNode)
std::optional< std::unique_ptr< GammaGammaPredicateCorrelation > > computeGammaGammaPredicateCorrelation(rvsdg::GammaNode &gammaNode)
std::optional< GammaSubregionRoles > determineGammaSubregionRoles(const ThetaGammaPredicateCorrelation &correlation)
std::string view(const rvsdg::Region *region)
Definition: view.cpp:142
std::vector< jlm::rvsdg::Node * > controlConstants
jlm::rvsdg::GammaNode & gammaNode
jlm::rvsdg::ThetaNode & thetaNode