Jlm
LoadChainSeparation.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
15 #include <jlm/rvsdg/delta.hpp>
16 #include <jlm/rvsdg/gamma.hpp>
17 #include <jlm/rvsdg/lambda.hpp>
18 #include <jlm/rvsdg/MatchType.hpp>
19 #include <jlm/rvsdg/Phi.hpp>
22 #include <jlm/rvsdg/theta.hpp>
23 #include <jlm/rvsdg/traverser.hpp>
24 
25 namespace jlm::llvm
26 {
27 
// Context: per-run bookkeeping of the LoadChainSeparation pass. It maps individual
// memory-state outputs to the ModRefChainLink::Type they contribute when a mod/ref chain is
// traced through them. In the visible code, entries are added for memory-state theta loop
// outputs (separateRefenceChainsInTheta) and read back when traceModRefChains walks through a
// theta node.
// NOTE(review): the class-head line (source line 28) is not part of this rendering.
29 {
30 public:
// Returns true if a link type has been recorded for `output`.
31  bool
32  hasModRefChainLinkType(const rvsdg::Output & output) const noexcept
33  {
34  return Types_.find(&output) != Types_.end();
35  }
36 
// Records `type` for `output`, which must be of MemoryStateType (asserted).
// Assigning through operator[] overwrites any type previously recorded for `output`.
37  void
38  add(const rvsdg::Output & output, const ModRefChainLink::Type & type)
39  {
40  JLM_ASSERT(is<MemoryStateType>(output.Type()));
42  Types_[&output] = type;
43  }
44 
// Returns the link type recorded for `output`.
// NOTE(review): the visible body dereferences the result of find() without comparing it to
// end(); looking up an output without an entry is undefined behavior unless the line not shown
// here (source line 48, presumably an assertion - TODO confirm) guards it. Callers should make
// sure hasModRefChainLinkType(output) holds.
46  getModRefChainLinkType(const rvsdg::Output & output) const
47  {
49  return Types_.find(&output)->second;
50  }
51 
// Factory: returns a newly allocated, empty Context.
52  static std::unique_ptr<Context>
54  {
55  return std::make_unique<Context>();
56  }
57 
58 private:
// Recorded link type per memory-state output, keyed by the output's address (non-owning).
59  std::unordered_map<const rvsdg::Output *, ModRefChainLink::Type> Types_{};
60 };
61 
63 
// Constructor: passes the name "LoadChainSeparation" to the Transformation base class.
// NOTE(review): the constructor's declarator line (source line 64) is not part of this
// rendering.
65  : Transformation("LoadChainSeparation")
66 {}
67 
// Run: entry point of the pass (documented as "Perform RVSDG transformation.").
// NOTE(review): the declarator and the body statements (source lines 69, 71 and 73) are not
// part of this rendering, so the body's behavior cannot be documented from here.
68 void
70 {
72 
74 }
75 
// separateReferenceChainsInRegion: visits every node of `region` in top-down order and
// separates the reference chains reachable from it. One visited set is shared by all seeds
// traced directly in this region (theta handling and dead simple-node outputs).
// - lambda / gamma: the handler bodies (source lines 88 and 96) are not visible here;
//   presumably they call separateReferenceChainsInLambda / separateRefenceChainsInGamma
//   - TODO confirm;
// - phi: recurses into the phi subregion;
// - theta: separateRefenceChainsInTheta with the shared visited set;
// - delta: nothing to do;
// - simple node: every dead memory-state output is used as a trace seed, because it can never
//   be reached from structural node results;
// - any other node: throws std::logic_error.
// NOTE(review): the declarator and the line opening the node-type match (source lines 77 and
// 84) are not part of this rendering.
76 void
78 {
79  util::HashSet<rvsdg::Output *> visitedOutputs;
80 
81  // We require a top-down traverser to ensure that lambda nodes are handled before call nodes
82  for (const auto & node : rvsdg::TopDownTraverser(&region))
83  {
85  *node,
86  [&](rvsdg::LambdaNode & lambdaNode)
87  {
89  },
90  [&](rvsdg::PhiNode & phiNode)
91  {
92  separateReferenceChainsInRegion(*phiNode.subregion());
93  },
94  [&](rvsdg::GammaNode & gammaNode)
95  {
97  },
98  [&](rvsdg::ThetaNode & thetaNode)
99  {
100  separateRefenceChainsInTheta(thetaNode, visitedOutputs);
101  },
102  [](rvsdg::DeltaNode &)
103  {
104  // Nothing needs to be done
105  },
106  [&](rvsdg::SimpleNode & simpleNode)
107  {
108  for (auto & output : simpleNode.Outputs())
109  {
110  if (output.IsDead() && is<MemoryStateType>(output.Type()))
111  {
112  // Dead memory state outputs will never be reachable from structural node results.
113  // Thus, we need to handle them here in order to separate all reference chains.
114  separateReferenceChains(output, visitedOutputs);
115  }
116  }
117  },
118  [&]()
119  {
120  throw std::logic_error(util::strfmt("Unhandled node type: ", node->DebugString()));
121  });
122  }
123 }
124 
// separateReferenceChainsInLambda: separates the reference chains of a lambda's body.
// The inner regions are handled first; the statement doing that (source line 129) is not
// visible here. Then all mod/ref chains reaching the lambda's memory-state region result are
// traced and separated, using a fresh visited set that is not shared with the enclosing region.
// NOTE(review): the declarator line (source line 126) is not part of this rendering.
125 void
127 {
128  // Handle innermost regions first
130 
131  util::HashSet<rvsdg::Output *> visitedOutputs;
132  separateReferenceChains(*GetMemoryStateRegionResult(lambdaNode).origin(), visitedOutputs);
133 }
134 
// separateRefenceChainsInGamma: separates reference chains inside the branches of a gamma.
// The inner regions of every subregion are handled first; the loop body doing that (source
// line 141) is not visible here. Then, for every memory-state exit variable, the chains ending
// in each branch result are separated. Each subregion has its own visited set, selected by the
// branch result's region index.
// NOTE(review): unlike separateRefenceChainsInTheta, the bool returned by
// separateReferenceChains is ignored here. traceModRefChains always records gamma outputs as
// Modification links, so no per-output link type is stored for gammas.
// NOTE(review): the declarator line (source line 136) is not part of this rendering.
135 void
137 {
138  // Handle innermost regions first
139  for (auto & subregion : gammaNode.Subregions())
140  {
142  }
143 
144  std::vector<util::HashSet<rvsdg::Output *>> visitedOutputs(gammaNode.nsubregions());
145  for (auto & [branchResults, output] : gammaNode.GetExitVars())
146  {
147  if (is<MemoryStateType>(output->Type()))
148  {
149  for (const auto branchResult : branchResults)
150  {
151  const auto regionIndex = branchResult->region()->index();
152  JLM_ASSERT(regionIndex < visitedOutputs.size());
153  separateReferenceChains(*branchResult->origin(), visitedOutputs[regionIndex]);
154  }
155  }
156  }
157 }
158 
// separateRefenceChainsInTheta: separates reference chains inside a theta's loop body and
// records in Context_ how each memory-state loop output behaves.
// - The statement that handles the inner region first (source line 165) is not visible here.
// - For every memory-state loop variable, chains are traced from the origin of the loop
//   variable's `post` result. One visited set is shared by all loop variables of the subregion.
// - If a Modification link was encountered, the loop output is recorded as Modification. The
//   other arm of the conditional (source line 179) is not visible here; presumably it is
//   Reference - TODO confirm.
// - Dead loop outputs are traced here with the enclosing region's `visitedOutputs`, mirroring
//   the handling of dead simple-node outputs in separateReferenceChainsInRegion.
// The recorded types are what traceModRefChains reads when it walks through a theta node.
// NOTE(review): the declarator line (source line 160) is not part of this rendering.
159 void
161  rvsdg::ThetaNode & thetaNode,
162  util::HashSet<rvsdg::Output *> & visitedOutputs)
163 {
164  // Handle innermost region first
166 
167  util::HashSet<rvsdg::Output *> visitedOutputsSubregion;
168  for (const auto loopVar : thetaNode.GetLoopVars())
169  {
170  if (!is<MemoryStateType>(loopVar.output->Type()))
171  continue;
172 
173  // Separate reference chains in theta subregion
174  auto hasModificationChainLink =
175  separateReferenceChains(*loopVar.post->origin(), visitedOutputsSubregion);
176  Context_->add(
177  *loopVar.output,
178  hasModificationChainLink ? ModRefChainLink::Type::Modification
180 
181  // Handle dead theta outputs
182  if (loopVar.output->IsDead())
183  {
184  separateReferenceChains(*loopVar.output, visitedOutputs);
185  }
186  }
187 }
188 
// separateReferenceChains: traces all mod/ref chains reaching `startOutput` (which must be of
// MemoryStateType; asserted). Every run of two or more consecutive Reference links is rewired
// so that those operations no longer depend on each other.
//
// traceModRefChains records links while walking from `startOutput` towards the origins.
// Therefore `links.front()` is the most downstream output of a subchain and `links.back()` the
// most upstream one. For each subchain:
//  1. The memory-state input of every link is diverted to the origin that the most upstream
//     link consumed, so all of these operations read the same memory state.
//  2. Unless the most downstream output is dead, a MemoryStateJoinOperation node joins all link
//     outputs. Every user of the most downstream output, except the join itself, is then
//     redirected to the join's output.
//
// `visitedOutputs` is passed through to traceModRefChains, which skips start outputs it has
// already traced. Returns the summary's hasModificationChainLink flag, i.e. whether a
// Modification link was encountered while tracing (as tracked by ModRefChainSummary).
// NOTE(review): the declarator line (source line 190) is not part of this rendering.
189 bool
191  rvsdg::Output & startOutput,
192  util::HashSet<rvsdg::Output *> & visitedOutputs)
193 {
194  JLM_ASSERT(is<MemoryStateType>(startOutput.Type()));
195 
196  ModRefChainSummary summary;
197  traceModRefChains(startOutput, visitedOutputs, summary);
198  for (auto & modRefChain : summary.modRefChains)
199  {
200  const auto refSubchains = extractReferenceSubchains(modRefChain);
201  for (const auto & [_, links] : refSubchains)
202  {
203  // Divert the operands of the respective inputs for each encountered reference node and
204  // collect join operands
205  std::vector<rvsdg::Output *> joinOperands;
206  const auto newMemoryStateOperand = mapMemoryStateOutputToInput(*links.back().output).origin();
207  for (auto [linkOutput, linkModRefType] : links)
208  {
209  JLM_ASSERT(linkModRefType == ModRefChainLink::Type::Reference);
210  auto & modRefChainInput = mapMemoryStateOutputToInput(*linkOutput);
211  modRefChainInput.divert_to(newMemoryStateOperand);
212  joinOperands.push_back(linkOutput);
213  }
214 
215  // Create join node and divert the current memory state output
216  if (!links.front().output->IsDead())
217  {
218  auto & joinNode = MemoryStateJoinOperation::CreateNode(joinOperands);
219  links.front().output->divertUsersWhere(
220  *joinNode.output(0),
221  [&joinNode](const rvsdg::Input & user)
222  {
223  return rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user) != &joinNode;
224  });
225  }
226  }
227  }
228 
229  return summary.hasModificationChainLink;
230 }
231 
// mapMemoryStateOutputToInput: maps a memory-state output that belongs to a reference
// subchain back to the input its memory state was threaded from.
// - Load nodes are checked first; their return statement (source line 238) is not visible here.
// - Theta nodes: returns the loop-variable input that corresponds to `output`.
// - Anything else: throws std::logic_error("Unhandled node type!").
// Loads and thetas suffice because, among the visible link creations in traceModRefChains,
// Reference links come only from load outputs and from theta outputs whose recorded type is
// Reference.
// NOTE(review): the declarator line (source line 233) is not part of this rendering.
232 rvsdg::Input &
234 {
235  if (auto [loadNode, loadOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<LoadOperation>(output);
236  loadOperation)
237  {
239  }
240 
241  if (const auto thetaNode = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(output))
242  {
243  return *thetaNode->MapOutputLoopVar(output).input;
244  }
245 
246  throw std::logic_error("Unhandled node type!");
247 }
248 
// extractReferenceSubchains: returns, in chain order, every maximal run of consecutive
// Reference links in `modRefChain` that contains at least two links. Lone Reference links and
// all non-Reference links are skipped, because only runs of two or more are separated. The
// iterator only moves forward, so each link is inspected once and the cost is linear in the
// chain length.
// NOTE(review): the declarator line (source line 250) is not part of this rendering.
249 std::vector<LoadChainSeparation::ModRefChain>
251 {
252  std::vector<ModRefChain> refSubchains;
253  for (auto linkIt = modRefChain.links.begin(); linkIt != modRefChain.links.end();)
254  {
255  if (linkIt->type != ModRefChainLink::Type::Reference)
256  {
257  // The current link is not a reference. Let's continue with the next one.
258  ++linkIt;
259  continue;
260  }
261 
262  auto nextLinkIt = std::next(linkIt);
263  if (nextLinkIt == modRefChain.links.end()
264  || nextLinkIt->type != ModRefChainLink::Type::Reference)
265  {
266  // We only want to separate reference chains with at least two links
267  ++linkIt;
268  continue;
269  }
270 
271  // We found a new reference subchain. Let's grab all the links
272  refSubchains.push_back({});
273  while (linkIt != modRefChain.links.end() && linkIt->type == ModRefChainLink::Type::Reference)
274  {
275  refSubchains.back().links.push_back(*linkIt);
276  ++linkIt;
277  }
278  }
279 
280  return refSubchains;
281 }
282 
// traceModRefChains: walks backwards from `startOutput` (which must be of MemoryStateType;
// asserted) along its memory-state edge. The traversed outputs are collected as one ModRefChain,
// which is added to `summary` at the end.
//
// If `startOutput` is already in `visitedOutputs`, the function returns immediately; otherwise
// it is inserted. Per traversed node:
// - region argument: tracing stops;
// - gamma: recorded as a Modification link; tracing stops and restarts recursively, as new
//   chains, from the origin of every memory-state entry variable;
// - theta: recorded with the link type stored in Context_ (see separateRefenceChainsInTheta);
//   tracing continues at the origin of the matching loop-variable input;
// - load: Reference link; store, free, memcpy: Modification link; tracing continues at the
//   origin of the corresponding memory-state input;
// - call: tracing stops and restarts from the origin of the call's memory-state input (the
//   callee line, source line 367, is not visible; presumably traceModRefChains);
// - MemoryStateJoin / MemoryStateMerge, plus two handlers whose case labels are not visible:
//   tracing stops and restarts from every input's origin;
// - one handler whose label is not visible restarts from the origin of input 0 only;
// - alloca, malloc, undef, and the handler commented as LambdaEntryMemoryStateSplitOperation:
//   tracing stops;
// - any other operation or node: throws std::logic_error.
//
// NOTE(review): only the start output of each (recursive) trace is inserted into
// `visitedOutputs`; outputs passed through during the linear walk are not. An output reached
// both as a trace start and in the middle of another chain can therefore appear in more than one
// chain - confirm that separateReferenceChains tolerates overlapping subchains.
// NOTE(review): the declarator, the match-call lines and several case labels (source lines 284,
// 308, 336, 347, 361, 367, 373, 381, 388, 395) are not part of this rendering.
283 void
285  rvsdg::Output & startOutput,
286  util::HashSet<rvsdg::Output *> & visitedOutputs,
287  ModRefChainSummary & summary)
288 {
289  JLM_ASSERT(is<MemoryStateType>(startOutput.Type()));
290 
291  if (!visitedOutputs.insert(&startOutput))
292  {
293  return;
294  }
295 
296  ModRefChain currentModRefChain;
297  rvsdg::Output * currentOutput = &startOutput;
298  bool doneTracing = false;
299  do
300  {
301  if (rvsdg::TryGetOwnerRegion(*currentOutput))
302  {
303  // We have a region argument. Stop tracing.
304  break;
305  }
306 
307  auto & node = rvsdg::AssertGetOwnerNode<rvsdg::Node>(*currentOutput);
309  node,
310  [&](const rvsdg::GammaNode & gammaNode)
311  {
312  // FIXME: I really would like that state edges through gammas would be recognized as
313  // either modifying or just referencing. However, we would need to know what the
314  // operations in the gamma on all branches are and which memory state exit variable maps
315  // to which memory state entry variable. We need some more machinery for it first before
316  // we can do that.
317  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Modification });
318  for (auto [entryVarInput, _] : gammaNode.GetEntryVars())
319  {
320  if (is<MemoryStateType>(entryVarInput->Type()))
321  {
322  traceModRefChains(*entryVarInput->origin(), visitedOutputs, summary);
323  }
324  }
325  doneTracing = true;
326  },
327  [&](const rvsdg::ThetaNode &)
328  {
329  const auto modRefChainLinkType = Context_->getModRefChainLinkType(*currentOutput);
330  currentModRefChain.add({ currentOutput, modRefChainLinkType });
331  currentOutput = mapMemoryStateOutputToInput(*currentOutput).origin();
332  },
333  [&](const rvsdg::SimpleNode & simpleNode)
334  {
335  auto & operation = simpleNode.GetOperation();
337  operation,
338  [&](const LoadOperation &)
339  {
340  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Reference });
341  currentOutput = LoadOperation::MapMemoryStateOutputToInput(*currentOutput).origin();
342  },
343  [&](const StoreOperation &)
344  {
345  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Modification });
346  currentOutput =
348  },
349  [&](const FreeOperation &)
350  {
351  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Modification });
352  currentOutput = FreeOperation::mapMemoryStateOutputToInput(*currentOutput).origin();
353  },
354  [&](const MemCpyOperation &)
355  {
356  // FIXME: We really would like to know here which memory state belongs to the source
357  // and which to the dst address. This would allow us to be more precise in the
358  // separation.
359  currentModRefChain.add({ currentOutput, ModRefChainLink::Type::Modification });
360  currentOutput =
362  },
363  [&](const CallOperation &)
364  {
365  // FIXME: I really would like that state edges through calls would be recognized as
366  // either modifying or just referencing.
368  *CallOperation::GetMemoryStateInput(node).origin(),
369  visitedOutputs,
370  summary);
371  doneTracing = true;
372  },
374  {
375  for (auto & nodeInput : node.Inputs())
376  {
377  traceModRefChains(*nodeInput.origin(), visitedOutputs, summary);
378  }
379  doneTracing = true;
380  },
382  {
383  // LambdaEntryMemoryStateSplitOperation nodes should always be connected to a lambda
384  // argument. In other words, this is as far as we can trace in the graph. Just
385  // return what we found so far.
386  doneTracing = true;
387  },
389  {
390  // FIXME: I really would like that state edges through calls would be recognized as
391  // either modifying or just referencing.
392  traceModRefChains(*node.input(0)->origin(), visitedOutputs, summary);
393  doneTracing = true;
394  },
396  {
397  for (auto & nodeInput : node.Inputs())
398  {
399  traceModRefChains(*nodeInput.origin(), visitedOutputs, summary);
400  }
401  doneTracing = true;
402  },
403  [&](const MemoryStateJoinOperation &)
404  {
405  for (auto & nodeInput : node.Inputs())
406  {
407  traceModRefChains(*nodeInput.origin(), visitedOutputs, summary);
408  }
409  doneTracing = true;
410  },
411  [&](const MemoryStateMergeOperation &)
412  {
413  for (auto & nodeInput : node.Inputs())
414  {
415  traceModRefChains(*nodeInput.origin(), visitedOutputs, summary);
416  }
417  doneTracing = true;
418  },
419  [&](const AllocaOperation &)
420  {
421  doneTracing = true;
422  },
423  [&](const MallocOperation &)
424  {
425  doneTracing = true;
426  },
427  [&](const UndefValueOperation &)
428  {
429  doneTracing = true;
430  },
431  [&]()
432  {
433  throw std::logic_error(
434  util::strfmt("Unhandled operation type: ", operation.debug_string()));
435  });
436  },
437  [&]()
438  {
439  throw std::logic_error(util::strfmt("Unhandled node type: ", node.DebugString()));
440  });
441  } while (!doneTracing);
442 
443  summary.add(std::move(currentModRefChain));
444 }
445 
446 }
Call operation class.
Definition: call.hpp:249
static rvsdg::Input & GetMemoryStateInput(const rvsdg::Node &node) noexcept
Definition: call.hpp:338
static rvsdg::Input & mapMemoryStateOutputToInput(rvsdg::Output &output) noexcept
Definition: operators.hpp:2543
static std::unique_ptr< Context > create()
std::unordered_map< const rvsdg::Output *, ModRefChainLink::Type > Types_
bool hasModRefChainLinkType(const rvsdg::Output &output) const noexcept
void add(const rvsdg::Output &output, const ModRefChainLink::Type &type)
ModRefChainLink::Type getModRefChainLinkType(const rvsdg::Output &output) const
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
std::unique_ptr< Context > Context_
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
void separateRefenceChainsInGamma(rvsdg::GammaNode &gammaNode)
static std::vector< ModRefChain > extractReferenceSubchains(const ModRefChain &modRefChain)
void traceModRefChains(rvsdg::Output &startOutput, util::HashSet< rvsdg::Output * > &visitedOutputs, ModRefChainSummary &summary)
void separateReferenceChainsInLambda(rvsdg::LambdaNode &lambdaNode)
bool separateReferenceChains(rvsdg::Output &startOutput, util::HashSet< rvsdg::Output * > &visitedOutputs)
~LoadChainSeparation() noexcept override
void separateReferenceChainsInRegion(rvsdg::Region &region)
void separateRefenceChainsInTheta(rvsdg::ThetaNode &thetaNode, util::HashSet< rvsdg::Output * > &visitedOutputs)
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
Definition: Load.hpp:142
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
Definition: MemCpy.hpp:107
static rvsdg::SimpleNode & CreateNode(const std::vector< rvsdg::Output * > &operands)
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
Definition: Store.hpp:111
UndefValueOperation class.
Definition: operators.hpp:992
Delta node.
Definition: delta.hpp:129
Conditional operator / pattern matching.
Definition: gamma.hpp:99
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
Definition: gamma.cpp:361
std::vector< EntryVar > GetEntryVars() const
Gets all entry variables for this gamma.
Definition: gamma.cpp:303
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
SubregionIteratorRange Subregions()
size_t nsubregions() const noexcept
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
Definition: theta.cpp:176
rvsdg::Region * subregion() const noexcept
Definition: theta.hpp:79
bool insert(ItemType item)
Definition: HashSet.hpp:210
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
rvsdg::Input & GetMemoryStateRegionResult(const rvsdg::LambdaNode &lambdaNode) noexcept
void MatchTypeWithDefault(T &obj, const Fns &... fns)
Pattern match over subclass type of given object with default handler.
static std::string type(const Node *n)
Definition: view.cpp:255
Region * TryGetOwnerRegion(const rvsdg::Input &input) noexcept
Definition: node.hpp:1021
static std::string strfmt(Args... args)
Definition: strfmt.hpp:35
void add(ModRefChainLink modRefChainLink)