Jlm
add-prints.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
7 #include <jlm/hls/ir/hls.hpp>
10 #include <jlm/rvsdg/gamma.hpp>
11 #include <jlm/rvsdg/theta.hpp>
12 #include <jlm/rvsdg/traverser.hpp>
13 
14 namespace jlm::hls
15 {
16 
// NOTE(review): this is a Doxygen-rendered listing of add-prints.cpp with the original
// line numbers fused onto each line. Source lines 6, 8, 9, 18, 47, 55, 61, 63, 68 and 92
// are missing from this extract, so it does not compile as-is.

// Recursively inserts a PrintOperation after selected nodes in `region`, including all
// subregions of its structural nodes. A node is selected when it:
//   - is a simple node,
//   - has exactly one output,
//   - has an output of type rvsdg::BitType, and
//   - is not an llvm::UndefValueOperation.
// All previous users of that output are rerouted to the PrintOperation's first result,
// so downstream consumers now read the value through the print node.
//
// NOTE(review): the signature (source line 18) is missing from this extract; the
// cross-reference at the end of the page gives it as `add_prints(rvsdg::Region *region)`.
// NOTE(review): the PrintOperations created below are themselves single-output, bit-typed
// simple nodes. This presumably relies on TopDownTraverser not visiting nodes created
// during the traversal; otherwise prints would be wrapped again. Confirm.
17 void
19 {
20  for (auto & node : rvsdg::TopDownTraverser(region))
21  {
22  if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
23  {
24  for (size_t n = 0; n < structnode->nsubregions(); n++)
25  {
26  add_prints(structnode->subregion(n));
27  }
28  }
29  if (dynamic_cast<jlm::rvsdg::SimpleNode *>(node) && node->noutputs() == 1
30  && jlm::rvsdg::is<rvsdg::BitType>(node->output(0)->Type())
31  && !jlm::rvsdg::is<llvm::UndefValueOperation>(node))
32  {
33  auto out = node->output(0);
// Snapshot the current users before creating the print. PrintOperation::create(*out)
// takes `out` as its operand, so the print node's own input must not be rediverted
// to the print's result. Copying the list also avoids mutating it while iterating.
34  std::vector<jlm::rvsdg::Input *> old_users;
35  for (auto & user : out->Users())
36  old_users.push_back(&user);
37  auto new_out = PrintOperation::create(*out)[0];
38  for (auto user : old_users)
39  {
40  user->divert_to(new_out);
41  }
42  }
43  }
44 }
45 
// Module-level entry point: runs add_prints on the root region of `rm`'s RVSDG.
// NOTE(review): the signature (source line 47) is missing from this extract. Given the
// body it is presumably `add_prints(llvm::LlvmRvsdgModule & rm)`; confirm.
46 void
48 {
49  auto & graph = rm.Rvsdg();
50  auto root = &graph.GetRootRegion();
51  add_prints(root);
52 }
53 
// Module-level entry point: lowers every PrintOperation in `rm` into a call.
//   - `fct` is the callee's function type.
//   - `printf` is the output providing the callee.
// Both are handed to the region-level convert_prints, starting at the root region.
//
// NOTE(review): the signature (source line 55) is missing here; the cross-reference at
// the end of the page gives it as `convert_prints(llvm::LlvmRvsdgModule &rm)`.
// NOTE(review): the right-hand sides of `fct` and `printf` (source lines 61 and 63) are
// also missing. The cross-references to rvsdg::FunctionType::Create and
// llvm::LlvmGraphImport::Create suggest `printf` is a graph import of type `fct`;
// confirm against the real file.
54 void
56 {
57  auto & graph = rm.Rvsdg();
58  auto root = &graph.GetRootRegion();
59  // TODO: make this less hacky by using the correct state types
60  auto fct =
62  auto & printf =
64  convert_prints(root, &printf, fct);
65 }
66 
// Replaces each PrintOperation in `region` with a call through `printf`, recursing into
// the subregions of structural nodes. The call receives two arguments:
//   1. a 64-bit integer constant holding the print's id(), and
//   2. the printed value.
// Users of the print node's output are then rediverted to the print's operand, and the
// print node is removed.
//
// @param region        Region to process.
// @param printf        Callee output; it is routed into each region that contains a print.
// @param functionType  Function type passed to llvm::CallOperation::Create.
//
// NOTE(review): the line carrying the function name (source line 68) is missing here.
// The recursive call below shows it is `convert_prints`.
67 void
69  rvsdg::Region * region,
70  jlm::rvsdg::Output * printf,
71  const std::shared_ptr<const rvsdg::FunctionType> & functionType)
72 {
73  for (auto & node : rvsdg::TopDownTraverser(region))
74  {
75  if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
76  {
77  for (size_t n = 0; n < structnode->nsubregions(); n++)
78  {
79  convert_prints(structnode->subregion(n), printf, functionType);
80  }
81  }
82  else if (auto po = dynamic_cast<const PrintOperation *>(&(node->GetOperation())))
83  {
// RouteToRegion runs once per print node, so every print in the same region routes
// `printf` again (hence the TODO).
84  auto printf_local = &rvsdg::RouteToRegion(*printf,
85  *region); // TODO: prevent repetition?
86  auto & constantNode = llvm::IntegerConstantOperation::Create(*region, 64, po->id());
87  jlm::rvsdg::Output * val = node->input(0)->origin();
// An operand that is not 64 bits wide must still be a bit type.
// NOTE(review): source line 92 is missing from this extract. It presumably widens `val`
// to 64 bits; the cross-reference to a `Create(Output &operand, resultType)` in
// operators.hpp suggests a cast. Confirm. Without it, `bt` is only used by the assert.
88  if (*val->Type() != *jlm::rvsdg::BitType::Create(64))
89  {
90  auto bt = std::dynamic_pointer_cast<const rvsdg::BitType>(val->Type());
91  JLM_ASSERT(bt);
93  }
// NOTE(review): the call's results are discarded and no state edges are threaded
// (cf. the TODO in the module-level convert_prints). Confirm that later passes do not
// treat this call as dead.
94  llvm::CallOperation::Create(printf_local, functionType, { constantNode.output(0), val });
// The print is a pass-through. Its consumers are reconnected to the print's original
// operand (node->input(0)->origin()), not to `val`.
95  node->output(0)->divert_users(node->input(0)->origin());
96  jlm::rvsdg::remove(node);
97  }
98  }
99 }
100 
101 } // namespace jlm::hls
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &value)
Definition: hls.hpp:557
static std::vector< rvsdg::Output * > Create(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:440
static rvsdg::Node & Create(rvsdg::Region &region, IntegerValueRepresentation representation)
static LlvmGraphImport & Create(rvsdg::Graph &graph, std::shared_ptr< const rvsdg::Type > valueType, std::shared_ptr< const rvsdg::Type > importedType, std::string name, Linkage linkage, bool isConstant=false)
Definition: RvsdgModule.hpp:81
static rvsdg::Output & Create(rvsdg::Output &operand, const std::shared_ptr< const rvsdg::Type > &resultType)
Definition: operators.hpp:822
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates a bit type of the specified width.
Definition: type.cpp:45
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
size_t noutputs() const noexcept
Definition: node.hpp:644
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
Represents acyclic RVSDG subgraphs.
Definition: region.hpp:213
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
#define JLM_ASSERT(x)
Definition: common.hpp:16
void convert_prints(llvm::LlvmRvsdgModule &rm)
Definition: add-prints.cpp:55
void add_prints(rvsdg::Region *region)
Definition: add-prints.cpp:18
static void remove(Node *node)
Definition: region.hpp:932
Output & RouteToRegion(Output &output, Region &region)
Definition: node.cpp:381