Jlm
mem-sep.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
10 #include <jlm/hls/ir/hls.hpp>
17 #include <jlm/rvsdg/gamma.hpp>
19 #include <jlm/rvsdg/theta.hpp>
20 #include <jlm/rvsdg/traverser.hpp>
21 #include <jlm/rvsdg/view.hpp>
22 
23 #include <algorithm>
24 
25 namespace jlm::hls
26 {
27 
28 static rvsdg::RegionResult *
30  jlm::rvsdg::Output * common_edge,
31  jlm::rvsdg::Output * new_edge,
32  std::vector<rvsdg::Node *> & load_nodes,
33  const std::vector<rvsdg::Node *> & store_nodes,
34  std::vector<rvsdg::Node *> & decouple_nodes)
35 {
36  // follows along common edge and routes new edge through the same regions
37  // redirects the supplied loads, stores and decouples to the new edge
38  // the new edge might be routed through unnecessary regions. This should be fixed by running DNE
39  while (true)
40  {
41  // each iteration should update common_edge and/or new_edge
42  JLM_ASSERT(common_edge->nusers() == 1);
43  JLM_ASSERT(new_edge->nusers() == 1);
44  auto & user = *common_edge->Users().begin();
45  auto & new_next = *new_edge->Users().begin();
46  if (auto res = dynamic_cast<rvsdg::RegionResult *>(&user))
47  {
48  // end of region reached
49  return res;
50  }
51  else if (auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(user))
52  {
53  auto ip = gammaNode->AddEntryVar(new_edge);
54  std::vector<jlm::rvsdg::Output *> vec;
55  new_edge = gammaNode->AddExitVar(ip.branchArgument).output;
56  new_next.divert_to(new_edge);
57 
58  auto rolevar = gammaNode->MapInput(user);
59 
60  if (auto entryvar = std::get_if<rvsdg::GammaNode::EntryVar>(&rolevar))
61  {
62  for (size_t i = 0; i < gammaNode->nsubregions(); ++i)
63  {
64  auto subres = trace_edge(
65  entryvar->branchArgument[i],
66  ip.branchArgument[i],
67  load_nodes,
68  store_nodes,
69  decouple_nodes);
70  common_edge = subres->output();
71  }
72  }
73  }
74  else if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(user))
75  {
76  auto olv = theta->MapInputLoopVar(user);
77  auto lv = theta->AddLoopVar(new_edge);
78  trace_edge(olv.pre, lv.pre, load_nodes, store_nodes, decouple_nodes);
79  common_edge = olv.output;
80  new_edge = lv.output;
81  new_next.divert_to(new_edge);
82  }
83  else if (auto sn = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user))
84  {
85  auto op = &sn->GetOperation();
86  if (dynamic_cast<const jlm::llvm::StoreNonVolatileOperation *>(op))
87  {
88  JLM_ASSERT(sn->noutputs() == 1);
89  if (store_nodes.end() != std::find(store_nodes.begin(), store_nodes.end(), sn))
90  {
91  user.divert_to(new_edge);
92  sn->output(0)->divert_users(common_edge);
93  new_edge = sn->output(0);
94  new_next.divert_to(new_edge);
95  }
96  else
97  {
98  common_edge = sn->output(0);
99  }
100  }
101  else if (dynamic_cast<const jlm::llvm::LoadNonVolatileOperation *>(op))
102  {
103  JLM_ASSERT(sn->noutputs() == 2);
104  if (load_nodes.end() != std::find(load_nodes.begin(), load_nodes.end(), sn))
105  {
106  auto & new_next = *new_edge->Users().begin();
107  user.divert_to(new_edge);
108  sn->output(1)->divert_users(common_edge);
109  new_next.divert_to(sn->output(1));
110  new_edge = sn->output(1);
111  new_next.divert_to(new_edge);
112  }
113  else
114  {
115  common_edge = sn->output(1);
116  }
117  }
118  else if (dynamic_cast<const jlm::llvm::CallOperation *>(op))
119  {
120  int oi = sn->noutputs() - sn->ninputs() + user.index();
121  // TODO: verify this is the right type of function call
122  if (decouple_nodes.end() != std::find(decouple_nodes.begin(), decouple_nodes.end(), sn))
123  {
124  auto & new_next = *new_edge->Users().begin();
125  user.divert_to(new_edge);
126  sn->output(oi)->divert_users(common_edge);
127  new_next.divert_to(sn->output(oi));
128  new_edge = new_next.origin();
129  }
130  else
131  {
132  common_edge = sn->output(oi);
133  }
134  }
135  else
136  {
137  JLM_ASSERT(sn->noutputs() == 1);
138  common_edge = sn->output(0);
139  }
140  }
141  else
142  {
143  JLM_UNREACHABLE("THIS SHOULD NOT HAPPEN");
144  }
145  }
146 }
147 
148 std::vector<rvsdg::Node *>
150 {
151  std::function<void(rvsdg::Region &, std::vector<rvsdg::Node *> &)> gatherCalls =
152  [&gatherCalls](rvsdg::Region & region, std::vector<rvsdg::Node *> & calls)
153  {
154  for (auto node : rvsdg::TopDownTraverser(&region))
155  {
156  // Handle innermost regions first
157  if (const auto structuralNode = dynamic_cast<rvsdg::StructuralNode *>(node))
158  {
159  for (auto & subregion : structuralNode->Subregions())
160  {
161  gatherCalls(subregion, calls);
162  }
163  }
164 
165  if (rvsdg::is<llvm::CallOperation>(node))
166  {
167  auto functionName = get_function_name(node->input(0));
168  if (functionName.rfind("decouple") == functionName.npos)
169  {
170  calls.push_back(node);
171  }
172  }
173  }
174  };
175 
176  std::vector<rvsdg::Node *> calls;
177  gatherCalls(region, calls);
178  return calls;
179 }
180 
181 void
183 {
184  const auto lambdaSubregion = lambdaNode.subregion();
185  auto & memoryStateArgument = llvm::GetMemoryStateRegionArgument(lambdaNode);
186 
187  auto tracedPointerNodesVector = TracePointerArguments(&lambdaNode);
188  for (auto & tp : tracedPointerNodesVector)
189  {
190  auto & decouple_nodes = tp.decoupleNodes;
191  auto decouple_requests_cnt = decouple_nodes.size();
192  // place decouple responses along same state edge
193  for (size_t i = 0; i < decouple_requests_cnt; ++i)
194  {
195  auto req = decouple_nodes[i];
196  auto channel = req->input(1)->origin();
197  auto channel_constant = jlm::hls::trace_constant(channel);
198  auto decouple_response = find_decouple_response(&lambdaNode, channel_constant);
199  decouple_nodes.push_back(decouple_response);
200  }
201  }
202 
203  // Create fake ports for non-decouple calls
204  const auto nonDecoupleCalls = gatherNonDecoupleCalls(*lambdaSubregion);
205  for (auto call : nonDecoupleCalls)
206  {
207  tracedPointerNodesVector.emplace_back();
208  tracedPointerNodesVector.back().decoupleNodes.push_back(call);
209  }
210 
211  const size_t numMemoryStates = tracedPointerNodesVector.size() + 1;
212 
213  // Assign memory node ids incrementally, used by both the split and merge
214  std::vector<llvm::MemoryNodeId> memoryNodeIds;
215  for (size_t i = 0; i < numMemoryStates; ++i)
216  {
217  memoryNodeIds.push_back(i);
218  }
219 
220  auto & lambdaEntrySplitNode =
221  llvm::LambdaEntryMemoryStateSplitOperation::CreateNode(memoryStateArgument, memoryNodeIds);
222  auto memoryStates = outputs(&lambdaEntrySplitNode);
223 
224  // handle existing state edge - TODO: remove entirely?
225  // The memory state chain that was previously between the state argument and state result,
226  // should now be attached as a chain between the final state on the split and merge, respectively
227  auto common_edge = memoryStates.back();
228  // Divert everything except the lambdaEntrySplit itself, to the lambdaEntrySplit
229  memoryStateArgument.divertUsersWhere(
230  *common_edge,
231  [&](const rvsdg::Input & input)
232  {
233  return &input != lambdaEntrySplitNode.input(0);
234  });
235 
236  auto state_result = &llvm::GetMemoryStateRegionResult(lambdaNode);
237  memoryStates.back() = state_result->origin();
238 
239  auto & lambdaExitMergeNode = llvm::LambdaExitMemoryStateMergeOperation::CreateNode(
240  *lambdaSubregion,
241  memoryStates,
242  memoryNodeIds);
243  memoryStates.pop_back();
244  state_result->divert_to(lambdaExitMergeNode.output(0));
245 
246  for (auto tp : tracedPointerNodesVector)
247  {
248  auto new_edge = memoryStates.back();
249  memoryStates.pop_back();
250  trace_edge(common_edge, new_edge, tp.loadNodes, tp.storeNodes, tp.decoupleNodes);
251  }
252 }
253 
255 
258 {}
259 
260 void
262 {
263  const auto & graph = rvsdgModule.Rvsdg();
264  const auto rootRegion = &graph.GetRootRegion();
265  if (rootRegion->numNodes() != 1)
266  {
267  throw std::logic_error("Root should have only one node now");
268  }
269 
270  const auto lambdaNode =
271  dynamic_cast<const rvsdg::LambdaNode *>(rootRegion->Nodes().begin().ptr());
272  if (!lambdaNode)
273  {
274  throw std::logic_error("Node needs to be a lambda");
275  }
276 
277  separateMemoryStates(*lambdaNode);
278 }
279 
280 } // namespace jlm::hls
~MemoryStateSeparation() noexcept override
static std::vector< rvsdg::Node * > gatherNonDecoupleCalls(rvsdg::Region &region)
Definition: mem-sep.cpp:149
static void separateMemoryStates(const rvsdg::LambdaNode &lambdaNode)
Definition: mem-sep.cpp:182
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Definition: mem-sep.cpp:261
Call operation class.
Definition: call.hpp:249
static rvsdg::SimpleNode & CreateNode(rvsdg::Output &operand, std::vector< MemoryNodeId > memoryNodeIds)
static rvsdg::Node & CreateNode(rvsdg::Region &region, const std::vector< rvsdg::Output * > &operands, const std::vector< MemoryNodeId > &memoryNodeIds)
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Lambda node.
Definition: lambda.hpp:83
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
UsersRange Users()
Definition: node.hpp:354
size_t index() const noexcept
Definition: node.hpp:274
size_t nusers() const noexcept
Definition: node.hpp:280
Represents the result of a region.
Definition: region.hpp:120
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
Represents an RVSDG transformation.
#define JLM_ASSERT(x)
Definition: common.hpp:16
#define JLM_UNREACHABLE(msg)
Definition: common.hpp:43
rvsdg::SimpleNode * find_decouple_response(const rvsdg::LambdaNode *lambda, const llvm::IntegerConstantOperation *request_constant)
Definition: mem-conv.cpp:27
std::string get_function_name(jlm::rvsdg::Input *input)
const llvm::IntegerConstantOperation * trace_constant(const rvsdg::Output *dst)
std::vector< TracedPointerNodes > TracePointerArguments(const rvsdg::LambdaNode *lambda)
Definition: mem-conv.cpp:339
static rvsdg::Output * trace_edge(rvsdg::Input *state_edge, rvsdg::Output *new_edge, rvsdg::SimpleNode *target_call, rvsdg::Output *end)
rvsdg::Input & GetMemoryStateRegionResult(const rvsdg::LambdaNode &lambdaNode) noexcept
rvsdg::Output & GetMemoryStateRegionArgument(const rvsdg::LambdaNode &lambdaNode) noexcept
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)
Definition: node.hpp:1058