mem-conv.cpp

/*
 * Copyright 2021 David Metz <david.c.metz@ntnu.no>
 * and Magnus Sjalander <work@sjalander.com>
 * See COPYING for terms of redistribution.
 */

#include <jlm/hls/ir/hls.hpp>
#include <jlm/rvsdg/theta.hpp>
#include <jlm/rvsdg/traverser.hpp>
#include <jlm/rvsdg/view.hpp>

namespace jlm::hls
{
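// Finds the decouple response call that matches a decouple request: traces the call sites of
// every "decouple_res" function argument of the lambda and returns the call whose channel
// constant (input 1) is equal to request_constant.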
rvsdg::SimpleNode *
find_decouple_response(
    const rvsdg::LambdaNode * lambda,
    const llvm::IntegerConstantOperation * request_constant)
{
  auto response_functions = find_function_arguments(lambda, "decouple_res");
  for (auto & func : response_functions)
  {
    std::unordered_set<rvsdg::Output *> visited;
    std::vector<rvsdg::SimpleNode *> response_calls;
    trace_function_calls(func.inner, response_calls, visited);
    for (auto & rc : response_calls)
    {
      auto response_constant = trace_constant(rc->input(1)->origin());
      if (*response_constant == *request_constant)
      {
        return rc;
      }
    }
  }
  JLM_UNREACHABLE("No response found");
}

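// Follows a memory state edge downstream until it reaches a memory state merge
// (MemoryStateMergeOperation or LambdaExitMemoryStateMergeOperation), exiting loop branches and
// non-loop muxes on the way. Returns the input of the merge node together with the mux inputs
// that were passed through.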
static std::pair<rvsdg::Input *, std::vector<rvsdg::Input *>>
TraceEdgeToMerge(rvsdg::Input * state_edge)
{
  std::vector<rvsdg::Input *> encountered_muxes;
  // should encounter no new loops or gammas, only exit them
  rvsdg::Input * previous_state_edge = nullptr;
  while (true)
  {
    // make sure we make progress
    JLM_ASSERT(previous_state_edge != state_edge);
    previous_state_edge = state_edge;
    if (dynamic_cast<jlm::rvsdg::RegionResult *>(state_edge))
    {
      JLM_UNREACHABLE("this should be handled by branch");
    }
    else if (rvsdg::TryGetOwnerNode<LoopNode>(*state_edge))
    {
      JLM_UNREACHABLE("there should be no new loops");
    }
    auto si = state_edge;
    auto sn = &rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*si);
    auto [branchNode, branchOperation] =
        rvsdg::TryGetSimpleNodeAndOptionalOp<BranchOperation>(*state_edge);
    auto [muxNode, muxOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<MuxOperation>(*state_edge);
    if (branchOperation)
    {
      // end of loop
      JLM_ASSERT(branchOperation->loop);
      state_edge = get_mem_state_user(
          util::assertedCast<rvsdg::RegionResult>(get_mem_state_user(sn->output(0)))->output());
    }
    else if (muxOperation && !muxOperation->loop)
    {
      // end of gamma
      encountered_muxes.push_back(si);
      state_edge = get_mem_state_user(sn->output(0));
    }
    else if (
        rvsdg::IsOwnerNodeOperation<llvm::MemoryStateMergeOperation>(*state_edge)
        || rvsdg::IsOwnerNodeOperation<llvm::LambdaExitMemoryStateMergeOperation>(*state_edge))
    {
      return { state_edge, encountered_muxes };
    }
    else
    {
      JLM_UNREACHABLE("unexpected node on state edge");
    }
  }
}

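// Diverts the sibling inputs of every mux passed between res_mem_state and the following memory
// state merge to undef values, so that the state edge leading up to the response can be removed
// later.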
void
OptimizeResMemState(rvsdg::Output * res_mem_state)
{
  // replace the other mux branches with undefs, so the state edge before the response can be killed
  auto [merge_in, encountered_muxes] = TraceEdgeToMerge(get_mem_state_user(res_mem_state));
  JLM_ASSERT(merge_in);
  for (auto si : encountered_muxes)
  {
    auto & sn = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*si);
    for (size_t i = 1; i < sn.ninputs(); ++i)
    {
      if (i != si->index())
      {
        auto state_dummy = llvm::UndefValueOperation::Create(*si->region(), si->Type());
        sn.input(i)->divert_to(state_dummy);
      }
    }
  }
}

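// Once we wait for the response there is no need to also wait for the request, so the memory
// state merge reached from req_mem_state is rebuilt without this state edge and the old merge
// node is removed.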
void
OptimizeReqMemState(rvsdg::Output * req_mem_state)
{
  // there is no reason to wait for requests if we already wait for responses, so we kill the rest
  // of this state edge
  auto [merge_in, _] = TraceEdgeToMerge(get_mem_state_user(req_mem_state));
  JLM_ASSERT(merge_in);
  auto & merge_node = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*merge_in);
  std::vector<rvsdg::Output *> merge_origins;
  for (size_t i = 0; i < merge_node.ninputs(); ++i)
  {
    if (i != merge_in->index())
    {
      merge_origins.push_back(merge_node.input(i)->origin());
    }
  }
  auto new_merge_output = llvm::MemoryStateMergeOperation::Create(merge_origins);
  merge_node.output(0)->divert_users(new_merge_output);
  JLM_ASSERT(merge_node.IsDead());
  remove(&merge_node);
}

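// Replaces a decouple request/response call pair with a decoupled load that is connected to an
// explicit memory port via resp. The request address is state-gated, users of the response value
// are redirected to the load result, the dead call nodes are removed, and the state edges of the
// request and response are optimized. Returns the newly created decoupled load node.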
rvsdg::SimpleNode *
ReplaceDecouple(
    const rvsdg::LambdaNode * lambda,
    rvsdg::SimpleNode * decouple_request,
    rvsdg::Output * resp)
{
  JLM_ASSERT(dynamic_cast<const llvm::CallOperation *>(&decouple_request->GetOperation()));
  auto channel = decouple_request->input(1)->origin();
  auto channel_constant = trace_constant(channel);

  auto decouple_response = find_decouple_response(lambda, channel_constant);

  // handle request
  auto addr = decouple_request->input(2)->origin();
  auto req_mem_state = decouple_request->input(decouple_request->ninputs() - 1)->origin();
  // state gate for req
  auto sg_out = StateGateOperation::create(*addr, { req_mem_state });
  addr = sg_out[0];
  req_mem_state = sg_out[1];
  // redirect memstate - iostate output has already been removed by mem-sep pass
  decouple_request->output(decouple_request->noutputs() - 1)->divert_users(req_mem_state);

  // handle response
  int load_capacity = 10;
  if (rvsdg::is<const rvsdg::BitType>(decouple_response->input(2)->Type()))
  {
    auto constant = trace_constant(decouple_response->input(2)->origin());
    load_capacity = constant->Representation().to_int();
    assert(load_capacity >= 0);
  }
  auto routed_resp = route_response_rhls(decouple_request->region(), resp);
  auto dload_out = DecoupledLoadOperation::create(*addr, *routed_resp, load_capacity);
  auto dload_node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*dload_out[0]);

  auto routed_data = route_to_region_rhls(decouple_response->region(), dload_out[0]);
  decouple_response->output(0)->divert_users(routed_data);
  auto response_state_origin = decouple_response->input(decouple_response->ninputs() - 1)->origin();

  if (decouple_request->region() != decouple_response->region())
  {
    // they are in different regions, so we handle the state edge at the response
    auto state_dummy = llvm::UndefValueOperation::Create(
        *response_state_origin->region(),
        response_state_origin->Type());
    auto sg_resp = StateGateOperation::create(*routed_data, { state_dummy });
    decouple_response->output(decouple_response->noutputs() - 1)->divert_users(sg_resp[1]);
    JLM_ASSERT(decouple_response->IsDead());
    remove(decouple_response);
    JLM_ASSERT(decouple_request->IsDead());
    remove(decouple_request);

    OptimizeResMemState(sg_resp[1]);
    OptimizeReqMemState(req_mem_state);
  }
  else
  {
    // they are in the same region, handle at the request
    // remove mem state from response call
    decouple_response->output(decouple_response->noutputs() - 1)
        ->divert_users(response_state_origin);

    auto state_dummy = llvm::UndefValueOperation::Create(
        *response_state_origin->region(),
        response_state_origin->Type());
    // put a state gate on the load response
    auto sg_resp = StateGateOperation::create(*dload_node->input(1)->origin(), { state_dummy });
    dload_node->input(1)->divert_to(sg_resp[0]);
    auto state_user = get_mem_state_user(req_mem_state);
    state_user->divert_to(sg_resp[1]);

    JLM_ASSERT(decouple_response->IsDead());
    remove(decouple_response);
    JLM_ASSERT(decouple_request->IsDead());
    remove(decouple_request);

    // these are swapped in this scenario, since we keep the state edge from the request
    OptimizeReqMemState(response_state_origin);
    OptimizeResMemState(sg_resp[1]);
  }

  auto nn = dynamic_cast<rvsdg::NodeOutput *>(dload_out[0])->node();
  return dynamic_cast<rvsdg::SimpleNode *>(nn);
}

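// Recursively collects all load, store, and decouple request nodes in the region that are not
// contained in the exclude set.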
void
gather_mem_nodes(
    rvsdg::Region * region,
    std::vector<rvsdg::Node *> & loadNodes,
    std::vector<rvsdg::Node *> & storeNodes,
    std::vector<rvsdg::Node *> & decoupleNodes,
    std::unordered_set<rvsdg::Node *> exclude)
{
  for (auto & node : rvsdg::TopDownTraverser(region))
  {
    if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
    {
      for (size_t n = 0; n < structnode->nsubregions(); n++)
        gather_mem_nodes(structnode->subregion(n), loadNodes, storeNodes, decoupleNodes, exclude);
    }
    else if (auto simplenode = dynamic_cast<rvsdg::SimpleNode *>(node))
    {
      if (exclude.find(simplenode) != exclude.end())
      {
        continue;
      }
      if (dynamic_cast<const llvm::StoreNonVolatileOperation *>(&simplenode->GetOperation()))
      {
        storeNodes.push_back(simplenode);
      }
      else if (dynamic_cast<const llvm::LoadNonVolatileOperation *>(&simplenode->GetOperation()))
      {
        loadNodes.push_back(simplenode);
      }
      else if (dynamic_cast<const llvm::CallOperation *>(&simplenode->GetOperation()))
      {
        // we only want to collect requests
        if (is_dec_req(simplenode))
          decoupleNodes.push_back(simplenode);
      }
    }
  }
}

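// Follows all users of a pointer output through simple nodes, structural nodes, and region
// results, and records every load, store, and decouple request reached by the pointer in
// tracedPointerNodes.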
static void
TracePointer(
    rvsdg::Output * output,
    std::unordered_set<rvsdg::Output *> & visited,
    TracedPointerNodes & tracedPointerNodes)
{
  if (!rvsdg::is<llvm::PointerType>(output->Type()))
  {
    // Only process pointer outputs
    return;
  }
  if (visited.count(output))
  {
    // Skip already processed outputs
    return;
  }
  visited.insert(output);
  for (auto & user : output->Users())
  {
    if (auto simplenode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user))
    {
      if (dynamic_cast<const llvm::StoreNonVolatileOperation *>(&simplenode->GetOperation()))
      {
        tracedPointerNodes.storeNodes.push_back(simplenode);
      }
      else if (dynamic_cast<const llvm::LoadNonVolatileOperation *>(&simplenode->GetOperation()))
      {
        tracedPointerNodes.loadNodes.push_back(simplenode);
      }
      else if (dynamic_cast<const llvm::CallOperation *>(&simplenode->GetOperation()))
      {
        // request
        JLM_ASSERT(is_dec_req(simplenode));
        tracedPointerNodes.decoupleNodes.push_back(simplenode);
      }
      else
      {
        for (size_t i = 0; i < simplenode->noutputs(); ++i)
        {
          TracePointer(simplenode->output(i), visited, tracedPointerNodes);
        }
      }
    }
    else if (auto sti = dynamic_cast<rvsdg::StructuralInput *>(&user))
    {
      for (auto & arg : sti->arguments)
      {
        TracePointer(&arg, visited, tracedPointerNodes);
      }
    }
    else if (auto r = dynamic_cast<rvsdg::RegionResult *>(&user))
    {
      if (auto ber = dynamic_cast<BackEdgeResult *>(r))
      {
        TracePointer(ber->argument(), visited, tracedPointerNodes);
      }
      else
      {
        TracePointer(r->output(), visited, tracedPointerNodes);
      }
    }
    else
    {
      JLM_UNREACHABLE("THIS SHOULD BE COVERED");
    }
  }
}

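// Traces each pointer-typed function argument and pointer-typed context variable of the lambda,
// producing one TracedPointerNodes entry per traced pointer.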
std::vector<TracedPointerNodes>
TracePointerArguments(const rvsdg::LambdaNode * lambda)
{
  std::vector<TracedPointerNodes> tracedPointerNodes;
  for (const auto argument : lambda->GetFunctionArguments())
  {
    if (rvsdg::is<llvm::PointerType>(argument->Type()))
    {
      std::unordered_set<rvsdg::Output *> visited;
      tracedPointerNodes.emplace_back();
      TracePointer(argument, visited, tracedPointerNodes.back());
    }
  }

  for (auto cv : lambda->GetContextVars())
  {
    if (rvsdg::is<llvm::PointerType>(cv.inner->Type()) && !is_function_argument(cv))
    {
      std::unordered_set<rvsdg::Output *> visited;
      tracedPointerNodes.emplace_back();
      TracePointer(cv.inner, visited, tracedPointerNodes.back());
    }
  }

  return tracedPointerNodes;
}

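// Walks up the region tree until the enclosing lambda node is found.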
rvsdg::LambdaNode *
find_containing_lambda(rvsdg::Region * region)
{
  if (auto l = dynamic_cast<rvsdg::LambdaNode *>(region->node()))
  {
    return l;
  }
  return find_containing_lambda(region->node()->region());
}

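// Determines the data width (in bits) of a memory port as the widest access among its loads,
// stores, and decouple responses.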
static size_t
CalculatePortWidth(const TracedPointerNodes & tracedPointerNodes)
{
  int max_width = 0;
  for (auto node : tracedPointerNodes.loadNodes)
  {
    auto loadOp = util::assertedCast<const llvm::LoadNonVolatileOperation>(&node->GetOperation());
    auto sz = JlmSize(loadOp->GetLoadedType().get());
    max_width = sz > max_width ? sz : max_width;
  }
  for (auto node : tracedPointerNodes.storeNodes)
  {
    auto storeOp = util::assertedCast<const llvm::StoreNonVolatileOperation>(&node->GetOperation());
    auto sz = JlmSize(&storeOp->GetStoredType());
    max_width = sz > max_width ? sz : max_width;
  }
  for (auto decoupleRequest : tracedPointerNodes.decoupleNodes)
  {
    auto lambda = find_containing_lambda(decoupleRequest->region());
    auto channel = decoupleRequest->input(1)->origin();
    auto channelConstant = trace_constant(channel);
    auto response = find_decouple_response(lambda, channelConstant);
    auto sz = JlmSize(response->output(0)->Type().get());
    max_width = sz > max_width ? sz : max_width;
  }
  JLM_ASSERT(max_width != 0);
  return max_width;
}

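// Replaces the copy of originalLoad in the new lambda with an HLS load that takes its data from
// the given response output: a decoupled load if the load carries no state edges, otherwise a
// regular HLS load. Updates the smap and returns the new load node.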
static rvsdg::SimpleNode *
ReplaceLoad(
    rvsdg::SubstitutionMap & smap,
    const rvsdg::Node * originalLoad,
    rvsdg::Output * response)
{
  // We have the load from the original lambda since it is needed to update the smap.
  // We need the load in the new lambda such that we can replace it with a load node with explicit
  // memory ports.
  auto replacedLoad =
      &rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.lookup(*originalLoad->output(0)));

  auto loadAddress = replacedLoad->input(0)->origin();
  std::vector<rvsdg::Output *> states;
  for (size_t i = 1; i < replacedLoad->ninputs(); ++i)
  {
    states.push_back(replacedLoad->input(i)->origin());
  }

  rvsdg::Node * newLoad = nullptr;
  if (states.empty())
  {
    size_t load_capacity = 10;
    auto outputs = DecoupledLoadOperation::create(*loadAddress, *response, load_capacity);
    newLoad = dynamic_cast<rvsdg::NodeOutput *>(outputs[0])->node();
  }
  else
  {
    // TODO: switch this to a decoupled load?
    auto outputs = LoadOperation::create(*loadAddress, states, *response);
    newLoad = dynamic_cast<rvsdg::NodeOutput *>(outputs[0])->node();
  }

  for (size_t i = 0; i < replacedLoad->noutputs(); ++i)
  {
    smap.insert(originalLoad->output(i), newLoad->output(i));
    replacedLoad->output(i)->divert_users(newLoad->output(i));
  }
  remove(replacedLoad);
  return dynamic_cast<rvsdg::SimpleNode *>(newLoad);
}

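// Replaces the copy of originalStore in the new lambda with an HLS store that uses the given
// response output. Each output state is routed through a single-slot pass-through buffer, the
// smap is updated, and the new store node is returned.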
static rvsdg::SimpleNode *
ReplaceStore(
    rvsdg::SubstitutionMap & smap,
    const rvsdg::Node * originalStore,
    rvsdg::Output * response)
{
  // We have the store from the original lambda since it is needed to update the smap.
  // We need the store in the new lambda such that we can replace it with a store node with
  // explicit memory ports.
  auto replacedStore =
      &rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.lookup(*originalStore->output(0)));

  auto addr = replacedStore->input(0)->origin();
  JLM_ASSERT(rvsdg::is<llvm::PointerType>(addr->Type()));
  auto data = replacedStore->input(1)->origin();
  std::vector<rvsdg::Output *> states;
  for (size_t i = 2; i < replacedStore->ninputs(); ++i)
  {
    states.push_back(replacedStore->input(i)->origin());
  }
  auto storeOuts = StoreOperation::create(*addr, *data, states, *response);
  auto newStore = dynamic_cast<rvsdg::NodeOutput *>(storeOuts[0])->node();
  // iterate over output states
  for (size_t i = 0; i < replacedStore->noutputs(); ++i)
  {
    // create a buffer to avoid a scenario where the response port is blocked because a merge
    // waits for the store
    // TODO: It might be better to have memstate merges consume individual tokens instead, and
    // fire the output once all inputs have consumed
    const auto bo = BufferOperation::create(*storeOuts[i], 1, true)[0];
    smap.insert(originalStore->output(i), bo);
    replacedStore->output(i)->divert_users(bo);
  }
  remove(replacedStore);
  return dynamic_cast<rvsdg::SimpleNode *>(newStore);
}

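// Wires up one memory port: creates a memory response node fed by the lambda argument at
// argumentIndex, replaces the given loads, decoupled requests, and stores with their
// explicit-port equivalents, routes their addresses and store data to the lambda region, and
// returns the output of the memory request node forming the request side of the port.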
static rvsdg::Output *
ConnectRequestResponseMemPorts(
    const rvsdg::LambdaNode * lambda,
    size_t argumentIndex,
    rvsdg::SubstitutionMap & smap,
    const std::vector<rvsdg::Node *> & originalLoadNodes,
    const std::vector<rvsdg::Node *> & originalStoreNodes,
    const std::vector<rvsdg::Node *> & originalDecoupledNodes)
{
  //
  // We have the memory operations from the original lambda and need to look up the corresponding
  // nodes in the new lambda
  //
  std::vector<rvsdg::SimpleNode *> loadNodes;
  std::vector<std::shared_ptr<const rvsdg::Type>> responseTypes;
  for (auto loadNode : originalLoadNodes)
  {
    auto oldLoadedValue = loadNode->output(0);
    JLM_ASSERT(smap.contains(*oldLoadedValue));
    auto & newLoadNode = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.lookup(*oldLoadedValue));
    loadNodes.push_back(&newLoadNode);
    auto loadOp =
        util::assertedCast<const llvm::LoadNonVolatileOperation>(&newLoadNode.GetOperation());
    responseTypes.push_back(loadOp->GetLoadedType());
  }
  std::vector<rvsdg::SimpleNode *> decoupledNodes;
  for (auto decoupleRequest : originalDecoupledNodes)
  {
    auto oldOutput = decoupleRequest->output(0);
    JLM_ASSERT(smap.contains(*oldOutput));
    auto & decoupledRequestNode =
        rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.lookup(*oldOutput));
    decoupledNodes.push_back(&decoupledRequestNode);
    // get the load type from the response output
    auto channel = decoupleRequest->input(1)->origin();
    auto channelConstant = trace_constant(channel);
    auto response = find_decouple_response(lambda, channelConstant);
    auto vt = response->output(0)->Type();
    responseTypes.push_back(vt);
  }
  std::vector<rvsdg::SimpleNode *> storeNodes;
  for (auto storeNode : originalStoreNodes)
  {
    auto oldOutput = storeNode->output(0);
    JLM_ASSERT(smap.contains(*oldOutput));
    auto & newStoreNode = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(smap.lookup(*oldOutput));
    storeNodes.push_back(&newStoreNode);
    // use the memory state type as response for stores
    auto vt = std::make_shared<llvm::MemoryStateType>();
    responseTypes.push_back(vt);
  }

  auto lambdaRegion = lambda->subregion();
  auto portWidth =
      CalculatePortWidth({ originalLoadNodes, originalStoreNodes, originalDecoupledNodes });
  auto responses = MemoryResponseOperation::create(
      *lambdaRegion->argument(argumentIndex),
      responseTypes,
      portWidth);
  // The (decoupled) load nodes are replaced, so the pointers to the types will become invalid
  std::vector<std::shared_ptr<const rvsdg::Type>> loadTypes;
  std::vector<rvsdg::Output *> loadAddresses;
  for (size_t i = 0; i < loadNodes.size(); ++i)
  {
    auto routed = route_response_rhls(loadNodes[i]->region(), responses[i]);
    // The smap contains the nodes from the original lambda, so we need to use the original load
    // node when replacing the load since the smap must be updated
    auto replacement = ReplaceLoad(smap, originalLoadNodes[i], routed);
    auto address =
        route_request_rhls(lambdaRegion, replacement->output(replacement->noutputs() - 1));
    loadAddresses.push_back(address);
    std::shared_ptr<const rvsdg::Type> type;
    if (auto loadOperation = dynamic_cast<const LoadOperation *>(&replacement->GetOperation()))
    {
      type = loadOperation->GetLoadedType();
    }
    else if (
        auto loadOperation =
            dynamic_cast<const DecoupledLoadOperation *>(&replacement->GetOperation()))
    {
      type = loadOperation->GetLoadedType();
    }
    else
    {
      JLM_UNREACHABLE("Unknown load operation");
    }
    JLM_ASSERT(type);
    loadTypes.push_back(type);
  }
  for (size_t i = 0; i < decoupledNodes.size(); ++i)
  {
    auto response = responses[loadNodes.size() + i];
    auto node = decoupledNodes[i];

    // TODO: this behavior is not completely correct - if a function returns a top-level result
    // from a decouple it fails, and smap translation would be required
    auto replacement = ReplaceDecouple(lambda, node, response);
    auto addr = route_request_rhls(lambdaRegion, replacement->output(1));
    loadAddresses.push_back(addr);
    loadTypes.push_back(dynamic_cast<const DecoupledLoadOperation *>(&replacement->GetOperation())
                            ->GetLoadedType());
  }
  std::vector<rvsdg::Output *> storeOperands;
  for (size_t i = 0; i < storeNodes.size(); ++i)
  {
    auto response = responses[loadNodes.size() + decoupledNodes.size() + i];
    auto routed = route_response_rhls(storeNodes[i]->region(), response);
    // The smap contains the nodes from the original lambda, so we need to use the original store
    // node when replacing the store since the smap must be updated
    auto replacement = ReplaceStore(smap, originalStoreNodes[i], routed);
    auto addr = route_request_rhls(lambdaRegion, replacement->output(replacement->noutputs() - 2));
    auto data = route_request_rhls(lambdaRegion, replacement->output(replacement->noutputs() - 1));
    storeOperands.push_back(addr);
    storeOperands.push_back(data);
  }

  return MemoryRequestOperation::create(loadAddresses, loadTypes, storeOperands, lambdaRegion)[0];
}

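// Entry point of the conversion: builds a new lambda whose signature carries one memory response
// argument and one memory request result per port, copies the old body into it, replaces all
// memory nodes with explicit-port versions, and removes the old lambda and dead state edges.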
static void
ConvertMemory(rvsdg::RvsdgModule & rvsdgModule)
{
  //
  // Replacing memory nodes with nodes that have explicit memory ports requires arguments and
  // results to be added to the lambda. The arguments must be added before the memory nodes are
  // replaced, else the input of the new memory node will be left dangling, which is not allowed.
  // We therefore need to first replace the lambda node with a new lambda node that has the new
  // arguments and results. We can then replace the memory nodes and connect them to the new
  // arguments.
  //

  const auto & graph = rvsdgModule.Rvsdg();
  const auto rootRegion = &graph.GetRootRegion();
  if (rootRegion->numNodes() != 1)
  {
    throw std::logic_error("Root should have only one node now");
  }

  const auto lambda = dynamic_cast<rvsdg::LambdaNode *>(rootRegion->Nodes().begin().ptr());
  if (!lambda)
  {
    throw std::logic_error("Node needs to be a lambda");
  }

  //
  // Converting loads and stores to explicitly use memory ports
  // This modifies the function signature, so we create a new lambda node to replace the old one
  //
  const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(lambda->GetOperation());
  auto oldFunctionType = op.type();
  std::vector<std::shared_ptr<const rvsdg::Type>> newArgumentTypes;
  for (size_t i = 0; i < oldFunctionType.NumArguments(); ++i)
  {
    newArgumentTypes.push_back(oldFunctionType.Arguments()[i]);
  }
  std::vector<std::shared_ptr<const rvsdg::Type>> newResultTypes;
  for (size_t i = 0; i < oldFunctionType.NumResults(); ++i)
  {
    newResultTypes.push_back(oldFunctionType.Results()[i]);
  }

  //
  // Get the load and store nodes and add an argument and result for each to represent the memory
  // response and request ports
  //
  auto tracedPointerNodesVector = TracePointerArguments(lambda);

  std::unordered_set<rvsdg::Node *> accountedNodes;
  for (auto & portNode : tracedPointerNodesVector)
  {
    auto portWidth = CalculatePortWidth(portNode);
    auto responseTypePtr = get_mem_res_type(rvsdg::BitType::Create(portWidth));
    auto requestTypePtr = get_mem_req_type(rvsdg::BitType::Create(portWidth), false);
    auto requestTypePtrWrite = get_mem_req_type(rvsdg::BitType::Create(portWidth), true);
    newArgumentTypes.push_back(responseTypePtr);
    if (portNode.storeNodes.empty())
    {
      newResultTypes.push_back(requestTypePtr);
    }
    else
    {
      newResultTypes.push_back(requestTypePtrWrite);
    }
    accountedNodes.insert(portNode.loadNodes.begin(), portNode.loadNodes.end());
    accountedNodes.insert(portNode.storeNodes.begin(), portNode.storeNodes.end());
    accountedNodes.insert(portNode.decoupleNodes.begin(), portNode.decoupleNodes.end());
  }
  std::vector<rvsdg::Node *> unknownLoadNodes;
  std::vector<rvsdg::Node *> unknownStoreNodes;
  std::vector<rvsdg::Node *> unknownDecoupledNodes;
  gather_mem_nodes(
      rootRegion,
      unknownLoadNodes,
      unknownStoreNodes,
      unknownDecoupledNodes,
      accountedNodes);
  if (!unknownLoadNodes.empty() || !unknownStoreNodes.empty() || !unknownDecoupledNodes.empty())
  {
    auto portWidth =
        CalculatePortWidth({ unknownLoadNodes, unknownStoreNodes, unknownDecoupledNodes });
    auto responseTypePtr = get_mem_res_type(rvsdg::BitType::Create(portWidth));
    auto requestTypePtr = get_mem_req_type(rvsdg::BitType::Create(portWidth), false);
    auto requestTypePtrWrite = get_mem_req_type(rvsdg::BitType::Create(portWidth), true);
    // Extra port for loads/stores not associated to a port yet (i.e., unknown base pointer)
    newArgumentTypes.push_back(responseTypePtr);
    if (unknownStoreNodes.empty())
    {
      newResultTypes.push_back(requestTypePtr);
    }
    else
    {
      newResultTypes.push_back(requestTypePtrWrite);
    }
  }

  //
  // Create a new lambda and copy the region from the old lambda
  //
  auto newFunctionType = rvsdg::FunctionType::Create(newArgumentTypes, newResultTypes);
  auto newLambda = rvsdg::LambdaNode::Create(
      *lambda->region(),
      llvm::LlvmLambdaOperation::Create(newFunctionType, op.name(), op.linkage(), op.attributes()));

  rvsdg::SubstitutionMap smap;
  for (const auto & ctxvar : lambda->GetContextVars())
  {
    smap.insert(ctxvar.inner, newLambda->AddContextVar(*ctxvar.input->origin()).inner);
  }

  auto args = lambda->GetFunctionArguments();
  auto newArgs = newLambda->GetFunctionArguments();
  // The new function has more arguments than the old function.
  // Substitution of existing arguments is safe, but note
  // that this is not an isomorphism.
  JLM_ASSERT(args.size() <= newArgs.size());
  for (size_t i = 0; i < args.size(); ++i)
  {
    smap.insert(args[i], newArgs[i]);
  }
  lambda->subregion()->copy(newLambda->subregion(), smap);

  //
  // All memory nodes need to be replaced with new nodes that have explicit memory ports.
  // This needs to happen first, and the smap needs to be updated with the new nodes,
  // before we can use the original lambda results and look them up in the updated smap.
  //

  std::vector<rvsdg::Output *> newResults;
  // The new arguments are placed directly after the original arguments, so we create an index
  // that points to the first new argument
  auto newArgumentsIndex = args.size();
  for (auto & portNode : tracedPointerNodesVector)
  {
    newResults.push_back(ConnectRequestResponseMemPorts(
        newLambda,
        newArgumentsIndex++,
        smap,
        portNode.loadNodes,
        portNode.storeNodes,
        portNode.decoupleNodes));
  }
  if (!unknownLoadNodes.empty() || !unknownStoreNodes.empty() || !unknownDecoupledNodes.empty())
  {
    newResults.push_back(ConnectRequestResponseMemPorts(
        newLambda,
        newArgumentsIndex++,
        smap,
        unknownLoadNodes,
        unknownStoreNodes,
        unknownDecoupledNodes));
  }

  std::vector<rvsdg::Output *> originalResults;
  for (auto result : lambda->GetFunctionResults())
  {
    originalResults.push_back(&smap.lookup(*result->origin()));
  }
  originalResults.insert(originalResults.end(), newResults.begin(), newResults.end());
  auto newOut = newLambda->finalize(originalResults);
  auto oldExport = llvm::ComputeCallSummary(*lambda).GetRvsdgExport();
  rvsdg::GraphExport::Create(*newOut, oldExport ? oldExport->Name() : "");

  JLM_ASSERT(lambda->output()->nusers() == 1);
  lambda->region()->RemoveResults({ (*lambda->output()->Users().begin()).index() });
  remove(lambda);

  // Remove imports for decouple_ function pointers
  DeadNodeElimination dne;
  util::StatisticsCollector statisticsCollector;
  dne.Run(*newLambda->subregion(), statisticsCollector);

  //
  // TODO
  // RemoveUnusedStates also creates a new lambda, which we have already done above.
  // It might be better to apply this functionality above such that we only create a new lambda
  // once.
  //
  UnusedStateRemoval::CreateAndRun(rvsdgModule, statisticsCollector);

  // Need to get the lambda from the root since RemoveUnusedStates replaces the lambda
  JLM_ASSERT(rootRegion->numNodes() == 1);
  newLambda = util::assertedCast<rvsdg::LambdaNode>(rootRegion->Nodes().begin().ptr());
  auto decouple_funcs = find_function_arguments(newLambda, "decoupled");
  // make sure the context vars are actually dead
  for (auto cv : decouple_funcs)
  {
    JLM_ASSERT(cv.inner->nusers() == 0);
  }
  // remove dead cvargs
  newLambda->PruneLambdaInputs();
}

MemoryConverter::~MemoryConverter() noexcept = default;

MemoryConverter::MemoryConverter()
{}

void
MemoryConverter::Run(
    rvsdg::RvsdgModule & rvsdgModule,
    util::StatisticsCollector & statisticsCollector)
{
  ConvertMemory(rvsdgModule);
}

}