Jlm
alloca-conv.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2021 David Metz <david.c.metz@ntnu.no>
3  * See COPYING for terms of redistribution.
4  */
5 
9 #include <jlm/hls/ir/hls.hpp>
20 #include <jlm/rvsdg/traverser.hpp>
21 
22 namespace jlm::hls
23 {
24 
26 {
27 public:
28  std::vector<jlm::rvsdg::SimpleNode *> load_nodes;
29  std::vector<jlm::rvsdg::SimpleNode *> store_nodes;
30 
32  {
33  trace(op);
34  }
35 
36 private:
37  void
39  {
40  if (!rvsdg::is<llvm::PointerType>(op->Type()))
41  {
42  // only process pointer outputs
43  return;
44  }
45  if (visited.count(op))
46  {
47  // skip already processed outputs
48  return;
49  }
50  visited.insert(op);
51  for (auto & user : op->Users())
52  {
53  if (auto simplenode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user))
54  {
55  if (dynamic_cast<const jlm::llvm::StoreNonVolatileOperation *>(&simplenode->GetOperation()))
56  {
57  store_nodes.push_back(simplenode);
58  }
59  else if (dynamic_cast<const jlm::llvm::LoadNonVolatileOperation *>(
60  &simplenode->GetOperation()))
61  {
62  load_nodes.push_back(simplenode);
63  }
64  else if (dynamic_cast<const jlm::llvm::CallOperation *>(&simplenode->GetOperation()))
65  {
66  // TODO: verify this is the right type of function call
67  throw util::Error("encountered a call for an alloca");
68  }
69  else
70  {
71  for (size_t i = 0; i < simplenode->noutputs(); ++i)
72  {
73  trace(simplenode->output(i));
74  }
75  }
76  }
77  else if (auto sti = dynamic_cast<rvsdg::StructuralInput *>(&user))
78  {
79  for (auto & arg : sti->arguments)
80  {
81  trace(&arg);
82  }
83  }
84  else if (auto r = dynamic_cast<rvsdg::RegionResult *>(&user))
85  {
86  if (auto ber = dynamic_cast<BackEdgeResult *>(r))
87  {
88  trace(ber->argument());
89  }
90  else
91  {
92  trace(r->output());
93  }
94  }
95  else
96  {
97  JLM_UNREACHABLE("THIS SHOULD BE COVERED");
98  }
99  }
100  }
101 
102  std::unordered_set<jlm::rvsdg::Output *> visited;
103 };
104 
105 static jlm::rvsdg::Output *
107 {
108  // TODO: handle geps that are not direct predecessors
109  auto & node = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*o);
110  util::assertedCast<const jlm::llvm::GetElementPtrOperation>(&node.GetOperation());
111  // pointer to array, i.e. first index is zero
112  // TODO: check
113  JLM_ASSERT(node.ninputs() == 3);
114  return node.input(2)->origin();
115 }
116 
117 static void
119 {
120  for (auto & node : rvsdg::TopDownTraverser(region))
121  {
122  if (auto structnode = dynamic_cast<rvsdg::StructuralNode *>(node))
123  {
124  for (size_t n = 0; n < structnode->nsubregions(); n++)
125  {
126  alloca_conv(structnode->subregion(n));
127  }
128  }
129  else if (auto po = dynamic_cast<const jlm::llvm::AllocaOperation *>(&(node->GetOperation())))
130  {
131  // ensure that the size is one
132  JLM_ASSERT(node->ninputs() == 1);
133  auto constant_output = dynamic_cast<rvsdg::NodeOutput *>(node->input(0)->origin());
134  JLM_ASSERT(constant_output);
135  auto constant_operation = dynamic_cast<const llvm::IntegerConstantOperation *>(
136  &constant_output->node()->GetOperation());
137  JLM_ASSERT(constant_operation);
138  JLM_ASSERT(constant_operation->Representation().to_uint() == 1);
139  // ensure that the alloca is an array type
140  auto at = std::dynamic_pointer_cast<const llvm::ArrayType>(po->ValueType());
141  JLM_ASSERT(at);
142  // detect loads and stores attached to alloca
143  TraceAllocaUses ta(node->output(0));
144  // create memory + response
145  auto mem_outs = LocalMemoryOperation::create(at, node->region());
146  auto resp_outs = LocalMemoryResponseOperation::create(*mem_outs[0], ta.load_nodes.size());
147  std::cout << "alloca converted " << at->debug_string() << std::endl;
148  // replace gep outputs (convert pointer to index calculation)
149  // replace loads and stores
150  std::vector<jlm::rvsdg::Output *> load_addrs;
151  for (auto l : ta.load_nodes)
152  {
153  auto index = gep_to_index(l->input(0)->origin());
154  auto response = route_response_rhls(l->region(), resp_outs.front());
155  resp_outs.erase(resp_outs.begin());
156  std::vector<jlm::rvsdg::Output *> states;
157  for (size_t i = 1; i < l->ninputs(); ++i)
158  {
159  states.push_back(l->input(i)->origin());
160  }
161  auto load_outs = LocalLoadOperation::create(*index, states, *response);
162  auto nn = dynamic_cast<rvsdg::NodeOutput *>(load_outs[0])->node();
163  for (size_t i = 0; i < l->noutputs(); ++i)
164  {
165  l->output(i)->divert_users(nn->output(i));
166  }
167  remove(l);
168  auto addr = route_request_rhls(node->region(), load_outs.back());
169  load_addrs.push_back(addr);
170  }
171  std::vector<jlm::rvsdg::Output *> store_operands;
172  for (auto s : ta.store_nodes)
173  {
174  auto index = gep_to_index(s->input(0)->origin());
175  std::vector<jlm::rvsdg::Output *> states;
176  for (size_t i = 2; i < s->ninputs(); ++i)
177  {
178  states.push_back(s->input(i)->origin());
179  }
180  auto store_outs = LocalStoreOperation::create(*index, *s->input(1)->origin(), states);
181  auto nn = dynamic_cast<rvsdg::NodeOutput *>(store_outs[0])->node();
182  for (size_t i = 0; i < s->noutputs(); ++i)
183  {
184  s->output(i)->divert_users(nn->output(i));
185  }
186  remove(s);
187  auto addr = route_request_rhls(node->region(), store_outs[store_outs.size() - 2]);
188  auto data = route_request_rhls(node->region(), store_outs.back());
189  store_operands.push_back(addr);
190  store_operands.push_back(data);
191  }
192  // TODO: ensure that loads/stores are either alloca or global, never both
193  // TODO: ensure that loads/stores have same width and alignment and geps can be merged -
194  // otherwise slice? create request
195  auto req_outs = LocalMemoryRequestOperation::create(*mem_outs[1], load_addrs, store_operands);
196 
197  // remove alloca from memstate merge
198  // TODO: handle general case of other nodes getting state edge without a merge
199  JLM_ASSERT(node->output(1)->nusers() == 1);
200  auto & merge_in = *node->output(1)->Users().begin();
201  auto merge_node = rvsdg::TryGetOwnerNode<rvsdg::Node>(merge_in);
202  if (dynamic_cast<const llvm::MemoryStateMergeOperation *>(&merge_node->GetOperation()))
203  {
204  // merge after alloca -> remove merge
205  JLM_ASSERT(merge_node->ninputs() == 2);
206  auto other_index = merge_in.index() ? 0 : 1;
207  merge_node->output(0)->divert_users(merge_node->input(other_index)->origin());
208  jlm::rvsdg::remove(merge_node);
209  }
210  else
211  {
212  // TODO: fix this properly by adding a state edge to the LambdaEntryMemState and routing it
213  // to the region
214  JLM_ASSERT(false);
215  }
216 
217  // TODO: run dne to
218  // remove loads/stores
219  // remove geps
220  // remove alloca pointer users
221  // remove alloca
222  }
223  }
224 }
225 
227 
230 {}
231 
// AllocaNodeConversion::Run: entry point of the pass. Starts the alloca
// conversion at the root region of the module's RVSDG; alloca_conv itself
// recurses into the subregions of structural nodes. The body shown here does
// not use the statistics collector.
232 void
234 {
235  alloca_conv(&rvsdgModule.Rvsdg().GetRootRegion());
236 }
237 
238 } // namespace jlm::hls
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
~AllocaNodeConversion() noexcept override
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &index, const std::vector< jlm::rvsdg::Output * > &states, jlm::rvsdg::Output &load_result)
Definition: hls.hpp:1596
static std::vector< jlm::rvsdg::Output * > create(std::shared_ptr< const llvm::ArrayType > at, rvsdg::Region *region)
Definition: hls.hpp:1490
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &mem, const std::vector< jlm::rvsdg::Output * > &load_operands, const std::vector< jlm::rvsdg::Output * > &store_operands)
Definition: hls.hpp:1742
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &mem, size_t resp_count)
Definition: hls.hpp:1533
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &index, jlm::rvsdg::Output &value, const std::vector< jlm::rvsdg::Output * > &states)
Definition: hls.hpp:1669
void trace(jlm::rvsdg::Output *op)
Definition: alloca-conv.cpp:38
std::vector< jlm::rvsdg::SimpleNode * > store_nodes
Definition: alloca-conv.cpp:29
std::unordered_set< jlm::rvsdg::Output * > visited
TraceAllocaUses(jlm::rvsdg::Output *op)
Definition: alloca-conv.cpp:31
std::vector< jlm::rvsdg::SimpleNode * > load_nodes
Definition: alloca-conv.cpp:28
Call operation class.
Definition: call.hpp:249
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
UsersRange Users()
Definition: node.hpp:354
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
Represents the result of a region.
Definition: region.hpp:120
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
Represents an RVSDG transformation.
#define JLM_ASSERT(x)
Definition: common.hpp:16
#define JLM_UNREACHABLE(msg)
Definition: common.hpp:43
rvsdg::Output * route_response_rhls(rvsdg::Region *target, rvsdg::Output *response)
rvsdg::Output * route_request_rhls(rvsdg::Region *target, rvsdg::Output *request)
static void alloca_conv(rvsdg::Region *region)
static jlm::rvsdg::Output * gep_to_index(jlm::rvsdg::Output *o)
static void remove(Node *node)
Definition: region.hpp:932