Jlm
DeadNodeElimination.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
9 #include <jlm/rvsdg/gamma.hpp>
10 #include <jlm/rvsdg/MatchType.hpp>
11 #include <jlm/rvsdg/theta.hpp>
12 #include <jlm/rvsdg/traverser.hpp>
13 #include <jlm/util/Statistics.hpp>
14 #include <jlm/util/time.hpp>
15 
16 namespace jlm::llvm
17 {
18 
37 {
38 public:
// Records `output` as alive.
// An output owned by a simple node is tracked by inserting the node itself into
// SimpleNodes_; every other output is inserted into Outputs_ by address.
39  void
40  MarkAlive(const jlm::rvsdg::Output & output)
41  {
42  if (auto simpleNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(output))
43  {
44  SimpleNodes_.insert(simpleNode);
45  return;
46  }
47 
48  Outputs_.insert(&output);
49  }
50 
// Returns whether `output` has been recorded as alive.
// Mirrors MarkAlive(): an output owned by a simple node is looked up through its
// owning node in SimpleNodes_, any other output is looked up in Outputs_.
51  bool
52  IsAlive(const jlm::rvsdg::Output & output) const noexcept
53  {
54  if (auto simpleNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(output))
55  {
56  return SimpleNodes_.Contains(simpleNode);
57  }
58 
59  return Outputs_.Contains(&output);
60  }
61 
// Returns whether `node` is alive.
// A simple node is alive iff it is contained in SimpleNodes_.
// Any other node is alive iff at least one of its outputs is alive.
62  bool
63  IsAlive(const rvsdg::Node & node) const noexcept
64  {
65  if (auto simpleNode = dynamic_cast<const jlm::rvsdg::SimpleNode *>(&node))
66  {
67  return SimpleNodes_.Contains(simpleNode);
68  }
69 
    // Not a simple node: fall back to checking each output individually.
70  for (size_t n = 0; n < node.noutputs(); n++)
71  {
72  if (IsAlive(*node.output(n)))
73  {
74  return true;
75  }
76  }
77 
78  return false;
79  }
80 
// Factory: returns a default-constructed Context wrapped in a std::unique_ptr.
// NOTE(review): the line carrying the function name is missing from this listing;
// the page's cross-reference index gives it as `Create()`.
81  static std::unique_ptr<Context>
83  {
84  return std::make_unique<Context>();
85  }
86 
87 private:
90 };
91 
96 {
// Presumably the names under which the mark-phase and sweep-phase timers are registered.
// NOTE(review): the statements that use these labels are not visible in this listing — confirm.
97  const char * MarkTimerLabel_ = "MarkTime";
98  const char * SweepTimerLabel_ = "SweepTime";
99 
100 public:
101  ~Statistics() override = default;
102 
// Constructs the statistics object for `sourceFile`, forwarding the
// DeadNodeElimination statistics id and the file path to the util::Statistics base.
103  explicit Statistics(const util::FilePath & sourceFile)
104  : util::Statistics(Statistics::Id::DeadNodeElimination, sourceFile)
105  {}
106 
// Called before the mark phase. Records the node count of `graph`'s root region
// (as computed by rvsdg::nnodes) under Label::NumRvsdgNodesBefore.
// NOTE(review): listing lines 111-112 are missing here, so further statements of this
// function (if any) are not shown — confirm against the real source file.
107  void
108  StartMarkStatistics(const rvsdg::Graph & graph) noexcept
109  {
110  AddMeasurement(Label::NumRvsdgNodesBefore, rvsdg::nnodes(&graph.GetRootRegion()));
113  }
114 
// NOTE(review): the signature line (116) and body line (118) are missing from this listing.
// Judging by its position and the calls made in Run(), this is presumably
// StopMarkStatistics() — confirm against the real source file.
115  void
117  {
119  }
120 
// NOTE(review): the signature line (122) and body line (124) are missing from this listing.
// Judging by its position and the calls made in Run(), this is presumably
// StartSweepStatistics() — confirm against the real source file.
121  void
123  {
125  }
126 
// Called after the sweep phase. Records the node count (rvsdg::nnodes) and the input
// count (rvsdg::ninputs) of `graph`'s root region under Label::NumRvsdgNodesAfter and
// Label::NumRvsdgInputsAfter respectively.
// NOTE(review): listing line 130 is missing here, so the first statement of this
// function is not shown — confirm against the real source file.
127  void
128  StopSweepStatistics(const rvsdg::Graph & graph) noexcept
129  {
131  AddMeasurement(Label::NumRvsdgNodesAfter, rvsdg::nnodes(&graph.GetRootRegion()));
132  AddMeasurement(Label::NumRvsdgInputsAfter, rvsdg::ninputs(&graph.GetRootRegion()));
133  }
134 
// Factory: returns a heap-allocated Statistics object for `sourceFile`,
// owned by the returned std::unique_ptr.
135  static std::unique_ptr<Statistics>
136  Create(const util::FilePath & sourceFile)
137  {
138  return std::make_unique<Statistics>(sourceFile);
139  }
140 };
141 
// Defaulted destructor, defined in this translation unit.
// NOTE(review): presumably kept out of the header so that the std::unique_ptr<Context>
// member is destroyed where Context is a complete type — confirm.
142 DeadNodeElimination::~DeadNodeElimination() noexcept = default;
143 
// Constructor: passes the name "DeadNodeElimination" to the Transformation base.
// NOTE(review): the signature line (144) is missing from this listing.
145  : Transformation("DeadNodeElimination")
146 {}
147 
// Runs dead node elimination on a single region: marks everything reachable from the
// region's results, then sweeps the region. Unlike Run(), this collects no statistics and
// calls SweepRegion() directly rather than SweepRvsdg(), so no graph imports are removed.
// NOTE(review): the signature line (149) and listing line 151 are missing here; the index
// gives the signature as `void run(rvsdg::Region &region)`. Line 151 presumably sets up
// Context_ before it is used — confirm.
148 void
150 {
152 
153  MarkRegion(region);
154  SweepRegion(region);
155 
156  // Discard internal state to free up memory after we are done
157  Context_.reset();
158 }
159 
// Runs dead node elimination on the whole RVSDG of `module`:
//   1. mark phase  - MarkRegion() on the root region,
//   2. sweep phase - SweepRvsdg(), which also removes dead graph imports,
// bracketing each phase with the corresponding Statistics calls, and finally hands the
// statistics object over to `statisticsCollector`.
// NOTE(review): the line carrying the function name (161) and listing line 165 are missing
// here; line 165 presumably sets up Context_ before it is used — confirm.
160 void
162  rvsdg::RvsdgModule & module,
163  util::StatisticsCollector & statisticsCollector)
164 {
166 
167  auto & rvsdg = module.Rvsdg();
    // SourceFilePath() returns a std::optional; value() throws std::bad_optional_access
    // if the module carries no source file path.
168  auto statistics = Statistics::Create(module.SourceFilePath().value());
169  statistics->StartMarkStatistics(rvsdg);
170  MarkRegion(rvsdg.GetRootRegion());
171  statistics->StopMarkStatistics();
172 
173  statistics->StartSweepStatistics();
174  SweepRvsdg(rvsdg);
175  statistics->StopSweepStatistics(rvsdg);
176 
177  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
178 
179  // Discard internal state to free up memory after we are done
180  Context_.reset();
181 }
182 
// Mark phase entry point for a region: the origin of every region result is a liveness
// root and is passed to MarkOutput().
// NOTE(review): the signature line (184) is missing from this listing; the index gives it
// as `void MarkRegion(const rvsdg::Region &region)`.
183 void
185 {
186  for (size_t n = 0; n < region.nresults(); n++)
187  {
188  MarkOutput(*region.result(n)->origin());
189  }
190 }
191 
// Marks `output` alive and recursively marks every output it depends on.
//
// An output that is already alive returns immediately, so each output is expanded at most
// once. Otherwise the output is recorded in Context_ and the function dispatches on what
// produces it:
//   - graph import:            nothing further to mark
//   - gamma output:            predicate origin + origin of every branch result of the exit var
//   - gamma region argument:   origin of the gamma input the argument is mapped from
//   - theta output:            predicate origin + post-value origin + loop input origin
//   - theta region argument:   the loop variable's output + loop input origin
//   - lambda output:           origins of all function results
//   - lambda region argument:  context variable -> its input origin; function argument -> nothing
//   - phi output:              origin of the fix variable's result
//   - phi region argument:     fix variable -> its result origin; context variable -> its input origin
//   - delta output:            origin of subregion result 0
//   - delta region argument:   origin of the argument's input
//   - simple node output:      origins of all of the node's inputs
// Any other kind of output hits JLM_UNREACHABLE.
//
// NOTE(review): the signature line (193) is missing from this listing; the index gives it
// as `void MarkOutput(const jlm::rvsdg::Output &output)`.
192 void
194 {
195  if (Context_->IsAlive(output))
196  {
197  return;
198  }
199 
200  Context_->MarkAlive(output);
201 
202  if (is<rvsdg::GraphImport>(&output))
203  {
204  return;
205  }
206 
207  if (auto gamma = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(output))
208  {
209  MarkOutput(*gamma->predicate()->origin());
210  for (const auto & result : gamma->MapOutputExitVar(output).branchResult)
211  {
212  MarkOutput(*result->origin());
213  }
214  return;
215  }
216 
217  if (auto gamma = rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(output))
218  {
    // MapBranchArgument() yields a variant; every alternative is accessed through its
    // `input` member, so a generic visitor suffices.
219  auto external_origin = std::visit(
220  [](const auto & rolevar) -> rvsdg::Output *
221  {
222  return rolevar.input->origin();
223  },
224  gamma->MapBranchArgument(output));
225  MarkOutput(*external_origin);
226  return;
227  }
228 
229  if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(output))
230  {
231  auto loopvar = theta->MapOutputLoopVar(output);
232  MarkOutput(*theta->predicate()->origin());
233  MarkOutput(*loopvar.post->origin());
234  MarkOutput(*loopvar.input->origin());
235  return;
236  }
237 
238  if (auto theta = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(output))
239  {
240  auto loopvar = theta->MapPreLoopVar(output);
241  MarkOutput(*loopvar.output);
242  MarkOutput(*loopvar.input->origin());
243  return;
244  }
245 
246  if (auto lambda = rvsdg::TryGetOwnerNode<rvsdg::LambdaNode>(output))
247  {
248  for (auto & result : lambda->GetFunctionResults())
249  {
250  MarkOutput(*result->origin());
251  }
252  return;
253  }
254 
255  if (auto lambda = rvsdg::TryGetRegionParentNode<rvsdg::LambdaNode>(output))
256  {
257  if (auto ctxvar = lambda->MapBinderContextVar(output))
258  {
259  // Bound context variable.
260  MarkOutput(*ctxvar->input->origin());
261  return;
262  }
263  else
264  {
265  // Function argument.
266  return;
267  }
268  }
269 
270  if (auto phi = rvsdg::TryGetOwnerNode<rvsdg::PhiNode>(output))
271  {
272  MarkOutput(*phi->MapOutputFixVar(output).result->origin());
273  return;
274  }
275 
276  if (auto phi = rvsdg::TryGetRegionParentNode<rvsdg::PhiNode>(output))
277  {
278  auto var = phi->MapArgument(output);
279  if (auto fix = std::get_if<rvsdg::PhiNode::FixVar>(&var))
280  {
281  // Recursion argument
282  MarkOutput(*fix->result->origin());
283  return;
284  }
285  else if (auto ctx = std::get_if<rvsdg::PhiNode::ContextVar>(&var))
286  {
287  // Bound context variable.
288  MarkOutput(*ctx->input->origin());
289  return;
290  }
291  else
292  {
293  JLM_UNREACHABLE("Phi argument must be either fixpoint or context variable");
294  }
295  }
296 
297  if (const auto deltaNode = rvsdg::TryGetOwnerNode<rvsdg::DeltaNode>(output))
298  {
299  const auto result = deltaNode->subregion()->result(0);
300  MarkOutput(*result->origin());
301  return;
302  }
303 
304  if (rvsdg::TryGetRegionParentNode<rvsdg::DeltaNode>(output))
305  {
306  const auto argument = util::assertedCast<const rvsdg::RegionArgument>(&output);
307  MarkOutput(*argument->input()->origin());
308  return;
309  }
310 
311  if (const auto simpleNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(output))
312  {
313  for (size_t n = 0; n < simpleNode->ninputs(); n++)
314  {
315  MarkOutput(*simpleNode->input(n)->origin());
316  }
317  return;
318  }
319 
320  JLM_UNREACHABLE("We should have never reached this statement.");
321 }
322 
// Sweep phase entry point for a whole graph: sweeps the root region, then removes every
// root-region argument (graph import) that was not marked alive.
// NOTE(review): the signature line (324) is missing from this listing; the index gives it
// as `void SweepRvsdg(rvsdg::Graph &rvsdg) const`.
323 void
325 {
326  SweepRegion(rvsdg.GetRootRegion());
327 
328  // Remove dead imports
    // Dead indices are collected first and removed in a single RemoveArguments() call,
    // rather than removing arguments while iterating over them.
329  util::HashSet<size_t> indices;
330  for (const auto argument : rvsdg.GetRootRegion().Arguments())
331  {
332  if (!Context_->IsAlive(*argument))
333  {
334  indices.insert(argument->index());
335  }
336  }
    // The count is only consumed by the assertion below, hence [[maybe_unused]].
337  [[maybe_unused]] const auto numRemovedArguments = rvsdg.GetRootRegion().RemoveArguments(indices);
338  JLM_ASSERT(numRemovedArguments == indices.Size());
339 }
340 
// Sweeps one region: prunes it non-recursively, then visits its nodes bottom-up, removing
// every node that is not alive and descending into every alive structural node via
// SweepStructuralNode(). Afterwards the region is asserted to have no bottom nodes left.
// NOTE(review): the signature line (342) is missing from this listing; the index gives it
// as `void SweepRegion(rvsdg::Region &region) const`.
341 void
343 {
344  region.prune(false);
345 
346  for (const auto node : rvsdg::BottomUpTraverser(&region))
347  {
348  if (!Context_->IsAlive(*node))
349  {
350  remove(node);
351  }
352  else if (const auto structuralNode = dynamic_cast<rvsdg::StructuralNode *>(node))
353  {
354  SweepStructuralNode(*structuralNode);
355  }
356  }
357 
358  JLM_ASSERT(region.numBottomNodes() == 0);
359 }
360 
// Dispatches `node` to the sweep routine matching its concrete type: gamma, theta, lambda,
// phi or delta. The delta handler captures nothing because SweepDelta() is static.
// NOTE(review): the signature line (362) and the line opening the dispatch call (364) are
// missing from this listing; the index lists `MatchTypeOrFail` ("Pattern match over subclass
// type of given object"), which is presumably the call these lambdas are passed to — confirm.
361 void
363 {
365  node,
366  [this](rvsdg::GammaNode & node)
367  {
368  SweepGamma(node);
369  },
370  [this](rvsdg::ThetaNode & node)
371  {
372  SweepTheta(node);
373  },
374  [this](rvsdg::LambdaNode & node)
375  {
376  SweepLambda(node);
377  },
378  [this](rvsdg::PhiNode & node)
379  {
380  SweepPhi(node);
381  },
382  [](rvsdg::DeltaNode & node)
383  {
384  SweepDelta(node);
385  });
386 }
387 
// Sweeps a gamma node in three steps:
//   1. remove every exit variable whose output is not alive,
//   2. sweep each subregion,
//   3. remove every entry variable none of whose branch arguments is alive.
// NOTE(review): exit variables are presumably removed before the subregion sweep so that
// nodes feeding only dead results lose their users first — confirm.
// NOTE(review): the signature line (389) is missing from this listing; the index gives it
// as `void SweepGamma(rvsdg::GammaNode &gammaNode) const`.
388 void
390 {
391  // Remove dead exit vars.
392  std::vector<rvsdg::GammaNode::ExitVar> deadExitVars;
393  for (const auto & exitvar : gammaNode.GetExitVars())
394  {
395  if (!Context_->IsAlive(*exitvar.output))
396  {
397  deadExitVars.push_back(exitvar);
398  }
399  }
400  gammaNode.RemoveExitVars(deadExitVars);
401 
402  // Sweep gamma subregions
403  for (size_t r = 0; r < gammaNode.nsubregions(); r++)
404  {
405  SweepRegion(*gammaNode.subregion(r));
406  }
407 
408  // Remove dead entry vars.
409  std::vector<rvsdg::GammaNode::EntryVar> deadEntryVars;
410  for (const auto & entryvar : gammaNode.GetEntryVars())
411  {
    // An entry variable stays as long as at least one of its branch arguments is alive.
412  bool alive = std::any_of(
413  entryvar.branchArgument.begin(),
414  entryvar.branchArgument.end(),
415  [this](const rvsdg::Output * arg)
416  {
417  return Context_->IsAlive(*arg);
418  });
419  if (!alive)
420  {
421  deadEntryVars.push_back(entryvar);
422  }
423  }
424  gammaNode.RemoveEntryVars(deadEntryVars);
425 }
426 
// Sweeps a theta (loop) node. A loop variable is dead when neither its pre-iteration value
// nor its output is alive. Each dead variable's post-iteration result is first diverted to
// its own pre-iteration value; the subregion is then swept, and finally the collected loop
// variables are removed from the theta node.
// NOTE(review): the signature line (428) is missing from this listing; the index gives it
// as `void SweepTheta(rvsdg::ThetaNode &thetaNode) const`.
427 void
429 {
430  // Determine loop variables to be removed.
431  std::vector<rvsdg::ThetaNode::LoopVar> loopvars;
432  for (const auto & loopvar : thetaNode.GetLoopVars())
433  {
434  if (!Context_->IsAlive(*loopvar.pre) && !Context_->IsAlive(*loopvar.output))
435  {
436  loopvar.post->divert_to(loopvar.pre);
437  loopvars.push_back(loopvar);
438  }
439  }
440 
441  // Now that the loop variables to be eliminated only point to
442  // their own pre-iteration values, any outputs within the subregion
443  // that only contributed to computing the post-iteration values
444  // of the variables are unlinked and can be removed as well.
445  SweepRegion(*thetaNode.subregion());
446 
447  // There are now no other users of the pre-iteration values of the
448  // variables to be removed left in the subregion anymore.
449  // The variables have become "loop-invariant" and can simply
450  // be eliminated from the theta node.
451  thetaNode.RemoveLoopVars(std::move(loopvars));
452 }
453 
// Sweeps a lambda node: sweeps its subregion first, then calls PruneLambdaInputs() on the
// node. NOTE(review): PruneLambdaInputs() presumably drops inputs that are unused after
// the sweep — confirm.
// NOTE(review): the signature line (455) is missing from this listing; the index gives it
// as `void SweepLambda(rvsdg::LambdaNode &lambdaNode) const`.
454 void
456 {
457  SweepRegion(*lambdaNode.subregion());
458  lambdaNode.PruneLambdaInputs();
459 }
460 
// Sweeps a phi node. A fix variable is dead when neither its output nor its recursion
// reference is alive; each dead one has its result diverted to its own recursion reference
// before the subregion is swept. After the sweep, every context variable whose inner
// argument reports IsDead() is collected. Context variables are removed first, then the
// dead fix variables.
// NOTE(review): the signature line (462) is missing from this listing; the index gives it
// as `void SweepPhi(rvsdg::PhiNode &phiNode) const`.
461 void
463 {
464  std::vector<rvsdg::PhiNode::FixVar> deadFixvars;
465  std::vector<rvsdg::PhiNode::ContextVar> deadCtxvars;
466 
467  for (const auto & fixvar : phiNode.GetFixVars())
468  {
469  bool isDead = !Context_->IsAlive(*fixvar.output) && !Context_->IsAlive(*fixvar.recref);
470  if (isDead)
471  {
472  deadFixvars.push_back(fixvar);
473  // Temporarily redirect the variable so it refers to itself
474  // (so the object is simply defined to be "itself").
475  fixvar.result->divert_to(fixvar.recref);
476  }
477  }
478 
479  SweepRegion(*phiNode.subregion());
480 
    // Unlike fix variables, context variables are judged after the sweep, via IsDead() on
    // the inner argument rather than via Context_.
481  for (const auto & ctxvar : phiNode.GetContextVars())
482  {
483  if (ctxvar.inner->IsDead())
484  {
485  deadCtxvars.push_back(ctxvar);
486  }
487  }
488 
489  phiNode.RemoveContextVars(std::move(deadCtxvars));
490  phiNode.RemoveFixVars(std::move(deadFixvars));
491 }
492 
// Sweeps a delta node: prunes its subregion non-recursively, then calls PruneDeltaInputs()
// on the node. Does not consult Context_ (the index declares this function static).
// NOTE(review): the signature line (494) is missing from this listing; the index gives it
// as `static void SweepDelta(rvsdg::DeltaNode &deltaNode)`.
493 void
495 {
496  // A delta subregion can only contain simple nodes. Thus, a simple prune is sufficient.
497  deltaNode.subregion()->prune(false);
498 
499  deltaNode.PruneDeltaInputs();
500 }
501 
502 }
Dead Node Elimination context class.
void MarkAlive(const jlm::rvsdg::Output &output)
static std::unique_ptr< Context > Create()
util::HashSet< const jlm::rvsdg::Output * > Outputs_
bool IsAlive(const rvsdg::Node &node) const noexcept
util::HashSet< const jlm::rvsdg::SimpleNode * > SimpleNodes_
bool IsAlive(const jlm::rvsdg::Output &output) const noexcept
Dead Node Elimination statistics class.
void StartMarkStatistics(const rvsdg::Graph &graph) noexcept
static std::unique_ptr< Statistics > Create(const util::FilePath &sourceFile)
Statistics(const util::FilePath &sourceFile)
void StopSweepStatistics(const rvsdg::Graph &graph) noexcept
Dead Node Elimination Optimization.
void SweepRegion(rvsdg::Region &region) const
void SweepGamma(rvsdg::GammaNode &gammaNode) const
void SweepStructuralNode(rvsdg::StructuralNode &node) const
static void SweepDelta(rvsdg::DeltaNode &deltaNode)
void run(rvsdg::Region &region)
void MarkOutput(const jlm::rvsdg::Output &output)
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
std::unique_ptr< Context > Context_
void SweepRvsdg(rvsdg::Graph &rvsdg) const
void SweepLambda(rvsdg::LambdaNode &lambdaNode) const
void SweepTheta(rvsdg::ThetaNode &thetaNode) const
~DeadNodeElimination() noexcept override
void MarkRegion(const rvsdg::Region &region)
void SweepPhi(rvsdg::PhiNode &phiNode) const
Delta node.
Definition: delta.hpp:129
rvsdg::Region * subregion() const noexcept
Definition: delta.hpp:234
size_t PruneDeltaInputs()
Definition: delta.hpp:277
Conditional operator / pattern matching.
Definition: gamma.hpp:99
void RemoveEntryVars(const std::vector< EntryVar > &entryVars)
Removes the given entry variables.
Definition: gamma.cpp:420
void RemoveExitVars(const std::vector< ExitVar > &exitVars)
Removes the given exit variables.
Definition: gamma.cpp:401
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
Definition: gamma.cpp:361
std::vector< EntryVar > GetEntryVars() const
Gets all entry variables for this gamma.
Definition: gamma.cpp:303
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
size_t PruneLambdaInputs()
Definition: lambda.hpp:245
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
rvsdg::Region * subregion() const noexcept
Definition: Phi.hpp:320
std::vector< FixVar > GetFixVars() const noexcept
Gets all fixpoint variables.
Definition: Phi.cpp:63
void RemoveContextVars(std::vector< ContextVar > vars)
Removes context variables from phi node.
Definition: Phi.cpp:150
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: Phi.cpp:50
void RemoveFixVars(std::vector< FixVar > vars)
Removes fixpoint variables from the phi node.
Definition: Phi.cpp:168
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
RegionResult * result(size_t index) const noexcept
Definition: region.hpp:471
RegionArgumentRange Arguments() noexcept
Definition: region.hpp:272
size_t nresults() const noexcept
Definition: region.hpp:465
void prune(bool recursive)
Definition: region.cpp:323
size_t numBottomNodes() const noexcept
Definition: region.hpp:499
size_t RemoveArguments(const util::HashSet< size_t > &indices)
Definition: region.cpp:210
const std::optional< util::FilePath > & SourceFilePath() const noexcept
Definition: RvsdgModule.hpp:73
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
size_t nsubregions() const noexcept
rvsdg::Region * subregion(size_t index) const noexcept
void RemoveLoopVars(std::vector< LoopVar > loopVars)
Removes loop variables.
Definition: theta.cpp:63
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
Definition: theta.cpp:176
rvsdg::Region * subregion() const noexcept
Definition: theta.hpp:79
bool insert(ItemType item)
Definition: HashSet.hpp:210
std::size_t Size() const noexcept
Definition: HashSet.hpp:187
bool Contains(const ItemType &item) const noexcept
Definition: HashSet.hpp:150
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:563
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:134
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:155
void AddMeasurement(std::string name, T value)
Definition: Statistics.hpp:174
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
#define JLM_UNREACHABLE(msg)
Definition: common.hpp:43
Global memory state passed between functions.
void MatchTypeOrFail(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.
static void remove(Node *node)
Definition: region.hpp:932
size_t nnodes(const jlm::rvsdg::Region *region) noexcept
Definition: region.cpp:629
size_t ninputs(const rvsdg::Region *region) noexcept
Definition: region.cpp:682
static const char * NumRvsdgNodesBefore
Definition: Statistics.hpp:211
static const char * NumRvsdgNodesAfter
Definition: Statistics.hpp:212
static const char * NumRvsdgInputsAfter
Definition: Statistics.hpp:215
static const char * NumRvsdgInputsBefore
Definition: Statistics.hpp:214