Jlm
AggregateAllocaSplitting.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2026 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
12 #include <jlm/llvm/ir/Trace.hpp>
14 #include <jlm/rvsdg/delta.hpp>
15 #include <jlm/rvsdg/gamma.hpp>
16 #include <jlm/rvsdg/lambda.hpp>
17 #include <jlm/rvsdg/MatchType.hpp>
18 #include <jlm/rvsdg/Phi.hpp>
21 #include <jlm/rvsdg/theta.hpp>
22 #include <jlm/util/Statistics.hpp>
23 
24 #include <deque>
25 
26 namespace jlm::llvm
27 {
28 
// Statistics record of the aggregate alloca splitting pass. It derives from util::Statistics
// (see the constructor's base initializer) and is registered under Id::AggregateAllocaSplitting.
// NOTE(review): the class head and the name lines of start(), stop() and create() fall on
// listing lines that are absent from this view.
30 {
 // Labels under which the four counters and the pass timer are reported.
31  const char * numAggregateAllocaNodesLabel_ = "#AggregateAllocaNodes";
32  const char * numAggregateStructAllocaNodesLabel_ = "#AggregateStructAllocaNodes";
33  const char * numSplitableTypeAggregateAllocaNodesLabel_ = "#SplitableTypeAggregateAllocaNodes";
34  const char * numSplitAggregateAllocaNodesLabel_ = "#SplitAggregateAllocaNodes";
35  const char * aggregateAllocaSplittingTimerLabel_ = "AggregateAllocaSplittingTime";
36 
37 public:
38  ~Statistics() noexcept override = default;
39 
 // Takes ownership of the file path and forwards it to the util::Statistics base.
40  explicit Statistics(util::FilePath filePath)
41  : util::Statistics(Id::AggregateAllocaSplitting, std::move(filePath))
42  {}
43 
 // start(): the body line is not visible here; presumably it adds and starts the timer
 // registered under aggregateAllocaSplittingTimerLabel_ - TODO confirm.
44  void
46  {
48  }
49 
 // stop(...): records the four counters handed in by the caller as measurements.
 // NOTE(review): a leading statement (presumably stopping the timer) and the head of the
 // third AddMeasurement call are on listing lines absent from this view.
50  void
52  const size_t numAggregateAllocaNodes,
53  const size_t numAggregateStructAllocaNodes,
54  const size_t numSplitableTypeAggregateAllocaNodes,
55  const size_t numSplitAggregateAllocaNodes)
56  {
58  AddMeasurement(numAggregateAllocaNodesLabel_, numAggregateAllocaNodes);
59  AddMeasurement(numAggregateStructAllocaNodesLabel_, numAggregateStructAllocaNodes);
62  numSplitableTypeAggregateAllocaNodes);
63  AddMeasurement(numSplitAggregateAllocaNodesLabel_, numSplitAggregateAllocaNodes);
64  }
65 
 // create(filePath): factory returning a heap-allocated Statistics owned by a unique_ptr.
66  static std::unique_ptr<Statistics>
68  {
69  return std::make_unique<Statistics>(std::move(filePath));
70  }
71 };
72 
// NOTE(review): the head and all members of this definition are on listing lines that are
// absent from this view. Presumably this is the pass-internal Context whose counters
// (numAggregateAllocaNodes, numAggregateStructAllocaNodes,
// numSplitableTypeAggregateAllocaNodes, numSplitAggregateAllocaNodes) are incremented in
// findSplitableAllocaNodes() and read in Run() - TODO confirm against the original file.
74 {
79 };
80 
// Result of tracing one alloca node: the alloca itself plus the simple nodes that consume its
// address. isSplitable() fills allocaConsumers with GetElementPtr nodes, and splitAllocaNode()
// later rewires each of them.
// NOTE(review): the struct head, the constructor signature and the allocaNode member
// (accessed as allocaTraceInfo.allocaNode in splitAllocaNode) are on listing lines absent
// from this view.
82 {
85  {}
86 
 // Consumers of the alloca's pointer output that have to be redirected when splitting.
88  std::vector<rvsdg::SimpleNode *> allocaConsumers{};
89 };
90 
92 
// Constructor: registers the pass with the Transformation base under the name
// "AggregateAllocaSplitting". NOTE(review): the signature line is absent from this listing.
94  : Transformation("AggregateAllocaSplitting")
95 {}
96 
// isSplitableType(const rvsdg::Type & type)
// Type-level precondition for splitting: returns true only if `type` is a StructType none of
// whose element types is itself an aggregate type. Everything else (non-struct types, structs
// with nested aggregates) returns false.
// NOTE(review): a struct with zero elements also returns true, because the loop body never
// runs - confirm that callers can handle this.
97 bool
99 {
100  // FIXME: We currently only look at alloca nodes with a struct type. We might be able
101  // to do something for alloca nodes with array types as well.
102  const auto structType = dynamic_cast<const StructType *>(&type);
103  if (!structType)
104  return false;
105 
106  for (const auto & elementType : structType->elementTypes())
107  {
108  if (IsAggregateType(*elementType))
109  {
110  // FIXME: We currently only look at alloca nodes that do not contain nested aggregate types.
111  return false;
112  }
113  }
114 
115  return true;
116 }
117 
// isSplitable(rvsdg::SimpleNode & allocaNode)
// Decides whether the given alloca node can be split. On success it returns an AllocaTraceInfo
// holding the GetElementPtr nodes reached from the alloca's pointer output; otherwise it
// returns std::nullopt.
//
// The pointer output is followed with a worklist (toVisit) through structural nodes:
//  - a user that is a gamma branch result continues at the matching gamma output,
//  - a user that is a theta post-loop result continues at the loop variable's pre argument and
//    at its output,
//  - a user that is a gamma entry variable continues at every branch argument,
//  - a user that is a theta input continues at the loop variable's pre argument.
// Handlers returning false: a user that is a lambda region result, a simple node that is not a
// GetElementPtr, and a GetElementPtr whose indices are not constant. Finally, every collected
// GetElementPtr must pass checkGetElementPtrUsers().
//
// Throws std::logic_error for owner/node/role kinds that have no handler.
// Precondition (assert only): the node is an alloca whose allocated type passes
// isSplitableType().
// NOTE(review): the listing lines that declare `seen` and that open the three
// MatchTypeWithDefault calls are absent from this view. Presumably each call's result is
// assigned to `isSplitable`, mirroring checkGetElementPtrUsers() - TODO confirm.
118 std::optional<AggregateAllocaSplitting::AllocaTraceInfo>
120 {
121  [[maybe_unused]] auto allocaOperation =
122  dynamic_cast<const AllocaOperation *>(&allocaNode.GetOperation());
123  JLM_ASSERT(allocaOperation && isSplitableType(*allocaOperation->allocatedType()));
124 
125  auto & address = AllocaOperation::getPointerOutput(allocaNode);
126 
127  bool isSplitable = true;
128  AllocaTraceInfo allocaTraceInfo(allocaNode);
129 
 // Worklist of outputs still to be inspected; `seen` keeps an output from being queued twice.
131  std::deque<rvsdg::Output *> toVisit{ &address };
132  auto addToVisitSet = [&](rvsdg::Output & output)
133  {
134  if (!seen.Contains(&output))
135  {
136  toVisit.push_back(&output);
137  }
138  seen.insert(&output);
139  };
 // Pops from the front, so outputs are processed in the order they were queued.
140  auto removeFromVisitSet = [&]()
141  {
142  const auto output = toVisit.front();
143  toVisit.pop_front();
144  return output;
145  };
146 
147  while (!toVisit.empty() && isSplitable)
148  {
149  const auto currentOutput = removeFromVisitSet();
150 
151  for (auto & user : currentOutput->Users())
152  {
153  if (!isSplitable)
154  {
155  // Stop handling users if the previous user was already not splitable
156  break;
157  }
158 
 // Case 1: the user is a result of a region, i.e. the value leaves a structural node's
 // subregion.
159  if (auto userRegion = rvsdg::TryGetOwnerRegion(user))
160  {
161  // We should never have an alloca connected to a graph export
162  JLM_ASSERT(userRegion->node());
163 
165  *userRegion->node(),
166  [&](const rvsdg::GammaNode & gammaNode)
167  {
168  auto & gammaOutput = gammaNode.mapBranchResultToOutput(user);
169  addToVisitSet(gammaOutput);
170  return true;
171  },
172  [&](const rvsdg::ThetaNode & thetaNode)
173  {
174  const auto loopVar = thetaNode.MapPostLoopVar(user);
175  addToVisitSet(*loopVar.pre);
176  addToVisitSet(*loopVar.output);
177  return true;
178  },
179  [&](const rvsdg::LambdaNode &)
180  {
181  return false;
182  },
183  [&]()
184  {
185  throw std::logic_error(util::strfmt(
186  "Unhandled owner region node type: ",
187  userRegion->node()->DebugString()));
188  // Silence compiler
189  return false;
190  });
191  }
 // Case 2: the user is an input of a node.
192  else if (auto userNode = rvsdg::TryGetOwnerNode<rvsdg::Node>(user))
193  {
195  *userNode,
196  [&](rvsdg::GammaNode & gammaNode)
197  {
198  auto roleVar = gammaNode.MapInput(user);
199  if (auto entryVar = std::get_if<rvsdg::GammaNode::EntryVar>(&roleVar))
200  {
201  for (auto argument : entryVar->branchArgument)
202  {
203  addToVisitSet(*argument);
204  }
205  }
206  else
207  {
208  throw std::logic_error(util::strfmt("Unhandled role variable."));
209  }
210 
211  return true;
212  },
213  [&](rvsdg::ThetaNode & thetaNode)
214  {
215  const auto loopVar = thetaNode.MapInputLoopVar(user);
216  addToVisitSet(*loopVar.pre);
217  return true;
218  },
219  [&](rvsdg::SimpleNode & simpleNode)
220  {
221  auto & operation = simpleNode.GetOperation();
223  operation,
224  [&](const GetElementPtrOperation &)
225  {
 // NOTE(review): the visible check only requires the indices to be constant, whereas
 // splitAllocaNode() asserts exactly two indices with the first equal to 0. Confirm
 // that other index shapes cannot reach splitAllocaNode() in builds without asserts.
226  JLM_ASSERT(userNode->input(0) == &user);
227  if (const auto indicesOpt =
229  !indicesOpt.has_value())
230  return false;
231 
232  allocaTraceInfo.allocaConsumers.push_back(&simpleNode);
233  return true;
234  },
235  [&]()
236  {
237  return false;
238  });
239  },
240  [&]()
241  {
242  throw std::logic_error(
243  util::strfmt("Unhandled node type: ", userNode->DebugString()));
244  // Silence compiler
245  return false;
246  });
247  }
248  else
249  {
250  throw std::logic_error("Unhandled owner type");
251  }
252  }
253  }
254 
255  if (!isSplitable)
256  return std::nullopt;
257 
 // Every address derived through a collected GetElementPtr must itself pass the user check.
258  for (const auto allocaConsumer : allocaTraceInfo.allocaConsumers)
259  {
260  if (!checkGetElementPtrUsers(*allocaConsumer))
261  return std::nullopt;
262  }
263 
264  return std::make_optional(allocaTraceInfo);
265 }
266 
// checkGetElementPtrUsers(const rvsdg::SimpleNode & gepNode)
// Returns true if every transitive user of the node's output 0 is acceptable:
//  - a load,
//  - a store that uses the value as its address input (any other store input is rejected),
//  - an IOBarrier, whose output 0 is then followed as well,
//  - gamma / theta routing, followed in the same way as in isSplitable().
// A user that is a lambda region result or any other simple node makes the function return
// false. Throws std::logic_error for owner/node/role kinds that have no handler.
// NOTE(review): the listing lines that declare `seen` and that open the inner
// MatchTypeWithDefault call on `operation` are absent from this view. The dynamic_cast result
// `gepOperation` is not checked in the visible code.
267 bool
269 {
270  [[maybe_unused]] auto gepOperation =
271  dynamic_cast<const GetElementPtrOperation *>(&gepNode.GetOperation());
272  auto & address = *gepNode.output(0);
273 
274  bool hasOnlyLoadsAndStores = true;
275 
 // Worklist of outputs still to be inspected; `seen` keeps an output from being queued twice.
277  std::deque<rvsdg::Output *> toVisit{ &address };
278  auto addToVisitSet = [&](rvsdg::Output & output)
279  {
280  if (!seen.Contains(&output))
281  {
282  toVisit.push_back(&output);
283  }
284  seen.insert(&output);
285  };
286  auto removeFromVisitSet = [&]()
287  {
288  const auto output = toVisit.front();
289  toVisit.pop_front();
290  return output;
291  };
292 
293  while (!toVisit.empty() && hasOnlyLoadsAndStores)
294  {
295  const auto currentOutput = removeFromVisitSet();
296  for (auto & user : currentOutput->Users())
297  {
298  if (!hasOnlyLoadsAndStores)
299  {
300  // Stop handling users if the previous user was already not a load or store
301  break;
302  }
303 
 // Case 1: the user is a result of a region, i.e. the value leaves a structural node's
 // subregion.
304  if (auto userRegion = rvsdg::TryGetOwnerRegion(user))
305  {
306  // We should never have a gep node connected to a graph export
307  JLM_ASSERT(userRegion->node());
308 
309  hasOnlyLoadsAndStores = rvsdg::MatchTypeWithDefault(
310  *userRegion->node(),
311  [&](const rvsdg::GammaNode & gammaNode)
312  {
313  auto & gammaOutput = gammaNode.mapBranchResultToOutput(user);
314  addToVisitSet(gammaOutput);
315  return true;
316  },
317  [&](const rvsdg::ThetaNode & thetaNode)
318  {
319  const auto loopVar = thetaNode.MapPostLoopVar(user);
320  addToVisitSet(*loopVar.pre);
321  addToVisitSet(*loopVar.output);
322  return true;
323  },
324  [&](const rvsdg::LambdaNode &)
325  {
326  return false;
327  },
328  [&]()
329  {
330  throw std::logic_error(util::strfmt(
331  "Unhandled owner region node type: ",
332  userRegion->node()->DebugString()));
333  // Silence compiler
334  return false;
335  });
336  }
 // Case 2: the user is an input of a node.
337  else if (auto userNode = rvsdg::TryGetOwnerNode<rvsdg::Node>(user))
338  {
339  hasOnlyLoadsAndStores = rvsdg::MatchTypeWithDefault(
340  *userNode,
341  [&](const rvsdg::GammaNode & gammaNode)
342  {
343  auto roleVar = gammaNode.MapInput(user);
344  if (auto entryVar = std::get_if<rvsdg::GammaNode::EntryVar>(&roleVar))
345  {
346  for (auto argument : entryVar->branchArgument)
347  {
348  addToVisitSet(*argument);
349  }
350  }
351  else
352  {
353  throw std::logic_error(util::strfmt("Unhandled role variable."));
354  }
355 
356  return true;
357  },
358  [&](const rvsdg::ThetaNode & thetaNode)
359  {
360  const auto loopVar = thetaNode.MapInputLoopVar(user);
361  addToVisitSet(*loopVar.pre);
362  return true;
363  },
364  [&](const rvsdg::SimpleNode & simpleNode)
365  {
366  auto & operation = simpleNode.GetOperation();
368  operation,
369  [&](const LoadOperation &)
370  {
371  return true;
372  },
373  [&](const StoreOperation &)
374  {
 // Only accept the store when the traced value is its address input.
375  if (&user != &StoreOperation::AddressInput(simpleNode))
376  return false;
377 
378  return true;
379  },
380  [&](const IOBarrierOperation &)
381  {
 // Continue the trace at the barrier's output 0.
382  addToVisitSet(*simpleNode.output(0));
383  return true;
384  },
385  [&]()
386  {
387  return false;
388  });
389  },
390  [&]()
391  {
392  throw std::logic_error(
393  util::strfmt("Unhandled node type: ", userNode->DebugString()));
394  // Silence compiler
395  return false;
396  });
397  }
398  else
399  {
400  throw std::logic_error("Unhandled owner type");
401  }
402  }
403  }
404 
405  return hasOnlyLoadsAndStores;
406 }
407 
// findSplitableAllocaNodes(rvsdg::Region & region) const
// Walks `region` and, recursively, the subregions of gamma, theta, lambda and phi nodes, and
// returns one AllocaTraceInfo per alloca node for which isSplitable() succeeds. Delta nodes
// are not descended into.
// Side effect: updates the counters in context_ for every alloca node encountered.
// Throws std::logic_error for a node kind that has no handler.
// NOTE(review): the listing line that opens the MatchTypeWithDefault call on `node` is absent
// from this view.
408 std::vector<AggregateAllocaSplitting::AllocaTraceInfo>
410 {
 // Declared as std::function so that the lambda can call itself recursively.
411  std::function<void(rvsdg::Region &, std::vector<AllocaTraceInfo> &)> findAllocaNodes =
412  [&](rvsdg::Region & region, std::vector<AllocaTraceInfo> & traceInfo)
413  {
414  for (auto & node : region.Nodes())
415  {
417  node,
418  [&](rvsdg::GammaNode & gammaNode)
419  {
420  for (auto & subregion : gammaNode.Subregions())
421  findAllocaNodes(subregion, traceInfo);
422  },
423  [&](rvsdg::ThetaNode & thetaNode)
424  {
425  findAllocaNodes(*thetaNode.subregion(), traceInfo);
426  },
427  [&](rvsdg::LambdaNode & lambdaNode)
428  {
429  findAllocaNodes(*lambdaNode.subregion(), traceInfo);
430  },
431  [&](rvsdg::PhiNode & phiNode)
432  {
433  findAllocaNodes(*phiNode.subregion(), traceInfo);
434  },
435  [&](rvsdg::DeltaNode &)
436  {
437  // Nothing needs to be done
438  },
439  [&](rvsdg::SimpleNode & simpleNode)
440  {
441  const auto allocaOperation =
442  dynamic_cast<const AllocaOperation *>(&simpleNode.GetOperation());
443  if (!allocaOperation)
444  return;
445 
 // Struct allocas count both as struct-aggregate and as aggregate; other aggregate
 // types only count as aggregate.
446  auto & allocaType = *allocaOperation->allocatedType();
447  if (is<StructType>(allocaType))
448  {
449  context_->numAggregateStructAllocaNodes++;
450  context_->numAggregateAllocaNodes++;
451  }
452  else if (IsAggregateType(allocaType))
453  {
454  context_->numAggregateAllocaNodes++;
455  }
456 
 // The use-based check only runs for types that pass the type-level check.
457  if (isSplitableType(*allocaOperation->allocatedType()))
458  {
459  context_->numSplitableTypeAggregateAllocaNodes++;
460  if (auto allocaTraceInfo = isSplitable(simpleNode))
461  {
462  context_->numSplitAggregateAllocaNodes++;
463  traceInfo.emplace_back(*allocaTraceInfo);
464  }
465  }
466  },
467  [&]()
468  {
469  throw std::logic_error("Unhandled node type.");
470  });
471  }
472  };
473 
474  std::vector<AllocaTraceInfo> traceInfo;
475  findAllocaNodes(region, traceInfo);
476  return traceInfo;
477 }
478 
// splitAllocaNode(const AllocaTraceInfo & allocaTraceInfo)
// Replaces one struct alloca by one alloca per struct element:
//  1. creates an alloca for every element type, reusing the original count operand and the
//     original alignment,
//  2. merges the new allocas' memory states and diverts the users of the original memory
//     state output to the merged state,
//  3. for every recorded GetElementPtr consumer, routes the pointer of the element alloca
//     selected by the second index into the consumer's region and diverts the consumer's
//     users to it.
// The original alloca and the GetElementPtr nodes are not removed here.
// Throws std::logic_error if a recorded consumer is not a GetElementPtr.
// NOTE(review): the shape of the indices (exactly two, first equal to 0) is only enforced via
// JLM_ASSERT before elementAllocaNodes[indices[1]] is read - confirm this cannot be reached
// with other index shapes in builds without asserts.
// NOTE(review): the listing line that opens the MatchTypeWithDefault call is absent from this
// view.
479 void
481 {
482  auto & allocaNode = *allocaTraceInfo.allocaNode;
483  const auto allocaOperation = dynamic_cast<const AllocaOperation *>(&allocaNode.GetOperation());
484  JLM_ASSERT(allocaOperation && isSplitableType(*allocaOperation->allocatedType()));
485  auto & allocaType = *std::static_pointer_cast<const StructType>(allocaOperation->allocatedType());
486  const auto & countInput = AllocaOperation::getCountInput(allocaNode);
487  const auto alignment = allocaOperation->alignment();
488 
489  // Create alloca nodes for each element in the aggregate type
490  std::vector<rvsdg::Node *> elementAllocaNodes;
491  std::vector<rvsdg::Output *> allocaMemoryStates;
492  for (const auto & elementType : allocaType.elementTypes())
493  {
494  auto & elementAlloca =
495  AllocaOperation::createNode(elementType, *countInput.origin(), alignment);
496  elementAllocaNodes.push_back(&elementAlloca);
497  allocaMemoryStates.push_back(&AllocaOperation::getMemoryStateOutput(elementAlloca));
498  }
499 
 // NOTE(review): for a struct without elements, allocaMemoryStates is empty at this point -
 // confirm that MemoryStateMergeOperation::Create accepts an empty operand list.
500  // Replace alloca node's memory state output
501  const auto memoryState = MemoryStateMergeOperation::Create(allocaMemoryStates);
502  AllocaOperation::getMemoryStateOutput(allocaNode).divert_users(memoryState);
503 
504  // Replace alloca node consumers
505  for (auto allocaConsumer : allocaTraceInfo.allocaConsumers)
506  {
508  allocaConsumer->GetOperation(),
509  [&](const GetElementPtrOperation &)
510  {
511  JLM_ASSERT(GetElementPtrOperation::numIndices(*allocaConsumer) == 2);
512  auto & consumerRegion = *allocaConsumer->region();
513 
514  const auto indices =
515  GetElementPtrOperation::tryGetConstantIndices(*allocaConsumer).value();
516  JLM_ASSERT(indices.size() == 2 && indices[0] == 0);
517 
 // The second index selects the per-element alloca created above.
518  auto elementAlloca = elementAllocaNodes[indices[1]];
519  // FIXME: Introduce caching of routed values to avoid duplicated routing.
520  auto & routedAddress = rvsdg::RouteToRegion(
521  AllocaOperation::getPointerOutput(*elementAlloca),
522  consumerRegion);
523  allocaConsumer->output(0)->divert_users(&routedAddress);
524  },
525  [&]()
526  {
527  throw std::logic_error(
528  util::strfmt("Unhandled node type: ", allocaConsumer->DebugString()));
529  });
530  }
531 }
532 
// splitAllocaNodes(rvsdg::RvsdgModule & rvsdgModule)
// Driver of the transformation: first collects all splitable alloca nodes starting at the
// module's root region, then splits each of them, and finally prunes the graph.
// All candidates are collected before the first split is performed.
533 void
535 {
536  const auto traceInfo = findSplitableAllocaNodes(rvsdgModule.Rvsdg().GetRootRegion());
537  for (const auto & allocaTraceInfo : traceInfo)
538  {
539  splitAllocaNode(allocaTraceInfo);
540  }
541 
542  // Remove all nodes that became dead throughout the transformation
543  rvsdgModule.Rvsdg().PruneNodes();
544 }
545 
// Run(rvsdg::RvsdgModule & module, util::StatisticsCollector & statisticsCollector) override
// Entry point of the pass: creates a fresh Context, runs splitAllocaNodes() between
// statistics->start() and statistics->stop(), hands the statistics to the collector, and
// releases the Context again.
// NOTE(review): SourceFilePath() returns a std::optional; calling value() on it throws
// std::bad_optional_access when the module has no source file path - confirm that every
// caller provides one.
// NOTE(review): the signature's first line and its second parameter are on listing lines
// absent from this view.
546 void
548  rvsdg::RvsdgModule & module,
550 {
551  context_ = std::make_unique<Context>();
552  auto statistics = Statistics::create(module.SourceFilePath().value());
553 
554  statistics->start();
555  splitAllocaNodes(module);
 // The counters were accumulated in context_ while searching for splitable alloca nodes.
556  statistics->stop(
557  context_->numAggregateAllocaNodes,
558  context_->numAggregateStructAllocaNodes,
559  context_->numSplitableTypeAggregateAllocaNodes,
560  context_->numSplitAggregateAllocaNodes);
561 
562  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
563 
564  // Discard internal state to free up memory after we are done
565  context_.reset();
566 }
567 
568 }
static jlm::util::StatisticsCollector statisticsCollector
void stop(const size_t numAggregateAllocaNodes, const size_t numAggregateStructAllocaNodes, const size_t numSplitableTypeAggregateAllocaNodes, const size_t numSplitAggregateAllocaNodes)
static std::unique_ptr< Statistics > create(util::FilePath filePath)
Aggregate Alloca Splitting Transformation.
static bool checkGetElementPtrUsers(const rvsdg::SimpleNode &gepNode)
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static void splitAllocaNode(const AllocaTraceInfo &allocaTraceInfo)
void splitAllocaNodes(rvsdg::RvsdgModule &rvsdgModule)
static bool isSplitableType(const rvsdg::Type &type)
static std::optional< AllocaTraceInfo > isSplitable(rvsdg::SimpleNode &allocaNode)
~AggregateAllocaSplitting() noexcept override
std::vector< AllocaTraceInfo > findSplitableAllocaNodes(rvsdg::Region &region) const
static rvsdg::Input & getCountInput(rvsdg::Node &node)
Definition: alloca.hpp:67
static rvsdg::SimpleNode & createNode(std::shared_ptr< const rvsdg::Type > allocatedType, rvsdg::Output &count, const size_t alignment)
Definition: alloca.hpp:109
static rvsdg::Output & getMemoryStateOutput(rvsdg::Node &node)
Definition: alloca.hpp:81
static rvsdg::Output & getPointerOutput(rvsdg::Node &node)
Definition: alloca.hpp:74
const std::shared_ptr< const rvsdg::Type > & allocatedType() const noexcept
Definition: alloca.hpp:55
static std::optional< std::vector< uint64_t > > tryGetConstantIndices(const rvsdg::Node &node) noexcept
static rvsdg::Output * Create(const std::vector< rvsdg::Output * > &operands)
static rvsdg::Input & AddressInput(const rvsdg::Node &node) noexcept
Definition: Store.hpp:75
StructType class.
Definition: types.hpp:184
Delta node.
Definition: delta.hpp:129
Conditional operator / pattern matching.
Definition: gamma.hpp:99
std::variant< MatchVar, EntryVar > MapInput(const rvsdg::Input &input) const
Maps gamma input to its role (match variable or entry variable).
Definition: gamma.cpp:316
void PruneNodes()
Definition: graph.hpp:116
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Lambda node.
Definition: lambda.hpp:83
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
NodeRange Nodes() noexcept
Definition: region.hpp:328
const std::optional< util::FilePath > & SourceFilePath() const noexcept
Definition: RvsdgModule.hpp:73
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
const SimpleOperation & GetOperation() const noexcept override
Definition: simple-node.cpp:48
NodeOutput * output(size_t index) const noexcept
Definition: simple-node.hpp:88
SubregionIteratorRange Subregions()
bool insert(ItemType item)
Definition: HashSet.hpp:210
bool Contains(const ItemType &item) const noexcept
Definition: HashSet.hpp:150
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:574
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:137
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:158
void AddMeasurement(std::string name, T value)
Definition: Statistics.hpp:177
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
bool IsAggregateType(const jlm::rvsdg::Type &type)
Definition: types.hpp:531
void MatchTypeWithDefault(T &obj, const Fns &... fns)
Pattern match over subclass type of given object with default handler.
Region * TryGetOwnerRegion(const rvsdg::Input &input) noexcept
Definition: node.hpp:1021
static std::string strfmt(Args... args)
Definition: strfmt.hpp:35