Jlm
add-prints.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
7 #include <jlm/hls/ir/hls.hpp>
12 #include <jlm/rvsdg/gamma.hpp>
13 #include <jlm/rvsdg/traverser.hpp>
14 
15 namespace jlm::hls
16 {
17 
// add_prints(rvsdg::Region * region)
// NOTE(review): the declarator (source line 19) is absent from this listing; the
// signature above is taken from the tooltip "Definition: add-prints.cpp:19".
//
// Walks `region` top-down and instruments bit-typed values with PrintOperation nodes:
//  - structural nodes: recurse into every subregion;
//  - simple nodes with exactly one output, whose output type is rvsdg::BitType and
//    which are not llvm::UndefValueOperation: a PrintOperation is created on that
//    output and every user that previously read the output is diverted to the print
//    node's first output, so the value now flows through the print node.
18 void
20 {
21  for (auto & node : rvsdg::TopDownTraverser(region))
22  {
23  if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
24  {
25  for (size_t n = 0; n < structnode->nsubregions(); n++)
26  {
27  add_prints(structnode->subregion(n));
28  }
29  }
// Only single-output simple nodes producing a bit value are instrumented;
// undef values are skipped.
30  if (dynamic_cast<jlm::rvsdg::SimpleNode *>(node) && node->noutputs() == 1
31  && jlm::rvsdg::is<rvsdg::BitType>(node->output(0)->Type())
32  && !jlm::rvsdg::is<llvm::UndefValueOperation>(node))
33  {
34  auto out = node->output(0);
// Snapshot the existing users *before* creating the print node. The print node is
// built from `out`, so presumably its own input becomes a user of `out`. Diverting
// the live user list afterwards would also redirect that input to the print's own
// output.
35  std::vector<jlm::rvsdg::Input *> old_users;
36  for (auto & user : out->Users())
37  old_users.push_back(&user);
38  auto new_out = PrintOperation::create(*out)[0];
39  for (auto user : old_users)
40  {
41  user->divert_to(new_out);
42  }
43  }
44  }
45 }
46 
// add_prints(llvm::LlvmRvsdgModule & rm)
// NOTE(review): the declarator (source line 48) is absent from this listing; the
// parameter type is presumed to match convert_prints(llvm::LlvmRvsdgModule &rm).
// TODO confirm.
//
// Instruments the whole module by running add_prints on the RVSDG root region.
47 void
49 {
50  auto & graph = rm.Rvsdg();
51  auto root = &graph.GetRootRegion();
52  add_prints(root);
53 }
54 
// convert_prints(llvm::LlvmRvsdgModule & rm)
// Signature per the tooltip "Definition: add-prints.cpp:56".
// NOTE(review): source lines 56, 62-63 and 67-68 are absent from this listing.
//
// Builds a function type `fct` and a function import named "printnode" in the
// module's graph. Judging by the tooltip for
// LlvmGraphImport::createFunctionImport(graph, functionType, name, linkage,
// callingConvention), this is done through that call; confirm against the real
// file. The function then lowers every PrintOperation reachable from the root
// region into a call to that import; `&printf` is passed where the region overload
// expects a jlm::rvsdg::Output *.
55 void
57 {
58  auto & graph = rm.Rvsdg();
59  auto root = &graph.GetRootRegion();
60  // TODO: make this less hacky by using the correct state types
61  auto fct =
64  graph,
65  fct,
66  "printnode",
69  convert_prints(root, &printf, fct);
70 }
71 
// convert_prints(region, printf, functionType)
// Lowers PrintOperation nodes into calls.
// NOTE(review): the line holding the function name (source line 73) is absent
// from this listing.
//
// Recurses into the subregions of structural nodes. For each PrintOperation node it:
//  1. routes `printf` into `region` with rvsdg::RouteToRegion;
//  2. creates a 64-bit llvm::IntegerConstantOperation holding the print's id();
//  3. creates an llvm::CallOperation on the routed `printf` with `functionType` and
//     the arguments (id constant, printed value);
//  4. diverts the print node's output users to the print's input origin, so the
//     node is bypassed, and then removes the node.
// The results returned by CallOperation::Create are not connected to anything here.
//
// @param region        Region whose PrintOperation nodes are rewritten.
// @param printf        Output that is called for every print; in this file it is the
//                      "printnode" import created by convert_prints(rm).
// @param functionType  Function type used for every created call.
72 void
74  rvsdg::Region * region,
75  jlm::rvsdg::Output * printf,
76  const std::shared_ptr<const rvsdg::FunctionType> & functionType)
77 {
78  for (auto & node : rvsdg::TopDownTraverser(region))
79  {
80  if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
81  {
82  for (size_t n = 0; n < structnode->nsubregions(); n++)
83  {
84  convert_prints(structnode->subregion(n), printf, functionType);
85  }
86  }
87  else if (auto po = dynamic_cast<const PrintOperation *>(&(node->GetOperation())))
88  {
// RouteToRegion is invoked once per print node, so a region with several prints
// routes `printf` repeatedly. Presumably that is what the TODO below refers to.
89  auto printf_local = &rvsdg::RouteToRegion(*printf,
90  *region); // TODO: prevent repetition?
91  auto & constantNode = llvm::IntegerConstantOperation::Create(*region, 64, po->id());
92  jlm::rvsdg::Output * val = node->input(0)->origin();
// A value that is not exactly a 64-bit bit type must at least be a BitType.
// NOTE(review): source line 97 is absent from this listing. The tooltip for
// Create(rvsdg::Output &operand, const std::shared_ptr<const rvsdg::Type> &resultType)
// (operators.hpp:853) suggests `val` is converted to 64 bits there. Confirm against
// the real file.
93  if (*val->Type() != *jlm::rvsdg::BitType::Create(64))
94  {
95  auto bt = std::dynamic_pointer_cast<const rvsdg::BitType>(val->Type());
96  JLM_ASSERT(bt);
98  }
99  llvm::CallOperation::Create(printf_local, functionType, { constantNode.output(0), val });
// Bypass the print node: its consumers read the original value directly again.
100  node->output(0)->divert_users(node->input(0)->origin());
101  jlm::rvsdg::remove(node);
102  }
103  }
104 }
105 
106 }
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &value)
Definition: hls.hpp:557
static std::vector< rvsdg::Output * > Create(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:464
static rvsdg::Node & Create(rvsdg::Region &region, IntegerValueRepresentation representation)
static LlvmGraphImport & createFunctionImport(rvsdg::Graph &graph, std::shared_ptr< const rvsdg::FunctionType > functionType, std::string name, Linkage linkage, CallingConvention callingConvention)
static rvsdg::Output & Create(rvsdg::Output &operand, const std::shared_ptr< const rvsdg::Type > &resultType)
Definition: operators.hpp:853
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates a bit type of the specified width.
Definition: type.cpp:45
static std::shared_ptr< const FunctionType > Create(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> argumentTypes, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> resultTypes)
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
size_t noutputs() const noexcept
Definition: node.hpp:644
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
Represents acyclic RVSDG subgraphs.
Definition: region.hpp:213
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
#define JLM_ASSERT(x)
Definition: common.hpp:16
void convert_prints(llvm::LlvmRvsdgModule &rm)
Definition: add-prints.cpp:56
void add_prints(rvsdg::Region *region)
Definition: add-prints.cpp:19
static void remove(Node *node)
Definition: region.hpp:978
Output & RouteToRegion(Output &output, Region &region)
Definition: node.cpp:381