Jlm
UnusedStateRemoval.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
10 #include <jlm/rvsdg/gamma.hpp>
11 #include <jlm/rvsdg/theta.hpp>
12 #include <jlm/rvsdg/traverser.hpp>
13 
14 #include <algorithm>
15 
16 namespace jlm::hls
17 {
18 
19 static bool
21 {
22  return rvsdg::ThetaLoopVarIsInvariant(loopVar) && loopVar.pre->nusers() == 1;
23 }
24 
25 static bool
27 {
28  if (argument.nusers() != 1)
29  {
30  return false;
31  }
32 
33  return rvsdg::is<rvsdg::RegionResult>(*argument.Users().begin());
34 }
35 
36 static bool
38 {
39  auto argument = dynamic_cast<rvsdg::RegionArgument *>(result.origin());
40  return argument != nullptr;
41 }
42 
43 static void
45 {
46  const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(lambdaNode.GetOperation());
47  auto & oldFunctionType = op.type();
48 
49  std::vector<std::shared_ptr<const jlm::rvsdg::Type>> newArgumentTypes;
50  for (size_t i = 0; i < oldFunctionType.NumArguments(); ++i)
51  {
52  auto argument = lambdaNode.subregion()->argument(i);
53  auto argumentType = oldFunctionType.Arguments()[i];
54  JLM_ASSERT(*argumentType == *argument->Type());
55 
56  if (!IsPassthroughArgument(*argument))
57  {
58  newArgumentTypes.push_back(argumentType);
59  }
60  }
61 
62  std::vector<std::shared_ptr<const jlm::rvsdg::Type>> newResultTypes;
63  for (size_t i = 0; i < oldFunctionType.NumResults(); ++i)
64  {
65  auto result = lambdaNode.subregion()->result(i);
66  auto resultType = oldFunctionType.Results()[i];
67  JLM_ASSERT(*resultType == *result->Type());
68 
69  if (!IsPassthroughResult(*result))
70  {
71  newResultTypes.push_back(resultType);
72  }
73  }
74 
75  auto newFunctionType = rvsdg::FunctionType::Create(newArgumentTypes, newResultTypes);
76  auto newLambda = rvsdg::LambdaNode::Create(
77  *lambdaNode.region(),
79  newFunctionType,
80  op.name(),
81  op.linkage(),
82  op.callingConvention(),
83  op.attributes()));
84 
85  rvsdg::SubstitutionMap substitutionMap;
86  for (const auto & ctxvar : lambdaNode.GetContextVars())
87  {
88  auto oldArgument = ctxvar.inner;
89  auto origin = ctxvar.input->origin();
90 
91  auto newArgument = newLambda->AddContextVar(*origin).inner;
92  substitutionMap.insert(oldArgument, newArgument);
93  }
94 
95  size_t new_i = 0;
96  auto newArgs = newLambda->GetFunctionArguments();
97  for (auto argument : lambdaNode.GetFunctionArguments())
98  {
99  if (!IsPassthroughArgument(*argument))
100  {
101  substitutionMap.insert(argument, newArgs[new_i]);
102  new_i++;
103  }
104  }
105  lambdaNode.subregion()->copy(newLambda->subregion(), substitutionMap);
106 
107  std::vector<jlm::rvsdg::Output *> newResults;
108  for (auto result : lambdaNode.GetFunctionResults())
109  {
110  if (!IsPassthroughResult(*result))
111  {
112  newResults.push_back(&substitutionMap.lookup(*result->origin()));
113  }
114  }
115  auto newLambdaOutput = newLambda->finalize(newResults);
116 
117  // TODO handle functions at other levels?
118  JLM_ASSERT(lambdaNode.region() == &lambdaNode.region()->graph()->GetRootRegion());
119  JLM_ASSERT(
120  (*lambdaNode.output()->Users().begin()).region()
121  == &lambdaNode.region()->graph()->GetRootRegion());
122 
123  JLM_ASSERT(lambdaNode.output()->nusers() == 1);
124  lambdaNode.region()->RemoveResults({ (*lambdaNode.output()->Users().begin()).index() });
125  auto oldExport = jlm::llvm::ComputeCallSummary(lambdaNode).GetRvsdgExport();
126  rvsdg::GraphExport::Create(*newLambdaOutput, oldExport ? oldExport->Name() : "");
127  remove(&lambdaNode);
128 }
129 
130 // If this output has a single user and that single user happens to be
131 // the exit variable of this gamma node, then return it.
132 static std::optional<rvsdg::GammaNode::ExitVar>
134 {
135  if (argument.nusers() == 1)
136  {
137  rvsdg::Input * user = &*argument.Users().begin();
138  if (rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(*user) == &gammaNode)
139  {
140  return gammaNode.MapBranchResultExitVar(*user);
141  }
142  }
143  return std::nullopt;
144 }
145 
146 static void
148 {
149  std::vector<rvsdg::GammaNode::EntryVar> deadEntryVars;
150  std::vector<rvsdg::Output *> deadGammaOutputs;
151 
152  for (const auto & entryvar : gammaNode.GetEntryVars())
153  {
154  std::optional<rvsdg::GammaNode::ExitVar> exitvar0 =
155  TryGetSingleUserExitVar(gammaNode, *entryvar.branchArgument[0]);
156 
157  bool shouldRemove = exitvar0
158  && std::all_of(
159  entryvar.branchArgument.begin(),
160  entryvar.branchArgument.end(),
161  [&gammaNode, &exitvar0](rvsdg::Output * argument) -> bool
162  {
163  auto exitvar = TryGetSingleUserExitVar(gammaNode, *argument);
164  return exitvar && exitvar->output == exitvar0->output;
165  });
166 
167  if (shouldRemove)
168  {
169  exitvar0->output->divert_users(entryvar.input->origin());
170  deadEntryVars.push_back(entryvar);
171  deadGammaOutputs.push_back(exitvar0->output);
172  }
173  }
174 
175  gammaNode.RemoveExitVars(deadGammaOutputs);
176  gammaNode.RemoveEntryVars(deadEntryVars);
177 }
178 
179 static void
181 {
182  std::vector<rvsdg::ThetaNode::LoopVar> passthroughLoopVars;
183  for (auto & loopVar : thetaNode.GetLoopVars())
184  {
185  if (IsPassthroughLoopVar(loopVar))
186  {
187  loopVar.output->divert_users(loopVar.input->origin());
188  passthroughLoopVars.emplace_back(loopVar);
189  }
190  }
191 
192  thetaNode.RemoveLoopVars(std::move(passthroughLoopVars));
193 }
194 
195 static void
197 
198 static void
200 {
201  // Remove unused states from innermost regions first
202  for (size_t n = 0; n < structuralNode.nsubregions(); n++)
203  {
204  RemoveUnusedStatesInRegion(*structuralNode.subregion(n));
205  }
206 
207  if (auto gammaNode = dynamic_cast<rvsdg::GammaNode *>(&structuralNode))
208  {
210  }
211  else if (auto thetaNode = dynamic_cast<rvsdg::ThetaNode *>(&structuralNode))
212  {
214  }
215  else if (auto lambdaNode = dynamic_cast<rvsdg::LambdaNode *>(&structuralNode))
216  {
217  RemoveUnusedStatesFromLambda(*lambdaNode);
218  }
219 }
220 
221 static void
223 {
224  for (auto & node : rvsdg::TopDownTraverser(&region))
225  {
226  if (auto structuralNode = dynamic_cast<rvsdg::StructuralNode *>(node))
227  {
228  RemoveUnusedStatesInStructuralNode(*structuralNode);
229  }
230  }
231 }
232 
// Out-of-line defaulted destructor of the UnusedStateRemoval transformation.
233 UnusedStateRemoval::~UnusedStateRemoval() noexcept = default;
234 
// NOTE(review): this rendered listing skips original lines 235-236, which presumably hold the
// UnusedStateRemoval constructor's declaration; only its empty body survives below.
237 {}
238 
// Entry point of the pass (declared as Run(rvsdg::RvsdgModule &, util::StatisticsCollector &)).
// NOTE(review): original lines 240 and 242 are missing from this listing. Line 242 presumably
// calls RemoveUnusedStatesInRegion on the module's root region
// (rvsdgModule.Rvsdg().GetRootRegion()); confirm against the real source.
239 void
241 {
243 }
244 
245 }
~UnusedStateRemoval() noexcept override
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
rvsdg::GraphExport * GetRvsdgExport() const noexcept
Lambda operation.
Definition: lambda.hpp:30
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::CallingConvention callingConvention, jlm::llvm::AttributeSet attributes)
Definition: lambda.hpp:84
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
Conditional operator / pattern matching.
Definition: gamma.hpp:99
void RemoveEntryVars(const std::vector< EntryVar > &entryVars)
Removes the given entry variables.
Definition: gamma.cpp:459
std::vector< EntryVar > GetEntryVars() const
Gets all entry variables for this gamma.
Definition: gamma.cpp:305
void RemoveExitVars(const std::vector< Output * > &gammaOutputs)
Removes the exit variables corresponding to the given gammaOutputs.
Definition: gamma.cpp:439
ExitVar MapBranchResultExitVar(const rvsdg::Input &input) const
Maps a gamma region exit result to its exit variable description.
Definition: gamma.cpp:409
static GraphExport & Create(Output &origin, std::string name)
Definition: graph.cpp:62
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
std::vector< rvsdg::Output * > GetFunctionArguments() const
Definition: lambda.cpp:57
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
std::vector< rvsdg::Input * > GetFunctionResults() const
Definition: lambda.cpp:69
rvsdg::Output * output() const noexcept
Definition: lambda.cpp:176
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
LambdaOperation & GetOperation() const noexcept override
Definition: lambda.cpp:51
const FunctionType & type() const noexcept
Definition: lambda.hpp:36
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
UsersRange Users()
Definition: node.hpp:354
size_t nusers() const noexcept
Definition: node.hpp:280
Represents the argument of a region.
Definition: region.hpp:41
Represents acyclic RVSDG subgraphs.
Definition: region.hpp:213
RegionResult * result(size_t index) const noexcept
Definition: region.hpp:471
size_t RemoveResults(const util::HashSet< size_t > &indices)
Definition: region.cpp:278
void copy(Region *target, SubstitutionMap &smap) const
Copy a region with substitutions.
Definition: region.cpp:314
Graph * graph() const noexcept
Definition: region.hpp:363
RegionArgument * argument(size_t index) const noexcept
Definition: region.hpp:437
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
size_t nsubregions() const noexcept
rvsdg::Region * subregion(size_t index) const noexcept
void insert(const Output *original, Output *substitute)
Output & lookup(const Output &original) const
void RemoveLoopVars(std::vector< LoopVar > loopVars)
Removes loop variables.
Definition: theta.cpp:63
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
Definition: theta.cpp:176
Represents an RVSDG transformation.
#define JLM_ASSERT(x)
Definition: common.hpp:16
static bool IsPassthroughResult(const rvsdg::Input &result)
static void RemoveUnusedStatesInRegion(rvsdg::Region &region)
static std::optional< rvsdg::GammaNode::ExitVar > TryGetSingleUserExitVar(rvsdg::GammaNode &gammaNode, rvsdg::Output &argument)
static void RemoveUnusedStatesFromThetaNode(rvsdg::ThetaNode &thetaNode)
static bool IsPassthroughArgument(const rvsdg::Output &argument)
static void RemoveUnusedStatesFromGammaNode(rvsdg::GammaNode &gammaNode)
static bool IsPassthroughLoopVar(const rvsdg::ThetaNode::LoopVar &loopVar)
static void RemoveUnusedStatesInStructuralNode(rvsdg::StructuralNode &structuralNode)
static void RemoveUnusedStatesFromLambda(rvsdg::LambdaNode &lambdaNode)
CallSummary ComputeCallSummary(const rvsdg::LambdaNode &lambdaNode)
Definition: CallSummary.cpp:30
static bool ThetaLoopVarIsInvariant(const ThetaNode::LoopVar &loopVar) noexcept
Definition: theta.hpp:227
static void remove(Node *node)
Definition: region.hpp:978
Description of a loop-carried variable.
Definition: theta.hpp:50
rvsdg::Output * pre
Variable before iteration (input argument to subregion).
Definition: theta.hpp:58