Jlm
push.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
8 #include <jlm/llvm/opt/push.hpp>
9 #include <jlm/rvsdg/gamma.hpp>
10 #include <jlm/rvsdg/theta.hpp>
11 #include <jlm/rvsdg/traverser.hpp>
12 #include <jlm/util/Statistics.hpp>
13 #include <jlm/util/time.hpp>
14 
15 #include <deque>
16 #include <jlm/rvsdg/MatchType.hpp>
17 #include <jlm/rvsdg/theta.hpp>
18 
19 namespace jlm::llvm
20 {
21 
23 {
24 public:
25  ~Statistics() override = default;
26 
27  explicit Statistics(const util::FilePath & sourceFile)
28  : util::Statistics(Statistics::Id::PushNodes, sourceFile)
29  {}
30 
31  void
32  start(const rvsdg::Graph & graph) noexcept
33  {
36  }
37 
38  void
39  end(const rvsdg::Graph & graph) noexcept
40  {
43  }
44 
45  static std::unique_ptr<Statistics>
46  Create(const util::FilePath & sourceFile)
47  {
48  return std::make_unique<Statistics>(sourceFile);
49  }
50 };
51 
53 {
54 public:
55  explicit Context(rvsdg::LambdaNode & lambdaNode)
56  : LambdaSubregion_(lambdaNode.subregion())
57  {}
58 
60  getLambdaSubregion() const noexcept
61  {
62  return *LambdaSubregion_;
63  }
64 
65  void
66  addRegionDepth(const rvsdg::Region & region, const size_t depth) noexcept
67  {
68  JLM_ASSERT(RegionDepth_.find(&region) == RegionDepth_.end());
69  RegionDepth_[&region] = depth;
70  }
71 
72  size_t
73  getRegionDeph(const rvsdg::Region & region) const noexcept
74  {
75  return RegionDepth_.at(&region);
76  }
77 
78  void
79  addTargetRegion(const rvsdg::Node & node, rvsdg::Region & region) noexcept
80  {
81  JLM_ASSERT(TargetRegion_.find(&node) == TargetRegion_.end());
82  TargetRegion_[&node] = &region;
83  }
84 
86  getTargetRegion(const rvsdg::Node & node) const noexcept
87  {
88  return *TargetRegion_.at(&node);
89  }
90 
91  static std::unique_ptr<Context>
92  create(rvsdg::LambdaNode & lambdaNode)
93  {
94  return std::make_unique<Context>(lambdaNode);
95  }
96 
97 private:
99  std::unordered_map<const rvsdg::Region *, size_t> RegionDepth_{};
100  std::unordered_map<const rvsdg::Node *, rvsdg::Region *> TargetRegion_{};
101 };
102 
// Destructor, defaulted but defined out of line in this file, after the Context class definition.
// NOTE(review): presumably placed here (rather than in push.hpp) because the member context_ is a
// std::unique_ptr<Context>, and Context appears to be complete only in this file; destroying a
// unique_ptr requires the pointee type to be complete. Confirm against push.hpp.
103 NodeHoisting::~NodeHoisting() noexcept = default;
104 
106  : Transformation("NodeHoisting")
107 {}
108 
109 size_t
111 {
112  if (dynamic_cast<const rvsdg::LambdaNode *>(region.node()))
113  {
114  return 0;
115  }
116 
117  const auto parentRegion = region.node()->region();
118  return context_->getRegionDeph(*parentRegion) + 1;
119 }
120 
121 bool
123 {
124  if (!is<MemoryStateType>(loopVar.output->Type()))
125  return false;
126 
127  if (loopVar.pre->nusers() != 1)
128  return false;
129 
130  const auto userNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*loopVar.pre->Users().begin());
131  const auto originNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*loopVar.post->origin());
132 
133  if (userNode != originNode)
134  return false;
135 
136  return true;
137 }
138 
141 {
142  // Handle lambda region arguments
143  if (rvsdg::TryGetRegionParentNode<rvsdg::LambdaNode>(output))
144  {
145  return *output.region();
146  }
147 
148  // Handle gamma region arguments
149  if (const auto gammaNode = rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(output))
150  {
151  if (output.Type()->Kind() == rvsdg::TypeKind::State)
152  {
153  // FIXME: This is a bit too conservative. For example, it avoids that load and store nodes are
154  // hoisted out of a gamma node, but we would only like to avoid store nodes being hoisted out.
155  // For load nodes, it is legal to hoist them out if they are not preceded by an IOBarrier.
156  return *output.region();
157  }
158 
159  const auto roleVar = gammaNode->MapBranchArgument(output);
160  if (const auto entryVar = std::get_if<rvsdg::GammaNode::EntryVar>(&roleVar))
161  {
162  return computeTargetRegion(*entryVar->input->origin());
163  }
164 
165  return *output.region();
166  }
167 
168  // Handle theta region arguments
169  if (const auto thetaNode = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(output))
170  {
171  const auto loopVar = thetaNode->MapPreLoopVar(output);
172  if (rvsdg::ThetaLoopVarIsInvariant(loopVar))
173  {
174  return computeTargetRegion(*loopVar.input->origin());
175  }
176 
177  if (isInvariantMemoryStateLoopVar(loopVar))
178  {
179  return computeTargetRegion(*loopVar.input->origin());
180  }
181 
182  return *output.region();
183  }
184 
185  // Handle gamma outputs
186  if (const auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(output))
187  {
188  return context_->getTargetRegion(*gammaNode);
189  }
190 
191  // Handle theta outputs
192  if (const auto thetaNode = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(output))
193  {
194  return context_->getTargetRegion(*thetaNode);
195  }
196 
197  // Handle simple node outputs
198  if (const auto node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(output))
199  {
200  return context_->getTargetRegion(*node);
201  }
202 
203  throw std::logic_error("Unhandled output type!");
204 }
205 
208 {
209  if (node.ninputs() == 0)
210  {
211  // Nodes that can only produce states, such as UndefValueOperation, will be removed in the
212  // back-end. There is no need to hoist them.
213  JLM_ASSERT(node.noutputs() == 1);
214  return node.output(0)->Type()->Kind() == rvsdg::TypeKind::State
215  ? *node.region()
216  : context_->getLambdaSubregion();
217  }
218 
219  // Compute target regions for all the inputs of the node
220  std::vector<rvsdg::Region *> targetRegions;
221  for (auto & input : node.Inputs())
222  {
223  auto & targetRegion = computeTargetRegion(*input.origin());
224  if (&targetRegion == node.region())
225  {
226  // One of the node's predecessors cannot be hoisted, which means this node cannot be hoisted
227  // either
228  return *node.region();
229  }
230 
231  targetRegions.push_back(&targetRegion);
232  }
233 
234  // Compute the lowermost target region in the region tree
235  return **std::max_element(
236  targetRegions.begin(),
237  targetRegions.end(),
238  [&](const rvsdg::Region * region1, const rvsdg::Region * region2)
239  {
240  return context_->getRegionDeph(*region1) < context_->getRegionDeph(*region2);
241  });
242 }
243 
244 void
246 {
247  const auto regionDepth = computeRegionDepth(region);
248  context_->addRegionDepth(region, regionDepth);
249 
250  for (const auto node : rvsdg::TopDownConstTraverser(&region))
251  {
253  *node,
254  [&](const rvsdg::StructuralNode & structuralNode)
255  {
256  // FIXME: We currently do not allow structural nodes (gamma and theta nodes) to be hoisted
257  context_->addTargetRegion(structuralNode, *structuralNode.region());
258 
259  // Handle innermost regions
260  for (auto & subregion : structuralNode.Subregions())
261  {
262  markNodes(subregion);
263  }
264  },
265  [&](const rvsdg::SimpleNode & simpleNode)
266  {
267  rvsdg::Region & targetRegion = computeTargetRegion(simpleNode);
268  context_->addTargetRegion(*node, targetRegion);
269  },
270  []()
271  {
272  throw std::logic_error("Unhandled node type!");
273  });
274  }
275 }
276 
279 {
280  if (output.region() == &targetRegion)
281  return output;
282 
283  // Handle gamma subregion arguments
284  if (const auto gammaNode = rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(output))
285  {
286  const auto roleVar = gammaNode->MapBranchArgument(output);
287  if (const auto entryVar = std::get_if<rvsdg::GammaNode::EntryVar>(&roleVar))
288  {
289  return getOperandFromTargetRegion(*entryVar->input->origin(), targetRegion);
290  }
291  }
292 
293  // Handle theta subregion arguments
294  if (const auto thetaNode = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(output))
295  {
296  const auto loopVar = thetaNode->MapPreLoopVar(output);
298  return getOperandFromTargetRegion(*loopVar.input->origin(), targetRegion);
299  }
300 
301  throw std::logic_error("Unhandled output type!");
302 }
303 
304 std::vector<rvsdg::Output *>
306 {
307  std::vector<rvsdg::Output *> operands;
308  for (auto & input : node.Inputs())
309  {
310  auto & operand = getOperandFromTargetRegion(*input.origin(), targetRegion);
311  operands.push_back(&operand);
312  }
313 
314  return operands;
315 }
316 
317 void
319 {
320  auto & targetRegion = context_->getTargetRegion(node);
321  JLM_ASSERT(&targetRegion != node.region());
322 
323  const auto operands = getOperandsFromTargetRegion(node, targetRegion);
324  const auto copiedNode = node.copy(&targetRegion, operands);
325 
326  // FIXME: I really would like to have a zip function here, but C++ does not really seem to have
327  // anything better to offer
328  auto itOrg = std::begin(node.Outputs());
329  const auto endOrg = std::end(node.Outputs());
330  auto itCpy = std::begin(copiedNode->Outputs());
331  const auto endCpy = std::end(copiedNode->Outputs());
332  JLM_ASSERT(std::distance(itOrg, endOrg) == std::distance(itCpy, endCpy));
333 
334  for (; itOrg != endOrg; ++itOrg, ++itCpy)
335  {
336  auto & outputOrg = *itOrg;
337  auto & outputCpy = *itCpy;
338  auto & newOutputOrg = rvsdg::RouteToRegion(outputCpy, *node.region());
339  outputOrg.divert_users(&newOutputOrg);
340  }
341 }
342 
343 void
345 {
346  // FIXME: We are routing unnecessary values through gamma and theta nodes. We should cluster
347  // subgraphs that need to be hoisted to avoid unnecessary routing.
348  for (const auto node : rvsdg::TopDownTraverser(&region))
349  {
350  auto & targetRegion = context_->getTargetRegion(*node);
351  if (&targetRegion != node->region())
352  {
353  copyNodeToTargetRegion(*node);
354  }
355 
356  // Handle innermost regions
357  if (const auto structuralNode = dynamic_cast<rvsdg::StructuralNode *>(node))
358  {
359  for (auto & subregion : structuralNode->Subregions())
360  {
361  hoistNodes(subregion);
362  }
363  }
364  }
365 
366  region.prune(false);
367 }
368 
369 void
371 {
372  context_ = Context::create(lambdaNode);
373 
374  markNodes(*lambdaNode.subregion());
375  hoistNodes(*lambdaNode.subregion());
376 
377  context_.reset();
378 }
379 
380 void
382 {
383  for (auto & node : rvsdg::TopDownTraverser(&region))
384  {
386  *node,
387  [&](rvsdg::LambdaNode & lambdaNode)
388  {
389  hoistNodesInLambda(lambdaNode);
390  },
391  [&](rvsdg::PhiNode & phiNode)
392  {
393  hoistNodesInRootRegion(*phiNode.subregion());
394  },
395  [](rvsdg::DeltaNode &)
396  {
397  // Nothing needs to be done
398  },
399  [](rvsdg::SimpleNode &)
400  {
401  // Nothing needs to be done
402  },
403  [&]()
404  {
405  throw std::logic_error(util::strfmt("Unhandled node type: ", node->DebugString()));
406  });
407  }
408 }
409 
410 void
412 {
413  auto statistics = Statistics::Create(rvsdgModule.SourceFilePath().value());
414 
415  statistics->start(rvsdgModule.Rvsdg());
416  hoistNodesInRootRegion(rvsdgModule.Rvsdg().GetRootRegion());
417  statistics->end(rvsdgModule.Rvsdg());
418 
419  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
420 }
421 }
std::unordered_map< const rvsdg::Region *, size_t > RegionDepth_
Definition: push.cpp:99
Context(rvsdg::LambdaNode &lambdaNode)
Definition: push.cpp:55
size_t getRegionDeph(const rvsdg::Region &region) const noexcept
Definition: push.cpp:73
rvsdg::Region & getLambdaSubregion() const noexcept
Definition: push.cpp:60
rvsdg::Region & getTargetRegion(const rvsdg::Node &node) const noexcept
Definition: push.cpp:86
rvsdg::Region * LambdaSubregion_
Definition: push.cpp:98
void addTargetRegion(const rvsdg::Node &node, rvsdg::Region &region) noexcept
Definition: push.cpp:79
void addRegionDepth(const rvsdg::Region &region, const size_t depth) noexcept
Definition: push.cpp:66
std::unordered_map< const rvsdg::Node *, rvsdg::Region * > TargetRegion_
Definition: push.cpp:100
static std::unique_ptr< Context > create(rvsdg::LambdaNode &lambdaNode)
Definition: push.cpp:92
void end(const rvsdg::Graph &graph) noexcept
Definition: push.cpp:39
Statistics(const util::FilePath &sourceFile)
Definition: push.cpp:27
static std::unique_ptr< Statistics > Create(const util::FilePath &sourceFile)
Definition: push.cpp:46
void start(const rvsdg::Graph &graph) noexcept
Definition: push.cpp:32
Node Hoisting Transformation.
Definition: push.hpp:37
void hoistNodesInRootRegion(rvsdg::Region &region)
Definition: push.cpp:381
void hoistNodes(rvsdg::Region &region)
Definition: push.cpp:344
void hoistNodesInLambda(rvsdg::LambdaNode &lambdaNode)
Definition: push.cpp:370
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Definition: push.cpp:411
~NodeHoisting() noexcept override
rvsdg::Region & computeTargetRegion(const rvsdg::Node &node) const
Definition: push.cpp:207
static std::vector< rvsdg::Output * > getOperandsFromTargetRegion(rvsdg::Node &node, rvsdg::Region &targetRegion)
Definition: push.cpp:305
std::unique_ptr< Context > context_
Definition: push.hpp:84
static bool isInvariantMemoryStateLoopVar(const rvsdg::ThetaNode::LoopVar &loopVar)
Definition: push.cpp:122
void copyNodeToTargetRegion(rvsdg::Node &node) const
Definition: push.cpp:318
void markNodes(const rvsdg::Region &region)
Definition: push.cpp:245
static rvsdg::Output & getOperandFromTargetRegion(rvsdg::Output &output, rvsdg::Region &targetRegion)
Definition: push.cpp:278
size_t computeRegionDepth(const rvsdg::Region &region) const
Definition: push.cpp:110
Delta node.
Definition: delta.hpp:129
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
OutputIteratorRange Outputs() noexcept
Definition: node.hpp:657
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
InputIteratorRange Inputs() noexcept
Definition: node.hpp:622
NodeOutput * output(size_t index) const noexcept
Definition: node.hpp:650
size_t ninputs() const noexcept
Definition: node.hpp:609
size_t noutputs() const noexcept
Definition: node.hpp:644
virtual Node * copy(rvsdg::Region *region, const std::vector< jlm::rvsdg::Output * > &operands) const
Definition: node.cpp:369
rvsdg::Region * region() const noexcept
Definition: node.cpp:151
UsersRange Users()
Definition: node.hpp:354
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
size_t nusers() const noexcept
Definition: node.hpp:280
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
rvsdg::StructuralNode * node() const noexcept
Definition: region.hpp:369
void prune(bool recursive)
Definition: region.cpp:323
const std::optional< util::FilePath > & SourceFilePath() const noexcept
Definition: RvsdgModule.hpp:73
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
SubregionIteratorRange Subregions()
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:563
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:134
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:155
void AddMeasurement(std::string name, T value)
Definition: Statistics.hpp:174
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
void MatchTypeWithDefault(T &obj, const Fns &... fns)
Pattern match over subclass type of given object with default handler.
static bool ThetaLoopVarIsInvariant(const ThetaNode::LoopVar &loopVar) noexcept
Definition: theta.hpp:227
Output & RouteToRegion(Output &output, Region &region)
Definition: node.cpp:381
@ State
Designate a state type.
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
Definition: node.hpp:1049
size_t ninputs(const rvsdg::Region *region) noexcept
Definition: region.cpp:682
static std::string strfmt(Args... args)
Definition: strfmt.hpp:35
Description of a loop-carried variable.
Definition: theta.hpp:50
rvsdg::Output * pre
Variable before iteration (input argument to subregion).
Definition: theta.hpp:58
rvsdg::Output * output
Variable at loop exit (output of theta).
Definition: theta.hpp:66
rvsdg::Input * post
Variable after iteration (output result from subregion).
Definition: theta.hpp:62
static const char * Timer
Definition: Statistics.hpp:240
static const char * NumRvsdgInputsAfter
Definition: Statistics.hpp:215
static const char * NumRvsdgInputsBefore
Definition: Statistics.hpp:214