Jlm
LoadChainSeparation.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
15 #include <jlm/rvsdg/delta.hpp>
16 #include <jlm/rvsdg/gamma.hpp>
17 #include <jlm/rvsdg/lambda.hpp>
18 #include <jlm/rvsdg/MatchType.hpp>
19 #include <jlm/rvsdg/Phi.hpp>
22 #include <jlm/rvsdg/theta.hpp>
23 #include <jlm/rvsdg/traverser.hpp>
24 
25 namespace jlm::llvm
26 {
27 
29 {
30 public:
32  {
34  };
35 
36  bool
37  hasModRefChainLinkType(const rvsdg::Output & output) const noexcept
38  { // Reports whether a ModRefChainLink::Type has been recorded for output in Types_ (see add()).
39  return Types_.count(&output) != 0;
40  }
41 
42  void
43  add(const rvsdg::Output & output, const ModRefChainLink::Type & type)
44  {
45  JLM_ASSERT(is<MemoryStateType>(output.Type()));
47  Types_[&output] = type;
48  }
49 
51  getModRefChainLinkType(const rvsdg::Output & output) const
52  {
54  return Types_.find(&output)->second;
55  }
56 
57  void
59  const rvsdg::Output & output,
60  ModRefChainInformation modRefChainInformation)
61  {
62  auto & outputMap = getOrInsertModRefChainInformationMap(*output.region());
63  JLM_ASSERT(outputMap.find(&output) == outputMap.end());
64  outputMap[&output] = std::move(modRefChainInformation);
65  }
66 
67  std::optional<ModRefChainInformation>
69  {
70  const auto regionMapIt = RegionMap_.find(output.region());
71  if (regionMapIt == RegionMap_.end())
72  {
73  return std::nullopt;
74  }
75 
76  auto & outputMap = regionMapIt->second;
77  const auto outputMapIt = outputMap.find(&output);
78  if (outputMapIt == outputMap.end())
79  {
80  return std::nullopt;
81  }
82 
83  return outputMapIt->second;
84  }
85 
86  void
88  {
89  RegionMap_.erase(&region);
90  }
91 
92  static std::unique_ptr<Context>
94  {
95  return std::make_unique<Context>();
96  }
97 
98 private:
100  std::unordered_map<const rvsdg::Output *, ModRefChainInformation>;
101 
104  {
105  if (const auto it = RegionMap_.find(&region); it != RegionMap_.end())
106  {
107  return it->second;
108  }
109 
110  return RegionMap_.emplace(&region, ModRefChainInformationMap()).first->second;
111  }
112 
113  std::unordered_map<const rvsdg::Output *, ModRefChainLink::Type> Types_{};
114  std::unordered_map<const rvsdg::Region *, ModRefChainInformationMap> RegionMap_{};
115 };
116 
117 LoadChainSeparation::~LoadChainSeparation() noexcept = default; // Defaulted out of line here, where Context is a complete type, so the std::unique_ptr<Context> member Context_ can be destroyed (presumably Context is only forward-declared in the header — confirm).
118 
120  : Transformation("LoadChainSeparation")
121 {}
122 
123 void
125 {
127 
129 }
130 
131 void
133 {
134  util::HashSet<rvsdg::Output *> visitedOutputs;
135 
136  // We require a top-down traverser to ensure that lambda nodes are handled before call nodes
137  for (const auto & node : rvsdg::TopDownTraverser(&region))
138  {
140  *node,
141  [&](rvsdg::LambdaNode & lambdaNode)
142  {
144  },
145  [&](rvsdg::PhiNode & phiNode)
146  {
147  separateReferenceChainsInRegion(*phiNode.subregion());
148  },
149  [&](rvsdg::GammaNode & gammaNode)
150  {
151  separateRefenceChainsInGamma(gammaNode);
152  },
153  [&](rvsdg::ThetaNode & thetaNode)
154  {
155  separateRefenceChainsInTheta(thetaNode);
156  },
157  [](rvsdg::DeltaNode &)
158  {
159  // Nothing needs to be done
160  },
161  [&](rvsdg::SimpleNode & simpleNode)
162  {
163  for (auto & output : simpleNode.Outputs())
164  {
165  if (output.IsDead() && is<MemoryStateType>(output.Type()))
166  {
167  // Dead memory state outputs will never be reachable from structural node results.
168  // Thus, we need to handle them here in order to separate all reference chains.
169  separateReferenceChains(output);
170  }
171  }
172  },
173  [&]()
174  {
175  throw std::logic_error(util::strfmt("Unhandled node type: ", node->DebugString()));
176  });
177  }
178 }
179 
180 void
182 {
183  // Handle innermost regions first
185 
187 
188  // We are done with the lambda subregion.
189  // Clean up all information we temporarily stored.
190  Context_->dropModRefChainInformation(*lambdaNode.subregion());
191 }
192 
193 void
195 {
196  // Handle innermost regions first
197  for (auto & subregion : gammaNode.Subregions())
198  {
200  }
201 
202  for (auto & [branchResults, output] : gammaNode.GetExitVars())
203  {
204  if (is<MemoryStateType>(output->Type()))
205  {
206  for (const auto branchResult : branchResults)
207  {
208  separateReferenceChains(*branchResult->origin());
209  }
210  }
211  }
212 
213  // We are done with the gamma subregions.
214  // Clean up all information we temporarily stored.
215  for (auto & subregion : gammaNode.Subregions())
216  {
217  Context_->dropModRefChainInformation(subregion);
218  }
219 }
220 
221 void
223 {
224  // Handle innermost region first
226 
227  util::HashSet<rvsdg::Output *> visitedOutputsSubregion;
228  for (const auto loopVar : thetaNode.GetLoopVars())
229  {
230  if (!is<MemoryStateType>(loopVar.output->Type()))
231  continue;
232 
233  // Separate reference chains in theta subregion
234  const auto hasModificationChainLinkAboveInRegion =
235  separateReferenceChains(*loopVar.post->origin());
236  Context_->add(
237  *loopVar.output,
238  hasModificationChainLinkAboveInRegion ? ModRefChainLink::Type::Modification
240 
241  // Handle dead theta outputs
242  if (loopVar.output->IsDead())
243  {
244  separateReferenceChains(*loopVar.output);
245  }
246  }
247 
248  // We are done with the theta subregion.
249  // Clean up all information we temporarily stored.
250  Context_->dropModRefChainInformation(*thetaNode.subregion());
251 }
252 
253 bool
255 {
256  JLM_ASSERT(is<MemoryStateType>(startOutput.Type()));
257 
258  ModRefChainSummary summary;
259  const bool hasModRefChainLinkAboveInRegion = traceModRefChains(startOutput, summary);
260  for (auto & modRefChain : summary.modRefChains)
261  {
262  const auto refSubchains = extractReferenceSubchains(modRefChain);
263  for (const auto & [links] : refSubchains)
264  {
265  // Divert the operands of the respective inputs for each encountered reference node and
266  // collect join operands
267  std::vector<rvsdg::Output *> joinOperands;
268  const auto newMemoryStateOperand = mapMemoryStateOutputToInput(*links.back().output).origin();
269  for (auto [linkOutput, linkModRefType] : links)
270  {
271  JLM_ASSERT(linkModRefType == ModRefChainLink::Type::Reference);
272  auto & modRefChainInput = mapMemoryStateOutputToInput(*linkOutput);
273  modRefChainInput.divert_to(newMemoryStateOperand);
274  joinOperands.push_back(linkOutput);
275  }
276 
277  // Create join node and divert the current memory state output
278  if (!links.front().output->IsDead())
279  {
280  auto & joinNode = MemoryStateJoinOperation::CreateNode(joinOperands);
281  links.front().output->divertUsersWhere(
282  *joinNode.output(0),
283  [&joinNode](const rvsdg::Input & user)
284  {
285  return rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user) != &joinNode;
286  });
287  }
288  }
289  }
290 
291  return hasModRefChainLinkAboveInRegion;
292 }
293 
294 rvsdg::Input &
296 {
297  if (auto [loadNode, loadOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<LoadOperation>(output);
298  loadOperation)
299  {
301  }
302 
303  if (const auto thetaNode = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(output))
304  {
305  return *thetaNode->MapOutputLoopVar(output).input;
306  }
307 
308  throw std::logic_error("Unhandled node type!");
309 }
310 
311 std::vector<LoadChainSeparation::ModRefChain>
313 {
314  std::vector<ModRefChain> refSubchains;
315  for (auto linkIt = modRefChain.links.begin(); linkIt != modRefChain.links.end();)
316  {
317  if (linkIt->type != ModRefChainLink::Type::Reference)
318  {
319  // The current link is not a reference. Let's continue with the next one.
320  ++linkIt;
321  continue;
322  }
323 
324  auto nextLinkIt = std::next(linkIt);
325  if (nextLinkIt == modRefChain.links.end()
326  || nextLinkIt->type != ModRefChainLink::Type::Reference)
327  {
328  // We only want to separate reference chains with at least two links
329  ++linkIt;
330  continue;
331  }
332 
333  // We found a new reference subchain. Let's grab all the links
334  refSubchains.push_back({});
335  while (linkIt != modRefChain.links.end() && linkIt->type == ModRefChainLink::Type::Reference)
336  {
337  refSubchains.back().links.push_back(*linkIt);
338  ++linkIt;
339  }
340  }
341 
342  return refSubchains;
343 }
344 
345 bool
347 {
348  JLM_ASSERT(is<MemoryStateType>(startOutput.Type()));
349 
350  if (const auto modRefChainInformationOpt = Context_->tryGetModRefChainInformation(startOutput))
351  {
352  // This output was visited before.
353  return modRefChainInformationOpt.value().hasModificationChainLinkAboveInRegion;
354  }
355 
356  ModRefChain currentModRefChain;
357  rvsdg::Output * currentOutput = &startOutput;
358  bool doneTracing = false;
359  bool hasModRefChainLinkAboveInRegion = false;
360  do
361  {
362  if (rvsdg::TryGetOwnerRegion(*currentOutput))
363  {
364  // We have a region argument. Stop tracing.
365  break;
366  }
367 
368  auto & node = rvsdg::AssertGetOwnerNode<rvsdg::Node>(*currentOutput);
370  node,
371  [&](const rvsdg::GammaNode & gammaNode)
372  {
373  // FIXME: I really would like that state edges through gammas would be recognized as
374  // either modifying or just referencing. However, we would need to know what the
375  // operations in the gamma on all branches are and which memory state exit variable maps
376  // to which memory state entry variable. We need some more machinery for it first before
377  // we can do that.
378  hasModRefChainLinkAboveInRegion = true;
379  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Modification });
380  for (auto [entryVarInput, _] : gammaNode.GetEntryVars())
381  {
382  if (is<MemoryStateType>(entryVarInput->Type()))
383  {
384  hasModRefChainLinkAboveInRegion |=
385  traceModRefChains(*entryVarInput->origin(), summary);
386  }
387  }
388  doneTracing = true;
389  },
390  [&](const rvsdg::ThetaNode &)
391  {
392  const auto modRefChainLinkType = Context_->getModRefChainLinkType(*currentOutput);
393  hasModRefChainLinkAboveInRegion |=
394  modRefChainLinkType == ModRefChainLink::Type::Modification;
395  currentModRefChain.add({ currentOutput, modRefChainLinkType });
396  currentOutput = mapMemoryStateOutputToInput(*currentOutput).origin();
397  },
398  [&](const rvsdg::SimpleNode & simpleNode)
399  {
400  auto & operation = simpleNode.GetOperation();
402  operation,
403  [&](const LoadOperation &)
404  {
405  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Reference });
406  currentOutput = LoadOperation::MapMemoryStateOutputToInput(*currentOutput).origin();
407  },
408  [&](const StoreOperation &)
409  {
410  hasModRefChainLinkAboveInRegion = true;
411  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Modification });
412  currentOutput =
414  },
415  [&](const FreeOperation &)
416  {
417  hasModRefChainLinkAboveInRegion = true;
418  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Modification });
419  currentOutput = FreeOperation::mapMemoryStateOutputToInput(*currentOutput).origin();
420  },
421  [&](const MemCpyOperation &)
422  {
423  // FIXME: We really would like to know here which memory state belongs to the source
424  // and which to the dst address. This would allow us to be more precise in the
425  // separation.
426  hasModRefChainLinkAboveInRegion = true;
427  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Modification });
428  currentOutput =
430  },
431  [&](const MemSetOperation &)
432  {
433  hasModRefChainLinkAboveInRegion = true;
434  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Modification });
435  currentOutput =
437  },
438  [&](const CallOperation &)
439  {
440  // FIXME: I really would like that state edges through calls would be recognized as
441  // either modifying or just referencing.
442  traceModRefChains(*CallOperation::GetMemoryStateInput(node).origin(), summary);
443  doneTracing = true;
444  },
446  {
447  for (auto & nodeInput : node.Inputs())
448  {
449  hasModRefChainLinkAboveInRegion |=
450  traceModRefChains(*nodeInput.origin(), summary);
451  }
452  doneTracing = true;
453  },
455  {
456  // LambdaEntryMemoryStateSplitOperation nodes should always be connected to a lambda
457  // argument. In other words, this is as far as we can trace in the graph. Just
458  // return what we found so far.
459  doneTracing = true;
460  },
462  {
463  // FIXME: I really would like that state edges through calls would be recognized as
464  // either modifying or just referencing.
465  hasModRefChainLinkAboveInRegion |=
466  traceModRefChains(*node.input(0)->origin(), summary);
467  doneTracing = true;
468  },
470  {
471  for (auto & nodeInput : node.Inputs())
472  {
473  hasModRefChainLinkAboveInRegion |=
474  traceModRefChains(*nodeInput.origin(), summary);
475  }
476  doneTracing = true;
477  },
478  [&](const MemoryStateJoinOperation &)
479  {
480  for (auto & nodeInput : node.Inputs())
481  {
482  hasModRefChainLinkAboveInRegion |=
483  traceModRefChains(*nodeInput.origin(), summary);
484  }
485  doneTracing = true;
486  },
487  [&](const MemoryStateMergeOperation &)
488  {
489  for (auto & nodeInput : node.Inputs())
490  {
491  hasModRefChainLinkAboveInRegion |=
492  traceModRefChains(*nodeInput.origin(), summary);
493  }
494  doneTracing = true;
495  },
496  [&](const AllocaOperation &)
497  {
498  doneTracing = true;
499  },
500  [&](const MallocOperation &)
501  {
502  doneTracing = true;
503  },
504  [&](const UndefValueOperation &)
505  {
506  doneTracing = true;
507  },
508  [&]()
509  {
510  throw std::logic_error(
511  util::strfmt("Unhandled operation type: ", operation.debug_string()));
512  });
513  },
514  [&]()
515  {
516  throw std::logic_error(util::strfmt("Unhandled node type: ", node.DebugString()));
517  });
518  } while (!doneTracing);
519 
520  summary.add(std::move(currentModRefChain));
521  Context_->addModRefChainInformation(startOutput, { hasModRefChainLinkAboveInRegion });
522  return hasModRefChainLinkAboveInRegion;
523 }
524 
525 }
Call operation class.
Definition: call.hpp:251
static rvsdg::Input & GetMemoryStateInput(const rvsdg::Node &node) noexcept
Definition: call.hpp:357
static rvsdg::Input & mapMemoryStateOutputToInput(rvsdg::Output &output) noexcept
Definition: operators.hpp:2547
std::optional< ModRefChainInformation > tryGetModRefChainInformation(const rvsdg::Output &output) const
void dropModRefChainInformation(const rvsdg::Region &region)
std::unordered_map< const rvsdg::Output *, ModRefChainInformation > ModRefChainInformationMap
static std::unique_ptr< Context > create()
std::unordered_map< const rvsdg::Output *, ModRefChainLink::Type > Types_
bool hasModRefChainLinkType(const rvsdg::Output &output) const noexcept
std::unordered_map< const rvsdg::Region *, ModRefChainInformationMap > RegionMap_
void addModRefChainInformation(const rvsdg::Output &output, ModRefChainInformation modRefChainInformation)
ModRefChainInformationMap & getOrInsertModRefChainInformationMap(const rvsdg::Region &region)
void add(const rvsdg::Output &output, const ModRefChainLink::Type &type)
ModRefChainLink::Type getModRefChainLinkType(const rvsdg::Output &output) const
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
std::unique_ptr< Context > Context_
bool separateReferenceChains(rvsdg::Output &startOutput)
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
void separateRefenceChainsInTheta(rvsdg::ThetaNode &thetaNode)
void separateRefenceChainsInGamma(rvsdg::GammaNode &gammaNode)
static std::vector< ModRefChain > extractReferenceSubchains(const ModRefChain &modRefChain)
bool traceModRefChains(rvsdg::Output &startOutput, ModRefChainSummary &summary)
void separateReferenceChainsInLambda(rvsdg::LambdaNode &lambdaNode)
~LoadChainSeparation() noexcept override
void separateReferenceChainsInRegion(rvsdg::Region &region)
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
Definition: Load.hpp:157
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
static rvsdg::SimpleNode & CreateNode(const std::vector< rvsdg::Output * > &operands)
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
Definition: Store.hpp:134
UndefValueOperation class.
Definition: operators.hpp:1023
Delta node.
Definition: delta.hpp:129
Conditional operator / pattern matching.
Definition: gamma.hpp:99
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
Definition: gamma.cpp:381
std::vector< EntryVar > GetEntryVars() const
Gets all entry variables for this gamma.
Definition: gamma.cpp:305
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
rvsdg::Region * region() const noexcept
Definition: node.cpp:151
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
SubregionIteratorRange Subregions()
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
Definition: theta.cpp:176
rvsdg::Region * subregion() const noexcept
Definition: theta.hpp:79
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
rvsdg::Input & GetMemoryStateRegionResult(const rvsdg::LambdaNode &lambdaNode) noexcept
void MatchTypeWithDefault(T &obj, const Fns &... fns)
Pattern match over subclass type of given object with default handler.
Region * TryGetOwnerRegion(const rvsdg::Input &input) noexcept
Definition: node.hpp:1021
static std::string strfmt(Args... args)
Definition: strfmt.hpp:35
void add(ModRefChainLink modRefChainLink)