Jlm
call.hpp
Go to the documentation of this file.
1 /*
2  * Copyright 2018 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #ifndef JLM_LLVM_IR_OPERATORS_CALL_HPP
7 #define JLM_LLVM_IR_OPERATORS_CALL_HPP
8 
11 #include <jlm/llvm/ir/tac.hpp>
12 #include <jlm/llvm/ir/types.hpp>
13 #include <jlm/rvsdg/Phi.hpp>
15 
16 namespace jlm::llvm
17 {
18 
23 class CallTypeClassifier final
24 {
25 public:
26  enum class CallType
27  {
33 
39 
44 
49  };
50 
52  : CallType_(callType),
53  Output_(&output)
54  {}
55 
59  [[nodiscard]] CallType
60  GetCallType() const noexcept
61  {
62  return CallType_;
63  }
64 
69  [[nodiscard]] bool
70  IsNonRecursiveDirectCall() const noexcept
71  {
73  }
74 
79  [[nodiscard]] bool
80  IsRecursiveDirectCall() const noexcept
81  {
83  }
84 
88  [[nodiscard]] bool
89  IsDirectCall() const noexcept
90  {
93  }
94 
99  [[nodiscard]] bool
100  IsExternalCall() const noexcept
101  {
103  }
104 
109  [[nodiscard]] bool
110  IsIndirectCall() const noexcept
111  {
113  }
114 
122  [[nodiscard]] rvsdg::Output &
123  GetLambdaOutput() const noexcept
124  {
126  {
127  return *Output_;
128  }
129 
131  auto argument = jlm::util::assertedCast<jlm::rvsdg::RegionArgument>(Output_);
132  /*
133  * FIXME: This assumes that all recursion variables were added before the dependencies. It
134  * would be better if we did not use the index for retrieving the result, but instead
135  * explicitly encoded it in a phi_argument.
136  */
137  return *argument->region()->result(argument->index())->origin();
138  }
139 
146  [[nodiscard]] rvsdg::RegionArgument &
147  GetImport() const noexcept
148  {
150  return *jlm::util::assertedCast<rvsdg::RegionArgument>(Output_);
151  }
152 
161  [[nodiscard]] jlm::rvsdg::Output &
162  GetFunctionOrigin() const noexcept
163  {
164  return *Output_;
165  }
166 
171  [[nodiscard]] bool
172  isSetjmpCall();
173 
178  [[nodiscard]] bool
179  isVaStartCall();
180 
190  static std::unique_ptr<CallTypeClassifier>
192  {
193  rvsdg::AssertGetOwnerNode<rvsdg::LambdaNode>(output);
194  return std::make_unique<CallTypeClassifier>(CallType::NonRecursiveDirectCall, output);
195  }
196 
206  static std::unique_ptr<CallTypeClassifier>
208  {
209  return std::make_unique<CallTypeClassifier>(CallType::RecursiveDirectCall, output);
210  }
211 
221  static std::unique_ptr<CallTypeClassifier>
223  {
224  JLM_ASSERT(argument.region() == &argument.region()->graph()->GetRootRegion());
225  return std::make_unique<CallTypeClassifier>(CallType::ExternalCall, argument);
226  }
227 
234  static std::unique_ptr<CallTypeClassifier>
236  {
237  return std::make_unique<CallTypeClassifier>(CallType::IndirectCall, output);
238  }
239 
240 private:
243 };
244 
249 {
250 public:
251  ~CallOperation() override;
252 
253  explicit CallOperation(std::shared_ptr<const rvsdg::FunctionType> functionType)
254  : SimpleOperation(create_srctypes(functionType), functionType->Results()),
255  FunctionType_(std::move(functionType))
256  {}
257 
258  bool
259  operator==(const Operation & other) const noexcept override;
260 
261  [[nodiscard]] std::string
262  debug_string() const override;
263 
264  [[nodiscard]] const std::shared_ptr<const rvsdg::FunctionType> &
265  GetFunctionType() const noexcept
266  {
267  return FunctionType_;
268  }
269 
270  [[nodiscard]] std::unique_ptr<Operation>
271  copy() const override;
272 
278  [[nodiscard]] static size_t
279  NumArguments(const rvsdg::Node & node) noexcept
280  {
281  JLM_ASSERT(is<CallOperation>(&node));
282  return node.ninputs() - 1;
283  }
284 
290  [[nodiscard]] static rvsdg::Input *
291  Argument(const rvsdg::Node & node, const size_t n)
292  {
293  JLM_ASSERT(is<CallOperation>(&node));
295  return node.input(n + 1);
296  }
297 
301  [[nodiscard]] static rvsdg::Input &
302  GetFunctionInput(const rvsdg::Node & node) noexcept
303  {
304  JLM_ASSERT(is<CallOperation>(&node));
305  const auto functionInput = node.input(0);
306  JLM_ASSERT(is<rvsdg::FunctionType>(functionInput->Type()));
307  return *functionInput;
308  }
309 
313  [[nodiscard]] static rvsdg::Input &
314  GetIOStateInput(const rvsdg::Node & node) noexcept
315  {
316  JLM_ASSERT(is<CallOperation>(&node));
317  const auto ioState = node.input(node.ninputs() - 2);
318  JLM_ASSERT(is<IOStateType>(ioState->Type()));
319  return *ioState;
320  }
321 
325  [[nodiscard]] static rvsdg::Output &
326  GetIOStateOutput(const rvsdg::Node & node) noexcept
327  {
328  JLM_ASSERT(is<CallOperation>(&node));
329  const auto ioState = node.output(node.noutputs() - 2);
330  JLM_ASSERT(is<IOStateType>(ioState->Type()));
331  return *ioState;
332  }
333 
337  [[nodiscard]] static rvsdg::Input &
338  GetMemoryStateInput(const rvsdg::Node & node) noexcept
339  {
340  JLM_ASSERT(is<CallOperation>(&node));
341  const auto memoryState = node.input(node.ninputs() - 1);
342  JLM_ASSERT(is<MemoryStateType>(memoryState->Type()));
343  return *memoryState;
344  }
345 
349  [[nodiscard]] static rvsdg::Output &
350  GetMemoryStateOutput(const rvsdg::Node & node) noexcept
351  {
352  JLM_ASSERT(is<CallOperation>(&node));
353  const auto memoryState = node.output(node.noutputs() - 1);
354  JLM_ASSERT(is<MemoryStateType>(memoryState->Type()));
355  return *memoryState;
356  }
357 
368  [[nodiscard]] static rvsdg::SimpleNode *
369  tryGetMemoryStateEntryMerge(const rvsdg::Node & callNode) noexcept
370  {
371  JLM_ASSERT(is<CallOperation>(&callNode));
372  const auto node =
373  rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*GetMemoryStateInput(callNode).origin());
374  return is<CallEntryMemoryStateMergeOperation>(node) ? node : nullptr;
375  }
376 
386  [[nodiscard]] static rvsdg::SimpleNode *
387  tryGetMemoryStateExitSplit(const rvsdg::Node & callNode) noexcept
388  {
389  JLM_ASSERT(is<CallOperation>(&callNode));
390 
391  // If a memory state exit split node is present, then we would expect the node to be the only
392  // user of the memory state output.
393  if (GetMemoryStateOutput(callNode).nusers() != 1)
394  return nullptr;
395 
396  const auto node =
397  rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(GetMemoryStateOutput(callNode).SingleUser());
398  return is<CallExitMemoryStateSplitOperation>(node) ? node : nullptr;
399  }
400 
412  static rvsdg::Output &
413  TraceFunctionInput(const rvsdg::SimpleNode & callNode);
414 
422  static std::unique_ptr<CallTypeClassifier>
423  ClassifyCall(const rvsdg::SimpleNode & callNode);
424 
425  static std::unique_ptr<ThreeAddressCode>
427  const Variable * function,
428  std::shared_ptr<const rvsdg::FunctionType> functionType,
429  const std::vector<const Variable *> & arguments)
430  {
431  CheckFunctionInputType(function->type());
432 
433  auto op = std::make_unique<CallOperation>(std::move(functionType));
434  std::vector<const Variable *> operands({ function });
435  operands.insert(operands.end(), arguments.begin(), arguments.end());
436  return ThreeAddressCode::create(std::move(op), operands);
437  }
438 
439  static std::vector<rvsdg::Output *>
441  rvsdg::Output * function,
442  std::shared_ptr<const rvsdg::FunctionType> functionType,
443  const std::vector<rvsdg::Output *> & arguments)
444  {
445  return outputs(&CreateNode(function, std::move(functionType), arguments));
446  }
447 
448  static std::vector<rvsdg::Output *>
450  rvsdg::Region & region,
451  std::unique_ptr<CallOperation> callOperation,
452  const std::vector<rvsdg::Output *> & operands)
453  {
454  return outputs(&CreateNode(region, std::move(callOperation), operands));
455  }
456 
457  static rvsdg::SimpleNode &
459  rvsdg::Region & region,
460  std::unique_ptr<CallOperation> callOperation,
461  const std::vector<rvsdg::Output *> & operands)
462  {
463  CheckFunctionType(*callOperation->GetFunctionType());
464 
465  return rvsdg::SimpleNode::Create(region, std::move(callOperation), operands);
466  }
467 
468  static rvsdg::SimpleNode &
470  rvsdg::Output * function,
471  std::shared_ptr<const rvsdg::FunctionType> functionType,
472  const std::vector<rvsdg::Output *> & arguments)
473  {
474  CheckFunctionInputType(*function->Type());
475 
476  auto callOperation = std::make_unique<CallOperation>(std::move(functionType));
477  std::vector operands({ function });
478  operands.insert(operands.end(), arguments.begin(), arguments.end());
479 
480  return CreateNode(*function->region(), std::move(callOperation), operands);
481  }
482 
483 private:
484  static inline std::vector<std::shared_ptr<const rvsdg::Type>>
485  create_srctypes(const std::shared_ptr<const rvsdg::FunctionType> & functionType)
486  {
487  std::vector<std::shared_ptr<const rvsdg::Type>> types({ functionType });
488  for (auto & argumentType : functionType->Arguments())
489  types.emplace_back(argumentType);
490 
491  return types;
492  }
493 
494  static void
496  {
497  if (!is<rvsdg::FunctionType>(type))
498  throw util::Error("Expected function type.");
499  }
500 
501  static void
503  {
504  auto CheckArgumentTypes = [](const rvsdg::FunctionType & functionType)
505  {
506  if (functionType.NumArguments() < 2)
507  throw util::Error("Expected at least three argument types.");
508 
509  auto memoryStateArgumentIndex = functionType.NumArguments() - 1;
510  auto iOStateArgumentIndex = functionType.NumArguments() - 2;
511 
512  if (!is<MemoryStateType>(functionType.ArgumentType(memoryStateArgumentIndex)))
513  throw util::Error("Expected memory state type.");
514 
515  if (!is<IOStateType>(functionType.ArgumentType(iOStateArgumentIndex)))
516  throw util::Error("Expected IO state type.");
517  };
518 
519  auto CheckResultTypes = [](const rvsdg::FunctionType & functionType)
520  {
521  if (functionType.NumResults() < 2)
522  throw util::Error("Expected at least three result types.");
523 
524  auto memoryStateResultIndex = functionType.NumResults() - 1;
525  auto iOStateResultIndex = functionType.NumResults() - 2;
526 
527  if (!is<MemoryStateType>(functionType.ResultType(memoryStateResultIndex)))
528  throw util::Error("Expected memory state type.");
529 
530  if (!is<IOStateType>(functionType.ResultType(iOStateResultIndex)))
531  throw util::Error("Expected IO state type.");
532  };
533 
534  CheckArgumentTypes(functionType);
535  CheckResultTypes(functionType);
536  }
537 
538  std::shared_ptr<const rvsdg::FunctionType> FunctionType_;
539 };
540 
541 }
542 
543 #endif
Call operation class.
Definition: call.hpp:249
static rvsdg::Input * Argument(const rvsdg::Node &node, const size_t n)
Definition: call.hpp:291
static rvsdg::Input & GetIOStateInput(const rvsdg::Node &node) noexcept
Definition: call.hpp:314
static rvsdg::SimpleNode & CreateNode(rvsdg::Region &region, std::unique_ptr< CallOperation > callOperation, const std::vector< rvsdg::Output * > &operands)
Definition: call.hpp:458
std::shared_ptr< const rvsdg::FunctionType > FunctionType_
Definition: call.hpp:538
static rvsdg::SimpleNode * tryGetMemoryStateExitSplit(const rvsdg::Node &callNode) noexcept
Definition: call.hpp:387
static rvsdg::Input & GetFunctionInput(const rvsdg::Node &node) noexcept
Definition: call.hpp:302
const std::shared_ptr< const rvsdg::FunctionType > & GetFunctionType() const noexcept
Definition: call.hpp:265
static std::vector< std::shared_ptr< const rvsdg::Type > > create_srctypes(const std::shared_ptr< const rvsdg::FunctionType > &functionType)
Definition: call.hpp:485
std::unique_ptr< Operation > copy() const override
Definition: call.cpp:33
static std::unique_ptr< CallTypeClassifier > ClassifyCall(const rvsdg::SimpleNode &callNode)
Classifies a call node.
Definition: call.cpp:47
static rvsdg::Input & GetMemoryStateInput(const rvsdg::Node &node) noexcept
Definition: call.hpp:338
static void CheckFunctionInputType(const jlm::rvsdg::Type &type)
Definition: call.hpp:495
static size_t NumArguments(const rvsdg::Node &node) noexcept
Definition: call.hpp:279
static rvsdg::Output & TraceFunctionInput(const rvsdg::SimpleNode &callNode)
Traces function input of call node.
Definition: call.cpp:39
static std::vector< rvsdg::Output * > Create(rvsdg::Region &region, std::unique_ptr< CallOperation > callOperation, const std::vector< rvsdg::Output * > &operands)
Definition: call.hpp:449
static rvsdg::Output & GetIOStateOutput(const rvsdg::Node &node) noexcept
Definition: call.hpp:326
static void CheckFunctionType(const rvsdg::FunctionType &functionType)
Definition: call.hpp:502
bool operator==(const Operation &other) const noexcept override
Definition: call.cpp:20
static rvsdg::Output & GetMemoryStateOutput(const rvsdg::Node &node) noexcept
Definition: call.hpp:350
std::string debug_string() const override
Definition: call.cpp:27
static rvsdg::SimpleNode & CreateNode(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:469
CallOperation(std::shared_ptr< const rvsdg::FunctionType > functionType)
Definition: call.hpp:253
static rvsdg::SimpleNode * tryGetMemoryStateEntryMerge(const rvsdg::Node &callNode) noexcept
Definition: call.hpp:369
static std::unique_ptr< ThreeAddressCode > create(const Variable *function, std::shared_ptr< const rvsdg::FunctionType > functionType, const std::vector< const Variable * > &arguments)
Definition: call.hpp:426
static std::vector< rvsdg::Output * > Create(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:440
Call node classifier.
Definition: call.hpp:24
rvsdg::Output & GetLambdaOutput() const noexcept
Returns the called function.
Definition: call.hpp:123
bool IsExternalCall() const noexcept
Determines whether call is an external call.
Definition: call.hpp:100
jlm::rvsdg::Output * Output_
Definition: call.hpp:242
static std::unique_ptr< CallTypeClassifier > CreateExternalCallClassifier(rvsdg::RegionArgument &argument)
Classify callee as external.
Definition: call.hpp:222
static std::unique_ptr< CallTypeClassifier > CreateIndirectCallClassifier(jlm::rvsdg::Output &output)
Classify callee as indirect.
Definition: call.hpp:235
bool IsIndirectCall() const noexcept
Determines whether call is an indirect call.
Definition: call.hpp:110
CallTypeClassifier(CallType callType, jlm::rvsdg::Output &output)
Definition: call.hpp:51
static std::unique_ptr< CallTypeClassifier > CreateNonRecursiveDirectCallClassifier(rvsdg::Output &output)
Classify callee as non-recursive.
Definition: call.hpp:191
bool IsNonRecursiveDirectCall() const noexcept
Determines whether call is a non-recursive direct call.
Definition: call.hpp:70
CallType GetCallType() const noexcept
Return call type.
Definition: call.hpp:60
rvsdg::RegionArgument & GetImport() const noexcept
Returns the imported function.
Definition: call.hpp:147
static std::unique_ptr< CallTypeClassifier > CreateRecursiveDirectCallClassifier(rvsdg::Output &output)
Classify callee as recursive.
Definition: call.hpp:207
bool IsRecursiveDirectCall() const noexcept
Determines whether call is a recursive direct call.
Definition: call.hpp:80
bool IsDirectCall() const noexcept
Definition: call.hpp:89
jlm::rvsdg::Output & GetFunctionOrigin() const noexcept
Return origin of a call node's function input.
Definition: call.hpp:162
static std::unique_ptr< llvm::ThreeAddressCode > create(std::unique_ptr< rvsdg::SimpleOperation > operation, const std::vector< const Variable * > &operands)
Definition: tac.hpp:135
Function type class.
size_t NumArguments() const noexcept
size_t NumResults() const noexcept
const jlm::rvsdg::Type & ArgumentType(size_t index) const noexcept
const jlm::rvsdg::Type & ResultType(size_t index) const noexcept
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
NodeInput * input(size_t index) const noexcept
Definition: node.hpp:615
rvsdg::Region * region() const noexcept
Definition: node.cpp:151
Represents the argument of a region.
Definition: region.hpp:41
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
Graph * graph() const noexcept
Definition: region.hpp:363
static SimpleNode & Create(Region &region, std::unique_ptr< Operation > operation, const std::vector< rvsdg::Output * > &operands)
Definition: simple-node.hpp:49
SimpleOperation(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> operands, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> results)
Definition: operation.hpp:61
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
static std::string type(const Node *n)
Definition: view.cpp:255
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
Definition: node.hpp:1049
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)
Definition: node.hpp:1058