Jlm
inlining.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
12 #include <jlm/llvm/ir/Trace.hpp>
14 #include <jlm/rvsdg/gamma.hpp>
15 #include <jlm/rvsdg/MatchType.hpp>
16 #include <jlm/rvsdg/theta.hpp>
17 #include <jlm/rvsdg/Trace.hpp>
18 #include <jlm/rvsdg/traverser.hpp>
19 #include <jlm/util/Statistics.hpp>
20 #include <jlm/util/time.hpp>
21 
22 namespace jlm::llvm
23 {
24 
26 {
 // Measurements recorded for one run of the function inlining pass. The constructor registers
 // them with util::Statistics under Id::FunctionInlining for the given source file.
 // NOTE(review): the class-head line (listing line 25) is absent from this extract.

27  // The total number of lambda nodes
28  static constexpr const char * NumFunctions_ = "#Functions";
29  // The total number of lambda nodes that are marked as possible to inline
30  static constexpr const char * NumInlineableFunctions_ = "#InlineableFunctions";
31  // The total number of call operations
32  static constexpr const char * NumFunctionCalls_ = "#FunctionCalls";
33  // The number of call operations that could in theory be inlined
 // NOTE(review): this label is spelled "#InlinableCalls" (no 'e'), unlike "#InlineableFunctions".
 // It is a runtime key, so it is deliberately left as is; confirm before normalizing.
34  static constexpr const char * NumInlineableCalls_ = "#InlinableCalls";
35  // The number of call operations that were actually inlined
36  static constexpr const char * NumCallsInlined_ = "#CallsInlined";
37 
38 public:
39  ~Statistics() override = default;
40 
 // Creates the statistics object for the module that was read from sourceFile.
41  explicit Statistics(const util::FilePath & sourceFile)
42  : util::Statistics(Id::FunctionInlining, sourceFile)
43  {}
44 
 // NOTE(review): the name line and the body statement of this method (listing lines 46 and 48)
 // are absent from this extract. The pass calls statistics->start() before doing its work, so
 // this is presumably start() and presumably starts a timer - confirm against the repository.
45  void
47  {
49  }
50 
 // Records the five counters gathered during the pass as measurements under the labels above.
 // NOTE(review): listing lines 52 (method name; the index lists it as stop) and 59 (first body
 // statement, presumably stopping the timer) are absent from this extract.
51  void
53  size_t numFunctions,
54  size_t numInlineableFunctions,
55  size_t numFunctionCalls,
56  size_t numInlineableCalls,
57  size_t numCallsInlined)
58  {
60  AddMeasurement(NumFunctions_, numFunctions);
61  AddMeasurement(NumInlineableFunctions_, numInlineableFunctions);
62  AddMeasurement(NumFunctionCalls_, numFunctionCalls);
63  AddMeasurement(NumInlineableCalls_, numInlineableCalls);
64  AddMeasurement(NumCallsInlined_, numCallsInlined);
65  }
66 
 // Convenience factory: returns a heap-allocated Statistics instance for sourceFile.
67  static std::unique_ptr<Statistics>
68  create(const util::FilePath & sourceFile)
69  {
70  return std::make_unique<Statistics>(sourceFile);
71  }
72 };
73 
75 {
 // Per-run working state of the inlining pass. It is created at the start of a run and reset
 // once the statistics have been handed over.
 // NOTE(review): the struct-head line (listing line 74) is absent from this extract.

76  // Functions that are possible to inline
77  // Just because a function is on this list, does not mean it should be inlined
 // NOTE(review): the member declaration (listing line 78) is absent here; the index lists it as
 // util::HashSet<const rvsdg::LambdaNode *> inlineableFunctions.
79 
80  // Functions that are not exported from the module, and only called once
 // NOTE(review): the member declaration (listing line 81) is absent here; the index lists it as
 // util::HashSet<const rvsdg::LambdaNode *> functionsCalledOnce.
82 
83  // Used for statistics
 // Number of lambda nodes visited.
84  size_t numFunctions = 0;
 // Number of call nodes that were considered, whether or not they could be inlined.
85  size_t numFunctionCalls = 0;
 // Number of direct, non-recursive calls whose callee was marked as inlineable.
86  size_t numInlineableCalls = 0;
 // Number of calls that were actually replaced by a copy of the callee body.
87  size_t numInlinedCalls = 0;
88 };
89 
// Destructor, defaulted in this translation unit.
// NOTE(review): presumably defined here rather than in the header so that the
// std::unique_ptr<Context> member is destroyed where Context is a complete type - confirm.
90 FunctionInlining::~FunctionInlining() noexcept = default;
91 
// Constructor: registers the pass with the Transformation base under the name "FunctionInlining".
// NOTE(review): the constructor's declarator line (listing line 92) is absent from this extract.
93  : Transformation("FunctionInlining")
94 {}
95 
/**
 * Makes the value bound to each context variable of \p callee available in \p region.
 *
 * For every context variable, the origin of its input is traced with an llvm::OutputTracer
 * (caching disabled, phi nodes not entered) and the traced output is routed into \p region.
 *
 * @return one routed output per context variable, in the order of callee.GetContextVars().
 *
 * NOTE(review): the line carrying the function name and parameter list (listing line 104) is
 * absent from this extract; the index lists the parameters as (rvsdg::Region & region,
 * const rvsdg::LambdaNode & callee).
 */
103 static std::vector<rvsdg::Output *>
105 {
106  constexpr bool enableCaching = false;
107  llvm::OutputTracer tracer(enableCaching);
108  // We avoid entering phi nodes, as we can not route from a sibling region
109  tracer.setEnterPhiNodes(false);
110 
111  std::vector<rvsdg::Output *> deps;
112  for (auto & ctxvar : callee.GetContextVars())
113  {
 // Trace from the value bound outside the callee, then bring the traced output into the region
114  auto & traced = tracer.trace(*ctxvar.input->origin());
115  auto & routed = rvsdg::RouteToRegion(traced, region);
116  deps.push_back(&routed);
117  }
118 
119  return deps;
120 }
121 
/**
 * Tries to connect memory state edges directly across an inlined call, bypassing the
 * call entry merge / lambda entry split pair and the lambda exit merge / call exit split pair.
 *
 * Both arguments must carry the expected operations (checked with JLM_ASSERT). The function
 * returns without changing anything unless all of the following hold:
 *  - output 0 of \p callEntryMerge has exactly one user,
 *  - that user is a LambdaEntryMemoryStateSplitOperation node,
 *  - the origin of input 0 of \p callExitSplit is a LambdaExitMemoryStateMergeOperation node.
 *
 * @param callEntryMerge the CallEntryMemoryStateMergeOperation node that fed the call
 * @param callExitSplit the CallExitMemoryStateSplitOperation node that consumed the call result
 * @throws std::runtime_error if a memory state of the call exit split is found neither in the
 * lambda exit merge nor in the call entry merge.
 *
 * NOTE(review): listing lines 160, 191-192, 210-211 and 221 are absent from this extract. They
 * hold the function name and the statements that define memoryStateId, mergeInput,
 * exitMergeInput and entryMergeInput. The index suggests these use mapOutputToMemoryNodeId and
 * tryMapMemoryNodeIdToInput - confirm against the repository.
 */
159 static void
161  rvsdg::SimpleNode & callEntryMerge,
162  rvsdg::SimpleNode & callExitSplit)
163 {
164  const auto callEntryMergeOp =
165  rvsdg::tryGetOperation<CallEntryMemoryStateMergeOperation>(callEntryMerge);
166  const auto callExitSplitOp =
167  rvsdg::tryGetOperation<CallExitMemoryStateSplitOperation>(callExitSplit);
168  JLM_ASSERT(callEntryMergeOp);
169  JLM_ASSERT(callExitSplitOp);
170 
171  // Use the output of the callEntryMerge to look for a lambdaEntrySplit
172  auto & callEntryMergeOutput = *callEntryMerge.output(0);
 // The merged state must flow only into the split, otherwise the pair can not be bypassed
173  if (callEntryMergeOutput.nusers() != 1)
174  return;
175  auto & user = callEntryMergeOutput.SingleUser();
176  const auto [lambdaEntrySplit, lambdaEntrySplitOp] =
177  rvsdg::TryGetSimpleNodeAndOptionalOp<LambdaEntryMemoryStateSplitOperation>(user);
178  if (!lambdaEntrySplitOp)
179  return;
180 
181  // Use the origin of the callExitSplit's input to look for a lambdaExitMerge
182  auto & callExitSplitInput = *callExitSplit.input(0)->origin();
 // Despite its name, lambdaExitSplit is the LambdaExitMemoryStateMergeOperation node
183  const auto [lambdaExitSplit, lambdaExitSplitOp] =
184  rvsdg::TryGetSimpleNodeAndOptionalOp<LambdaExitMemoryStateMergeOperation>(callExitSplitInput);
185  if (!lambdaExitSplitOp)
186  return;
187 
188  // For each memory state output of the lambdaEntrySplit, move its users or create undef nodes
189  for (auto & output : lambdaEntrySplit->Outputs())
190  {
 // NOTE(review): the start of this statement is absent; it looks up the callEntryMerge input
 // for this output's memory state id and stores it in mergeInput.
193  callEntryMerge,
194  memoryStateId);
195  if (mergeInput)
196  {
197  output.divert_users(mergeInput->origin());
198  }
199  else
200  {
201  // The call has no matching memory state going into it, so we create an undef node
202  const auto undef = UndefValueOperation::Create(*output.region(), output.Type());
203  output.divert_users(undef);
204  }
205  }
206 
207  // For each memory state output of the callExitSplit, move its users
208  for (auto & output : callExitSplit.Outputs())
209  {
 // NOTE(review): the start of this statement is absent; it looks up the lambda exit merge input
 // for this output's memory state id and stores it in exitMergeInput.
212  *lambdaExitSplit,
213  memoryStateId);
214  if (exitMergeInput)
215  {
216  output.divert_users(exitMergeInput->origin());
217  }
218  else
219  {
220  // the memory state id was never routed through the inside of the lambda, so route it around
222  callEntryMerge,
223  memoryStateId);
224  if (!entryMergeInput)
225  throw std::runtime_error("MemoryStateId in call exit split not found in call entry merge");
226  output.divert_users(entryMergeInput->origin());
227  }
228  }
229 }
230 
/**
 * Moves the copies of the callee's alloca nodes to the top level region of the caller.
 *
 * For every alloca node in the callee's subregion, the copy created by inlining is found
 * through \p smap. A fresh copy of its count node and of the alloca itself is created in
 * caller.subregion(), the new outputs are routed to the region of the old copy, the users of
 * the old copy are diverted to them, and the old copy is removed.
 *
 * @param callee the function whose body was copied
 * @param caller the function that now contains the copy
 * @param smap the substitution map that was used for the copy; maps callee outputs to copies
 * @throws std::runtime_error if the count of a copied alloca does not trace back to a simple
 * node without inputs.
 *
 * NOTE(review): the line carrying the function name (listing line 240) is absent from this
 * extract.
 */
239 static void
241  const rvsdg::LambdaNode & callee,
242  rvsdg::LambdaNode & caller,
243  rvsdg::SubstitutionMap & smap)
244 {
245  // All alloca operations in the callee must be on the top level, with constant count,
246  // otherwise the callee would not have qualified for being inlined
247 
248  for (auto & node : callee.subregion()->Nodes())
249  {
250  if (!is<AllocaOperation>(&node))
251  continue;
252 
253  // Find the same alloca in the caller
254  auto oldAllocaNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(smap.lookup(*node.output(0)));
255  JLM_ASSERT(oldAllocaNode);
256 
 // Trace the count back to the node that produces it; it must be a node without inputs
257  auto countOrigin = AllocaOperation::getCountInput(*oldAllocaNode).origin();
258  countOrigin = &rvsdg::traceOutputIntraProcedurally(*countOrigin);
259  auto countNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*countOrigin);
260  if (!countNode || countNode->ninputs() != 0)
261  throw std::runtime_error("Alloca did not have a nullary count origin");
262 
263  // Create copies of the count node and alloca node at the top level
264  const auto newCountNode = countNode->copy(caller.subregion(), {});
265  const auto newAllocaNode = oldAllocaNode->copy(caller.subregion(), { newCountNode->output(0) });
266 
267  // Route the outputs of the new alloca to the region of the old alloca, and divert old users
268  for (size_t n = 0; n < newAllocaNode->noutputs(); n++)
269  {
270  auto & oldOutput = *oldAllocaNode->output(n);
271  auto & newOutput = *newAllocaNode->output(n);
272  auto & routed = rvsdg::RouteToRegion(newOutput, *oldAllocaNode->region());
273  oldOutput.divert_users(&routed);
274  }
275 
276  // Remove the old alloca node, which is now dead
277  remove(oldAllocaNode);
278  }
279 }
280 
/**
 * Replaces \p callNode by a copy of the body of \p callee.
 *
 * Steps, in order: map the callee's arguments and context variables to values available in the
 * call's region, copy the callee's subregion into that region, divert the users of the call's
 * outputs to the copied result origins, hoist copied allocas if the call is nested inside a
 * structural node, remove the call node, and finally try to bypass the call's memory state
 * merge and split nodes.
 *
 * @param callNode a node with a CallOperation (checked with JLM_ASSERT); removed by this function
 * @param caller the lambda node that contains \p callNode
 * @param callee the lambda node whose body is copied
 *
 * NOTE(review): listing lines 282 (function name) and 294 (presumably the declaration of the
 * substitution map smap) are absent from this extract.
 */
281 void
283  rvsdg::SimpleNode & callNode,
284  rvsdg::LambdaNode & caller,
285  const rvsdg::LambdaNode & callee)
286 {
287  JLM_ASSERT(is<CallOperation>(&callNode));
288 
289  // Make note of the call's entry and exit memory state nodes, if they exist
 // These must be looked up now, because the call node is removed further down.
 // Despite its name, callExitMemoryStateMerge holds the exit *split* node.
290  auto callEntryMemoryStateMerge = CallOperation::tryGetMemoryStateEntryMerge(callNode);
291  auto callExitMemoryStateMerge = CallOperation::tryGetMemoryStateExitSplit(callNode);
292 
293  // Set up substitution map for function arguments and context variables
295  auto arguments = callee.GetFunctionArguments();
296  for (size_t n = 0; n < arguments.size(); n++)
297  {
 // Argument n is supplied by call input n + 1 (input 0 is presumably the called function)
298  smap.insert(arguments[n], callNode.input(n + 1)->origin());
299  }
300 
301  const auto routedDeps = routeContextVariablesToRegion(*callNode.region(), callee);
302  const auto contextVars = callee.GetContextVars();
303  JLM_ASSERT(contextVars.size() == routedDeps.size());
304  for (size_t n = 0; n < contextVars.size(); n++)
305  {
306  smap.insert(contextVars[n].inner, routedDeps[n]);
307  }
308 
309  // Use the substitution map to copy the function body into the caller region
310  callee.subregion()->copy(callNode.region(), smap);
311 
312  // Move all users of the call node's outputs to the callee's result origins
313  const auto calleeResults = callee.GetFunctionResults();
314  JLM_ASSERT(callNode.noutputs() == calleeResults.size());
315  for (size_t n = 0; n < callNode.noutputs(); n++)
316  {
 // After the copy, smap maps each result origin in the callee to its counterpart in the caller
317  const auto resultOrigin = calleeResults[n]->origin();
318  const auto newOrigin = &smap.lookup(*resultOrigin);
319  callNode.output(n)->divert_users(newOrigin);
320  }
321 
322  // If the callee was copied into a structural node within the caller function,
323  // hoist any copied alloca nodes to the top level region of the caller function
324  if (callNode.region() != caller.subregion())
325  {
326  hoistInlinedAllocas(callee, caller, smap);
327  }
328 
329  // The call node is now dead. Remove it
330  remove(&callNode);
331 
332  // If the call had memory state merge and split nodes,
333  // try connecting memory state edges directly instead
334  if (callEntryMemoryStateMerge && callExitMemoryStateMerge)
335  {
336  tryRerouteMemoryStateMergeAndSplit(*callEntryMemoryStateMerge, *callExitMemoryStateMerge);
337  }
338 }
339 
// Convenience overload: determines the lambda node surrounding callNode and inlines callee
// into it via the three-argument inlineCall.
// NOTE(review): the declarator line (listing line 341) is absent from this extract; the body
// uses the names callNode and callee.
340 void
342 {
343  auto & caller = rvsdg::getSurroundingLambdaNode(callNode);
344  inlineCall(callNode, caller, callee);
345 }
346 
/**
 * Checks whether the nodes in \p region permit the enclosing function to be inlined.
 *
 * The region is rejected if it, or any subregion of a structural node within it, contains
 *  - an alloca outside of the top level region,
 *  - an alloca whose count does not trace back to a simple node without inputs, or
 *  - a call classified as a setjmp call or a va_start call.
 *
 * @param region the region to inspect; subregions are inspected recursively
 * @param topLevelRegion true if \p region is the top level region of the function; the
 * recursion passes false for every nested subregion
 * @return true if no disqualifying node was found, otherwise false
 */
347 bool
348 FunctionInlining::canBeInlined(rvsdg::Region & region, bool topLevelRegion)
349 {
350  for (auto & node : region.Nodes())
351  {
352  if (const auto structural = dynamic_cast<rvsdg::StructuralNode *>(&node))
353  {
 // A single disqualifying node anywhere in the nesting rejects the whole function
354  for (auto & subregion : structural->Subregions())
355  {
356  if (!canBeInlined(subregion, false))
357  return false;
358  }
359  }
360  else if (is<AllocaOperation>(&node))
361  {
362  // Having allocas that are not on the top level of the function disqualifies from inlining
363  if (!topLevelRegion)
364  return false;
365 
366  // Having allocation sizes that are not compile time constants also disqualifies from inlining
367  auto countOutput = AllocaOperation::getCountInput(node).origin();
368  countOutput = &rvsdg::traceOutputIntraProcedurally(*countOutput);
369  auto countNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*countOutput);
370 
371  // The count must come from a node, and it must be nullary
372  if (!countNode || countNode->ninputs() != 0)
373  return false;
374  }
375  else if (const auto [simple, callOp] =
376  rvsdg::TryGetSimpleNodeAndOptionalOp<CallOperation>(node);
377  simple && callOp)
378  {
379  const auto classification = CallOperation::ClassifyCall(*simple);
380  if (classification->isSetjmpCall())
381  {
382  // Calling setjmp weakens guarantees about local variables in the caller,
383  // but not local variables in the caller's caller. Inlining would mix them up.
384  return false;
385  }
386  if (classification->isVaStartCall())
387  {
388  // Calling va_start requires parameters to be passed in as expected by the ABI.
389  // This gets broken if we start inlining.
390  return false;
391  }
392  }
393  }
394 
395  return true;
396 }
397 
// Overload for a whole function: inspects the subregion of callee as the top level region.
// NOTE(review): the declarator line (listing line 399) is absent from this extract; the body
// uses the name callee.
398 bool
400 {
401  return canBeInlined(*callee.subregion(), true);
402 }
403 
// Inlining heuristic: returns true if callee is in the context's set of functions called once.
// callNode and caller are currently not consulted, hence [[maybe_unused]].
// NOTE(review): the line carrying the function name (listing line 405) is absent from this
// extract; the index lists it as shouldInline.
404 bool
406  [[maybe_unused]] rvsdg::SimpleNode & callNode,
407  [[maybe_unused]] rvsdg::LambdaNode & caller,
408  rvsdg::LambdaNode & callee)
409 {
410  // For now the inlining heuristic is very simple: Inline functions that are called exactly once
411  return context_->functionsCalledOnce.Contains(&callee);
412 }
413 
/**
 * Decides whether \p callNode is inlined, inlines it if so, and updates the statistics counters.
 *
 * Every call counts towards numFunctionCalls. A call counts as inlineable only if it is a
 * direct call, does not call \p callerLambda itself, and its callee is in the context's set of
 * inlineable functions. An inlineable call is inlined when shouldInline() agrees.
 *
 * @param callNode the call node under consideration; removed if the call is inlined
 * @param callerLambda the lambda node that contains \p callNode
 *
 * NOTE(review): the line carrying the function name (listing line 415) is absent from this
 * extract.
 */
414 void
416  rvsdg::SimpleNode & callNode,
417  rvsdg::LambdaNode & callerLambda)
418 {
419  context_->numFunctionCalls++;
420 
421  auto classification = CallOperation::ClassifyCall(callNode);
422  if (!classification->IsDirectCall())
423  return;
424 
 // For a direct call, the classification's lambda output is expected to belong to a lambda node
425  auto & calleeOutput = classification->GetLambdaOutput();
426  auto callee = rvsdg::TryGetOwnerNode<rvsdg::LambdaNode>(calleeOutput);
427  JLM_ASSERT(callee != nullptr);
428 
429  // We can not inline a function into itself
430  if (callee == &callerLambda)
431  return;
432 
433  // We can only inline functions that have been marked as inlineable
434  if (!context_->inlineableFunctions.Contains(callee))
435  return;
436 
437  // At this point we know that it is technically possible to do inlining
438  context_->numInlineableCalls++;
439  if (shouldInline(callNode, callerLambda, *callee))
440  {
441  context_->numInlinedCalls++;
442  inlineCall(callNode, callerLambda, *callee);
443  }
444 }
445 
// Walks the nodes of a region inside a function body in top-down order. Structural nodes are
// descended into, one subregion at a time; simple nodes with a CallOperation are handed to
// considerCallForInlining together with the enclosing lambda.
// NOTE(review): listing lines 447 (declarator) and 451 (start of the dispatch call, presumably
// a MatchType-style helper) are absent from this extract; the body uses the names region and
// lambda.
446 void
448 {
449  for (auto node : rvsdg::TopDownTraverser(&region))
450  {
452  *node,
453  [&](rvsdg::StructuralNode & structural)
454  {
455  for (auto & subregion : structural.Subregions())
456  {
457  visitIntraProceduralRegion(subregion, lambda);
458  }
459  },
460  [&](rvsdg::SimpleNode & simple)
461  {
462  if (is<CallOperation>(&simple))
463  {
464  considerCallForInlining(simple, lambda);
465  }
466  });
467  }
468 }
469 
// Processes one function: counts it, inlines calls inside its body, and then records whether
// the function itself is inlineable and whether it is called exactly once. The order matters:
// eligibility is evaluated on the body as it looks after inlining into it.
// NOTE(review): the declarator line (listing line 471) is absent from this extract; the body
// uses the name lambda.
470 void
472 {
473  context_->numFunctions++;
474 
475  // Visits the lambda's body and performs inlining of calls when determined to be beneficial
476  visitIntraProceduralRegion(*lambda.subregion(), lambda);
477 
478  // After doing inlining inside lambda, we check if the function is eligible for being inlined
479  if (canBeInlined(lambda))
480  context_->inlineableFunctions.insert(&lambda);
481 
482  // Check if the function is only called once, and not exported from the module.
483  // In which case inlining it is "free" in terms of total code size
 // The test below requires that the call summary has only direct calls, and exactly one of them
484  auto callSummary = ComputeCallSummary(lambda);
485  if (callSummary.HasOnlyDirectCalls() && callSummary.NumDirectCalls() == 1)
486  context_->functionsCalledOnce.insert(&lambda);
487 }
488 
// Walks the nodes of a region outside of function bodies in top-down order and dispatches on
// the node type: lambda nodes are handed to visitLambda, phi nodes get their own handler.
// NOTE(review): listing lines 490 (declarator), 494 (start of the dispatch call) and 498 (body
// of the phi handler, presumably a recursive visit of the phi's subregion) are absent from this
// extract; the body uses the name region.
489 void
491 {
492  for (auto node : rvsdg::TopDownTraverser(&region))
493  {
495  *node,
496  [&](rvsdg::PhiNode & phi)
497  {
499  },
500  [&](rvsdg::LambdaNode & lambda)
501  {
502  visitLambda(lambda);
503  });
504  }
505 }
506 
// Entry point of the pass: creates the statistics object and a fresh Context, runs the pass
// between statistics->start() and statistics->stop(...), hands the statistics to the collector
// and finally discards the Context.
// NOTE(review): listing lines 508 (declarator; the index lists it as
// Run(rvsdg::RvsdgModule & module, util::StatisticsCollector & statisticsCollector)) and 514
// (presumably the visit of the module's root region) are absent from this extract.
507 void
509 {
 // An empty file path is used when the module has no source file path
510  auto statistics = Statistics::create(module.SourceFilePath().value_or(util::FilePath("")));
511 
512  context_ = std::make_unique<Context>();
513  statistics->start();
515  statistics->stop(
516  context_->numFunctions,
517  context_->inlineableFunctions.Size(),
518  context_->numFunctionCalls,
519  context_->numInlineableCalls,
520  context_->numInlinedCalls);
521 
522  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
523 
 // The context is only meaningful during a run
524  context_.reset();
525 }
526 
527 }
static jlm::util::StatisticsCollector statisticsCollector
static rvsdg::Input & getCountInput(rvsdg::Node &node)
Definition: alloca.hpp:67
static rvsdg::Input * tryMapMemoryNodeIdToInput(const rvsdg::SimpleNode &node, MemoryNodeId memoryNodeId)
static MemoryNodeId mapOutputToMemoryNodeId(const rvsdg::Output &output)
static rvsdg::SimpleNode * tryGetMemoryStateExitSplit(const rvsdg::Node &callNode) noexcept
Definition: call.hpp:406
static std::unique_ptr< CallTypeClassifier > ClassifyCall(const rvsdg::SimpleNode &callNode)
Classifies a call node.
Definition: call.cpp:49
static rvsdg::SimpleNode * tryGetMemoryStateEntryMerge(const rvsdg::Node &callNode) noexcept
Definition: call.hpp:388
static constexpr const char * NumCallsInlined_
Definition: inlining.cpp:36
void stop(size_t numFunctions, size_t numInlineableFunctions, size_t numFunctionCalls, size_t numInlineableCalls, size_t numCallsInlined)
Definition: inlining.cpp:52
static std::unique_ptr< Statistics > create(const util::FilePath &sourceFile)
Definition: inlining.cpp:68
static constexpr const char * NumFunctions_
Definition: inlining.cpp:28
static constexpr const char * NumInlineableFunctions_
Definition: inlining.cpp:30
Statistics(const util::FilePath &sourceFile)
Definition: inlining.cpp:41
static constexpr const char * NumInlineableCalls_
Definition: inlining.cpp:34
static constexpr const char * NumFunctionCalls_
Definition: inlining.cpp:32
Performs function inlining on functions that are determined to be good candidates,...
Definition: inlining.hpp:25
bool shouldInline(rvsdg::SimpleNode &callNode, rvsdg::LambdaNode &caller, rvsdg::LambdaNode &callee)
Definition: inlining.cpp:405
void visitInterProceduralRegion(rvsdg::Region &region)
Definition: inlining.cpp:490
void visitLambda(rvsdg::LambdaNode &lambda)
Definition: inlining.cpp:471
void considerCallForInlining(rvsdg::SimpleNode &callNode, rvsdg::LambdaNode &callerLambda)
Definition: inlining.cpp:415
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Definition: inlining.cpp:508
static bool canBeInlined(rvsdg::Region &region, bool topLevelRegion)
Definition: inlining.cpp:348
void visitIntraProceduralRegion(rvsdg::Region &region, rvsdg::LambdaNode &lambda)
Definition: inlining.cpp:447
static void inlineCall(rvsdg::SimpleNode &callNode, rvsdg::LambdaNode &caller, const rvsdg::LambdaNode &callee)
Definition: inlining.cpp:282
~FunctionInlining() noexcept override
std::unique_ptr< Context > context_
Definition: inlining.hpp:129
static MemoryNodeId mapOutputToMemoryNodeId(const rvsdg::Output &output)
static rvsdg::Input * tryMapMemoryNodeIdToInput(const rvsdg::SimpleNode &node, MemoryNodeId memoryNodeId)
static jlm::rvsdg::Output * Create(rvsdg::Region &region, std::shared_ptr< const jlm::rvsdg::Type > type)
Definition: operators.hpp:1055
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
std::vector< rvsdg::Output * > GetFunctionArguments() const
Definition: lambda.cpp:57
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
std::vector< rvsdg::Input * > GetFunctionResults() const
Definition: lambda.cpp:69
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
OutputIteratorRange Outputs() noexcept
Definition: node.hpp:657
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
size_t noutputs() const noexcept
Definition: node.hpp:644
void setEnterPhiNodes(bool value) noexcept
Definition: Trace.hpp:83
Output & trace(Output &output)
Definition: Trace.cpp:22
rvsdg::Input & SingleUser() noexcept
Definition: node.hpp:347
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
rvsdg::Region * subregion() const noexcept
Definition: Phi.hpp:320
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
void copy(Region *target, SubstitutionMap &smap) const
Copy a region with substitutions.
Definition: region.cpp:314
NodeRange Nodes() noexcept
Definition: region.hpp:328
const std::optional< util::FilePath > & SourceFilePath() const noexcept
Definition: RvsdgModule.hpp:73
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
NodeInput * input(size_t index) const noexcept
Definition: simple-node.hpp:82
NodeOutput * output(size_t index) const noexcept
Definition: simple-node.hpp:88
SubregionIteratorRange Subregions()
void insert(const Output *original, Output *substitute)
Output & lookup(const Output &original) const
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:574
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:137
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:158
void AddMeasurement(std::string name, T value)
Definition: Statistics.hpp:177
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
static void tryRerouteMemoryStateMergeAndSplit(rvsdg::SimpleNode &callEntryMerge, rvsdg::SimpleNode &callExitSplit)
Definition: inlining.cpp:160
static void hoistInlinedAllocas(const rvsdg::LambdaNode &callee, rvsdg::LambdaNode &caller, rvsdg::SubstitutionMap &smap)
Definition: inlining.cpp:240
CallSummary ComputeCallSummary(const rvsdg::LambdaNode &lambdaNode)
Definition: CallSummary.cpp:30
static std::vector< rvsdg::Output * > routeContextVariablesToRegion(rvsdg::Region &region, const rvsdg::LambdaNode &callee)
Definition: inlining.cpp:104
void MatchType(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.
static void remove(Node *node)
Definition: region.hpp:978
Output & RouteToRegion(Output &output, Region &region)
Definition: node.cpp:381
Output & traceOutputIntraProcedurally(Output &output)
Definition: Trace.cpp:283
rvsdg::LambdaNode & getSurroundingLambdaNode(rvsdg::Node &node)
Definition: lambda.cpp:272
util::HashSet< const rvsdg::LambdaNode * > inlineableFunctions
Definition: inlining.cpp:78
util::HashSet< const rvsdg::LambdaNode * > functionsCalledOnce
Definition: inlining.cpp:81
static const char * Timer
Definition: Statistics.hpp:251