Jlm
PredicateCorrelation.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
9 #include <jlm/rvsdg/delta.hpp>
10 #include <jlm/rvsdg/gamma.hpp>
11 #include <jlm/rvsdg/lambda.hpp>
12 #include <jlm/rvsdg/MatchType.hpp>
13 #include <jlm/rvsdg/Phi.hpp>
15 #include <jlm/rvsdg/theta.hpp>
16 #include <jlm/util/Statistics.hpp>
17 
18 namespace jlm::llvm
19 {
20 
// Collects the constant value produced by every branch result that feeds the given gamma output.
// Origins that are a ControlConstantOperation, a BitConstantOperation or an
// IntegerConstantOperation are recognized. Returns std::nullopt as soon as one branch result has
// any other origin; otherwise returns the values in the iteration order of the branch results.
// NOTE(review): listing line 31 (function name and parameter list) is absent from this extract.
30 static std::optional<std::vector<uint64_t>>
32 {
// The owner is fetched with AssertGetOwnerNode - presumably callers must pass an output that is
// owned by a gamma node; confirm against the rvsdg API.
33  const auto & gammaNode = rvsdg::AssertGetOwnerNode<rvsdg::GammaNode>(gammaOutput);
34 
35  std::vector<uint64_t> alternatives;
36  auto [branchResults, _] = gammaNode.MapOutputExitVar(gammaOutput);
37  for (const auto branchResult : branchResults)
38  {
// Case 1: the branch result originates from a control constant.
39  {
40  auto [constantNode, constantOperation] =
41  rvsdg::TryGetSimpleNodeAndOptionalOp<rvsdg::ControlConstantOperation>(
42  *branchResult->origin());
43  if (constantOperation)
44  {
45  alternatives.push_back(constantOperation->value().alternative());
46  continue;
47  }
48  }
49 
// Case 2: the branch result originates from a bit constant.
50  {
51  auto [constantNode, constantOperation] =
52  rvsdg::TryGetSimpleNodeAndOptionalOp<rvsdg::BitConstantOperation>(
53  *branchResult->origin());
54  if (constantOperation)
55  {
56  alternatives.push_back(constantOperation->value().to_uint());
57  continue;
58  }
59  }
60 
// Case 3: the branch result originates from an integer constant.
61  {
62  auto [constantNode, constantOperation] =
63  rvsdg::TryGetSimpleNodeAndOptionalOp<IntegerConstantOperation>(*branchResult->origin());
64  if (constantOperation)
65  {
66  alternatives.push_back(constantOperation->Representation().to_uint());
67  continue;
68  }
69  }
70 
// None of the recognized constant kinds matched, so no complete set of alternatives exists.
71  return std::nullopt;
72  }
73 
74  return alternatives;
75 }
76 
// Tries to correlate a theta predicate that is fed directly by an output of a gamma node whose
// branch results are all constants. Returns std::nullopt when the predicate origin is not owned
// by a gamma node or when extractConstantAlternatives() finds a non-constant branch result.
// NOTE(review): listing lines 78 (function name/parameters) and 94 (start of the return
// statement, presumably a call to CreateControlConstantCorrelation) are absent from this extract.
77 static std::optional<std::unique_ptr<ThetaGammaPredicateCorrelation>>
79 {
80  const auto & thetaPredicateOperand = *thetaNode.predicate()->origin();
81  auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(thetaPredicateOperand);
82  if (!gammaNode)
83  {
84  return std::nullopt;
85  }
86 
87  const auto controlAlternativesOpt = extractConstantAlternatives(thetaPredicateOperand);
88  if (!controlAlternativesOpt.has_value())
89  {
90  return std::nullopt;
91  }
92  const auto controlAlternatives = controlAlternativesOpt.value();
93 
// Arguments of the correlation that is returned: theta node, gamma node, constant alternatives.
95  thetaNode,
96  *gammaNode,
97  controlAlternatives);
98 }
99 
// Tries to correlate a theta predicate that is produced by a match node whose first input is an
// output of a gamma node with only constant branch results. Returns std::nullopt when the
// predicate origin is not a match operation, when the match input is not owned by a gamma node,
// or when extractConstantAlternatives() finds a non-constant branch result.
// NOTE(review): listing lines 101 (function name/parameters) and 124 (start of the return
// statement, presumably a call to CreateMatchConstantCorrelation) are absent from this extract.
100 static std::optional<std::unique_ptr<ThetaGammaPredicateCorrelation>>
102 {
103  auto [matchNode, matchOperation] =
104  rvsdg::TryGetSimpleNodeAndOptionalOp<rvsdg::MatchOperation>(*thetaNode.predicate()->origin());
105  if (!matchOperation)
106  {
107  return std::nullopt;
108  }
109 
// The value that the match node inspects must itself come out of a gamma node.
110  const auto & gammaOutput = *matchNode->input(0)->origin();
111  const auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(gammaOutput);
112  if (!gammaNode)
113  {
114  return std::nullopt;
115  }
116 
117  const auto alternativesOpt = extractConstantAlternatives(gammaOutput);
118  if (!alternativesOpt.has_value())
119  {
120  return std::nullopt;
121  }
122  const auto alternatives = alternativesOpt.value();
123 
// Arguments of the correlation that is returned; the data bundles the match node with the
// constants produced by the gamma branches.
125  thetaNode,
126  *gammaNode,
127  { matchNode, alternatives });
128 }
129 
// Tries to correlate a theta predicate that is produced by a match node whose output is also
// consumed by a gamma node. The users of the match node's first output are scanned and the first
// user owned by a gamma node determines the correlation. Returns std::nullopt when the predicate
// origin is not a match operation or when no such user exists.
// NOTE(review): listing lines 131 (function name/parameters) and 144 (start of the return
// statement, presumably a call to CreateMatchCorrelation) are absent from this extract.
130 static std::optional<std::unique_ptr<ThetaGammaPredicateCorrelation>>
132 {
133  auto [matchNode, matchOperation] =
134  rvsdg::TryGetSimpleNodeAndOptionalOp<rvsdg::MatchOperation>(*thetaNode.predicate()->origin());
135  if (!matchOperation)
136  {
137  return std::nullopt;
138  }
139 
140  for (auto & user : matchNode->output(0)->Users())
141  {
// TryGetOwnerNode yields a null pointer for users that are not owned by a gamma node.
142  if (const auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(user))
143  {
145  thetaNode,
146  *gammaNode,
147  { matchNode });
148  }
149  }
150 
151  return std::nullopt;
152 }
153 
// Computes the predicate correlation between a theta node and a gamma node, if there is one.
// The three detection strategies are tried in a fixed order - control constant, match constant,
// plain match - and the result of the first one that yields a value is returned. Returns
// std::nullopt when none of them applies.
// NOTE(review): listing line 155 (function name/parameters) is absent from this extract.
154 std::optional<std::unique_ptr<ThetaGammaPredicateCorrelation>>
156 {
157  if (auto correlationOpt = computeControlConstantCorrelation(thetaNode))
158  {
159  return correlationOpt;
160  }
161 
162  if (auto correlationOpt = computeMatchConstantCorrelation(thetaNode))
163  {
164  return correlationOpt;
165  }
166 
167  if (auto correlationOpt = computeMatchCorrelation(thetaNode))
168  {
169  return correlationOpt;
170  }
171 
172  return std::nullopt;
173 }
174 
// Tries to correlate the predicate of gammaNode1 with the predicate of another gamma node that
// consumes the output of the same match node. The users of the match node's first output are
// scanned and the first user owned by a gamma node other than gammaNode1 determines the
// correlation. Returns std::nullopt when the predicate origin is not a match operation or when
// no such gamma node exists.
// NOTE(review): listing lines 176 (function name/parameters) and 190 (start of the return
// statement, presumably a call to GammaGammaPredicateCorrelation::CreateMatchCorrelation) are
// absent from this extract.
175 static std::optional<std::unique_ptr<GammaGammaPredicateCorrelation>>
177 {
178  auto [matchNode, matchOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<rvsdg::MatchOperation>(
179  *gammaNode1.predicate()->origin());
180  if (!matchOperation)
181  {
182  return std::nullopt;
183  }
184 
185  for (auto & user : matchNode->output(0)->Users())
186  {
// TryGetOwnerNode yields a null pointer for users that are not owned by a gamma node. Such users
// must be skipped explicitly: a null pointer also compares unequal to &gammaNode1 and would
// otherwise be dereferenced below.
187  if (const auto gammaNode2 = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(user);
188  gammaNode2 != nullptr && gammaNode2 != &gammaNode1)
189  {
191  gammaNode1,
192  *gammaNode2,
193  { matchNode });
194  }
195  }
196 
197  return std::nullopt;
198 }
199 
// Computes the predicate correlation between the given gamma node and another gamma node, if
// there is one. Match correlation is the only strategy tried here; std::nullopt is returned when
// it does not apply.
// NOTE(review): listing line 201 (function name/parameters) is absent from this extract.
200 std::optional<std::unique_ptr<GammaGammaPredicateCorrelation>>
202 {
203  if (auto correlationOpt = computeMatchCorrelation(gammaNode))
204  {
205  return correlationOpt;
206  }
207 
208  return std::nullopt;
209 }
210 
// Determines which subregion of the correlated gamma node is the exit subregion and which is
// the repetition subregion. Returns std::nullopt when the correlation data does not hold exactly
// two alternatives or when the correlation type is not handled by the switch.
// NOTE(review): listing line 212 (function name/parameters) and the three case labels (listing
// lines 216, 240 and 266) are absent from this extract.
211 std::optional<GammaSubregionRoles>
213 {
214  switch (correlation.type())
215  {
// Branch that reads ControlConstantCorrelationData: the gamma branches produce the control
// constants directly.
217  {
218  const auto controlAlternatives =
219  std::get<ThetaGammaPredicateCorrelation::ControlConstantCorrelationData>(
220  correlation.data());
221  if (controlAlternatives.size() != 2)
222  {
223  return std::nullopt;
224  }
225 
226  GammaSubregionRoles roles;
// The subregion whose constant is 0 becomes the exit subregion, the other one the repetition
// subregion.
227  if (controlAlternatives[0] == 0)
228  {
229  roles.exitSubregion = correlation.gammaNode().subregion(0);
230  roles.repetitionSubregion = correlation.gammaNode().subregion(1);
231  }
232  else
233  {
234  roles.exitSubregion = correlation.gammaNode().subregion(1);
235  roles.repetitionSubregion = correlation.gammaNode().subregion(0);
236  }
237 
238  return roles;
239  }
// Branch that reads MatchConstantCorrelationData: the gamma branches produce constants that are
// translated by a match operation.
241  {
242  const auto [matchNode, alternatives] =
243  std::get<ThetaGammaPredicateCorrelation::MatchConstantCorrelationData>(correlation.data());
244 
245  if (alternatives.size() != 2)
246  {
247  return std::nullopt;
248  }
249 
250  GammaSubregionRoles roles;
251  const auto matchOperation =
252  util::assertedCast<const rvsdg::MatchOperation>(&matchNode->GetOperation());
// The constant of subregion 0 is mapped through the match operation; a result of 0 makes
// subregion 0 the exit subregion.
253  if (matchOperation->alternative(alternatives[0]) == 0)
254  {
255  roles.exitSubregion = correlation.gammaNode().subregion(0);
256  roles.repetitionSubregion = correlation.gammaNode().subregion(1);
257  }
258  else
259  {
260  roles.exitSubregion = correlation.gammaNode().subregion(1);
261  roles.repetitionSubregion = correlation.gammaNode().subregion(0);
262  }
263 
264  return roles;
265  }
// Remaining explicit branch (presumably the plain match correlation - confirm against the
// missing case label): subregion 0 is always the exit subregion and subregion 1 the repetition
// subregion.
267  {
268  JLM_ASSERT(correlation.gammaNode().nsubregions() == 2);
269 
270  GammaSubregionRoles roles;
271  roles.exitSubregion = correlation.gammaNode().subregion(0);
272  roles.repetitionSubregion = correlation.gammaNode().subregion(1);
273  return roles;
274  }
275  default:
276  return std::nullopt;
277  }
278 }
279 
281 
// Visits every node of the given region and recurses into the subregions of lambda, phi, theta
// and gamma nodes. Theta nodes additionally get their predicate correlated after their subregion
// has been processed. Delta nodes and simple nodes are left untouched.
// NOTE(review): listing line 287 (start of the dispatch call, presumably rvsdg::MatchTypeOrFail)
// is absent from this extract.
282 void
283 PredicateCorrelation::correlatePredicatesInRegion(rvsdg::Region & region)
284 {
285  for (auto & node : region.Nodes())
286  {
288  node,
289  [](rvsdg::LambdaNode & lambdaNode)
290  {
291  correlatePredicatesInRegion(*lambdaNode.subregion());
292  },
293  [](rvsdg::PhiNode & phiNode)
294  {
295  correlatePredicatesInRegion(*phiNode.subregion());
296  },
297  [](const rvsdg::DeltaNode &)
298  {
299  // Nothing needs to be done
300  },
301  [](rvsdg::ThetaNode & thetaNode)
302  {
303  // Handle innermost subregions first
304  correlatePredicatesInRegion(*thetaNode.subregion());
305 
306  correlatePredicatesInTheta(thetaNode);
307  },
308  [](rvsdg::GammaNode & gammaNode)
309  {
310  for (auto & subregion : gammaNode.Subregions())
311  {
312  correlatePredicatesInRegion(subregion);
313  }
314  },
315  [](rvsdg::SimpleNode &)
316  {
317  // Nothing needs to be done
318  });
319  }
320 }
321 
// Repeatedly computes the theta-gamma predicate correlation of the given theta node and hands it
// to the matching handler. The loop ends when no correlation is found or when the handler
// reports that it did not redirect the predicate. Throws std::logic_error for a correlation type
// that the switch does not handle.
// NOTE(review): listing line 323 (function name/parameters) and the three case labels (listing
// lines 343, 346 and 349) are absent from this extract.
322 void
324 {
325  // FIXME: Reevaluate the fix-point computation after we introduced gamma-gamma predicate
326  // correlation. The pattern is a strict top-down pattern, which means that once we resolved the
327  // gamma-gamma predicate correlations, there should only be a single theta-gamma predicate
328  // correlation left, if any. Thus, it might be that the fix-point computation is unnecessary.
329  bool predicateWasRedirected = false;
330  do
331  {
332  predicateWasRedirected = false;
333 
// The correlation is recomputed in every iteration because a redirected predicate has a new
// origin.
334  const auto correlationOpt = computeThetaGammaPredicateCorrelation(thetaNode);
335  if (!correlationOpt.has_value())
336  {
337  return;
338  }
339  const auto & correlation = correlationOpt.value();
340 
341  switch (correlation->type())
342  {
344  predicateWasRedirected = handleControlConstantCorrelation(*correlation);
345  break;
347  predicateWasRedirected = handleMatchConstantCorrelation(*correlation);
348  break;
// This branch performs no redirection, which terminates the loop.
350  predicateWasRedirected = false;
351  break;
352  default:
353  throw std::logic_error("Unhandled theta-gamma predicate correlation.");
354  }
355  } while (predicateWasRedirected);
356 }
357 
// Redirects the theta predicate to the origin of the gamma predicate when the gamma branches
// produce exactly the control constants 0 and 1, in that order. Returns true if the predicate
// was redirected and false otherwise.
// NOTE(review): listing lines 359 (function name) and 362 are absent from this extract.
358 bool
360  const ThetaGammaPredicateCorrelation & correlation)
361 {
363  const auto & gammaNode = correlation.gammaNode();
364  const auto & thetaNode = correlation.thetaNode();
365 
366  const auto controlAlternatives =
367  std::get<ThetaGammaPredicateCorrelation::ControlConstantCorrelationData>(correlation.data());
// Any other constant pattern leaves the predicate untouched.
368  if (controlAlternatives.size() != 2 || controlAlternatives[0] != 0 || controlAlternatives[1] != 1)
369  {
370  return false;
371  }
372 
373  thetaNode.predicate()->divert_to(gammaNode.predicate()->origin());
374  return true;
375 }
376 
// Redirects the theta predicate to the origin of the gamma predicate when the gamma has exactly
// two branch constants and the match operation maps the first one to alternative 0 and the
// second one to alternative 1. Returns true if the predicate was redirected and false otherwise.
// NOTE(review): listing lines 378 (function name) and 381 are absent from this extract.
377 bool
379  const ThetaGammaPredicateCorrelation & correlation)
380 {
382  const auto & gammaNode = correlation.gammaNode();
383  const auto & thetaNode = correlation.thetaNode();
384 
385  const auto [matchNode, alternatives] =
386  std::get<ThetaGammaPredicateCorrelation::MatchConstantCorrelationData>(correlation.data());
387 
388  if (alternatives.size() != 2)
389  {
390  return false;
391  }
392 
393  const auto matchOperation =
394  util::assertedCast<const rvsdg::MatchOperation>(&matchNode->GetOperation());
// Any other mapping leaves the predicate untouched.
395  if (matchOperation->alternative(alternatives[0]) != 0
396  || matchOperation->alternative(alternatives[1]) != 1)
397  {
398  return false;
399  }
400 
401  thetaNode.predicate()->divert_to(gammaNode.predicate()->origin());
402  return true;
403 }
404 
// NOTE(review): listing lines 406 (function name/parameters) and 408 (the only body statement)
// are absent from this extract. Presumably this is PredicateCorrelation::Run, the entry point of
// the transformation - confirm against the original source file.
405 void
407 {
409 }
410 
411 }
static std::unique_ptr< GammaGammaPredicateCorrelation > CreateMatchCorrelation(rvsdg::GammaNode &gammaNode1, rvsdg::GammaNode &gammaNode2, MatchCorrelationData correlationData)
static bool handleMatchConstantCorrelation(const ThetaGammaPredicateCorrelation &correlation)
static void correlatePredicatesInRegion(rvsdg::Region &region)
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static bool handleControlConstantCorrelation(const ThetaGammaPredicateCorrelation &correlation)
static void correlatePredicatesInTheta(rvsdg::ThetaNode &thetaNode)
~PredicateCorrelation() noexcept override
const CorrelationData & data() const noexcept
rvsdg::GammaNode & gammaNode() const noexcept
static std::unique_ptr< ThetaGammaPredicateCorrelation > CreateControlConstantCorrelation(rvsdg::ThetaNode &thetaNode, rvsdg::GammaNode &gammaNode, ControlConstantCorrelationData data)
rvsdg::ThetaNode & thetaNode() const noexcept
static std::unique_ptr< ThetaGammaPredicateCorrelation > CreateMatchConstantCorrelation(rvsdg::ThetaNode &thetaNode, rvsdg::GammaNode &gammaNode, MatchConstantCorrelationData data)
static std::unique_ptr< ThetaGammaPredicateCorrelation > CreateMatchCorrelation(rvsdg::ThetaNode &thetaNode, rvsdg::GammaNode &gammaNode, MatchCorrelationData data)
Delta node.
Definition: delta.hpp:129
Conditional operator / pattern matching.
Definition: gamma.hpp:99
rvsdg::Input * predicate() const noexcept
Definition: gamma.hpp:398
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
size_t nsubregions() const noexcept
rvsdg::Region * subregion(size_t index) const noexcept
RegionResult * predicate() const noexcept
Definition: theta.hpp:85
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
static std::optional< std::unique_ptr< ThetaGammaPredicateCorrelation > > computeMatchCorrelation(rvsdg::ThetaNode &thetaNode)
std::optional< std::unique_ptr< ThetaGammaPredicateCorrelation > > computeThetaGammaPredicateCorrelation(rvsdg::ThetaNode &thetaNode)
static std::optional< std::unique_ptr< ThetaGammaPredicateCorrelation > > computeControlConstantCorrelation(rvsdg::ThetaNode &thetaNode)
static std::optional< std::unique_ptr< ThetaGammaPredicateCorrelation > > computeMatchConstantCorrelation(rvsdg::ThetaNode &thetaNode)
std::optional< std::unique_ptr< GammaGammaPredicateCorrelation > > computeGammaGammaPredicateCorrelation(rvsdg::GammaNode &gammaNode)
static std::optional< std::vector< uint64_t > > extractConstantAlternatives(const rvsdg::Output &gammaOutput)
std::optional< GammaSubregionRoles > determineGammaSubregionRoles(const ThetaGammaPredicateCorrelation &correlation)
void MatchTypeOrFail(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.