Jlm
hls-function-util.cpp
Go to the documentation of this file.
1 //
2 // Created by david on 7/2/21.
3 //
4 
6 #include <jlm/hls/ir/hls.hpp>
13 #include <jlm/rvsdg/theta.hpp>
14 #include <jlm/rvsdg/traverser.hpp>
15 #include <jlm/rvsdg/view.hpp>
16 
17 #include <deque>
18 
19 namespace jlm::hls
20 {
21 
22 std::vector<rvsdg::LambdaNode::ContextVar>
23 find_function_arguments(const rvsdg::LambdaNode * lambda, std::string name_contains)
24 {
25  std::vector<rvsdg::LambdaNode::ContextVar> result;
26  for (auto cv : lambda->GetContextVars())
27  {
28  auto ip = cv.input;
29  auto traced = trace_call_rhls(ip);
30  JLM_ASSERT(traced);
31  auto arg = util::assertedCast<const llvm::LlvmGraphImport>(traced);
32  if (dynamic_cast<const rvsdg::FunctionType *>(arg->ImportedType().get())
33  && arg->Name().find(name_contains) != arg->Name().npos)
34  {
35  result.push_back(cv);
36  }
37  }
38  return result;
39 }
40 
41 void
43  rvsdg::Output * output,
44  std::vector<rvsdg::SimpleNode *> & calls,
45  std::unordered_set<rvsdg::Output *> & visited)
46 {
47  if (visited.count(output))
48  {
49  // skip already processed outputs
50  return;
51  }
52  visited.insert(output);
53  for (auto & user : output->Users())
54  {
55  if (auto simplenode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user))
56  {
57  if (dynamic_cast<const llvm::CallOperation *>(&simplenode->GetOperation()))
58  {
59  // TODO: verify this is the right type of function call
60  calls.push_back(simplenode);
61  }
62  else
63  {
64  for (size_t i = 0; i < simplenode->noutputs(); ++i)
65  {
66  trace_function_calls(simplenode->output(i), calls, visited);
67  }
68  }
69  }
70  else if (auto sti = dynamic_cast<rvsdg::StructuralInput *>(&user))
71  {
72  for (auto & arg : sti->arguments)
73  {
74  trace_function_calls(&arg, calls, visited);
75  }
76  }
77  else if (auto r = dynamic_cast<rvsdg::RegionResult *>(&user))
78  {
79  if (auto ber = dynamic_cast<BackEdgeResult *>(r))
80  {
81  trace_function_calls(ber->argument(), calls, visited);
82  }
83  else
84  {
85  trace_function_calls(r->output(), calls, visited);
86  }
87  }
88  else
89  {
90  JLM_UNREACHABLE("THIS SHOULD BE COVERED");
91  }
92  }
93 }
94 
97 {
98  if (auto arg = dynamic_cast<const rvsdg::RegionArgument *>(dst))
99  {
100  return trace_constant(arg->input()->origin());
101  }
102 
103  auto [constantNode, constantOperation] =
104  rvsdg::TryGetSimpleNodeAndOptionalOp<llvm::IntegerConstantOperation>(*dst);
105  if (constantNode)
106  {
107  if (constantOperation)
108  return constantOperation;
109 
110  for (size_t i = 0; i < constantNode->ninputs(); ++i)
111  {
112  // TODO: fix, this is a hack - only works because of distribute constants
113  if (*constantNode->input(i)->Type() == *dst->Type())
114  {
115  return trace_constant(constantNode->input(i)->origin());
116  }
117  }
118  }
119 
120  JLM_UNREACHABLE("Constant not found");
121 }
122 
125 {
126  // create lists of nested regions
127  std::deque<rvsdg::Region *> target_regions = get_parent_regions(target);
128  std::deque<rvsdg::Region *> out_regions = get_parent_regions(out->region());
129  JLM_ASSERT(target_regions.front() == out_regions.front());
130  // remove common ancestor regions
131  rvsdg::Region * common_region = nullptr;
132  while (!target_regions.empty() && !out_regions.empty()
133  && target_regions.front() == out_regions.front())
134  {
135  common_region = target_regions.front();
136  target_regions.pop_front();
137  out_regions.pop_front();
138  }
139  // route out to convergence point from out
140  rvsdg::Output * common_out = route_request_rhls(common_region, out);
141  auto common_loop = dynamic_cast<LoopNode *>(common_region->node());
142  if (common_loop)
143  {
144  // add a backedge to prevent cycles
145  auto arg = common_loop->add_backedge(out->Type());
146  arg->result()->divert_to(common_out);
147  // route inwards from convergence point to target
148  auto result = route_response_rhls(target, arg);
149  return result;
150  }
151  else
152  {
153  // lambda is common region - might create cycle
154  // TODO: how to check that this won't create a cycle
155  JLM_ASSERT(
156  target_regions.empty() || target_regions.front()->node()->region() == common_out->region());
157  return route_response_rhls(target, common_out);
158  }
159 }
160 
163 {
164  if (response->region() == target)
165  {
166  return response;
167  }
168  else
169  {
170  auto parent_response = route_response_rhls(target->node()->region(), response);
171  auto ln = util::assertedCast<LoopNode>(target->node());
172  return ln->addResponseInput(parent_response);
173  }
174 }
175 
178 {
179  if (request->region() == target)
180  {
181  return request;
182  }
183 
184  auto ln = util::assertedCast<LoopNode>(request->region()->node());
185  auto output = ln->addRequestOutput(request);
186 
187  return route_request_rhls(target, output);
188 }
189 
190 std::deque<rvsdg::Region *>
192 {
193  std::deque<rvsdg::Region *> regions;
194  rvsdg::Region * target_region = region;
195  while (!dynamic_cast<const llvm::LlvmLambdaOperation *>(&target_region->node()->GetOperation()))
196  {
197  regions.push_front(target_region);
198  target_region = target_region->node()->region();
199  }
200  regions.push_front(target_region);
201  return regions;
202 }
203 
204 const rvsdg::Output *
206 {
207  // version of trace call for rhls
208  if (auto argument = dynamic_cast<const rvsdg::RegionArgument *>(output))
209  {
210  auto graph = output->region()->graph();
211  if (argument->region() == &graph->GetRootRegion())
212  {
213  return argument;
214  }
215  else if (dynamic_cast<const BackEdgeArgument *>(argument))
216  {
217  // don't follow backedges to avoid cycles
218  return nullptr;
219  }
220  return trace_call_rhls(argument->input());
221  }
222  else if (auto so = dynamic_cast<const rvsdg::StructuralOutput *>(output))
223  {
224  for (auto & r : so->results)
225  {
226  if (auto result = trace_call_rhls(&r))
227  {
228  return result;
229  }
230  }
231  }
232  else if (auto simpleNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*output))
233  {
234  for (size_t i = 0; i < simpleNode->ninputs(); ++i)
235  {
236  auto ip = simpleNode->input(i);
237  if (*ip->Type() == *output->Type())
238  {
239  if (auto result = trace_call_rhls(ip))
240  {
241  return result;
242  }
243  }
244  }
245  }
246  else
247  {
248  JLM_UNREACHABLE("");
249  }
250  return nullptr;
251 }
252 
253 const rvsdg::Output *
255 {
256  // version of trace call for rhls
257  return trace_call_rhls(input->origin());
258 }
259 
260 bool
262 {
263  auto ip = cv.input;
264  auto traced = trace_call_rhls(ip);
265  JLM_ASSERT(traced);
266  auto arg = util::assertedCast<const llvm::LlvmGraphImport>(traced);
267  return dynamic_cast<const rvsdg::FunctionType *>(arg->ImportedType().get());
268 }
269 
270 std::string
272 {
273  auto traced = jlm::hls::trace_call_rhls(input);
274  JLM_ASSERT(traced);
275  auto arg = jlm::util::assertedCast<const jlm::llvm::LlvmGraphImport>(traced);
276  return arg->Name();
277 }
278 
279 bool
281 {
282  if (dynamic_cast<const llvm::CallOperation *>(&node->GetOperation()))
283  {
284  auto name = get_function_name(node->input(0));
285  if (name.rfind("decouple_req") != name.npos)
286  return true;
287  }
288  return false;
289 }
290 
291 bool
293 {
294  if (dynamic_cast<const llvm::CallOperation *>(&node->GetOperation()))
295  {
296  auto name = get_function_name(node->input(0));
297  if (name.rfind("decouple_res") != name.npos)
298  return true;
299  }
300  return false;
301 }
302 
303 rvsdg::Input *
305 {
306  JLM_ASSERT(state_edge);
307  JLM_ASSERT(state_edge->nusers() == 1);
308  JLM_ASSERT(rvsdg::is<llvm::MemoryStateType>(state_edge->Type()));
309  return &state_edge->SingleUser();
310 }
311 
314 {
315  if (auto ba = dynamic_cast<BackEdgeArgument *>(out))
316  {
317  return FindSourceNode(ba->result()->origin());
318  }
319  else if (auto ra = dynamic_cast<rvsdg::RegionArgument *>(out))
320  {
321  if (ra->input() && rvsdg::TryGetOwnerNode<LoopNode>(*ra->input()))
322  {
323  return FindSourceNode(ra->input()->origin());
324  }
325  else
326  {
327  // lambda argument
328  return ra;
329  }
330  }
331  else if (auto so = dynamic_cast<rvsdg::StructuralOutput *>(out))
332  {
333  JLM_ASSERT(rvsdg::TryGetOwnerNode<LoopNode>(*out));
334  return FindSourceNode(so->results.begin()->origin());
335  }
336 
337  JLM_ASSERT(rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*out));
338  return out;
339 }
340 }
BackEdgeResult * result()
Definition: hls.hpp:632
BackEdgeArgument * add_backedge(std::shared_ptr< const jlm::rvsdg::Type > type)
Definition: hls.cpp:284
Call operation class.
Definition: call.hpp:249
Lambda operation.
Definition: lambda.hpp:30
Function type class.
void divert_to(Output *new_origin)
Definition: node.cpp:64
Output * origin() const noexcept
Definition: node.hpp:58
Lambda node.
Definition: lambda.hpp:83
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
virtual const Operation & GetOperation() const noexcept=0
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
rvsdg::Input & SingleUser() noexcept
Definition: node.hpp:347
rvsdg::Region * region() const noexcept
Definition: node.cpp:151
UsersRange Users()
Definition: node.hpp:354
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
size_t nusers() const noexcept
Definition: node.hpp:280
Represents the argument of a region.
Definition: region.hpp:41
Represents the result of a region.
Definition: region.hpp:120
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
rvsdg::StructuralNode * node() const noexcept
Definition: region.hpp:369
Graph * graph() const noexcept
Definition: region.hpp:363
const SimpleOperation & GetOperation() const noexcept override
Definition: simple-node.cpp:48
NodeInput * input(size_t index) const noexcept
Definition: simple-node.hpp:82
#define JLM_ASSERT(x)
Definition: common.hpp:16
#define JLM_UNREACHABLE(msg)
Definition: common.hpp:43
rvsdg::Output * route_response_rhls(rvsdg::Region *target, rvsdg::Output *response)
void trace_function_calls(rvsdg::Output *output, std::vector< rvsdg::SimpleNode * > &calls, std::unordered_set< rvsdg::Output * > &visited)
std::deque< rvsdg::Region * > get_parent_regions(rvsdg::Region *region)
rvsdg::Output * FindSourceNode(rvsdg::Output *out)
bool is_function_argument(const rvsdg::LambdaNode::ContextVar &cv)
rvsdg::Output * route_request_rhls(rvsdg::Region *target, rvsdg::Output *request)
bool is_dec_res(rvsdg::SimpleNode *node)
std::string get_function_name(jlm::rvsdg::Input *input)
rvsdg::Output * route_to_region_rhls(rvsdg::Region *target, rvsdg::Output *out)
const llvm::IntegerConstantOperation * trace_constant(const rvsdg::Output *dst)
const rvsdg::Output * trace_call_rhls(const rvsdg::Output *output)
rvsdg::Input * get_mem_state_user(rvsdg::Output *state_edge)
std::vector< rvsdg::LambdaNode::ContextVar > find_function_arguments(const rvsdg::LambdaNode *lambda, std::string name_contains)
bool is_dec_req(rvsdg::SimpleNode *node)
Bound context variable.
Definition: lambda.hpp:100
rvsdg::Input * input
Input variable bound into lambda node.
Definition: lambda.hpp:108