Jlm
UnusedStateRemoval.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
10 #include <jlm/rvsdg/gamma.hpp>
11 #include <jlm/rvsdg/theta.hpp>
12 #include <jlm/rvsdg/traverser.hpp>
13 
14 #include <algorithm>
15 
16 namespace jlm::hls
17 {
18 
19 static bool
21 {
22  return rvsdg::ThetaLoopVarIsInvariant(loopVar) && loopVar.pre->nusers() == 1;
23 }
24 
25 static bool
27 {
28  if (argument.nusers() != 1)
29  {
30  return false;
31  }
32 
33  return rvsdg::is<rvsdg::RegionResult>(*argument.Users().begin());
34 }
35 
36 static bool
38 {
39  auto argument = dynamic_cast<rvsdg::RegionArgument *>(result.origin());
40  return argument != nullptr;
41 }
42 
43 static void
45 {
46  const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(lambdaNode.GetOperation());
47  auto & oldFunctionType = op.type();
48 
49  std::vector<std::shared_ptr<const jlm::rvsdg::Type>> newArgumentTypes;
50  for (size_t i = 0; i < oldFunctionType.NumArguments(); ++i)
51  {
52  auto argument = lambdaNode.subregion()->argument(i);
53  auto argumentType = oldFunctionType.Arguments()[i];
54  JLM_ASSERT(*argumentType == *argument->Type());
55 
56  if (!IsPassthroughArgument(*argument))
57  {
58  newArgumentTypes.push_back(argumentType);
59  }
60  }
61 
62  std::vector<std::shared_ptr<const jlm::rvsdg::Type>> newResultTypes;
63  for (size_t i = 0; i < oldFunctionType.NumResults(); ++i)
64  {
65  auto result = lambdaNode.subregion()->result(i);
66  auto resultType = oldFunctionType.Results()[i];
67  JLM_ASSERT(*resultType == *result->Type());
68 
69  if (!IsPassthroughResult(*result))
70  {
71  newResultTypes.push_back(resultType);
72  }
73  }
74 
75  auto newFunctionType = rvsdg::FunctionType::Create(newArgumentTypes, newResultTypes);
76  auto newLambda = rvsdg::LambdaNode::Create(
77  *lambdaNode.region(),
78  llvm::LlvmLambdaOperation::Create(newFunctionType, op.name(), op.linkage(), op.attributes()));
79 
80  rvsdg::SubstitutionMap substitutionMap;
81  for (const auto & ctxvar : lambdaNode.GetContextVars())
82  {
83  auto oldArgument = ctxvar.inner;
84  auto origin = ctxvar.input->origin();
85 
86  auto newArgument = newLambda->AddContextVar(*origin).inner;
87  substitutionMap.insert(oldArgument, newArgument);
88  }
89 
90  size_t new_i = 0;
91  auto newArgs = newLambda->GetFunctionArguments();
92  for (auto argument : lambdaNode.GetFunctionArguments())
93  {
94  if (!IsPassthroughArgument(*argument))
95  {
96  substitutionMap.insert(argument, newArgs[new_i]);
97  new_i++;
98  }
99  }
100  lambdaNode.subregion()->copy(newLambda->subregion(), substitutionMap);
101 
102  std::vector<jlm::rvsdg::Output *> newResults;
103  for (auto result : lambdaNode.GetFunctionResults())
104  {
105  if (!IsPassthroughResult(*result))
106  {
107  newResults.push_back(&substitutionMap.lookup(*result->origin()));
108  }
109  }
110  auto newLambdaOutput = newLambda->finalize(newResults);
111 
112  // TODO handle functions at other levels?
113  JLM_ASSERT(lambdaNode.region() == &lambdaNode.region()->graph()->GetRootRegion());
114  JLM_ASSERT(
115  (*lambdaNode.output()->Users().begin()).region()
116  == &lambdaNode.region()->graph()->GetRootRegion());
117 
118  JLM_ASSERT(lambdaNode.output()->nusers() == 1);
119  lambdaNode.region()->RemoveResults({ (*lambdaNode.output()->Users().begin()).index() });
120  auto oldExport = jlm::llvm::ComputeCallSummary(lambdaNode).GetRvsdgExport();
121  rvsdg::GraphExport::Create(*newLambdaOutput, oldExport ? oldExport->Name() : "");
122  remove(&lambdaNode);
123 }
124 
125 // If this output has a single user and that single user happens to be
126 // the exit variable of this gamma node, then return it.
127 static std::optional<rvsdg::GammaNode::ExitVar>
129 {
130  if (argument.nusers() == 1)
131  {
132  rvsdg::Input * user = &*argument.Users().begin();
133  if (rvsdg::TryGetRegionParentNode<rvsdg::GammaNode>(*user) == &gammaNode)
134  {
135  return gammaNode.MapBranchResultExitVar(*user);
136  }
137  }
138  return std::nullopt;
139 }
140 
141 static void
143 {
144  std::vector<rvsdg::GammaNode::EntryVar> deadEntryVars;
145  std::vector<rvsdg::GammaNode::ExitVar> deadExitVars;
146 
147  for (const auto & entryvar : gammaNode.GetEntryVars())
148  {
149  std::optional<rvsdg::GammaNode::ExitVar> exitvar0 =
150  TryGetSingleUserExitVar(gammaNode, *entryvar.branchArgument[0]);
151 
152  bool shouldRemove = exitvar0
153  && std::all_of(
154  entryvar.branchArgument.begin(),
155  entryvar.branchArgument.end(),
156  [&gammaNode, &exitvar0](rvsdg::Output * argument) -> bool
157  {
158  auto exitvar = TryGetSingleUserExitVar(gammaNode, *argument);
159  return exitvar && exitvar->output == exitvar0->output;
160  });
161 
162  if (shouldRemove)
163  {
164  exitvar0->output->divert_users(entryvar.input->origin());
165  deadEntryVars.push_back(entryvar);
166  deadExitVars.push_back(*exitvar0);
167  }
168  }
169 
170  gammaNode.RemoveExitVars(deadExitVars);
171  gammaNode.RemoveEntryVars(deadEntryVars);
172 }
173 
174 static void
176 {
177  std::vector<rvsdg::ThetaNode::LoopVar> passthroughLoopVars;
178  for (auto & loopVar : thetaNode.GetLoopVars())
179  {
180  if (IsPassthroughLoopVar(loopVar))
181  {
182  loopVar.output->divert_users(loopVar.input->origin());
183  passthroughLoopVars.emplace_back(loopVar);
184  }
185  }
186 
187  thetaNode.RemoveLoopVars(std::move(passthroughLoopVars));
188 }
189 
190 static void
192 
193 static void
195 {
196  // Remove unused states from innermost regions first
197  for (size_t n = 0; n < structuralNode.nsubregions(); n++)
198  {
199  RemoveUnusedStatesInRegion(*structuralNode.subregion(n));
200  }
201 
202  if (auto gammaNode = dynamic_cast<rvsdg::GammaNode *>(&structuralNode))
203  {
205  }
206  else if (auto thetaNode = dynamic_cast<rvsdg::ThetaNode *>(&structuralNode))
207  {
209  }
210  else if (auto lambdaNode = dynamic_cast<rvsdg::LambdaNode *>(&structuralNode))
211  {
212  RemoveUnusedStatesFromLambda(*lambdaNode);
213  }
214 }
215 
216 static void
218 {
219  for (auto & node : rvsdg::TopDownTraverser(&region))
220  {
221  if (auto structuralNode = dynamic_cast<rvsdg::StructuralNode *>(node))
222  {
223  RemoveUnusedStatesInStructuralNode(*structuralNode);
224  }
225  }
226 }
227 
// Out-of-line, defaulted destructor of the pass.
228 UnusedStateRemoval::~UnusedStateRemoval() noexcept = default;
229 
// NOTE(review): the constructor's signature and base-class initializer lines
// (listing lines 230-231) are missing from this listing; only its empty body
// below survives. Restore them from the original UnusedStateRemoval.cpp.
232 {}
233 
234 void
236 {
238 }
239 
240 }
~UnusedStateRemoval() noexcept override
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
rvsdg::GraphExport * GetRvsdgExport() const noexcept
Lambda operation.
Definition: lambda.hpp:30
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::AttributeSet attributes)
Definition: lambda.hpp:77
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
Conditional operator / pattern matching.
Definition: gamma.hpp:99
void RemoveEntryVars(const std::vector< EntryVar > &entryVars)
Removes the given entry variables.
Definition: gamma.cpp:420
void RemoveExitVars(const std::vector< ExitVar > &exitVars)
Removes the given exit variables.
Definition: gamma.cpp:401
std::vector< EntryVar > GetEntryVars() const
Gets all entry variables for this gamma.
Definition: gamma.cpp:303
ExitVar MapBranchResultExitVar(const rvsdg::Input &input) const
Maps gamma region exit result to exit variable description.
Definition: gamma.cpp:389
static GraphExport & Create(Output &origin, std::string name)
Definition: graph.cpp:62
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
std::vector< rvsdg::Output * > GetFunctionArguments() const
Definition: lambda.cpp:57
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
std::vector< rvsdg::Input * > GetFunctionResults() const
Definition: lambda.cpp:69
rvsdg::Output * output() const noexcept
Definition: lambda.cpp:176
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
LambdaOperation & GetOperation() const noexcept override
Definition: lambda.cpp:51
const FunctionType & type() const noexcept
Definition: lambda.hpp:36
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
UsersRange Users()
Definition: node.hpp:354
size_t nusers() const noexcept
Definition: node.hpp:280
Represents the argument of a region.
Definition: region.hpp:41
Represents acyclic RVSDG subgraphs.
Definition: region.hpp:213
RegionResult * result(size_t index) const noexcept
Definition: region.hpp:471
size_t RemoveResults(const util::HashSet< size_t > &indices)
Definition: region.cpp:278
void copy(Region *target, SubstitutionMap &smap) const
Copy a region with substitutions.
Definition: region.cpp:314
Graph * graph() const noexcept
Definition: region.hpp:363
RegionArgument * argument(size_t index) const noexcept
Definition: region.hpp:437
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
size_t nsubregions() const noexcept
rvsdg::Region * subregion(size_t index) const noexcept
void insert(const Output *original, Output *substitute)
Output & lookup(const Output &original) const
void RemoveLoopVars(std::vector< LoopVar > loopVars)
Removes loop variables.
Definition: theta.cpp:63
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
Definition: theta.cpp:176
Represents an RVSDG transformation.
#define JLM_ASSERT(x)
Definition: common.hpp:16
static bool IsPassthroughResult(const rvsdg::Input &result)
static void RemoveUnusedStatesInRegion(rvsdg::Region &region)
static std::optional< rvsdg::GammaNode::ExitVar > TryGetSingleUserExitVar(rvsdg::GammaNode &gammaNode, rvsdg::Output &argument)
static void RemoveUnusedStatesFromThetaNode(rvsdg::ThetaNode &thetaNode)
static bool IsPassthroughArgument(const rvsdg::Output &argument)
static void RemoveUnusedStatesFromGammaNode(rvsdg::GammaNode &gammaNode)
static bool IsPassthroughLoopVar(const rvsdg::ThetaNode::LoopVar &loopVar)
static void RemoveUnusedStatesInStructuralNode(rvsdg::StructuralNode &structuralNode)
static void RemoveUnusedStatesFromLambda(rvsdg::LambdaNode &lambdaNode)
CallSummary ComputeCallSummary(const rvsdg::LambdaNode &lambdaNode)
Definition: CallSummary.cpp:30
static bool ThetaLoopVarIsInvariant(const ThetaNode::LoopVar &loopVar) noexcept
Definition: theta.hpp:227
static void remove(Node *node)
Definition: region.hpp:932
Description of a loop-carried variable.
Definition: theta.hpp:50
rvsdg::Output * pre
Variable before iteration (input argument to subregion).
Definition: theta.hpp:58