Jlm
distribute-constants.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
8 #include <jlm/hls/ir/hls.hpp>
9 #include <jlm/hls/util/view.hpp>
10 #include <jlm/rvsdg/delta.hpp>
11 #include <jlm/rvsdg/gamma.hpp>
12 #include <jlm/rvsdg/lambda.hpp>
13 #include <jlm/rvsdg/MatchType.hpp>
14 #include <jlm/rvsdg/Phi.hpp>
15 #include <jlm/rvsdg/theta.hpp>
16 #include <jlm/rvsdg/traverser.hpp>

#include <unordered_set>
17 
18 namespace jlm::hls
19 {
20 
21 void
23 {
24  for (auto & node : region.Nodes())
25  {
27  node,
28  [&](rvsdg::LambdaNode & lambdaNode)
29  {
30  distributeConstantsInLambda(lambdaNode);
31  },
32  [&](rvsdg::PhiNode & phiNode)
33  {
34  distributeConstantsInRootRegion(*phiNode.subregion());
35  },
36  [](rvsdg::DeltaNode &)
37  {
38  // Nothing needs to be done
39  });
40  }
41 }
42 
43 void
45 {
46  const auto constants = collectConstants(*lambdaNode.subregion());
47  for (const auto constant : constants.Items())
48  {
49  // Keep track to which regions we already distributed a constant such that we avoid to create
50  // duplicated instances
51  std::unordered_map<rvsdg::Region *, rvsdg::Node *> distributedConstants;
52  distributedConstants[constant->region()] = constant;
53 
54  auto insertAndDivertToNewConstant =
55  [&distributedConstants](rvsdg::Output & output, const rvsdg::Node & oldConstant)
56  {
57  rvsdg::Node * newConstant = nullptr;
58  const auto region = output.region();
59 
60  const auto it = distributedConstants.find(region);
61  if (it == distributedConstants.end())
62  {
63  newConstant = oldConstant.copy(region, {});
64  distributedConstants[region] = newConstant;
65  }
66  else
67  {
68  newConstant = it->second;
69  }
70 
71  output.divert_users(newConstant->output(0));
72  };
73 
74  auto outputs = collectOutputs(*constant);
75  for (const auto output : outputs.Items())
76  {
77  // Handle gamma, theta, simple node outputs, as well as theta subregion arguments
78  if (rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(*output)
79  || rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*output)
80  || rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*output)
81  || rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*output))
82  {
83  insertAndDivertToNewConstant(*output, *constant);
84  }
85  // Handle gamma subregion arguments
86  else if (const auto gammaNode = rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(*output))
87  {
88  // We would like to create constants in every gamma subregion
89  auto roleVar = gammaNode->MapBranchArgument(*output);
90  if (const auto entryVar = std::get_if<rvsdg::GammaNode::EntryVar>(&roleVar))
91  {
92  for (const auto argument : entryVar->branchArgument)
93  {
94  insertAndDivertToNewConstant(*argument, *constant);
95  }
96  }
97  }
98  else
99  {
100  throw std::logic_error("Unhandled output type");
101  }
102  }
103  }
104 }
105 
108 {
109  std::function<void(rvsdg::Region &, util::HashSet<rvsdg::SimpleNode *> &)> collect =
110  [&collect](rvsdg::Region & region, util::HashSet<rvsdg::SimpleNode *> & constants)
111  {
112  for (auto & node : region.Nodes())
113  {
115  node,
116  [&](rvsdg::GammaNode & gammaNode)
117  {
118  for (auto & subregion : gammaNode.Subregions())
119  {
120  collect(subregion, constants);
121  }
122  },
123  [&](rvsdg::ThetaNode & thetaNode)
124  {
125  collect(*thetaNode.subregion(), constants);
126  },
127  [&constants](rvsdg::SimpleNode & simpleNode)
128  {
129  if (is_constant(&simpleNode))
130  {
131  constants.insert(&simpleNode);
132  }
133  });
134  }
135  };
136 
138  collect(region, constants);
139  return constants;
140 }
141 
144 {
145  JLM_ASSERT(is_constant(&constantNode));
146  JLM_ASSERT(constantNode.noutputs() == 1);
147 
150  {
151  for (auto & user : output.Users())
152  {
153  // Handle theta node inputs
154  if (const auto thetaNode = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(user))
155  {
156  const auto loopVar = thetaNode->MapInputLoopVar(user);
157  if (rvsdg::ThetaLoopVarIsInvariant(loopVar))
158  {
159  collectOutputs(*loopVar.pre, outputs);
160  }
161  }
162  // Handle theta subregion results
163  else if (const auto thetaNode = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(user))
164  {
165  if (&user != thetaNode->predicate())
166  {
167  const auto loopVar = thetaNode->MapPostLoopVar(user);
168  collectOutputs(*loopVar.output, outputs);
169  }
170  }
171  // Handle gamma node inputs
172  else if (const auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(user))
173  {
174  auto roleVar = gammaNode->MapInput(user);
175  if (const auto entryVar = std::get_if<rvsdg::GammaNode::EntryVar>(&roleVar))
176  {
177  for (const auto argument : entryVar->branchArgument)
178  {
179  collectOutputs(*argument, outputs);
180  }
181  }
182  }
183  // Handle gamma subregion results
184  else if (const auto gammaNode = rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(user))
185  {
186  const auto exitVar = gammaNode->MapBranchResultExitVar(user);
187  if (rvsdg::GetGammaInvariantOrigin(*gammaNode, exitVar))
188  {
189  collectOutputs(*exitVar.output, outputs);
190  }
191  }
192  // Handle simple nodes and lambda subregion results
193  else if (
194  rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user)
195  || rvsdg::TryGetRegionParentNode<rvsdg::LambdaNode>(user))
196  {
197  outputs.insert(&output);
198  }
199  else
200  {
201  throw std::logic_error("Unexpected node type");
202  }
203  }
204  };
205 
207  collectOutputs(*constantNode.output(0), outputs);
208  return outputs;
209 }
210 
212 
// NOTE(review): the declaration lines for the empty body below (listing lines 213-214) are
// missing from this listing. Presumably this is a ConstantDistribution constructor or
// destructor; TODO confirm against the real source.
215 {}
216 
// Run(): declared as `void Run(rvsdg::RvsdgModule &, util::StatisticsCollector &) override`
// ("Perform RVSDG transformation."). Its signature (line 218) and body statement (line 220) are
// missing from this listing. Presumably it calls distributeConstantsInRootRegion() on
// rvsdgModule.Rvsdg().GetRootRegion(); TODO confirm.
217 void
219 {
221 }
222 
223 }
static void distributeConstantsInRootRegion(rvsdg::Region &region)
static util::HashSet< rvsdg::Output * > collectOutputs(const rvsdg::SimpleNode &constantNode)
static void distributeConstantsInLambda(const rvsdg::LambdaNode &lambdaNode)
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
~ConstantDistribution() noexcept override
static util::HashSet< rvsdg::SimpleNode * > collectConstants(rvsdg::Region &region)
Delta node.
Definition: delta.hpp:129
Conditional operator / pattern matching.
Definition: gamma.hpp:99
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
NodeOutput * output(size_t index) const noexcept
Definition: node.hpp:650
size_t noutputs() const noexcept
Definition: node.hpp:644
virtual Node * copy(rvsdg::Region *region, const std::vector< jlm::rvsdg::Output * > &operands) const
Definition: node.cpp:369
rvsdg::Region * region() const noexcept
Definition: node.cpp:151
UsersRange Users()
Definition: node.hpp:354
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
A phi node represents the fixpoint of mutually recursive definitions.
Definition: Phi.hpp:46
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
NodeRange Nodes() noexcept
Definition: region.hpp:328
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
NodeOutput * output(size_t index) const noexcept
Definition: simple-node.hpp:88
SubregionIteratorRange Subregions()
Represents an RVSDG transformation.
#define JLM_ASSERT(x)
Definition: common.hpp:16
static bool is_constant(const rvsdg::Node *node)
Definition: rvsdg2rhls.hpp:20
void MatchTypeOrFail(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.
static bool ThetaLoopVarIsInvariant(const ThetaNode::LoopVar &loopVar) noexcept
Definition: theta.hpp:227
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)
Definition: node.hpp:1058
std::optional< rvsdg::Output * > GetGammaInvariantOrigin(const GammaNode &gamma, const GammaNode::ExitVar &exitvar)
Determines whether a gamma exit var is path-invariant.
Definition: gamma.cpp:470