Jlm
DeadNodeElimination.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
9 #include <jlm/rvsdg/delta.hpp>
10 #include <jlm/rvsdg/gamma.hpp>
11 #include <jlm/rvsdg/MatchType.hpp>
12 #include <jlm/rvsdg/theta.hpp>
13 #include <jlm/rvsdg/traverser.hpp>
14 #include <jlm/util/Statistics.hpp>
15 #include <jlm/util/time.hpp>
16 
17 #include <algorithm>
18 
19 namespace jlm::llvm
20 {
21 
40 {
41 public:
46  bool
47  markAlive(const rvsdg::Output & output)
48  {
49  return Outputs_.insert(&output);
50  }
51 
56  bool
57  markAlive(const rvsdg::SimpleNode & simpleNode)
58  {
59  return SimpleNodes_.insert(&simpleNode);
60  }
61 
62  bool
63  isAlive(const rvsdg::Output & output) const noexcept
64  {
65  return Outputs_.Contains(&output);
66  }
67 
68  bool
69  isAlive(const rvsdg::SimpleNode & simpleNode) const noexcept
70  {
71  return SimpleNodes_.Contains(&simpleNode);
72  }
73 
/**
 * Factory for a default-constructed liveness context, i.e. one in which nothing
 * has been marked alive yet.
 * NOTE(review): the declarator line (`create()`) is not shown in this listing.
 */
74  static std::unique_ptr<Context>
76  {
77  return std::make_unique<Context>();
78  }
79 
80 private:
83 };
84 
89 {
90  const char * MarkTimerLabel_ = "MarkTime";
91  const char * SweepTimerLabel_ = "SweepTime";
92 
93 public:
94  ~Statistics() override = default;
95 
96  explicit Statistics(const util::FilePath & sourceFile)
97  : util::Statistics(Id::DeadNodeElimination, sourceFile)
98  {}
99 
/**
 * Records pre-transformation measurements for graph, including the node count
 * of its root region as reported by rvsdg::nnodes.
 * NOTE(review): two lines of this body are not shown in this listing; presumably
 * they record NumRvsdgInputsBefore and start the mark timer (MarkTimerLabel_).
 * Confirm against the source.
 */
100  void
101  startMarkStatistics(const rvsdg::Graph & graph) noexcept
102  {
103  AddMeasurement(Label::NumRvsdgNodesBefore, rvsdg::nnodes(&graph.GetRootRegion()));
106  }
107 
// NOTE(review): the declarators and bodies of the following two members are not
// shown in this listing. Judging by their call order in Run(), they are
// stopMarkStatistics() and startSweepStatistics(). Presumably they stop the mark
// timer and start the sweep timer (MarkTimerLabel_ / SweepTimerLabel_). Confirm.
108  void
110  {
112  }
113 
114  void
116  {
118  }
119 
/**
 * Records post-transformation measurements for graph: the node count
 * (rvsdg::nnodes) and input count (rvsdg::ninputs) of its root region.
 * NOTE(review): one line of this body is not shown in this listing; presumably
 * it stops the sweep timer (SweepTimerLabel_). Confirm.
 */
120  void
121  stopSweepStatistics(const rvsdg::Graph & graph) noexcept
122  {
124  AddMeasurement(Label::NumRvsdgNodesAfter, rvsdg::nnodes(&graph.GetRootRegion()));
125  AddMeasurement(Label::NumRvsdgInputsAfter, rvsdg::ninputs(&graph.GetRootRegion()));
126  }
127 
128  static std::unique_ptr<Statistics>
129  create(const util::FilePath & sourceFile)
130  {
131  return std::make_unique<Statistics>(sourceFile);
132  }
133 };
134 
// Defaulted out of line in this file, where Context is a complete type.
// Presumably this is required because the class owns a std::unique_ptr<Context>
// (Context_), whose destructor needs the complete type. Verify against the header.
135 DeadNodeElimination::~DeadNodeElimination() noexcept = default;
136 
// Constructs the pass, passing the name "DeadNodeElimination" to the
// Transformation base class.
// NOTE(review): the constructor's declarator line is not shown in this listing.
138  : Transformation("DeadNodeElimination")
139 {}
140 
/**
 * Performs dead node elimination on a single region.
 *
 * Marks everything reachable from the region's results as alive, then sweeps
 * the region, removing everything not marked. The liveness context is
 * discarded afterwards to free memory.
 * NOTE(review): the declarator and one line before markRegion() are not shown
 * in this listing; presumably that line creates Context_ (Context::create()).
 */
141 void
143 {
145 
146  markRegion(region);
147  sweepRegion(region);
148 
149  // Discard internal state to free up memory after we are done
150  Context_.reset();
151 }
152 
/**
 * Module-level entry point. Performs dead node elimination on the whole RVSDG of
 * module and records timing and size statistics for the mark and sweep phases.
 *
 * @param module The module whose RVSDG is transformed in place.
 * @param statisticsCollector Receives the statistics via CollectDemandedStatistics.
 *
 * NOTE: module.SourceFilePath() returns a std::optional; calling value() on it
 * throws std::bad_optional_access if the module has no source file path.
 * NOTE(review): a few lines (declarator parts and, presumably, the creation of
 * Context_) are not shown in this listing.
 */
153 void
155  rvsdg::RvsdgModule & module,
157 {
159 
160  auto & rvsdg = module.Rvsdg();
161  auto statistics = Statistics::create(module.SourceFilePath().value());
  // Mark phase: everything reachable from the root region's results.
162  statistics->startMarkStatistics(rvsdg);
163  markRegion(rvsdg.GetRootRegion());
164  statistics->stopMarkStatistics();
165 
  // Sweep phase: remove unmarked nodes and dead root-region arguments.
166  statistics->startSweepStatistics();
167  sweepRvsdg(rvsdg);
168  statistics->stopSweepStatistics(rvsdg);
169 
170  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
171 
172  // Discard internal state to free up memory after we are done
173  Context_.reset();
174 }
175 
/**
 * Returns true if output is a MemoryStateType output of a simple node whose
 * operation is a LoadNonVolatileOperation.
 *
 * markOutput tracks such outputs individually, so a load whose loaded value is
 * unused can still be removed even though its memory states are used.
 * NOTE(review): the declarator line is not shown in this listing.
 */
176 static bool
178 {
179  const auto simpleNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(output);
180  return simpleNode && is<LoadNonVolatileOperation>(simpleNode)
181  && is<MemoryStateType>(output.Type());
182 }
183 
/**
 * Marks the origin of every result of region as alive. markOutput then
 * transitively marks everything those origins depend on.
 * NOTE(review): the declarator line is not shown in this listing.
 */
184 void
186 {
187  for (const auto result : region.Results())
188  {
189  markOutput(*result->origin());
190  }
191 }
192 
/**
 * Marks output as alive and recursively marks everything it depends on.
 *
 * - Outputs of simple nodes mark the whole node alive, together with the origins
 *   of all its inputs.
 * - Memory-state outputs of non-volatile loads are the exception: they are
 *   marked individually and only the corresponding memory-state input is followed.
 * - Outputs of structural nodes (gamma, theta, lambda, phi, delta), and arguments
 *   of their subregions, are marked individually, and only the origins they
 *   depend on are followed.
 *
 * Recursion stops at outputs and simple nodes that are already alive, and at
 * graph imports.
 *
 * @throws std::logic_error if a phi argument is neither a fixpoint nor a context
 *         variable, or if output belongs to a construct not handled here.
 * NOTE(review): the declarator line is not shown in this listing.
 */
193 void
195 {
196  if (auto simpleNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(output))
197  {
198  if (Context_->isAlive(*simpleNode))
199  return;
200 
201  // LoadNonVolatile operations get special handling,
202  // where memory state outputs are marked as alive individually.
203  // The SimpleNode itself only gets marked as alive if the loaded value is used.
  // NOTE(review): the condition line of the following block is not shown in this
  // listing; presumably it is `if (isLoadNonVolatileMemoryStateOutput(output))`.
205  {
206  // Mark the memory state output as alive, or return if it was already marked.
207  if (!Context_->markAlive(output))
208  return;
209 
210  // Continue marking only the origin of the corresponding memory state input
  // NOTE(review): the marking call itself is not shown in this listing.
212 
213  return;
214  }
215 
216  // The simple node is alive, mark it along with the origins of all its inputs
217  Context_->markAlive(*simpleNode);
218  for (auto & input : simpleNode->Inputs())
219  {
220  markOutput(*input.origin());
221  }
222 
223  return;
224  }
225 
226  // Mark the output as alive, or return if it has already been marked
227  if (!Context_->markAlive(output))
228  {
229  return;
230  }
231 
  // Graph imports have no origin to follow.
232  if (is<rvsdg::GraphImport>(&output))
233  {
234  return;
235  }
236 
  // Gamma output: needs the predicate and, in every branch, the result of the
  // corresponding exit variable.
237  if (const auto gamma = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(output))
238  {
239  markOutput(*gamma->predicate()->origin());
240  for (const auto & result : gamma->MapOutputExitVar(output).branchResult)
241  {
242  markOutput(*result->origin());
243  }
244  return;
245  }
246 
  // Gamma branch argument: needs the origin of the corresponding gamma input.
247  if (const auto gamma = rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(output))
248  {
249  markOutput(*gamma->mapBranchArgumentToInput(output).origin());
250  return;
251  }
252 
  // Theta output: needs the loop predicate, the loop variable's post-iteration
  // value, and the loop variable's initial input.
253  if (const auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(output))
254  {
255  const auto loopVar = theta->MapOutputLoopVar(output);
256  markOutput(*theta->predicate()->origin());
257  markOutput(*loopVar.post->origin());
258  markOutput(*loopVar.input->origin());
259  return;
260  }
261 
  // Theta pre-iteration argument: keep the whole loop variable (via its output)
  // and its initial input alive.
262  if (const auto theta = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(output))
263  {
264  const auto loopVar = theta->MapPreLoopVar(output);
265  markOutput(*loopVar.output);
266  markOutput(*loopVar.input->origin());
267  return;
268  }
269 
  // Lambda output: every function result is needed.
270  if (const auto lambda = rvsdg::TryGetOwnerNode<rvsdg::LambdaNode>(output))
271  {
272  for (const auto result : lambda->GetFunctionResults())
273  {
274  markOutput(*result->origin());
275  }
276  return;
277  }
278 
279  if (const auto lambda = rvsdg::TryGetRegionParentNode<rvsdg::LambdaNode>(output))
280  {
281  if (const auto ctxVar = lambda->MapBinderContextVar(output))
282  {
283  // Bound context variable.
284  markOutput(*ctxVar->input->origin());
285  return;
286  }
287 
288  // Function argument.
289  return;
290  }
291 
  // Phi output: needs the result of the corresponding fixpoint variable.
292  if (const auto phi = rvsdg::TryGetOwnerNode<rvsdg::PhiNode>(output))
293  {
294  markOutput(*phi->MapOutputFixVar(output).result->origin());
295  return;
296  }
297 
298  if (const auto phi = rvsdg::TryGetRegionParentNode<rvsdg::PhiNode>(output))
299  {
300  const auto var = phi->MapArgument(output);
301  if (const auto fixVar = std::get_if<rvsdg::PhiNode::FixVar>(&var))
302  {
303  // Recursion argument
304  markOutput(*fixVar->result->origin());
305  return;
306  }
307 
308  if (const auto ctxVar = std::get_if<rvsdg::PhiNode::ContextVar>(&var))
309  {
310  // Bound context variable.
311  markOutput(*ctxVar->input->origin());
312  return;
313  }
314 
315  throw std::logic_error("Phi argument must be either fixpoint or context variable");
316  }
317 
  // Delta output: needs the origin of result(0) of the delta subregion.
318  if (const auto deltaNode = rvsdg::TryGetOwnerNode<rvsdg::DeltaNode>(output))
319  {
320  const auto result = deltaNode->subregion()->result(0);
321  markOutput(*result->origin());
322  return;
323  }
324 
  // Delta subregion argument: needs the origin of the input it is bound to.
325  if (rvsdg::TryGetRegionParentNode<rvsdg::DeltaNode>(output))
326  {
327  const auto argument = util::assertedCast<const rvsdg::RegionArgument>(&output);
328  markOutput(*argument->input()->origin());
329  return;
330  }
331 
  // Every kind of output this pass supports is handled above.
332  throw std::logic_error("We should have never reached this statement.");
333 }
334 
/**
 * Sweeps the root region of rvsdg, then removes every root-region argument
 * (import) that was not marked alive.
 * NOTE(review): the declarator line is not shown in this listing.
 */
335 void
337 {
338  sweepRegion(rvsdg.GetRootRegion());
339 
340  // Remove dead imports
341  util::HashSet<size_t> indices;
342  for (const auto argument : rvsdg.GetRootRegion().Arguments())
343  {
344  if (!Context_->isAlive(*argument))
345  {
346  indices.insert(argument->index());
347  }
348  }
  // Every collected argument is expected to be removable.
349  [[maybe_unused]] const auto numRemovedArguments = rvsdg.GetRootRegion().RemoveArguments(indices);
350  JLM_ASSERT(numRemovedArguments == indices.Size());
351 }
352 
/**
 * Removes all nodes of region that are not alive, and recursively sweeps the
 * subregions of structural nodes that are alive.
 *
 * A simple node is alive iff it was marked alive itself. Any other node is
 * alive if at least one of its outputs was marked alive. Nodes are visited with
 * rvsdg::BottomUpTraverser.
 * NOTE(review): the declarator line is not shown in this listing.
 */
353 void
355 {
356  auto isAlive = [this](const rvsdg::Node & node)
357  {
  // Simple nodes are tracked as a whole.
358  if (const auto simpleNode = dynamic_cast<const rvsdg::SimpleNode *>(&node))
359  {
360  return Context_->isAlive(*simpleNode);
361  }
362 
  // Structural nodes are tracked per output.
363  for (auto & output : node.Outputs())
364  {
365  if (Context_->isAlive(output))
366  {
367  return true;
368  }
369  }
370 
371  return false;
372  };
373 
374  for (const auto node : rvsdg::BottomUpTraverser(&region))
375  {
376  if (!isAlive(*node))
377  {
378  removeNode(*node);
379  }
380  else if (const auto structuralNode = dynamic_cast<rvsdg::StructuralNode *>(node))
381  {
382  sweepStructuralNode(*structuralNode);
383  }
384  }
385 
  // Sanity check: the sweep must leave no bottom nodes behind.
386  JLM_ASSERT(region.numBottomNodes() == 0);
387 }
388 
/**
 * Dispatches to the sweep routine that matches the concrete type of the live
 * structural node.
 *
 * @throws std::logic_error for structural node types other than gamma, theta,
 *         lambda, phi and delta.
 * NOTE(review): the declarator and the line opening the type-dispatch call
 * (presumably rvsdg::MatchTypeWithDefault) are not shown in this listing.
 */
389 void
391 {
393  node,
394  [this](rvsdg::GammaNode & node)
395  {
396  sweepGamma(node);
397  },
398  [this](rvsdg::ThetaNode & node)
399  {
400  sweepTheta(node);
401  },
402  [this](rvsdg::LambdaNode & node)
403  {
404  sweepLambda(node);
405  },
406  [this](rvsdg::PhiNode & node)
407  {
408  sweepPhi(node);
409  },
  // sweepDelta is a static member, so this handler needs no capture.
410  [](rvsdg::DeltaNode & node)
411  {
412  sweepDelta(node);
413  },
  // Default handler for any other structural node type.
414  [&node]()
415  {
416  throw std::logic_error(util::strfmt("Unhandled node type: ", node.DebugString()));
417  });
418 }
419 
/**
 * Sweeps a live gamma node in three steps:
 * 1. Remove the exit variables whose gamma outputs are not alive.
 * 2. Sweep every branch subregion.
 * 3. Remove the entry variables none of whose branch arguments is alive.
 *
 * Exit variables are removed before the branches are swept. Presumably this
 * ensures that values which only fed dead exits have no remaining users when
 * they are removed. Confirm against RemoveExitVars.
 * NOTE(review): the declarator line is not shown in this listing.
 */
420 void
422 {
423  // Remove dead exit variables.
424  const auto deadGammaOutputs = gammaNode.GetOutputsWhere(
425  [this](const rvsdg::Output & output)
426  {
427  return !Context_->isAlive(output);
428  });
429  gammaNode.RemoveExitVars(deadGammaOutputs);
430 
431  // Sweep gamma subregions
432  for (auto & subregion : gammaNode.Subregions())
433  {
434  sweepRegion(subregion);
435  }
436 
437  // Remove dead entry variables.
438  std::vector<rvsdg::GammaNode::EntryVar> deadEntryVars;
439  for (const auto & entryVar : gammaNode.GetEntryVars())
440  {
  // An entry variable is alive if it is used in at least one branch.
441  const bool isAlive = std::any_of(
442  entryVar.branchArgument.begin(),
443  entryVar.branchArgument.end(),
444  [this](const rvsdg::Output * arg)
445  {
446  return Context_->isAlive(*arg);
447  });
448  if (!isAlive)
449  {
450  deadEntryVars.push_back(entryVar);
451  }
452  }
453  gammaNode.RemoveEntryVars(deadEntryVars);
454 }
455 
/**
 * Sweeps a live theta node.
 *
 * A loop variable is dead if neither its pre-iteration argument nor its output
 * is alive. The post-iteration result of each dead variable is redirected to
 * the variable's own pre argument before the subregion is swept, and the dead
 * variables are removed from the node afterwards.
 * NOTE(review): the declarator line is not shown in this listing.
 */
456 void
458 {
459  // Determine dead loop variables.
460  std::vector<rvsdg::ThetaNode::LoopVar> loopVars;
461  for (const auto & loopVar : thetaNode.GetLoopVars())
462  {
463  if (!Context_->isAlive(*loopVar.pre) && !Context_->isAlive(*loopVar.output))
464  {
465  loopVar.post->divert_to(loopVar.pre);
466  loopVars.push_back(loopVar);
467  }
468  }
469 
470  // Now that the loop variables to be eliminated only point to
471  // their own pre-iteration values, any outputs within the subregion
472  // that only contributed to computing the post-iteration values
473  // of the variables are unlinked and can be removed as well.
474  sweepRegion(*thetaNode.subregion());
475 
476  // There are now no other users of the pre-iteration values of the
477  // variables to be removed left in the subregion anymore.
478  // The variables have become "loop-invariant" and can simply
479  // be eliminated from the theta node.
480  thetaNode.RemoveLoopVars(std::move(loopVars));
481 }
482 
/**
 * Sweeps the body of a live lambda node, then calls PruneLambdaInputs() on it.
 * Presumably PruneLambdaInputs() drops context inputs that have become unused;
 * its size_t return value is ignored here.
 * NOTE(review): the declarator line is not shown in this listing.
 */
483 void
485 {
486  sweepRegion(*lambdaNode.subregion());
487  lambdaNode.PruneLambdaInputs();
488 }
489 
/**
 * Sweeps a live phi node.
 *
 * A fixpoint variable is dead if neither its output nor its recursion reference
 * (recref) is alive. The result of each dead variable is diverted to its own
 * recref, so the definition it previously referred to can be swept from the
 * subregion. After the sweep, context variables whose inner argument is dead
 * are removed, followed by the dead fixpoint variables.
 * NOTE(review): the declarator line is not shown in this listing.
 */
490 void
492 {
493  std::vector<rvsdg::PhiNode::FixVar> deadFixVars;
494  std::vector<rvsdg::PhiNode::ContextVar> deadCtxVars;
495 
496  for (const auto & fixVar : phiNode.GetFixVars())
497  {
498  if (!Context_->isAlive(*fixVar.output) && !Context_->isAlive(*fixVar.recref))
499  {
500  deadFixVars.push_back(fixVar);
501  // Temporarily redirect the variable so it refers to itself
502  // (so the object is simply defined to be "itself").
503  fixVar.result->divert_to(fixVar.recref);
504  }
505  }
506 
507  sweepRegion(*phiNode.subregion());
508 
  // Context variables are judged after the sweep by whether their inner
  // argument still has users.
509  for (const auto & ctxvar : phiNode.GetContextVars())
510  {
511  if (ctxvar.inner->IsDead())
512  {
513  deadCtxVars.push_back(ctxvar);
514  }
515  }
516 
517  phiNode.RemoveContextVars(std::move(deadCtxVars));
518  phiNode.RemoveFixVars(std::move(deadFixVars));
519 }
520 
/**
 * Sweeps a live delta node. It prunes the delta subregion non-recursively and
 * then calls PruneDeltaInputs(), whose size_t return value is ignored. Unlike
 * the other sweep routines, this one does not consult the liveness context.
 * NOTE(review): the declarator line is not shown in this listing.
 */
521 void
523 {
524  // A delta subregion can only contain simple nodes. Thus, a simple prune is sufficient.
525  deltaNode.subregion()->prune(false);
526 
527  deltaNode.PruneDeltaInputs();
528 }
529 
/**
 * Removes a dead node from its region.
 *
 * The memory-state outputs of a LoadNonVolatile node may still have users,
 * because they are marked alive individually. For such a node, the users of
 * each memory-state output are first diverted to the origin of the matching
 * memory-state input. This bypasses the load while preserving the memory-state
 * chain, and the node is then removed.
 * NOTE(review): the declarator and one line at the end of the load-specific
 * branch are not shown in this listing.
 */
530 void
532 {
533  if (is<LoadNonVolatileOperation>(node.GetOperation()))
534  {
535  for (auto & memoryStateOutput : LoadOperation::MemoryStateOutputs(node))
536  {
537  const auto origin = LoadOperation::MapMemoryStateOutputToInput(memoryStateOutput).origin();
538  memoryStateOutput.divert_users(origin);
539  }
541  }
542 
543  remove(&node);
544 }
545 
546 }
static jlm::util::StatisticsCollector statisticsCollector
Dead Node Elimination context class.
util::HashSet< const rvsdg::SimpleNode * > SimpleNodes_
util::HashSet< const rvsdg::Output * > Outputs_
bool markAlive(const rvsdg::Output &output)
bool markAlive(const rvsdg::SimpleNode &simpleNode)
static std::unique_ptr< Context > create()
bool isAlive(const rvsdg::Output &output) const noexcept
bool isAlive(const rvsdg::SimpleNode &simpleNode) const noexcept
Dead Node Elimination statistics class.
static std::unique_ptr< Statistics > create(const util::FilePath &sourceFile)
void stopSweepStatistics(const rvsdg::Graph &graph) noexcept
Statistics(const util::FilePath &sourceFile)
void startMarkStatistics(const rvsdg::Graph &graph) noexcept
Dead Node Elimination Optimization.
void sweepStructuralNode(rvsdg::StructuralNode &node) const
static void sweepDelta(rvsdg::DeltaNode &deltaNode)
void sweepGamma(rvsdg::GammaNode &gammaNode) const
void sweepLambda(rvsdg::LambdaNode &lambdaNode) const
void run(rvsdg::Region &region)
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
std::unique_ptr< Context > Context_
void markRegion(const rvsdg::Region &region)
void sweepPhi(rvsdg::PhiNode &phiNode) const
void sweepRvsdg(rvsdg::Graph &rvsdg) const
~DeadNodeElimination() noexcept override
static void removeNode(rvsdg::Node &node)
void markOutput(const rvsdg::Output &output)
void sweepRegion(rvsdg::Region &region) const
void sweepTheta(rvsdg::ThetaNode &thetaNode) const
static rvsdg::Output & LoadedValueOutput(const rvsdg::Node &node)
Definition: Load.hpp:84
static rvsdg::Node::OutputIteratorRange MemoryStateOutputs(const rvsdg::Node &node) noexcept
Definition: Load.hpp:116
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
Definition: Load.hpp:157
Delta node.
Definition: delta.hpp:129
rvsdg::Region * subregion() const noexcept
Definition: delta.hpp:234
size_t PruneDeltaInputs()
Definition: delta.hpp:277
Conditional operator / pattern matching.
Definition: gamma.hpp:99
void RemoveEntryVars(const std::vector< EntryVar > &entryVars)
Removes the given entry variables.
Definition: gamma.cpp:459
std::vector< Output * > GetOutputsWhere(const F &match)
Gets all gamma outputs that match the condition defined by match.
Definition: gamma.hpp:317
std::vector< EntryVar > GetEntryVars() const
Gets all entry variables for this gamma.
Definition: gamma.cpp:305
void RemoveExitVars(const std::vector< Output * > &gammaOutputs)
Removes the exit variables corresponding to the given gammaOutputs.
Definition: gamma.cpp:439
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
size_t PruneLambdaInputs()
Definition: lambda.hpp:245
virtual const Operation & GetOperation() const noexcept=0
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
rvsdg::Region * subregion() const noexcept
Definition: Phi.hpp:320
std::vector< FixVar > GetFixVars() const noexcept
Gets all fixpoint variables.
Definition: Phi.cpp:63
void RemoveContextVars(std::vector< ContextVar > vars)
Removes context variables from phi node.
Definition: Phi.cpp:150
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: Phi.cpp:50
void RemoveFixVars(std::vector< FixVar > vars)
Removes fixpoint variables from the phi node.
Definition: Phi.cpp:168
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
RegionArgumentRange Arguments() noexcept
Definition: region.hpp:272
void prune(bool recursive)
Definition: region.cpp:323
size_t numBottomNodes() const noexcept
Definition: region.hpp:499
size_t RemoveArguments(const util::HashSet< size_t > &indices)
Definition: region.cpp:210
RegionResultRange Results() noexcept
Definition: region.hpp:290
const std::optional< util::FilePath > & SourceFilePath() const noexcept
Definition: RvsdgModule.hpp:73
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
std::string DebugString() const override
SubregionIteratorRange Subregions()
void RemoveLoopVars(std::vector< LoopVar > loopVars)
Removes loop variables.
Definition: theta.cpp:63
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
Definition: theta.cpp:176
rvsdg::Region * subregion() const noexcept
Definition: theta.hpp:79
bool insert(ItemType item)
Definition: HashSet.hpp:210
std::size_t Size() const noexcept
Definition: HashSet.hpp:187
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:574
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:137
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:158
void AddMeasurement(std::string name, T value)
Definition: Statistics.hpp:177
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
static bool isLoadNonVolatileMemoryStateOutput(const rvsdg::Output &output)
void MatchTypeWithDefault(T &obj, const Fns &... fns)
Pattern match over subclass type of given object with default handler.
static void remove(Node *node)
Definition: region.hpp:978
size_t nnodes(const jlm::rvsdg::Region *region) noexcept
Definition: region.cpp:785
size_t ninputs(const rvsdg::Region *region) noexcept
Definition: region.cpp:838
static std::string strfmt(Args... args)
Definition: strfmt.hpp:35
static const char * NumRvsdgNodesBefore
Definition: Statistics.hpp:214
static const char * NumRvsdgNodesAfter
Definition: Statistics.hpp:215
static const char * NumRvsdgInputsAfter
Definition: Statistics.hpp:218
static const char * NumRvsdgInputsBefore
Definition: Statistics.hpp:217