Jlm
rvsdg2rhls.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
27 #include <jlm/hls/opt/cne.hpp>
30 #include <jlm/hls/util/view.hpp>
33 #include <jlm/llvm/DotWriter.hpp>
48 #include <jlm/rvsdg/traverser.hpp>
49 #include <jlm/rvsdg/view.hpp>
50 #include <jlm/util/Statistics.hpp>
51 #include <llvm/IR/LLVMContext.h>
52 #include <llvm/IR/Module.h>
53 #include <llvm/Support/raw_ostream.h>
54 #include <llvm/Support/SourceMgr.h>
55 
56 #include <regex>
57 
58 namespace jlm::hls
59 {
60 
/**
 * Runs a fixed clean-up pass sequence over a whole RVSDG module. It is called
 * from split_hls_function right after inline_calls has inlined the calls of
 * the function that is extracted for HLS.
 *
 * Pass order: loop unswitching, dead node elimination, common node
 * elimination, invariant value redirection, `red` (presumably an
 * llvm::NodeReduction instance -- its declaration is on a line not shown
 * here; TODO confirm), and a final dead node elimination.
 *
 * @param rm Module that every pass transforms in place.
 */
61 void
63 {
64  // TODO: figure out which optimizations to use here
 // NOTE(review): ivrConfiguration is constexpr (and therefore const), so the
 // std::move below yields a const rvalue and results in a plain copy; it is
 // harmless but has no effect.
67  constexpr llvm::InvariantValueRedirection::Configuration ivrConfiguration;
68  llvm::InvariantValueRedirection invariantValueRedirection(std::move(ivrConfiguration));
72  loopUnswitching.Run(rm, statisticsCollector);
73  dne.Run(rm, statisticsCollector);
74  cne.Run(rm, statisticsCollector);
75  invariantValueRedirection.Run(rm, statisticsCollector);
76  red.Run(rm, statisticsCollector);
 // Remove nodes that became dead through the preceding passes.
77  dne.Run(rm, statisticsCollector);
78 }
79 
/**
 * Runs a fixed optimization sequence over a whole RVSDG module. dump_ref
 * applies it to the copy of the HLS module from which the LLVM IR reference is
 * generated.
 *
 * Pass order: loop unswitching, dead node elimination, common node
 * elimination, invariant value redirection, then dead node elimination,
 * common node elimination and a final dead node elimination. Unlike
 * split_opt, the visible sequence does not run `red`.
 *
 * @param rm Module that every pass transforms in place.
 */
80 void
82 {
83  // TODO: figure out which optimizations to use here
 // NOTE(review): ivrConfiguration is constexpr (and therefore const), so the
 // std::move below results in a plain copy.
86  constexpr llvm::InvariantValueRedirection::Configuration ivrConfiguration;
87  llvm::InvariantValueRedirection invariantValueRedirection(std::move(ivrConfiguration));
90  loopUnswitching.Run(rm, statisticsCollector);
91  dne.Run(rm, statisticsCollector);
92  cne.Run(rm, statisticsCollector);
93  invariantValueRedirection.Run(rm, statisticsCollector);
 // Alternate dead/common node elimination to clean up after redirection.
94  dne.Run(rm, statisticsCollector);
95  cne.Run(rm, statisticsCollector);
96  dne.Run(rm, statisticsCollector);
97 }
98 
99 bool
100 function_match(rvsdg::LambdaNode * ln, const std::string & function_name)
101 {
102  const std::regex fn_regex(function_name);
103  if (std::regex_match(
104  dynamic_cast<llvm::LlvmLambdaOperation &>(ln->GetOperation()).name(),
105  fn_regex))
106  { // TODO: handle C++ name mangling
107  return true;
108  }
109  return false;
110 }
111 
/**
 * Follows the origin of an input upwards to the output that ultimately
 * provides its value. inline_calls uses it to find the callee of a call node.
 *
 * - If the origin is an output of a theta node, tracing continues at the
 *   theta input of the same loop variable.
 * - If the origin is not a region argument, it is returned unchanged. This
 *   includes outputs of structural nodes other than theta.
 * - If the origin is an argument of the graph's root region (e.g. a graph
 *   import), that argument is returned.
 * - Otherwise the origin is an argument of a nested region. Tracing continues
 *   at the structural input that feeds it. Arguments without such an input
 *   trip the JLM_ASSERT below.
 *
 * @param input Input whose value source is traced.
 * @return The output at which tracing stopped.
 */
112 const jlm::rvsdg::Output *
114 {
115  auto graph = input->region()->graph();
116 
117  auto argument = dynamic_cast<const rvsdg::RegionArgument *>(input->origin());
118  const rvsdg::Output * result = nullptr;
 // Theta outputs are checked first; they are not region arguments.
119  if (auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(*input->origin()))
120  {
121  result = trace_call(theta->MapOutputLoopVar(*input->origin()).input);
122  }
123  else if (argument == nullptr)
124  {
125  result = input->origin();
126  }
127  else if (argument->region() == &graph->GetRootRegion())
128  {
129  result = argument;
130  }
131  else
132  {
 // Nested-region argument: continue outside, at the input that feeds it.
133  JLM_ASSERT(argument->input() != nullptr);
134  result = trace_call(argument->input());
135  }
136  return result;
137 }
138 
/**
 * Inlines every call in a region, and recursively in the subregions of its
 * structural nodes, whose callee traces (via trace_call) to a lambda node.
 *
 * Calls whose callee traces to a graph import named "decouple_*" or "hls_*"
 * are left in place, since these are pseudo functions. A call to any other
 * graph import raises util::Error. After each successful inlining the function
 * restarts on the same region and returns, presumably because the traversal
 * is no longer valid after the region was modified.
 *
 * @param region Region whose calls are inlined in place.
 * @throws util::Error if a call targets an external function that cannot be inlined.
 */
139 void
141 {
142  for (auto & node : rvsdg::TopDownTraverser(region))
143  {
144  if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
145  {
146  for (size_t n = 0; n < structnode->nsubregions(); n++)
147  {
148  inline_calls(structnode->subregion(n));
149  }
150  }
151  else if (dynamic_cast<const llvm::CallOperation *>(&(node->GetOperation())))
152  {
 // Input 0 of the call node is traced to find the called function.
153  auto traced = jlm::hls::trace_call(node->input(0));
154  auto so = dynamic_cast<const rvsdg::StructuralOutput *>(traced);
155  if (!so)
156  {
157  if (auto graphImport = dynamic_cast<const llvm::LlvmGraphImport *>(traced))
158  {
159  if (graphImport->Name().rfind("decouple_", 0) == 0)
160  {
161  // can't inline pseudo functions used for decoupling
162  continue;
163  }
164  if (graphImport->Name().rfind("hls_", 0) == 0)
165  {
166  // can't inline pseudo functions used for streaming
167  continue;
168  }
169  throw util::Error("can not inline external function " + graphImport->Name());
170  }
171  }
 // NOTE(review): if `traced` is neither a StructuralOutput nor an
 // LlvmGraphImport, `so` is null at this point. Both the assertion and
 // the cast on the next statement then dereference a null pointer.
 // Throwing util::Error for that case would be safer.
172  JLM_ASSERT(rvsdg::is<rvsdg::LambdaOperation>(so->node()));
 // The callee is the lambda node owning `so`; the inlining call itself
 // starts on a line not shown here and receives the call node and `ln`.
173  auto ln = dynamic_cast<const rvsdg::StructuralOutput *>(traced)->node();
175  *dynamic_cast<rvsdg::SimpleNode *>(node),
176  *dynamic_cast<const rvsdg::LambdaNode *>(ln));
177  // restart for this region
178  inline_calls(region);
179  return;
180  }
181  }
182 }
183 
186 {
 // rename_delta: replaces delta node `odn` by a copy whose name has every '.'
 // replaced with '_'.
 //
 // The new delta is created in odn's region with the same type, constness
 // and alignment. Its context variables read from the same origins, and its
 // subregion is a copy of odn's. All users of odn's output are diverted to
 // the new output, and odn is removed. The rename is logged to std::cout.
 // Returns the new delta node.
187  auto op = util::assertedCast<const llvm::DeltaOperation>(&odn->GetOperation());
188  auto name = op->name();
 // Replace every '.' in the name with '_'.
189  std::replace_if(
190  name.begin(),
191  name.end(),
192  [](char c)
193  {
194  return c == '.';
195  },
196  '_');
197  std::cout << "renaming delta node " << op->name() << " to " << name << "\n";
 // NOTE(review): the section argument of the new DeltaOperation is the empty
 // string "" rather than the original operation's section. Confirm this is
 // intended. The linkage argument is on a line not shown here.
198  auto db = rvsdg::DeltaNode::Create(
199  odn->region(),
201  odn->Type(),
202  name,
204  "",
205  op->constant(),
206  op->getAlignment()));
207  /* add dependencies */
 // Map each old inner context-variable argument to the new one in `rmap`
 // (a substitution map declared on a line not shown here).
209  for (auto ctxVar : odn->GetContextVars())
210  {
211  auto input = ctxVar.input;
212  auto nd = db->AddContextVar(*input->origin()).inner;
213  rmap.insert(ctxVar.inner, nd);
214  }
215 
216  /* copy subregion */
217  odn->subregion()->copy(db->subregion(), rmap);
218 
 // The delta's value is the copy of whatever feeds the old result(0).
219  auto result = &rmap.lookup(*odn->subregion()->result(0)->origin());
220  auto data = &db->finalize(result);
221 
222  odn->output().divert_users(data);
223  jlm::rvsdg::remove(odn);
224  return rvsdg::TryGetOwnerNode<rvsdg::DeltaNode>(*data);
225 }
226 
229 {
 // change_linkage: replaces lambda `ln` by an identical copy that has
 // linkage `link`.
 //
 // The new lambda keeps the function type, name, calling convention and
 // attributes of `ln`. Its context variables read from the same origins, and
 // function arguments are mapped one-to-one. The subregion is copied and the
 // results are rewired. All users of `ln` are then diverted to the new lambda
 // and `ln` is removed. Returns the new lambda node.
230  const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(ln->GetOperation());
231  auto lambda = rvsdg::LambdaNode::Create(
232  *ln->region(),
234  op.Type(),
235  op.name(),
236  link,
237  op.callingConvention(),
238  op.attributes()));
239 
240  /* add context variables */
241  rvsdg::SubstitutionMap subregionmap;
242  for (const auto & cv : ln->GetContextVars())
243  {
244  auto origin = cv.input->origin();
245  auto newcv = lambda->AddContextVar(*origin);
246  subregionmap.insert(cv.inner, newcv.inner);
247  }
248  /* collect function arguments */
249  auto args = ln->GetFunctionArguments();
250  auto newArgs = lambda->GetFunctionArguments();
 // Both lambdas share the same function type, so the argument lists line up.
251  JLM_ASSERT(args.size() == newArgs.size());
252  for (std::size_t n = 0; n < args.size(); ++n)
253  {
254  subregionmap.insert(args[n], newArgs[n]);
255  }
256 
257  /* copy subregion */
258  ln->subregion()->copy(lambda->subregion(), subregionmap);
259 
260  /* collect function results */
261  std::vector<jlm::rvsdg::Output *> results;
262  for (auto result : ln->GetFunctionResults())
263  results.push_back(&subregionmap.lookup(*result->origin()));
264 
265  /* finalize lambda */
266  lambda->finalize(results);
267 
 // Redirect all users of the old lambda before deleting it.
268  divert_users(ln, outputs(lambda));
269  jlm::rvsdg::remove(ln);
270 
271  return lambda;
272 }
273 
/**
 * Moves the first root-region lambda whose name matches a pattern out of a
 * module and into a newly created module, which is returned.
 *
 * For the matching lambda:
 *  1. its calls are inlined (inline_calls) and split_opt is run on @p rm;
 *  2. each input of the lambda node is re-created in the new module:
 *     - an origin that is not a node output is treated as a graph import and
 *       copied into the new module's root region;
 *     - a delta node origin becomes a global import in the new module and a
 *       graph export in @p rm. If its name contains '.', the delta is first
 *       renamed via rename_delta;
 *     - a lambda origin, or any other node type, raises util::Error;
 *  3. the lambda is copied into the new module and exported under the name
 *     of its old export, or "" if it had none;
 *  4. in @p rm, users of the lambda are diverted to a new function import
 *     with external linkage, and the lambda is removed.
 * Progress messages are printed to std::cout.
 *
 * @param rm Source module; modified in place.
 * @param function_name Regular expression matched against lambda names (see function_match).
 * @return The new module containing the extracted function.
 * @throws util::Error if no lambda matches, or if the lambda depends on an unsupported origin.
 */
274 std::unique_ptr<jlm::llvm::LlvmRvsdgModule>
275 split_hls_function(llvm::LlvmRvsdgModule & rm, const std::string & function_name)
276 {
277  // TODO: use a different datastructure for rhls?
278  // create a copy of rm
279  auto rhls =
281  std::cout << "processing " << rm.SourceFileName().name() << "\n";
282  auto root = &rm.Rvsdg().GetRootRegion();
283  for (auto node : rvsdg::TopDownTraverser(root))
284  {
285  if (auto ln = dynamic_cast<rvsdg::LambdaNode *>(node))
286  {
287  if (!function_match(ln, function_name))
288  {
289  continue;
290  }
291  inline_calls(ln->subregion());
 // NOTE(review): split_opt runs module-wide passes, including dead node
 // elimination, while `ln` is still used below. If a pass could remove or
 // replace this lambda (e.g. an unexported function without users), `ln`
 // would dangle. Confirm this cannot happen.
292  split_opt(rm);
293 
 // Re-create every dependency of the lambda in rhls and record the mapping
 // in `smap` (declared on a line not shown here).
295  for (size_t i = 0; i < ln->ninputs(); ++i)
296  {
297  auto orig_node_output = dynamic_cast<rvsdg::NodeOutput *>(ln->input(i)->origin());
298  if (!orig_node_output)
299  {
300  // handle decouple stuff
 // NOTE(review): the result of this dynamic_cast is not checked. If the
 // origin is neither a node output nor an LlvmGraphImport,
 // oldGraphImport is null and the Copy call dereferences null.
301  auto oldGraphImport = dynamic_cast<llvm::LlvmGraphImport *>(ln->input(i)->origin());
302  auto & newGraphImport = oldGraphImport->Copy(rhls->Rvsdg().GetRootRegion(), nullptr);
303  smap.insert(ln->input(i)->origin(), &newGraphImport);
304  continue;
305  }
306  auto orig_node = orig_node_output->node();
307  if (auto oln = dynamic_cast<rvsdg::LambdaNode *>(orig_node))
308  {
309  throw util::Error(
310  "Inlining of function "
311  + dynamic_cast<llvm::LlvmLambdaOperation &>(oln->GetOperation()).name()
312  + " not supported");
313  }
314  else if (auto odn = dynamic_cast<rvsdg::DeltaNode *>(orig_node))
315  {
316  auto op = util::assertedCast<const llvm::DeltaOperation>(&odn->GetOperation());
317  // modify name to not contain .
318  if (op->name().find('.') != std::string::npos)
319  {
 // rename_delta diverts users, so ln->input(i)->origin() now refers to
 // the renamed delta's output.
320  odn = rename_delta(odn);
321  op = util::assertedCast<const llvm::DeltaOperation>(&odn->GetOperation());
322  }
323  std::cout << "delta node " << op->name() << ": " << op->Type()->debug_string() << "\n";
324  // add import for delta to rhls
325  auto & graphImport = llvm::LlvmGraphImport::createGlobalImport(
326  rhls->Rvsdg(),
327  op->Type(),
329  op->name(),
331  op->constant(),
332  op->getAlignment());
333  smap.insert(ln->input(i)->origin(), &graphImport);
334  // add export for delta to rm
335  // TODO: check if not already exported and maybe adjust linkage?
336  rvsdg::GraphExport::Create(odn->output(), op->name());
337  }
338  else
339  {
340  throw util::Error("Unsupported node type: " + orig_node->DebugString());
341  }
342  }
343  // copy function into rhls
344  auto new_ln = ln->copy(&rhls->Rvsdg().GetRootRegion(), smap);
346  auto oldExport = jlm::llvm::ComputeCallSummary(*ln).GetRvsdgExport();
347  rvsdg::GraphExport::Create(*new_ln->output(), oldExport ? oldExport->Name() : "");
348  // add function as input to rm and remove it
349  const auto & op = dynamic_cast<llvm::LlvmLambdaOperation &>(ln->GetOperation());
350  auto & graphImport = llvm::LlvmGraphImport::createFunctionImport(
351  rm.Rvsdg(),
352  op.Type(),
353  op.name(),
354  llvm::Linkage::externalLinkage, // TODO: change linkage?
356  ln->output()->divert_users(&graphImport);
357  remove(ln);
358  std::cout << "function "
359  << dynamic_cast<llvm::LlvmLambdaOperation &>(new_ln->GetOperation()).name()
360  << " extracted for HLS\n";
361  return rhls;
362  }
363  }
364  throw util::Error("HLS function " + function_name + " not found");
365 }
366 
/**
 * Thin wrapper that forwards to dump_ref(rhls, path). That function writes an
 * LLVM IR reference version of the module to the file at @p path.
 *
 * @param rhls Module the reference is generated from.
 * @param path Output file path.
 */
367 void
369 {
370  dump_ref(rhls, path);
371 }
372 
/**
 * Builds the transformation pipeline that lowers an RVSDG to RHLS.
 *
 * Each pass is instantiated exactly once. Several passes appear more than once
 * in the sequence (dead node elimination, common node elimination, loop
 * unswitching, gamma merge, unused state removal); every occurrence uses the
 * same shared object. The sequence starts with generic LLVM-level passes,
 * continues with the HLS passes (GammaNodeConversion, ThetaNodeConversion,
 * memory-related conversions, ...), and ends with SinkInsertion,
 * ForkInsertion, BufferInsertion and RhlsVerification.
 *
 * @param dotWriter Passed through to the TransformationSequence.
 * @param dumpRvsdgGraphs Passed through to the TransformationSequence.
 * @return The assembled transformation sequence.
 */
373 std::unique_ptr<rvsdg::TransformationSequence>
374 createTransformationSequence(rvsdg::DotWriter & dotWriter, const bool dumpRvsdgGraphs)
375 {
376  auto predicateCorrelation = std::make_shared<llvm::PredicateCorrelation>();
377  auto deadNodeElimination = std::make_shared<llvm::DeadNodeElimination>();
378  auto commonNodeElimination = std::make_shared<CommonNodeElimination>();
379  auto invariantValueRedirection = std::make_shared<llvm::InvariantValueRedirection>(
381  auto loopUnswitching =
382  std::make_shared<llvm::LoopUnswitching>(llvm::LoopUnswitchingDefaultHeuristic::create());
383  auto ioBarrierRemoval = std::make_shared<IOBarrierRemoval>();
384  auto ioStateElimination = std::make_shared<IOStateElimination>();
385  auto memoryStateSeparation = std::make_shared<MemoryStateSeparation>();
386  auto gammaMerge = std::make_shared<GammaMerge>();
387  auto unusedStateRemoval = std::make_shared<UnusedStateRemoval>();
388  auto constantDistribution = std::make_shared<ConstantDistribution>();
389  auto gammaNodeConversion = std::make_shared<GammaNodeConversion>();
390  auto thetaNodeConversion = std::make_shared<ThetaNodeConversion>();
391  auto rhlsDeadNodeElimination = std::make_shared<RhlsDeadNodeElimination>();
392  auto allocaNodeConversion = std::make_shared<AllocaNodeConversion>();
393  auto streamConversion = std::make_shared<StreamConversion>();
394  auto addressQueueInsertion = std::make_shared<AddressQueueInsertion>();
395  auto memoryStateDecoupling = std::make_shared<MemoryStateDecoupling>();
396  auto memoryConverter = std::make_shared<MemoryConverter>();
397  auto nodeReduction = std::make_shared<llvm::NodeReduction>();
398  auto memoryStateSplitConversion = std::make_shared<MemoryStateSplitConversion>();
399  auto redundantBufferElimination = std::make_shared<RedundantBufferElimination>();
400  auto sinkInsertion = std::make_shared<SinkInsertion>();
401  auto forkInsertion = std::make_shared<ForkInsertion>();
402  auto bufferInsertion = std::make_shared<BufferInsertion>();
403  auto rhlsVerification = std::make_shared<RhlsVerification>();
404 
405  // Use this transformation to dump HLS dot graphs at specific points in the sequence
 // It is not part of the sequence below; insert `dumpDot` wherever a dump is wanted.
406  [[maybe_unused]] auto dumpDot = std::make_shared<DumpDotTransformation>();
407 
 // Order matters: passes run front to back, and repeated entries share the
 // same pass object.
408  std::vector<std::shared_ptr<rvsdg::Transformation>> sequence({
409  loopUnswitching,
410  deadNodeElimination,
411  commonNodeElimination,
412  invariantValueRedirection,
413  predicateCorrelation,
414  deadNodeElimination,
415  commonNodeElimination,
416  deadNodeElimination,
417  ioBarrierRemoval,
418  ioStateElimination,
419  memoryStateSeparation,
420  gammaMerge,
421  unusedStateRemoval,
422  deadNodeElimination,
423  loopUnswitching,
424  commonNodeElimination,
425  deadNodeElimination,
426  gammaMerge,
427  deadNodeElimination,
428  unusedStateRemoval,
429  constantDistribution,
430  deadNodeElimination,
431  gammaNodeConversion,
432  thetaNodeConversion,
433  commonNodeElimination,
434  rhlsDeadNodeElimination,
435  allocaNodeConversion,
436  streamConversion,
437  addressQueueInsertion,
438  memoryStateDecoupling,
439  unusedStateRemoval,
440  memoryConverter,
441  nodeReduction,
442  memoryStateSplitConversion,
443  redundantBufferElimination,
444  sinkInsertion,
445  forkInsertion,
446  bufferInsertion,
447  rhlsVerification,
448  });
449 
450  return std::make_unique<rvsdg::TransformationSequence>(
451  std::move(sequence),
452  dotWriter,
453  dumpRvsdgGraphs);
454 }
455 
/**
 * Writes an LLVM IR reference version of an RVSDG module to a file.
 *
 * The function works on a copy of @p rhls. It runs pre_opt and instrument_ref
 * on the copy and prints every root-region import of the copy to std::cout.
 * It then prints `lm2`, the LLVM module built from the copy on lines not shown
 * here, to the file at @p path.
 *
 * @param rhls Module the reference is generated from.
 * @param path Output file path.
 */
456 void
458 {
 // copy() returns std::unique_ptr<RvsdgModule>; the static_cast relies on the
 // copy of an LlvmRvsdgModule being an LlvmRvsdgModule as well.
459  const std::unique_ptr<llvm::LlvmRvsdgModule> reference(
460  static_cast<llvm::LlvmRvsdgModule *>(rhls.copy().release()));
461  pre_opt(*reference);
462  instrument_ref(*reference);
 // Every root-region argument is required to be an LlvmGraphImport (assertedCast).
463  for (size_t i = 0; i < reference->Rvsdg().GetRootRegion().narguments(); ++i)
464  {
465  auto graphImport = util::assertedCast<const llvm::LlvmGraphImport>(
466  reference->Rvsdg().GetRootRegion().argument(i));
 // NOTE(review): the log text below misspells "import" as "impport"; it is
 // runtime output and is left unchanged here.
467  std::cout << "impport " << graphImport->Name() << ": " << graphImport->Type()->debug_string()
468  << "\n";
469  }
470  ::llvm::LLVMContext ctx;
 // NOTE(review): EC is never checked after constructing the stream. If
 // `path` cannot be opened, the module is silently not written.
474  std::error_code EC;
475  ::llvm::raw_fd_ostream os(path.to_str(), EC);
476  lm2->print(os, nullptr);
477 }
478 
479 }
static jlm::util::StatisticsCollector statisticsCollector
Common Node Elimination This is mainly a copy of the CNE optimization in the LLVM backend with the ad...
Definition: cne.hpp:24
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Definition: cne.cpp:623
Call operation class.
Definition: call.hpp:251
rvsdg::GraphExport * GetRvsdgExport() const noexcept
Dead Node Elimination Optimization.
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static std::unique_ptr< DeltaOperation > Create(std::shared_ptr< const rvsdg::Type > type, const std::string &name, const Linkage &linkage, std::string section, bool constant, const size_t alignment)
Definition: delta.hpp:84
static void inlineCall(rvsdg::SimpleNode &callNode, rvsdg::LambdaNode &caller, const rvsdg::LambdaNode &callee)
Definition: inlining.cpp:282
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static std::unique_ptr<::llvm::Module > CreateAndConvertModule(InterProceduralGraphModule &ipGraphModule, ::llvm::LLVMContext &ctx)
static LlvmGraphImport & createGlobalImport(rvsdg::Graph &graph, std::shared_ptr< const rvsdg::Type > valueType, std::shared_ptr< const rvsdg::Type > importedType, std::string name, Linkage linkage, const bool isConstant, const size_t alignment)
static LlvmGraphImport & createFunctionImport(rvsdg::Graph &graph, std::shared_ptr< const rvsdg::FunctionType > functionType, std::string name, Linkage linkage, CallingConvention callingConvention)
LlvmGraphImport & Copy(rvsdg::Region &region, rvsdg::StructuralInput *input) const override
Definition: RvsdgModule.cpp:13
Lambda operation.
Definition: lambda.hpp:30
static std::unique_ptr< LlvmLambdaOperation > Create(std::shared_ptr< const jlm::rvsdg::FunctionType > type, std::string name, const jlm::llvm::Linkage &linkage, jlm::llvm::CallingConvention callingConvention, jlm::llvm::AttributeSet attributes)
Definition: lambda.hpp:84
const std::string & name() const noexcept
Definition: lambda.hpp:42
const util::FilePath & SourceFileName() const noexcept
const std::string & TargetTriple() const noexcept
static std::unique_ptr< LlvmRvsdgModule > Create(const util::FilePath &sourceFileName, const std::string &targetTriple, const std::string &dataLayout)
std::unique_ptr< RvsdgModule > copy() const override
Definition: RvsdgModule.cpp:31
const std::string & DataLayout() const noexcept
static std::shared_ptr< const LoopUnswitchingDefaultHeuristic > create()
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
Definition: reduction.cpp:63
static std::shared_ptr< const PointerType > Create()
Definition: types.cpp:45
static std::unique_ptr< InterProceduralGraphModule > CreateAndConvertModule(LlvmRvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector)
Delta node.
Definition: delta.hpp:129
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: delta.cpp:39
rvsdg::Region * subregion() const noexcept
Definition: delta.hpp:234
const DeltaOperation & GetOperation() const noexcept override
Definition: delta.cpp:71
static DeltaNode * Create(rvsdg::Region *parent, std::unique_ptr< DeltaOperation > op)
Definition: delta.hpp:313
rvsdg::Output & output() const noexcept
Definition: delta.cpp:110
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: delta.hpp:243
static GraphExport & Create(Output &origin, std::string name)
Definition: graph.cpp:62
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
Region * region() const noexcept
Definition: node.cpp:83
Lambda node.
Definition: lambda.hpp:83
std::vector< rvsdg::Output * > GetFunctionArguments() const
Definition: lambda.cpp:57
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
std::vector< rvsdg::Input * > GetFunctionResults() const
Definition: lambda.cpp:69
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
LambdaOperation & GetOperation() const noexcept override
Definition: lambda.cpp:51
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
Represents the argument of a region.
Definition: region.hpp:41
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
RegionResult * result(size_t index) const noexcept
Definition: region.hpp:471
void copy(Region *target, SubstitutionMap &smap) const
Copy a region with substitutions.
Definition: region.cpp:314
Graph * graph() const noexcept
Definition: region.hpp:363
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
void insert(const Output *original, Output *substitute)
Output & lookup(const Output &original) const
std::string name() const noexcept
Returns the name of the file, excluding the path.
Definition: file.hpp:81
const std::string & to_str() const noexcept
Definition: file.hpp:275
#define JLM_ASSERT(x)
Definition: common.hpp:16
rvsdg::LambdaNode * change_linkage(rvsdg::LambdaNode *ln, llvm::Linkage link)
Definition: rvsdg2rhls.cpp:228
rvsdg::DeltaNode * rename_delta(rvsdg::DeltaNode *odn)
Definition: rvsdg2rhls.cpp:185
bool function_match(rvsdg::LambdaNode *ln, const std::string &function_name)
Definition: rvsdg2rhls.cpp:100
void split_opt(llvm::LlvmRvsdgModule &rm)
Definition: rvsdg2rhls.cpp:62
void inline_calls(rvsdg::Region *region)
Definition: rvsdg2rhls.cpp:140
static void divert_users(jlm::rvsdg::Output *output, Context &ctx)
Definition: cne.cpp:504
std::unique_ptr< rvsdg::TransformationSequence > createTransformationSequence(rvsdg::DotWriter &dotWriter, const bool dumpRvsdgGraphs)
Definition: rvsdg2rhls.cpp:374
void instrument_ref(llvm::LlvmRvsdgModule &rm)
void dump_ref(llvm::LlvmRvsdgModule &rhls, const util::FilePath &path)
Definition: rvsdg2rhls.cpp:457
void pre_opt(jlm::llvm::LlvmRvsdgModule &rm)
Definition: rvsdg2rhls.cpp:81
std::unique_ptr< jlm::llvm::LlvmRvsdgModule > split_hls_function(llvm::LlvmRvsdgModule &rm, const std::string &function_name)
Definition: rvsdg2rhls.cpp:275
const jlm::rvsdg::Output * trace_call(jlm::rvsdg::Input *input)
Definition: rvsdg2rhls.cpp:113
void rvsdg2ref(llvm::LlvmRvsdgModule &rhls, const util::FilePath &path)
Definition: rvsdg2rhls.cpp:368
CallSummary ComputeCallSummary(const rvsdg::LambdaNode &lambdaNode)
Definition: CallSummary.cpp:30
static void remove(Node *node)
Definition: region.hpp:978
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)
Definition: node.hpp:1058