Jlm
LoopUnswitching.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2017 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
9 #include <jlm/llvm/opt/pull.hpp>
10 #include <jlm/rvsdg/gamma.hpp>
11 #include <jlm/rvsdg/node.hpp>
13 #include <jlm/rvsdg/theta.hpp>
14 #include <jlm/rvsdg/traverser.hpp>
15 #include <jlm/util/Statistics.hpp>
16 #include <jlm/util/time.hpp>
17 
18 namespace jlm::llvm
19 {
20 
22 
// Out-of-line definition of the defaulted destructor of the default heuristic.
23 LoopUnswitchingDefaultHeuristic::~LoopUnswitchingDefaultHeuristic() noexcept = default;
24 
// Default decision whether a loop should be unswitched.
//
// Walks the nodes of the theta node's subregion and rejects the loop as soon as it finds a node
// with a structural operation that is not the correlated gamma node itself.
//
// \param correlation Supplies the theta node and the gamma node that are examined.
// \return true if the gamma node is the only structural node found in the theta subregion,
// otherwise false.
25 bool
26 LoopUnswitchingDefaultHeuristic::shouldUnswitchLoop(
27  ThetaGammaPredicateCorrelation & correlation) const noexcept
28 {
29  const auto & thetaNode = correlation.thetaNode();
30  const auto & gammaNode = correlation.gammaNode();
31 
32  for (const auto node : rvsdg::TopDownConstTraverser(thetaNode.subregion()))
33  {
 // The correlated gamma node is expected in the loop body and does not count against it.
34  if (node == &gammaNode)
35  continue;
36 
 // Any other structural node in the loop body vetoes unswitching.
37  if (rvsdg::is<rvsdg::StructuralOperation>(node))
38  return false;
39  }
40 
41  return true;
42 }
43 
// Returns a shared pointer to a single, function-local static default heuristic instance.
//
// NOTE(review): the declarator line of this definition (listing line 45) is missing from this
// extract; the page index names the function LoopUnswitchingDefaultHeuristic::create() - confirm
// against the original file.
44 std::shared_ptr<const LoopUnswitchingDefaultHeuristic>
46 {
47  static const LoopUnswitchingDefaultHeuristic instance;
 // Aliasing constructor with an empty owner: the returned pointer refers to the static instance
 // but owns nothing, so releasing it never deletes the instance.
48  return std::shared_ptr<const LoopUnswitchingDefaultHeuristic>(std::shared_ptr<void>(), &instance);
49 }
50 
// Statistics collected for a run of the loop unswitching pass.
//
// NOTE(review): the class header (listing line 51) and parts of the start()/end() bodies (listing
// lines 63-65 and 72-73) are missing from this extract; only the visible statements are
// documented below.
52 {
53 public:
54  ~Statistics() override = default;
55 
 // Registers the statistics under Id::LoopUnswitching for the given source file.
56  explicit Statistics(const util::FilePath & sourceFile)
57  : util::Statistics(Id::LoopUnswitching, sourceFile)
58  {}
59 
 // Invoked before the transformation runs; the body is not visible in this extract.
60  void
61  start(const rvsdg::Graph & graph) noexcept
62  {
66  }
67 
 // Invoked after the transformation ran; records the number of nodes reported by rvsdg::nnodes()
 // for the root region of the graph as NumRvsdgNodesAfter.
68  void
69  end(const rvsdg::Graph & graph) noexcept
70  {
71  AddMeasurement(Label::NumRvsdgNodesAfter, rvsdg::nnodes(&graph.GetRootRegion()));
74  }
75 
 // Convenience factory returning a heap-allocated Statistics instance for sourceFile.
76  static std::unique_ptr<Statistics>
77  Create(const util::FilePath & sourceFile)
78  {
79  return std::make_unique<Statistics>(sourceFile);
80  }
81 };
82 
// Checks whether the theta node `theta` has the shape required for loop unswitching and, if so,
// returns the gamma node that shares its predicate; returns nullptr otherwise.
//
// The visible checks are:
//  - the theta predicate originates from a simple node with a MatchOperation,
//  - the output of that match node has exactly two users,
//  - the user that is not the theta predicate is owned by a gamma node,
//  - the post value of every loop variable originates either from a theta subregion argument or
//    from an output of that gamma node.
//
// NOTE(review): the signature (listing lines 83-84) is missing from this extract; the page index
// lists it as IsUnswitchable(const rvsdg::ThetaNode &) returning rvsdg::GammaNode *.
85 {
86  auto [matchNode, matchOperation] =
87  rvsdg::TryGetSimpleNodeAndOptionalOp<rvsdg::MatchOperation>(*theta.predicate()->origin());
88  if (!matchOperation)
89  return nullptr;
90 
91  // The output of the match node should only be connected to the theta and gamma node
92  if (matchNode->output(0)->nusers() != 2)
93  return nullptr;
94 
 // Find the user of the match output that is not the theta predicate; it must belong to a gamma
 // node.
95  rvsdg::GammaNode * gammaNode = nullptr;
96  for (const auto & user : matchNode->output(0)->Users())
97  {
98  if (&user == theta.predicate())
99  continue;
100 
101  gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(user);
102  if (!gammaNode)
103  return nullptr;
104  }
105 
106  // Only apply loop unswitching if the theta node is a converted for loop, i.e., everything but the
107  // predicate is contained in the gamma
108  for (const auto & loopVar : theta.GetLoopVars())
109  {
110  const auto origin = loopVar.post->origin();
111  if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*origin))
112  {
113  // origin is a theta subregion argument
114  }
115  else if (rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(*origin) == gammaNode)
116  {
117  // origin is an output of gamma node
118  }
119  else
120  {
121  // we don't want to invert this
122  return nullptr;
123  }
124  }
125 
126  return gammaNode;
127 }
128 
// Routes every loop variable of thetaNode through gammaNode and then calls pullin_top() on the
// gamma node.
//
// \param gammaNode The gamma node inside the theta subregion; it receives new entry and exit
// variables.
// \param thetaNode The theta node whose loop variables are inspected.
//
// NOTE(review): the declarator line (listing line 130) is missing from this extract; the page
// index names the function SinkNodesIntoGamma.
129 void
131  rvsdg::GammaNode & gammaNode,
132  const rvsdg::ThetaNode & thetaNode)
133 {
134  // Ensure all loop variables are routed through the gamma node
135  for (const auto & loopVar : thetaNode.GetLoopVars())
136  {
137  if (rvsdg::TryGetOwnerNode<rvsdg::Node>(*loopVar.post->origin()) != &gammaNode)
138  {
 // Pass the value into the gamma and straight back out of both branches, then let the loop
 // variable's post value read the new gamma output instead of the original origin.
139  auto [input, branchArgument] = gammaNode.AddEntryVar(loopVar.post->origin());
140  JLM_ASSERT(branchArgument.size() == 2);
141  auto [_, output] = gammaNode.AddExitVar({ branchArgument[0], branchArgument[1] });
142  loopVar.post->divert_to(output);
143  }
144  }
145 
 // Presumably moves the nodes feeding the gamma node into its subregions - confirm against the
 // pullin_top() documentation in pull.hpp.
146  pullin_top(&gammaNode);
147 }
148 
// Collects all nodes of the theta subregion except the gamma node, grouped by their depth.
//
// \param thetaNode The theta node whose subregion is scanned.
// \param gammaNode The gamma node that is excluded; it must reside directly in thetaNode.
// \return A vector indexed by the depth obtained from rvsdg::computeDepthMap(); entry d holds
// the nodes of depth d.
//
// NOTE(review): the declarator line (listing line 150) is missing from this extract; the page
// index names the function CollectPredicateNodes.
149 std::vector<std::vector<rvsdg::Node *>>
151  const rvsdg::ThetaNode & thetaNode,
152  const rvsdg::GammaNode & gammaNode)
153 {
154  JLM_ASSERT(gammaNode.region()->node() == &thetaNode);
155 
156  auto depthMap = rvsdg::computeDepthMap(*thetaNode.subregion());
157 
158  std::vector<std::vector<rvsdg::Node *>> nodes;
159  for (auto & node : thetaNode.subregion()->Nodes())
160  {
161  if (&node == &gammaNode)
162  continue;
163 
 // Grow the outer vector on demand so that the depth can be used as an index.
164  const auto depth = depthMap[&node];
165  if (depth >= nodes.size())
166  nodes.resize(depth + 1);
167  nodes[depth].push_back(&node);
168  }
169 
170  return nodes;
171 }
172 
// Copies the given nodes into the target region, one depth level after the other.
//
// \param target The region that receives the copies.
// \param substitutionMap The substitution map handed to every node copy; callers look up the
// copied outputs in it afterwards.
// \param nodes The nodes to copy, grouped by depth as produced by CollectPredicateNodes().
//
// NOTE(review): the declarator line (listing line 174) is missing from this extract; the page
// index names the function CopyPredicateNodes.
173 void
175  rvsdg::Region & target,
176  rvsdg::SubstitutionMap & substitutionMap,
177  const std::vector<std::vector<rvsdg::Node *>> & nodes)
178 {
 // Lower depth levels are copied first - presumably so that the operands of a node are already
 // present in the substitution map when the node itself is copied.
179  for (auto & sameDepthNodes : nodes)
180  {
181  for (const auto & node : sameDepthNodes)
182  node->copy(&target, substitutionMap);
183  }
184 }
185 
// Checks that the post value of every loop variable of thetaNode originates from gammaNode.
//
// \return true if the origin of every loop variable's post value is owned by gammaNode (also
// true if there are no loop variables), otherwise false.
//
// NOTE(review): the declarator line (listing line 187) is missing from this extract; the page
// index names the function allLoopVarsAreRoutedThroughGamma.
186 bool
188  const rvsdg::ThetaNode & thetaNode,
189  const rvsdg::GammaNode & gammaNode)
190 {
191  for (const auto & loopVar : thetaNode.GetLoopVars())
192  {
193  const auto node = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(*loopVar.post->origin());
194  if (node != &gammaNode)
195  return false;
196  }
197 
198  return true;
199 }
200 
// Tries to unswitch the given theta node.
//
// If the theta node passes IsUnswitchable() and the heuristic agrees, a new gamma node that
// contains a new theta node is built in the parent region of the old theta node, and all users
// of the old theta node's outputs are diverted to the new nodes. The old theta node itself is
// not removed here.
//
// \return true if the loop was unswitched, otherwise false.
//
// NOTE(review): SinkNodesIntoGamma() runs before the heuristic is consulted, so the old theta
// node may already have been modified when this function returns false because of the heuristic.
//
// NOTE(review): the declarator line (listing line 202) is missing from this extract; the page
// index lists it as UnswitchLoop(rvsdg::ThetaNode &), and the body calls the parameter
// oldThetaNode.
201 bool
203 {
204  auto oldGammaNode = IsUnswitchable(oldThetaNode);
205  if (!oldGammaNode)
206  return false;
207 
208  SinkNodesIntoGamma(*oldGammaNode, oldThetaNode);
209 
210  // At this point, we have established the following invariant:
211  // All loop variables of the original theta node are routed through its contained gamma node,
212  // i.e., the origin of a loop variable's post value must be the output of the gamma node. This
213  // helps to simplify the transformation significantly.
214  JLM_ASSERT(allLoopVarsAreRoutedThroughGamma(oldThetaNode, *oldGammaNode));
215 
216  // FIXME: We should get this correlation from the IsUnswitchable() method, if it is possible
217  // to perform the transformation.
218  const auto correlationOpt = computeThetaGammaPredicateCorrelation(oldThetaNode);
219  JLM_ASSERT(correlationOpt.has_value());
220  auto & correlation = correlationOpt.value();
221 
222  if (!heuristic_->shouldUnswitchLoop(*correlation))
223  return false;
224 
225  // The rest of the transformation is now performed in several stages:
226  // 1. Copy the predicate subgraph into the parent region of the old theta node.
227  // 2. Create gamma node (using the copied predicate) and the new theta node in the new gamma
228  // node's repetition subregion.
229  // 3. Copy old repetition subregion into the new theta node.
230  // 4. Copy predicate subgraph into new theta node.
231  // 5. Adjust the loop variables of the new theta node, finalizing the new theta node.
232  // 6. Add exit variables to the new gamma node, finalizing the new gamma node.
233  // 7. Copy old exit subregion into the parent region of the old theta node.
234  // 8. Divert the users of the old theta node's loop variables, rendering the old theta
235  // node dead.
236  //
237  // Along the way, we keep track of replaced variables with substitution maps for some of the
238  // stages. These substitution maps are then utilized by succeeding stages to find the correct
239  // replacements for old outputs.
240 
241  // Stage 1 - Copy predicate subgraph into the old theta node's parent region
242  rvsdg::SubstitutionMap stage1SMap;
243  {
 // Outside the loop, every pre value is replaced by the value that enters the old theta node.
244  for (const auto & oldLoopVar : oldThetaNode.GetLoopVars())
245  stage1SMap.insert(oldLoopVar.pre, oldLoopVar.input->origin());
246 
247  auto conditionNodes = CollectPredicateNodes(oldThetaNode, *oldGammaNode);
248  CopyPredicateNodes(*oldThetaNode.region(), stage1SMap, conditionNodes);
249  }
250 
251  // Stage 2 - Create new gamma and theta node
252  auto newGammaNode = rvsdg::GammaNode::create(
253  &stage1SMap.lookup(*oldGammaNode->predicate()->origin()),
254  oldGammaNode->nsubregions());
255 
256  const auto [oldRepetitionSubregion, oldExitSubregion] =
257  determineGammaSubregionRoles(*correlation).value();
258  const auto & newRepetitionSubregion = newGammaNode->subregion(oldRepetitionSubregion->index());
259  const auto repetitionSubregionIndex = oldRepetitionSubregion->index();
260  const auto exitSubregionIndex = oldExitSubregion->index();
261 
262  auto newThetaNode = rvsdg::ThetaNode::create(newRepetitionSubregion);
263 
 // For every entry variable of the old gamma node, create an entry variable of the new gamma
 // node and a loop variable of the new theta node, and remember both under the old gamma input.
264  std::unordered_map<rvsdg::Input *, rvsdg::Input *> oldGammaNewGammaInputMap;
265  std::unordered_map<rvsdg::Input *, rvsdg::Input *> oldGammaNewThetaInputMap;
266  for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
267  {
268  auto & newOrigin = stage1SMap.lookup(*oldInput->origin());
269  auto newEntryVar = newGammaNode->AddEntryVar(&newOrigin);
270  auto newLoopVar =
271  newThetaNode->AddLoopVar(newEntryVar.branchArgument[repetitionSubregionIndex]);
272  oldGammaNewGammaInputMap[oldInput] = newEntryVar.input;
273  oldGammaNewThetaInputMap[oldInput] = newLoopVar.input;
274  }
275 
276  // Stage 3 - Copy repetition subregion into new theta node
277  rvsdg::SubstitutionMap stage3SMap;
278  {
 // The arguments of the old repetition subregion become the pre values of the new loop
 // variables.
279  for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
280  {
281  auto newLoopInput = oldGammaNewThetaInputMap[oldInput];
282  auto newLoopVar = newThetaNode->MapInputLoopVar(*newLoopInput);
283  stage3SMap.insert(oldBranchArgument[repetitionSubregionIndex], newLoopVar.pre);
284  }
285 
286  oldRepetitionSubregion->copy(newThetaNode->subregion(), stage3SMap);
287  }
288 
289  // Stage 4 - Copy predicate subgraph into new theta node subregion
290  rvsdg::SubstitutionMap stage4SMap;
291  {
 // Inside the new loop, every old pre value is replaced by the copy of the value that the old
 // repetition subregion produced for that loop variable.
292  for (auto oldLoopVar : oldThetaNode.GetLoopVars())
293  {
294  auto oldExitVar = oldGammaNode->MapOutputExitVar(*oldLoopVar.post->origin());
295  auto oldOrigin = oldExitVar.branchResult[repetitionSubregionIndex]->origin();
296  auto & newOrigin = stage3SMap.lookup(*oldOrigin);
297  stage4SMap.insert(oldLoopVar.pre, &newOrigin);
298  }
299 
300  auto conditionNodes = CollectPredicateNodes(oldThetaNode, *oldGammaNode);
301  CopyPredicateNodes(*newThetaNode->subregion(), stage4SMap, conditionNodes);
302  }
303 
304  // Stage 5 - Adjust loop variables
305  newThetaNode->set_predicate(&stage4SMap.lookup(*oldThetaNode.predicate()->origin()));
 // The post value of each new loop variable becomes the stage 4 replacement of the value that
 // fed the corresponding old gamma input.
306  for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
307  {
308  auto newLoopVarInput = oldGammaNewThetaInputMap[oldInput];
309  auto newLoopVar = newThetaNode->MapInputLoopVar(*newLoopVarInput);
310  auto & newOrigin = stage4SMap.lookup(*oldInput->origin());
311  newLoopVar.post->divert_to(&newOrigin);
312  }
313 
314  // Stage 6 - Add new gamma exit variables
315  std::unordered_map<rvsdg::Input *, rvsdg::Output *> oldGammaNewGammaOutputMap;
316  {
317  for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
318  {
319  auto newGammaInput = oldGammaNewGammaInputMap[oldInput];
320  auto newEntryVar =
321  std::get<rvsdg::GammaNode::EntryVar>(newGammaNode->MapInput(*newGammaInput));
322  auto newLoopVarInput = oldGammaNewThetaInputMap[oldInput];
323  auto newLoopVar = newThetaNode->MapInputLoopVar(*newLoopVarInput);
324 
 // The exit branch passes the entry variable through unchanged, the repetition branch yields
 // the output of the new theta node.
 // NOTE(review): sized for exactly two subregions although the new gamma node is created with
 // oldGammaNode->nsubregions() - presumably always two here, confirm.
325  std::vector<rvsdg::Output *> values(2);
326  values[exitSubregionIndex] = newEntryVar.branchArgument[exitSubregionIndex];
327  values[repetitionSubregionIndex] = newLoopVar.output;
328  auto newExitVar = newGammaNode->AddExitVar(values);
329  oldGammaNewGammaOutputMap[oldInput] = newExitVar.output;
330  }
331  }
332 
333  // Stage 7 - Copy exit subregion into old theta node parent region
334  rvsdg::SubstitutionMap stage7SMap;
335  {
 // The arguments of the old exit subregion are replaced by the outputs of the new gamma node.
336  for (const auto & [oldInput, oldBranchArgument] : oldGammaNode->GetEntryVars())
337  {
338  auto newOrigin = oldGammaNewGammaOutputMap[oldInput];
339  stage7SMap.insert(oldBranchArgument[exitSubregionIndex], newOrigin);
340  }
341 
342  oldExitSubregion->copy(oldThetaNode.region(), stage7SMap);
343  }
344 
345  // Stage 8 - Replace old theta node outputs
346  for (auto oldLoopVar : oldThetaNode.GetLoopVars())
347  {
 // Users of an old theta output now read the copy of the value that the old exit subregion
 // produced for that loop variable.
348  auto oldExitVar = oldGammaNode->MapOutputExitVar(*oldLoopVar.post->origin());
349  auto oldOrigin = oldExitVar.branchResult[exitSubregionIndex]->origin();
350  auto & newOrigin = stage7SMap.lookup(*oldOrigin);
351  oldLoopVar.output->divert_users(&newOrigin);
352  }
353 
354  return true;
355 }
356 
// Applies loop unswitching to all theta nodes in the given region and, recursively, in the
// subregions of every structural node it contains.
//
// NOTE(review): the declarator line (listing line 358) is missing from this extract; the page
// index lists it as HandleRegion(rvsdg::Region & region).
//
// NOTE(review): UnswitchLoop() creates new nodes in the parent region of the theta node, i.e. in
// `region`, while region.Nodes() is being iterated - confirm that this iteration tolerates
// insertions.
357 void
359 {
360  bool unswitchedLoop = false;
361  for (auto & node : region.Nodes())
362  {
363  if (const auto structuralNode = dynamic_cast<rvsdg::StructuralNode *>(&node))
364  {
365  // Handle innermost theta nodes first
366  for (auto & subregion : structuralNode->Subregions())
367  HandleRegion(subregion);
368 
369  if (const auto thetaNode = dynamic_cast<rvsdg::ThetaNode *>(structuralNode))
370  {
371  unswitchedLoop |= UnswitchLoop(*thetaNode);
372  }
373  }
374  }
375 
376  // If we successfully unswitched a loop, ensure the old nodes are pruned.
377  if (unswitchedLoop)
378  {
 // Non-recursive prune of this region only.
379  region.prune(false);
380  }
381 }
382 
// Out-of-line definition of the defaulted destructor of the transformation.
383 LoopUnswitching::~LoopUnswitching() noexcept = default;
384 
// Runs loop unswitching on the whole RVSDG of the module, starting at its root region, and hands
// the gathered statistics to the statistics collector.
//
// \param rvsdgModule The module whose RVSDG is transformed.
// \param statisticsCollector Receives the statistics of this run.
//
// NOTE(review): the declarator line (listing line 386) is missing from this extract; the page
// index lists it as Run(...) override.
385 void
387  rvsdg::RvsdgModule & rvsdgModule,
388  util::StatisticsCollector & statisticsCollector)
389 {
 // NOTE(review): .value() suggests SourceFilePath() returns an optional - presumably a source
 // file path is always set at this point, confirm.
390  auto statistics = Statistics::Create(rvsdgModule.SourceFilePath().value());
391 
392  statistics->start(rvsdgModule.Rvsdg());
393  HandleRegion(rvsdgModule.Rvsdg().GetRootRegion());
394  statistics->end(rvsdgModule.Rvsdg());
395 
396  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
397 }
398 
// Convenience wrapper: constructs a LoopUnswitching instance with the given heuristic and runs
// it on the module.
//
// NOTE(review): the declarator line and the statisticsCollector parameter (listing lines 400 and
// 402) are missing from this extract; the page index lists the full signature as
// CreateAndRun(rvsdg::RvsdgModule &, util::StatisticsCollector &,
// std::shared_ptr<const LoopUnswitchingHeuristic>).
399 void
401  rvsdg::RvsdgModule & rvsdgModule,
403  std::shared_ptr<const LoopUnswitchingHeuristic> heuristic)
404 {
405  LoopUnswitching loopUnswitching(std::move(heuristic));
406  loopUnswitching.Run(rvsdgModule, statisticsCollector);
407 }
408 
409 }
static jlm::util::StatisticsCollector statisticsCollector
static std::shared_ptr< const LoopUnswitchingDefaultHeuristic > create()
virtual ~LoopUnswitchingHeuristic() noexcept
Statistics(const util::FilePath &sourceFile)
void end(const rvsdg::Graph &graph) noexcept
static std::unique_ptr< Statistics > Create(const util::FilePath &sourceFile)
void start(const rvsdg::Graph &graph) noexcept
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static rvsdg::GammaNode * IsUnswitchable(const rvsdg::ThetaNode &thetaNode)
static bool allLoopVarsAreRoutedThroughGamma(const rvsdg::ThetaNode &thetaNode, const rvsdg::GammaNode &gammaNode)
void HandleRegion(rvsdg::Region &region)
static void SinkNodesIntoGamma(rvsdg::GammaNode &gammaNode, const rvsdg::ThetaNode &thetaNode)
static void CreateAndRun(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector, std::shared_ptr< const LoopUnswitchingHeuristic > heuristic)
std::shared_ptr< const LoopUnswitchingHeuristic > heuristic_
static void CopyPredicateNodes(rvsdg::Region &target, rvsdg::SubstitutionMap &substitutionMap, const std::vector< std::vector< rvsdg::Node * >> &nodes)
static std::vector< std::vector< rvsdg::Node * > > CollectPredicateNodes(const rvsdg::ThetaNode &thetaNode, const rvsdg::GammaNode &gammaNode)
~LoopUnswitching() noexcept override
bool UnswitchLoop(rvsdg::ThetaNode &thetaNode)
Conditional operator / pattern matching.
Definition: gamma.hpp:99
EntryVar AddEntryVar(rvsdg::Output *origin)
Routes a variable into the gamma branches.
Definition: gamma.cpp:260
static GammaNode * create(jlm::rvsdg::Output *predicate, size_t nalternatives)
Definition: gamma.hpp:161
ExitVar AddExitVar(const std::vector< rvsdg::Output * > &values)
Routes per-branch result of gamma to output.
Definition: gamma.cpp:362
Output * origin() const noexcept
Definition: node.hpp:58
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
rvsdg::StructuralNode * node() const noexcept
Definition: region.hpp:369
void prune(bool recursive)
Definition: region.cpp:323
NodeRange Nodes() noexcept
Definition: region.hpp:328
void insert(const Output *original, Output *substitute)
Output & lookup(const Output &original) const
RegionResult * predicate() const noexcept
Definition: theta.hpp:85
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
Definition: theta.cpp:176
rvsdg::Region * subregion() const noexcept
Definition: theta.hpp:79
static ThetaNode * create(rvsdg::Region *parent)
Definition: theta.hpp:73
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:574
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:137
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:158
void AddMeasurement(std::string name, T value)
Definition: Statistics.hpp:177
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
void pullin_top(rvsdg::GammaNode *gamma)
Definition: pull.cpp:135
std::optional< std::unique_ptr< ThetaGammaPredicateCorrelation > > computeThetaGammaPredicateCorrelation(rvsdg::ThetaNode &thetaNode)
std::optional< GammaSubregionRoles > determineGammaSubregionRoles(const ThetaGammaPredicateCorrelation &correlation)
std::unordered_map< const Node *, size_t > computeDepthMap(const Region &region)
Definition: region.cpp:765
size_t nnodes(const jlm::rvsdg::Region *region) noexcept
Definition: region.cpp:785
size_t ninputs(const rvsdg::Region *region) noexcept
Definition: region.cpp:838
static const char * NumRvsdgNodesBefore
Definition: Statistics.hpp:214
static const char * NumRvsdgNodesAfter
Definition: Statistics.hpp:215
static const char * Timer
Definition: Statistics.hpp:251
static const char * NumRvsdgInputsAfter
Definition: Statistics.hpp:218
static const char * NumRvsdgInputsBefore
Definition: Statistics.hpp:217