Jlm
inlining.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
9 #include <jlm/llvm/ir/Trace.hpp>
11 #include <jlm/rvsdg/gamma.hpp>
12 #include <jlm/rvsdg/MatchType.hpp>
13 #include <jlm/rvsdg/theta.hpp>
14 #include <jlm/rvsdg/Trace.hpp>
15 #include <jlm/rvsdg/traverser.hpp>
16 #include <jlm/util/Statistics.hpp>
17 #include <jlm/util/time.hpp>
18 
19 namespace jlm::llvm
20 {
21 
23 {
24  // The total number of lambda nodes
25  static constexpr const char * NumFunctions_ = "#Functions";
26  // The total number of lambda nodes that are marked as possible to inline
27  static constexpr const char * NumInlineableFunctions_ = "#InlineableFunctions";
28  // The total number of call operations
29  static constexpr const char * NumFunctionCalls_ = "#FunctionCalls";
30  // The number of call operations that could in theory be inlined
31  static constexpr const char * NumInlineableCalls_ = "#InlinableCalls";
32  // The number of call operations that were actually inlined
33  static constexpr const char * NumCallsInlined_ = "#CallsInlined";
34 
35 public:
36  ~Statistics() override = default;
37 
38  explicit Statistics(const util::FilePath & sourceFile)
39  : util::Statistics(Id::FunctionInlining, sourceFile)
40  {}
41 
42  void
44  {
46  }
47 
48  void
50  size_t numFunctions,
51  size_t numInlineableFunctions,
52  size_t numFunctionCalls,
53  size_t numInlineableCalls,
54  size_t numCallsInlined)
55  {
57  AddMeasurement(NumFunctions_, numFunctions);
58  AddMeasurement(NumInlineableFunctions_, numInlineableFunctions);
59  AddMeasurement(NumFunctionCalls_, numFunctionCalls);
60  AddMeasurement(NumInlineableCalls_, numInlineableCalls);
61  AddMeasurement(NumCallsInlined_, numCallsInlined);
62  }
63 
64  static std::unique_ptr<Statistics>
65  create(const util::FilePath & sourceFile)
66  {
67  return std::make_unique<Statistics>(sourceFile);
68  }
69 };
70 
72 {
73  // Functions that are possible to inline
74  // Just because a function is on this list, does not mean it should be inlined
76 
77  // Functions that are not exported from the module, and only called once
79 
80  // Used for statistics
81  size_t numFunctions = 0;
82  size_t numFunctionCalls = 0;
83  size_t numInlineableCalls = 0;
84  size_t numInlinedCalls = 0;
85 };
86 
// Destructor, defaulted out of line in this .cpp file. The member context_ is a
// std::unique_ptr<Context>, and Context is only defined in this file. Presumably the
// destructor lives here so that unique_ptr's deleter sees Context as a complete type.
87 FunctionInlining::~FunctionInlining() noexcept = default;
88 
90  : Transformation("FunctionInlining")
91 {}
92 
// Makes the values bound to the callee's context variables available inside `region`.
// For each context variable of `callee`, in GetContextVars() order:
// - trace the origin of the context variable's input with an OutputTracer,
// - route the traced output into `region` with rvsdg::RouteToRegion.
// Returns the routed outputs in the same order as callee.GetContextVars().
100 static std::vector<rvsdg::Output *>
102 {
103  llvm::OutputTracer tracer;
104  // We avoid entering phi nodes, as we cannot route from a sibling region
105  tracer.setEnterPhiNodes(false);
106 
107  std::vector<rvsdg::Output *> deps;
108  for (auto & ctxvar : callee.GetContextVars())
109  {
110  auto & traced = tracer.trace(*ctxvar.input->origin());
111  auto & routed = rvsdg::RouteToRegion(traced, region);
112  deps.push_back(&routed);
113  }
114 
115  return deps;
116 }
117 
// After a call has been inlined, tries to connect memory state edges directly,
// bypassing the call's entry merge / exit split nodes.
//
// Preconditions (asserted):
// - callEntryMerge is a CallEntryMemoryStateMergeOperation node.
// - callExitSplit is a CallExitMemoryStateSplitOperation node.
//
// Returns without changes unless both of these hold:
// - the entry merge's output has exactly one user, and that user belongs to a
//   LambdaEntryMemoryStateSplitOperation node;
// - the origin of the exit split's input belongs to a
//   LambdaExitMemoryStateMergeOperation node.
//   NOTE: the local variable `lambdaExitSplit` holds this merge node, despite its name.
//
// Otherwise, matching by memory state id:
// - Each lambda entry split output is diverted to the origin of the matching
//   call entry merge input, or to a fresh undef value if there is none.
// - Each call exit split output is diverted to the origin of the matching
//   lambda exit merge input. If there is none, it falls back to the matching
//   call entry merge input; if that is missing too, std::runtime_error is thrown.
//
// This function does not remove the bypassed nodes itself.
155 static void
157  rvsdg::SimpleNode & callEntryMerge,
158  rvsdg::SimpleNode & callExitSplit)
159 {
160  const auto callEntryMergeOp =
161  rvsdg::tryGetOperation<CallEntryMemoryStateMergeOperation>(callEntryMerge);
162  const auto callExitSplitOp =
163  rvsdg::tryGetOperation<CallExitMemoryStateSplitOperation>(callExitSplit);
164  JLM_ASSERT(callEntryMergeOp);
165  JLM_ASSERT(callExitSplitOp);
166 
167  // Use the output of the callEntryMerge to look for a lambdaEntrySplit
168  auto & callEntryMergeOutput = *callEntryMerge.output(0);
169  if (callEntryMergeOutput.nusers() != 1)
170  return;
171  auto & user = callEntryMergeOutput.SingleUser();
172  const auto [lambdaEntrySplit, lambdaEntrySplitOp] =
173  rvsdg::TryGetSimpleNodeAndOptionalOp<LambdaEntryMemoryStateSplitOperation>(user);
174  if (!lambdaEntrySplitOp)
175  return;
176 
177  // Use the origin of the callExitSplit's input to look for a lambdaExitMerge
178  auto & callExitSplitInput = *callExitSplit.input(0)->origin();
179  const auto [lambdaExitSplit, lambdaExitSplitOp] =
180  rvsdg::TryGetSimpleNodeAndOptionalOp<LambdaExitMemoryStateMergeOperation>(callExitSplitInput);
181  if (!lambdaExitSplitOp)
182  return;
183 
184  // For each memory state output of the lambdaEntrySplit, move its users or create undef nodes
185  for (auto & output : lambdaEntrySplit->Outputs())
186  {
189  callEntryMerge,
190  memoryStateId);
191  if (mergeInput)
192  {
193  output.divert_users(mergeInput->origin());
194  }
195  else
196  {
197  // The call has no matching memory state going into it, so we create an undef node
198  const auto undef = UndefValueOperation::Create(*output.region(), output.Type());
199  output.divert_users(undef);
200  }
201  }
202 
203  // For each memory state output of the callExitSplit, move its users
204  for (auto & output : callExitSplit.Outputs())
205  {
208  *lambdaExitSplit,
209  memoryStateId);
210  if (exitMergeInput)
211  {
212  output.divert_users(exitMergeInput->origin());
213  }
214  else
215  {
216  // The memory state id was never routed through the inside of the lambda, so route it around
218  callEntryMerge,
219  memoryStateId);
220  if (!entryMergeInput)
221  throw std::runtime_error("MemoryStateId in call exit split not found in call entry merge");
222  output.divert_users(entryMergeInput->origin());
223  }
224  }
225 }
226 
235 static void
237  const rvsdg::LambdaNode & callee,
238  rvsdg::LambdaNode & caller,
239  rvsdg::SubstitutionMap & smap)
240 {
241  // All alloca operations in the callee must be on the top level, with constant count,
242  // otherwise the callee would not have qualified for being inlined
243 
244  for (auto & node : callee.subregion()->Nodes())
245  {
246  if (!is<AllocaOperation>(&node))
247  continue;
248 
249  // Find the same alloca in the caller
250  auto oldAllocaNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(smap.lookup(*node.output(0)));
251  JLM_ASSERT(oldAllocaNode);
252 
253  auto countOrigin = AllocaOperation::getCountInput(*oldAllocaNode).origin();
254  countOrigin = &rvsdg::traceOutputIntraProcedurally(*countOrigin);
255  auto countNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*countOrigin);
256  if (!countNode || countNode->ninputs() != 0)
257  throw std::runtime_error("Alloca did not have a nullary count origin");
258 
259  // Create copies of the count node and alloca node at the top level
260  const auto newCountNode = countNode->copy(caller.subregion(), {});
261  const auto newAllocaNode = oldAllocaNode->copy(caller.subregion(), { newCountNode->output(0) });
262 
263  // Route the outputs of the new alloca to the region of the old alloca, and divert old users
264  for (size_t n = 0; n < newAllocaNode->noutputs(); n++)
265  {
266  auto & oldOutput = *oldAllocaNode->output(n);
267  auto & newOutput = *newAllocaNode->output(n);
268  auto & routed = rvsdg::RouteToRegion(newOutput, *oldAllocaNode->region());
269  oldOutput.divert_users(&routed);
270  }
271 
272  // Remove the old alloca node, which is now dead
273  remove(oldAllocaNode);
274  }
275 }
276 
277 void
279  rvsdg::SimpleNode & callNode,
280  rvsdg::LambdaNode & caller,
281  const rvsdg::LambdaNode & callee)
282 {
283  JLM_ASSERT(is<CallOperation>(&callNode));
284 
285  // Make note of the call's entry and exit memory state nodes, if they exist
286  auto callEntryMemoryStateMerge = CallOperation::tryGetMemoryStateEntryMerge(callNode);
287  auto callExitMemoryStateMerge = CallOperation::tryGetMemoryStateExitSplit(callNode);
288 
289  // Set up substitution map for function arguments and context variables
291  auto arguments = callee.GetFunctionArguments();
292  for (size_t n = 0; n < arguments.size(); n++)
293  {
294  smap.insert(arguments[n], callNode.input(n + 1)->origin());
295  }
296 
297  const auto routedDeps = routeContextVariablesToRegion(*callNode.region(), callee);
298  const auto contextVars = callee.GetContextVars();
299  JLM_ASSERT(contextVars.size() == routedDeps.size());
300  for (size_t n = 0; n < contextVars.size(); n++)
301  {
302  smap.insert(contextVars[n].inner, routedDeps[n]);
303  }
304 
305  // Use the substitution map to copy the function body into the caller region
306  callee.subregion()->copy(callNode.region(), smap);
307 
308  // Move all users of the call node's outputs to the callee's result origins
309  const auto calleeResults = callee.GetFunctionResults();
310  JLM_ASSERT(callNode.noutputs() == calleeResults.size());
311  for (size_t n = 0; n < callNode.noutputs(); n++)
312  {
313  const auto resultOrigin = calleeResults[n]->origin();
314  const auto newOrigin = &smap.lookup(*resultOrigin);
315  callNode.output(n)->divert_users(newOrigin);
316  }
317 
318  // If the callee was copied into a structural node within the caller function,
319  // hoist any copied alloca nodes to the top level region of the caller function
320  if (callNode.region() != caller.subregion())
321  {
322  hoistInlinedAllocas(callee, caller, smap);
323  }
324 
325  // The call node is now dead. Remove it
326  remove(&callNode);
327 
328  // If the call had memory state merge and split nodes,
329  // try connecting memory state edges directly instead
330  if (callEntryMemoryStateMerge && callExitMemoryStateMerge)
331  {
332  tryRerouteMemoryStateMergeAndSplit(*callEntryMemoryStateMerge, *callExitMemoryStateMerge);
333  }
334 }
335 
336 void
338 {
339  auto & caller = rvsdg::getSurroundingLambdaNode(callNode);
340  inlineCall(callNode, caller, callee);
341 }
342 
343 bool
344 FunctionInlining::canBeInlined(rvsdg::Region & region, bool topLevelRegion)
345 {
346  for (auto & node : region.Nodes())
347  {
348  if (const auto structural = dynamic_cast<rvsdg::StructuralNode *>(&node))
349  {
350  for (auto & subregion : structural->Subregions())
351  {
352  if (!canBeInlined(subregion, false))
353  return false;
354  }
355  }
356  else if (is<AllocaOperation>(&node))
357  {
358  // Having allocas that are not on the top level of the function disqualifies from inlining
359  if (!topLevelRegion)
360  return false;
361 
362  // Having allocation sizes that are not compile time constants also disqualifies from inlining
363  auto countOutput = AllocaOperation::getCountInput(node).origin();
364  countOutput = &rvsdg::traceOutputIntraProcedurally(*countOutput);
365  auto countNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*countOutput);
366 
367  // The count must come from a node, and it must be nullary
368  if (!countNode || countNode->ninputs() != 0)
369  return false;
370  }
371  else if (const auto [simple, callOp] =
372  rvsdg::TryGetSimpleNodeAndOptionalOp<CallOperation>(node);
373  simple && callOp)
374  {
375  const auto classification = CallOperation::ClassifyCall(*simple);
376  if (classification->isSetjmpCall())
377  {
378  // Calling setjmp weakens guarantees about local variables in the caller,
379  // but not local variables in the caller's caller. Inlining would mix them up.
380  return false;
381  }
382  if (classification->isVaStartCall())
383  {
384  // Calling va_start requires parameters to be passed in as expected by the ABI.
385  // This gets broken if we start inlining.
386  return false;
387  }
388  }
389  }
390 
391  return true;
392 }
393 
394 bool
396 {
397  return canBeInlined(*callee.subregion(), true);
398 }
399 
400 bool
402  [[maybe_unused]] rvsdg::SimpleNode & callNode,
403  [[maybe_unused]] rvsdg::LambdaNode & caller,
404  rvsdg::LambdaNode & callee)
405 {
406  // For now the inlining heuristic is very simple: Inline functions that are called exactly once
407  return context_->functionsCalledOnce.Contains(&callee);
408 }
409 
410 void
412  rvsdg::SimpleNode & callNode,
413  rvsdg::LambdaNode & callerLambda)
414 {
415  context_->numFunctionCalls++;
416 
417  auto classification = CallOperation::ClassifyCall(callNode);
418  if (!classification->IsDirectCall())
419  return;
420 
421  auto & calleeOutput = classification->GetLambdaOutput();
422  auto callee = rvsdg::TryGetOwnerNode<rvsdg::LambdaNode>(calleeOutput);
423  JLM_ASSERT(callee != nullptr);
424 
425  // We can not inline a function into itself
426  if (callee == &callerLambda)
427  return;
428 
429  // We can only inline functions that have been marked as inlineable
430  if (!context_->inlineableFunctions.Contains(callee))
431  return;
432 
433  // At this point we know that it is technically possible to do inlining
434  context_->numInlineableCalls++;
435  if (shouldInline(callNode, callerLambda, *callee))
436  {
437  context_->numInlinedCalls++;
438  inlineCall(callNode, callerLambda, *callee);
439  }
440 }
441 
442 void
444 {
445  for (auto node : rvsdg::TopDownTraverser(&region))
446  {
448  *node,
449  [&](rvsdg::StructuralNode & structural)
450  {
451  for (auto & subregion : structural.Subregions())
452  {
453  visitIntraProceduralRegion(subregion, lambda);
454  }
455  },
456  [&](rvsdg::SimpleNode & simple)
457  {
458  if (is<CallOperation>(&simple))
459  {
460  considerCallForInlining(simple, lambda);
461  }
462  });
463  }
464 }
465 
466 void
468 {
469  context_->numFunctions++;
470 
471  // Visits the lambda's body and performs inlining of calls when determined to be beneficial
472  visitIntraProceduralRegion(*lambda.subregion(), lambda);
473 
474  // After doing inlining inside lambda, we check if the function is eligible for being inlined
475  if (canBeInlined(lambda))
476  context_->inlineableFunctions.insert(&lambda);
477 
478  // Check if the function is only called once, and not exported from the module.
479  // In which case inlining it is "free" in terms of total code size
480  auto callSummary = ComputeCallSummary(lambda);
481  if (callSummary.HasOnlyDirectCalls() && callSummary.NumDirectCalls() == 1)
482  context_->functionsCalledOnce.insert(&lambda);
483 }
484 
485 void
487 {
488  for (auto node : rvsdg::TopDownTraverser(&region))
489  {
491  *node,
492  [&](rvsdg::PhiNode & phi)
493  {
495  },
496  [&](rvsdg::LambdaNode & lambda)
497  {
498  visitLambda(lambda);
499  });
500  }
501 }
502 
503 void
505 {
506  auto statistics = Statistics::create(module.SourceFilePath().value_or(util::FilePath("")));
507 
508  context_ = std::make_unique<Context>();
509  statistics->start();
511  statistics->stop(
512  context_->numFunctions,
513  context_->inlineableFunctions.Size(),
514  context_->numFunctionCalls,
515  context_->numInlineableCalls,
516  context_->numInlinedCalls);
517 
518  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
519 
520  context_.reset();
521 }
522 
523 }
static rvsdg::Input & getCountInput(rvsdg::Node &node)
Definition: alloca.hpp:69
static rvsdg::Input * tryMapMemoryNodeIdToInput(const rvsdg::SimpleNode &node, MemoryNodeId memoryNodeId)
static MemoryNodeId mapOutputToMemoryNodeId(const rvsdg::Output &output)
static rvsdg::SimpleNode * tryGetMemoryStateExitSplit(const rvsdg::Node &callNode) noexcept
Definition: call.hpp:387
static std::unique_ptr< CallTypeClassifier > ClassifyCall(const rvsdg::SimpleNode &callNode)
Classifies a call node.
Definition: call.cpp:47
static rvsdg::SimpleNode * tryGetMemoryStateEntryMerge(const rvsdg::Node &callNode) noexcept
Definition: call.hpp:369
static constexpr const char * NumCallsInlined_
Definition: inlining.cpp:33
void stop(size_t numFunctions, size_t numInlineableFunctions, size_t numFunctionCalls, size_t numInlineableCalls, size_t numCallsInlined)
Definition: inlining.cpp:49
static std::unique_ptr< Statistics > create(const util::FilePath &sourceFile)
Definition: inlining.cpp:65
static constexpr const char * NumFunctions_
Definition: inlining.cpp:25
static constexpr const char * NumInlineableFunctions_
Definition: inlining.cpp:27
Statistics(const util::FilePath &sourceFile)
Definition: inlining.cpp:38
static constexpr const char * NumInlineableCalls_
Definition: inlining.cpp:31
static constexpr const char * NumFunctionCalls_
Definition: inlining.cpp:29
Performs function inlining on functions that are determined to be good candidates,...
Definition: inlining.hpp:25
bool shouldInline(rvsdg::SimpleNode &callNode, rvsdg::LambdaNode &caller, rvsdg::LambdaNode &callee)
Definition: inlining.cpp:401
void visitInterProceduralRegion(rvsdg::Region &region)
Definition: inlining.cpp:486
void visitLambda(rvsdg::LambdaNode &lambda)
Definition: inlining.cpp:467
void considerCallForInlining(rvsdg::SimpleNode &callNode, rvsdg::LambdaNode &callerLambda)
Definition: inlining.cpp:411
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Definition: inlining.cpp:504
static bool canBeInlined(rvsdg::Region &region, bool topLevelRegion)
Definition: inlining.cpp:344
void visitIntraProceduralRegion(rvsdg::Region &region, rvsdg::LambdaNode &lambda)
Definition: inlining.cpp:443
static void inlineCall(rvsdg::SimpleNode &callNode, rvsdg::LambdaNode &caller, const rvsdg::LambdaNode &callee)
Definition: inlining.cpp:278
~FunctionInlining() noexcept override
std::unique_ptr< Context > context_
Definition: inlining.hpp:129
static MemoryNodeId mapOutputToMemoryNodeId(const rvsdg::Output &output)
static rvsdg::Input * tryMapMemoryNodeIdToInput(const rvsdg::SimpleNode &node, MemoryNodeId memoryNodeId)
static jlm::rvsdg::Output * Create(rvsdg::Region &region, std::shared_ptr< const jlm::rvsdg::Type > type)
Definition: operators.hpp:1024
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
std::vector< rvsdg::Output * > GetFunctionArguments() const
Definition: lambda.cpp:57
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
std::vector< rvsdg::Input * > GetFunctionResults() const
Definition: lambda.cpp:69
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
OutputIteratorRange Outputs() noexcept
Definition: node.hpp:657
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
size_t noutputs() const noexcept
Definition: node.hpp:644
void setEnterPhiNodes(bool value) noexcept
Definition: Trace.hpp:76
Output & trace(Output &output)
Definition: Trace.cpp:20
rvsdg::Input & SingleUser() noexcept
Definition: node.hpp:347
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
rvsdg::Region * subregion() const noexcept
Definition: Phi.hpp:320
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
void copy(Region *target, SubstitutionMap &smap) const
Copy a region with substitutions.
Definition: region.cpp:314
NodeRange Nodes() noexcept
Definition: region.hpp:328
const std::optional< util::FilePath > & SourceFilePath() const noexcept
Definition: RvsdgModule.hpp:73
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
NodeInput * input(size_t index) const noexcept
Definition: simple-node.hpp:82
NodeOutput * output(size_t index) const noexcept
Definition: simple-node.hpp:88
SubregionIteratorRange Subregions()
void insert(const Output *original, Output *substitute)
Output & lookup(const Output &original) const
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:563
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:134
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:155
void AddMeasurement(std::string name, T value)
Definition: Statistics.hpp:174
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
static void tryRerouteMemoryStateMergeAndSplit(rvsdg::SimpleNode &callEntryMerge, rvsdg::SimpleNode &callExitSplit)
Definition: inlining.cpp:156
static void hoistInlinedAllocas(const rvsdg::LambdaNode &callee, rvsdg::LambdaNode &caller, rvsdg::SubstitutionMap &smap)
Definition: inlining.cpp:236
CallSummary ComputeCallSummary(const rvsdg::LambdaNode &lambdaNode)
Definition: CallSummary.cpp:30
static std::vector< rvsdg::Output * > routeContextVariablesToRegion(rvsdg::Region &region, const rvsdg::LambdaNode &callee)
Definition: inlining.cpp:101
void MatchType(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.
static void remove(Node *node)
Definition: region.hpp:932
Output & RouteToRegion(Output &output, Region &region)
Definition: node.cpp:381
Output & traceOutputIntraProcedurally(Output &output)
Definition: Trace.cpp:223
rvsdg::LambdaNode & getSurroundingLambdaNode(rvsdg::Node &node)
Definition: lambda.cpp:272
util::HashSet< const rvsdg::LambdaNode * > inlineableFunctions
Definition: inlining.cpp:75
util::HashSet< const rvsdg::LambdaNode * > functionsCalledOnce
Definition: inlining.cpp:78
static const char * Timer
Definition: Statistics.hpp:240