Jlm
rhls-dne.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
8 #include <jlm/hls/ir/hls.hpp>
11 #include <jlm/rvsdg/traverser.hpp>
12 
13 namespace jlm::hls
14 {
15 
16 static bool
18 {
19  util::HashSet<size_t> resultIndices;
20  util::HashSet<size_t> argumentIndices;
21  const auto subregion = loopNode->subregion();
22  for (const auto argument : subregion->Arguments())
23  {
24  if ((dynamic_cast<BackEdgeArgument *>(argument) && argument->nusers() == 1)
25  || argument->IsDead())
26  {
27  auto & user = *argument->Users().begin();
28  if (const auto result = dynamic_cast<BackEdgeResult *>(&user))
29  {
30  resultIndices.insert(result->index());
31  argumentIndices.insert(argument->index());
32  }
33  }
34  }
35 
36  [[maybe_unused]] const auto numRemovedResults = subregion->RemoveResults(resultIndices);
37  JLM_ASSERT(numRemovedResults == resultIndices.Size());
38 
39  [[maybe_unused]] const auto numRemovedArguments = subregion->RemoveArguments(argumentIndices);
40  JLM_ASSERT(numRemovedArguments == argumentIndices.Size());
41 
42  return numRemovedArguments != 0;
43 }
44 
45 static bool
47 {
48  bool any_changed = false;
49  // go through in reverse because we remove some
50  for (int i = ln->noutputs() - 1; i >= 0; --i)
51  {
52  const auto out = ln->output(i);
53  if (out->nusers() == 0)
54  {
55  ln->removeLoopOutput(out);
56  any_changed = true;
57  }
58  }
59  return any_changed;
60 }
61 
62 static bool
64 {
65  bool any_changed = false;
66  // go through in reverse because we remove some
67  for (int i = ln->ninputs() - 1; i >= 0; --i)
68  {
69  const auto in = ln->input(i);
70  JLM_ASSERT(in->arguments.size() == 1);
71  const auto arg = in->arguments.begin();
72  if (arg->nusers() != 1)
73  continue;
74 
75  auto & user = *arg->Users().begin();
76  if (const auto result = dynamic_cast<rvsdg::RegionResult *>(&user))
77  {
78  result->output()->divert_users(in->origin());
79  ln->removeLoopOutput(result->output());
80  ln->removeLoopInput(arg->input());
81  any_changed = true;
82  }
83  }
84  return any_changed;
85 }
86 
87 static bool
89 {
90  bool any_changed = false;
91  auto sr = ln->subregion();
92  // go through in reverse because we remove some
93  for (int i = ln->ninputs() - 1; i >= 0; --i)
94  {
95  auto in = ln->input(i);
96  JLM_ASSERT(in->arguments.size() == 1);
97  auto arg = in->arguments.begin();
98  if (arg->nusers() == 0)
99  {
100  ln->removeLoopInput(in);
101  any_changed = true;
102  }
103  }
104  // clean up unused arguments - only ones without an input should be left
105  // go through in reverse because we remove some
106  for (int i = sr->narguments() - 1; i >= 0; --i)
107  {
108  auto arg = sr->argument(i);
109  if (auto ba = dynamic_cast<BackEdgeArgument *>(arg))
110  {
111  auto result = ba->result();
112  JLM_ASSERT(*result->Type() == *arg->Type());
113  if (arg->nusers() == 0 || (arg->nusers() == 1 && result->origin() == arg))
114  {
115  sr->RemoveResults({ result->index() });
116  sr->RemoveArguments({ arg->index() });
117  }
118  }
119  else
120  {
121  JLM_ASSERT(arg->nusers() != 0);
122  }
123  }
124  return any_changed;
125 }
126 
127 static bool
129 {
130  const auto mux_op = util::assertedCast<const MuxOperation>(&dmux_node->GetOperation());
131  JLM_ASSERT(mux_op->discarding);
132  // check if all inputs have the same origin
133  bool all_inputs_same = true;
134  auto first_origin = dmux_node->input(1)->origin();
135  for (size_t i = 2; i < dmux_node->ninputs(); ++i)
136  {
137  if (dmux_node->input(i)->origin() != first_origin)
138  {
139  all_inputs_same = false;
140  break;
141  }
142  }
143  if (all_inputs_same)
144  {
145  dmux_node->output(0)->divert_users(first_origin);
146  remove(dmux_node);
147  return true;
148  }
149  return false;
150 }
151 
152 static bool
154 {
155  auto mux_op = util::assertedCast<const MuxOperation>(&ndmux_node->GetOperation());
156  JLM_ASSERT(!mux_op->discarding);
157  // check if all inputs go to outputs of same branch
158  bool all_inputs_same_branch = true;
159  rvsdg::Node * origin_branch = nullptr;
160  for (size_t i = 1; i < ndmux_node->ninputs(); ++i)
161  {
162  if (auto node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*ndmux_node->input(i)->origin()))
163  {
164  if (dynamic_cast<const BranchOperation *>(&node->GetOperation())
165  && ndmux_node->input(i)->origin()->nusers() == 1)
166  {
167  if (i == 1)
168  {
169  origin_branch = node;
170  continue;
171  }
172  else if (origin_branch == node)
173  {
174  continue;
175  }
176  }
177  }
178  all_inputs_same_branch = false;
179  break;
180  }
181  if (all_inputs_same_branch && origin_branch->input(0)->origin() == ndmux_node->input(0)->origin())
182  {
183  // same control origin + all inputs to branch
184  ndmux_node->output(0)->divert_users(origin_branch->input(1)->origin());
185  remove(ndmux_node);
186  JLM_ASSERT(origin_branch != nullptr);
187  remove(origin_branch);
188  return true;
189  }
190  return false;
191 }
192 
193 static bool
194 dead_loop(rvsdg::Node * ndmux_node)
195 {
196  const auto mux_op = util::assertedCast<const MuxOperation>(&ndmux_node->GetOperation());
197  JLM_ASSERT(!mux_op->discarding);
198  // origin is a backedege argument
199  auto backedge_arg = dynamic_cast<BackEdgeArgument *>(ndmux_node->input(2)->origin());
200  if (!backedge_arg)
201  {
202  return false;
203  }
204  // one branch
205  if (ndmux_node->output(0)->nusers() != 1)
206  {
207  return false;
208  }
209  auto branch_in_node =
210  rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*ndmux_node->output(0)->Users().begin());
211  if (!branch_in_node || !dynamic_cast<const BranchOperation *>(&branch_in_node->GetOperation()))
212  {
213  return false;
214  }
215  // one buffer
216  if (branch_in_node->output(1)->nusers() != 1)
217  {
218  return false;
219  }
220  auto buf_in_node =
221  rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*branch_in_node->output(1)->Users().begin());
222  if (!buf_in_node || !dynamic_cast<const BufferOperation *>(&buf_in_node->GetOperation()))
223  {
224  return false;
225  }
226  auto buf_out = buf_in_node->output(0);
227  if (buf_out != backedge_arg->result()->origin())
228  {
229  // no connection back up
230  return false;
231  }
232  // depend on same control
233  auto branch_cond_origin = branch_in_node->input(0)->origin();
234  auto pred_buf_out_node =
235  rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*ndmux_node->input(0)->origin());
236  if (!pred_buf_out_node
237  || !dynamic_cast<const PredicateBufferOperation *>(&pred_buf_out_node->GetOperation()))
238  {
239  return false;
240  }
241  auto pred_buf_cond_origin = pred_buf_out_node->input(0)->origin();
242  // TODO: remove this once predicate buffers decouple combinatorial loops
243  auto extra_buf_out_node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*pred_buf_cond_origin);
244  if (!extra_buf_out_node
245  || !dynamic_cast<const BufferOperation *>(&extra_buf_out_node->GetOperation()))
246  {
247  return false;
248  }
249  auto extra_buf_cond_origin = extra_buf_out_node->input(0)->origin();
250 
251  if (auto pred_be = dynamic_cast<BackEdgeArgument *>(extra_buf_cond_origin))
252  {
253  extra_buf_cond_origin = pred_be->result()->origin();
254  }
255  if (extra_buf_cond_origin != branch_cond_origin)
256  {
257  return false;
258  }
259  // divert users
260  branch_in_node->output(0)->divert_users(ndmux_node->input(1)->origin());
261  buf_out->divert_users(backedge_arg);
262  remove(buf_in_node);
263  remove(branch_in_node);
264  auto region = ndmux_node->region();
265  remove(ndmux_node);
266  region->RemoveResults({ backedge_arg->result()->index() });
267  region->RemoveArguments({ backedge_arg->index() });
268  return true;
269 }
270 
271 static bool
273 {
274  JLM_ASSERT(jlm::rvsdg::is<LoopConstantBufferOperation>(lcb_node));
275 
276  // one branch
277  if (lcb_node->output(0)->nusers() != 1)
278  {
279  return false;
280  }
281  auto [branchNode, branchOperation] =
282  rvsdg::TryGetSimpleNodeAndOptionalOp<BranchOperation>(*lcb_node->output(0)->Users().begin());
283  if (!branchNode || !branchOperation || !branchOperation->loop)
284  {
285  return false;
286  }
287  // no user
288  if (branchNode->output(1)->nusers())
289  {
290  return false;
291  }
292  // depend on same control
293  auto branch_cond_origin = branchNode->input(0)->origin();
294  auto pred_buf_out = dynamic_cast<rvsdg::NodeOutput *>(lcb_node->input(0)->origin());
295  if (!pred_buf_out
296  || !dynamic_cast<const PredicateBufferOperation *>(&pred_buf_out->node()->GetOperation()))
297  {
298  return false;
299  }
300  auto pred_buf_cond_origin = pred_buf_out->node()->input(0)->origin();
301  // TODO: remove this once predicate buffers decouple combinatorial loops
302  auto extra_buf_out = dynamic_cast<rvsdg::NodeOutput *>(pred_buf_cond_origin);
303  if (!extra_buf_out
304  || !dynamic_cast<const BufferOperation *>(&extra_buf_out->node()->GetOperation()))
305  {
306  return false;
307  }
308  auto extra_buf_cond_origin = extra_buf_out->node()->input(0)->origin();
309 
310  if (auto pred_be = dynamic_cast<BackEdgeArgument *>(extra_buf_cond_origin))
311  {
312  extra_buf_cond_origin = pred_be->result()->origin();
313  }
314  if (extra_buf_cond_origin != branch_cond_origin)
315  {
316  return false;
317  }
318  // divert users
319  branchNode->output(0)->divert_users(lcb_node->input(1)->origin());
320  remove(branchNode);
321  remove(lcb_node);
322  return true;
323 }
324 
325 static bool
327 {
328  if (split_node->noutputs() == 1)
329  {
330  split_node->output(0)->divert_users(split_node->input(0)->origin());
331  JLM_ASSERT(split_node->IsDead());
332  remove(split_node);
333  return true;
334  }
335  // this merges downward and removes unused outputs (should only exist as a result of eliminating
336  // merges)
337  std::vector<rvsdg::Output *> combined_outputs;
338  for (size_t i = 0; i < split_node->noutputs(); ++i)
339  {
340  if (split_node->output(i)->IsDead())
341  continue;
342  auto user = get_mem_state_user(split_node->output(i));
343  if (rvsdg::IsOwnerNodeOperation<llvm::MemoryStateSplitOperation>(*user))
344  {
345  auto sub_split = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*user);
346  for (size_t j = 0; j < sub_split->noutputs(); ++j)
347  {
348  combined_outputs.push_back(sub_split->output(j));
349  }
350  }
351  else
352  {
353  combined_outputs.push_back(split_node->output(i));
354  }
355  }
356  if (combined_outputs.size() != split_node->noutputs())
357  {
358  auto new_outputs = llvm::MemoryStateSplitOperation::Create(
359  *split_node->input(0)->origin(),
360  combined_outputs.size());
361  for (size_t i = 0; i < combined_outputs.size(); ++i)
362  {
363  combined_outputs[i]->divert_users(new_outputs[i]);
364  }
365  return true;
366  }
367  return false;
368 }
369 
370 static bool
372 {
373  // remove single merge
374  if (merge_node->ninputs() == 1)
375  {
376  merge_node->output(0)->divert_users(merge_node->input(0)->origin());
377  JLM_ASSERT(merge_node->IsDead());
378  remove(merge_node);
379  return true;
380  }
381  std::vector<rvsdg::Output *> combined_origins;
382  std::unordered_set<rvsdg::SimpleNode *> splits;
383  for (size_t i = 0; i < merge_node->ninputs(); ++i)
384  {
385  auto origin = merge_node->input(i)->origin();
386  if (rvsdg::IsOwnerNodeOperation<llvm::MemoryStateMergeOperation>(*origin))
387  {
388  auto sub_merge = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*origin);
389  for (size_t j = 0; j < sub_merge->ninputs(); ++j)
390  {
391  combined_origins.push_back(sub_merge->input(j)->origin());
392  }
393  }
394  else if (rvsdg::IsOwnerNodeOperation<llvm::MemoryStateSplitOperation>(*origin))
395  {
396  // ensure that there is only one direct connection to a split.
397  // We need to keep one, so that the optimizations for decouple edges work
398  auto split = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*origin);
399  if (!splits.count(split))
400  {
401  splits.insert(split);
402  combined_origins.push_back(origin);
403  }
404  }
405  else
406  {
407  combined_origins.push_back(merge_node->input(i)->origin());
408  }
409  }
410  if (combined_origins.empty())
411  {
412  // if none of the inputs are real keep the first one
413  combined_origins.push_back(merge_node->input(0)->origin());
414  }
415  if (combined_origins.size() != merge_node->ninputs())
416  {
417  auto new_output = llvm::MemoryStateMergeOperation::Create(combined_origins);
418  merge_node->output(0)->divert_users(new_output);
419  JLM_ASSERT(merge_node->IsDead());
420  return true;
421  }
422  return false;
423 }
424 
425 bool
427  rvsdg::Region & region,
428  util::StatisticsCollector & statisticsCollector)
429 {
430  bool any_changed = false;
431  bool changed = false;
432  do
433  {
434  changed = false;
435  for (auto & node : rvsdg::BottomUpTraverser(&region))
436  {
437  if (node->IsDead())
438  {
439  if (rvsdg::is<MemoryRequestOperation>(node))
440  {
441  // TODO: fix this once memory connections are explicit
442  continue;
443  }
444  if (rvsdg::is<LocalMemoryRequestOperation>(node))
445  {
446  continue;
447  }
448  if (rvsdg::is<LocalMemoryResponseOperation>(node))
449  {
450  // TODO: fix - this scenario has only stores and should just be optimized away completely
451  continue;
452  }
453  remove(node);
454  changed = true;
455  }
456  else if (dynamic_cast<rvsdg::LambdaNode *>(node))
457  {
458  JLM_UNREACHABLE("This function works on lambda subregions");
459  }
460  else if (auto ln = dynamic_cast<LoopNode *>(node))
461  {
462  changed |= remove_unused_loop_outputs(ln);
463  changed |= remove_unused_loop_inputs(ln);
464  changed |= remove_unused_loop_backedges(ln);
465  changed |= remove_loop_passthrough(ln);
466  changed |= Run(*ln->subregion(), statisticsCollector);
467  }
468  else if (const auto mux = dynamic_cast<const MuxOperation *>(&node->GetOperation()))
469  {
470  if (mux->discarding)
471  {
472  changed |= dead_spec_gamma(node);
473  }
474  else
475  {
476  changed |= dead_nonspec_gamma(node) || dead_loop(node);
477  }
478  }
479  else if (rvsdg::is<LoopConstantBufferOperation>(node))
480  {
481  changed |= dead_loop_lcb(node);
482  }
483  else if (dynamic_cast<const llvm::MemoryStateSplitOperation *>(&node->GetOperation()))
484  {
485  if (fix_mem_split(node))
486  {
487  changed = true;
488  }
489  }
490  else if (dynamic_cast<const llvm::MemoryStateMergeOperation *>(&node->GetOperation()))
491  {
492  if (fix_mem_merge(node))
493  {
494  changed = true;
495  }
496  }
497  if (changed)
498  {
499  // Changes might break bottom up traversal
500  break;
501  }
502  }
503  any_changed |= changed;
504  } while (changed);
505 
506  return any_changed;
507 }
508 
510 
513 {}
514 
515 void
517  rvsdg::RvsdgModule & rvsdgModule,
518  util::StatisticsCollector & statisticsCollector)
519 {
520  auto & graph = rvsdgModule.Rvsdg();
521  const auto rootRegion = &graph.GetRootRegion();
522  if (rootRegion->numNodes() != 1)
523  {
524  throw util::Error("Root should have only one node now");
525  }
526  const auto lambdaNode =
527  dynamic_cast<const rvsdg::LambdaNode *>(rootRegion->Nodes().begin().ptr());
528  if (!lambdaNode)
529  {
530  throw util::Error("Node needs to be a lambda");
531  }
532  Run(*lambdaNode->subregion(), statisticsCollector);
533 }
534 
535 } // namespace jlm::hls
rvsdg::Region * subregion() const noexcept
Definition: hls.hpp:725
void removeLoopOutput(rvsdg::StructuralOutput *output)
Definition: hls.cpp:197
void removeLoopInput(rvsdg::StructuralInput *input)
Definition: hls.cpp:209
~RhlsDeadNodeElimination() noexcept override
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Definition: rhls-dne.cpp:516
static rvsdg::Output * Create(const std::vector< rvsdg::Output * > &operands)
static std::vector< rvsdg::Output * > Create(rvsdg::Output &operand, const size_t numResults)
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
Node * node() const noexcept
Definition: node.hpp:572
virtual const Operation & GetOperation() const noexcept=0
bool IsDead() const noexcept
Determines whether the node is dead.
Definition: node.hpp:688
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
NodeInput * input(size_t index) const noexcept
Definition: node.hpp:615
NodeOutput * output(size_t index) const noexcept
Definition: node.hpp:650
size_t ninputs() const noexcept
Definition: node.hpp:609
size_t noutputs() const noexcept
Definition: node.hpp:644
UsersRange Users()
Definition: node.hpp:354
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
bool IsDead() const noexcept
Definition: node.hpp:295
size_t nusers() const noexcept
Definition: node.hpp:280
Represents the result of a region.
Definition: region.hpp:120
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
StructuralOutput * output(size_t index) const noexcept
StructuralInput * input(size_t index) const noexcept
Represents an RVSDG transformation.
bool insert(ItemType item)
Definition: HashSet.hpp:210
std::size_t Size() const noexcept
Definition: HashSet.hpp:187
#define JLM_ASSERT(x)
Definition: common.hpp:16
#define JLM_UNREACHABLE(msg)
Definition: common.hpp:43
static bool remove_loop_passthrough(LoopNode *ln)
Definition: rhls-dne.cpp:63
static bool dead_loop_lcb(rvsdg::Node *lcb_node)
Definition: rhls-dne.cpp:272
static bool dead_loop(rvsdg::Node *ndmux_node)
Definition: rhls-dne.cpp:194
static bool remove_unused_loop_outputs(LoopNode *ln)
Definition: rhls-dne.cpp:46
static bool fix_mem_split(rvsdg::Node *split_node)
Definition: rhls-dne.cpp:326
static bool dead_spec_gamma(rvsdg::Node *dmux_node)
Definition: rhls-dne.cpp:128
static bool remove_unused_loop_inputs(LoopNode *ln)
Definition: rhls-dne.cpp:88
static bool remove_unused_loop_backedges(LoopNode *loopNode)
Definition: rhls-dne.cpp:17
static bool fix_mem_merge(rvsdg::Node *merge_node)
Definition: rhls-dne.cpp:371
rvsdg::Input * get_mem_state_user(rvsdg::Output *state_edge)
static bool dead_nonspec_gamma(rvsdg::Node *ndmux_node)
Definition: rhls-dne.cpp:153
static void remove(Node *node)
Definition: region.hpp:932