Jlm
call.hpp
Go to the documentation of this file.
1 /*
2  * Copyright 2018 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
6 #ifndef JLM_LLVM_IR_OPERATORS_CALL_HPP
7 #define JLM_LLVM_IR_OPERATORS_CALL_HPP
8 
13 #include <jlm/llvm/ir/tac.hpp>
14 #include <jlm/llvm/ir/types.hpp>
15 #include <jlm/rvsdg/Phi.hpp>
17 
18 namespace jlm::llvm
19 {
20 
/**
 * Call node classifier.
 *
 * Pairs a CallType with a pointer to an rvsdg::Output and offers predicates and accessors on that
 * pair. Instances are produced by the static Create*Classifier() factories further down.
 *
 * NOTE(review): this listing omits several original lines (visible as gaps in the embedded line
 * numbers): the enumerators, the constructor signature, the predicate bodies, the factory
 * signatures, and the private data members.
 */
25 class CallTypeClassifier final
26 {
27 public:
    // Kind of call. The enumerator lines are not visible in this listing; the factories below
    // reference NonRecursiveDirectCall, RecursiveDirectCall, ExternalCall, and IndirectCall.
28  enum class CallType
29  {
35 
41 
46 
51  };
52 
    // Constructor body: stores the call type and the address of the given output.
    // NOTE(review): the signature (line 53) is missing here; the symbol index gives it as
    // CallTypeClassifier(CallType callType, jlm::rvsdg::Output & output).
54  : CallType_(callType),
55  Output_(&output)
56  {}
57 
    /** Return call type. */
61  [[nodiscard]] CallType
62  GetCallType() const noexcept
63  {
64  return CallType_;
65  }
66 
    /**
     * Determines whether call is a non-recursive direct call.
     * NOTE(review): the return statement (line 74) is not visible in this listing.
     */
71  [[nodiscard]] bool
72  IsNonRecursiveDirectCall() const noexcept
73  {
75  }
76 
    /**
     * Determines whether call is a recursive direct call.
     * NOTE(review): the return statement (line 84) is not visible in this listing.
     */
81  [[nodiscard]] bool
82  IsRecursiveDirectCall() const noexcept
83  {
85  }
86 
    /**
     * Presumably true for both recursive and non-recursive direct calls - TODO confirm; the body
     * (lines 93-94) is not visible in this listing.
     */
90  [[nodiscard]] bool
91  IsDirectCall() const noexcept
92  {
95  }
96 
    /**
     * Determines whether call is an external call.
     * NOTE(review): the return statement (line 104) is not visible in this listing.
     */
101  [[nodiscard]] bool
102  IsExternalCall() const noexcept
103  {
105  }
106 
    /**
     * Determines whether call is an indirect call.
     * NOTE(review): the return statement (line 114) is not visible in this listing.
     */
111  [[nodiscard]] bool
112  IsIndirectCall() const noexcept
113  {
115  }
116 
    /**
     * Returns the called function.
     *
     * One path returns *Output_ directly; the other treats Output_ as a region argument and
     * returns the origin of the region result that has the same index as that argument.
     * NOTE(review): the condition selecting between the two paths (line 127) and line 132 are not
     * visible in this listing.
     */
124  [[nodiscard]] rvsdg::Output &
125  GetLambdaOutput() const noexcept
126  {
128  {
129  return *Output_;
130  }
131 
133  auto argument = jlm::util::assertedCast<jlm::rvsdg::RegionArgument>(Output_);
134  /*
135  * FIXME: This assumes that all recursion variables were added before the dependencies. It
136  * would be better if we did not use the index for retrieving the result, but instead
137  * explicitly encoded it in a phi_argument.
138  */
139  return *argument->region()->result(argument->index())->origin();
140  }
141 
    /**
     * Returns the imported function, i.e. Output_ cast to an rvsdg::RegionArgument.
     * NOTE(review): line 151 is not visible in this listing.
     */
148  [[nodiscard]] rvsdg::RegionArgument &
149  GetImport() const noexcept
150  {
152  return *jlm::util::assertedCast<rvsdg::RegionArgument>(Output_);
153  }
154 
    /** Return origin of a call node's function input (the stored output, for any call type). */
163  [[nodiscard]] jlm::rvsdg::Output &
164  GetFunctionOrigin() const noexcept
165  {
166  return *Output_;
167  }
168 
    // Declared without an inline body; presumably tests whether the callee is setjmp - TODO
    // confirm against the definition.
173  [[nodiscard]] bool
174  isSetjmpCall();
175 
    // Declared without an inline body; presumably tests whether the callee is va_start - TODO
    // confirm against the definition.
180  [[nodiscard]] bool
181  isVaStartCall();
182 
    /**
     * Classify callee as non-recursive. The owner of the output is checked to be a
     * rvsdg::LambdaNode via AssertGetOwnerNode.
     * NOTE(review): the name/parameter line (193) is missing here; the symbol index gives it as
     * CreateNonRecursiveDirectCallClassifier(rvsdg::Output & output).
     */
192  static std::unique_ptr<CallTypeClassifier>
194  {
195  rvsdg::AssertGetOwnerNode<rvsdg::LambdaNode>(output);
196  return std::make_unique<CallTypeClassifier>(CallType::NonRecursiveDirectCall, output);
197  }
198 
    /**
     * Classify callee as recursive.
     * NOTE(review): the name/parameter line (209) is missing here; the symbol index gives it as
     * CreateRecursiveDirectCallClassifier(rvsdg::Output & output).
     */
208  static std::unique_ptr<CallTypeClassifier>
210  {
211  return std::make_unique<CallTypeClassifier>(CallType::RecursiveDirectCall, output);
212  }
213 
    /**
     * Classify callee as external. Asserts that the argument belongs to the root region of its
     * graph.
     * NOTE(review): the name/parameter line (224) is missing here; the symbol index gives it as
     * CreateExternalCallClassifier(rvsdg::RegionArgument & argument).
     */
223  static std::unique_ptr<CallTypeClassifier>
225  {
226  JLM_ASSERT(argument.region() == &argument.region()->graph()->GetRootRegion());
227  return std::make_unique<CallTypeClassifier>(CallType::ExternalCall, argument);
228  }
229 
    /**
     * Classify callee as indirect.
     * NOTE(review): the name/parameter line (237) is missing here; the symbol index gives it as
     * CreateIndirectCallClassifier(jlm::rvsdg::Output & output).
     */
236  static std::unique_ptr<CallTypeClassifier>
238  {
239  return std::make_unique<CallTypeClassifier>(CallType::IndirectCall, output);
240  }
241 
242 private:
    // NOTE(review): the data member declarations (lines 243-244) are missing here; the code above
    // uses CallType_ and Output_, and the symbol index lists jlm::rvsdg::Output * Output_.
245 };
246 
251 {
252 public:
// Destructor override; declared here without an inline body.
253  ~CallOperation() override;
254 
/**
 * Constructs a call operation from a function type, a calling convention, and an attribute list.
 *
 * SimpleOperation is initialized with the operand types built by create_srctypes() (the function
 * type followed by its argument types) and with the function type's result types.
 *
 * @param functionType The type of the called function; stored in FunctionType_.
 * @param callingConvention The calling convention; stored in callingConvention_.
 * @param attributes The attribute list; moved into attributes_.
 */
255  explicit CallOperation(
256  std::shared_ptr<const rvsdg::FunctionType> functionType,
257  CallingConvention callingConvention,
258  AttributeList attributes)
259  : SimpleOperation(create_srctypes(functionType), functionType->Results()),
// functionType is read by the SimpleOperation initializer above and only moved from here.
// NOTE(review): this relies on SimpleOperation being a base class (initialized before members);
// the class header is not visible in this listing - confirm.
260  FunctionType_(std::move(functionType)),
261  callingConvention_(callingConvention),
262  attributes_(std::move(attributes))
263  {}
264 
// Equality comparison against another Operation; overrides the base declaration and is declared
// here without an inline body.
265  bool
266  operator==(const Operation & other) const noexcept override;
267 
// Returns a std::string for this operation; overrides the base declaration and is declared here
// without an inline body, so the exact text is not visible in this header.
268  [[nodiscard]] std::string
269  debug_string() const override;
270 
271  [[nodiscard]] const std::shared_ptr<const rvsdg::FunctionType> &
272  GetFunctionType() const noexcept
273  {
274  return FunctionType_;
275  }
276 
277  [[nodiscard]] CallingConvention
278  getCallingConvention() const noexcept
279  {
280  return callingConvention_;
281  }
282 
283  [[nodiscard]] const AttributeList &
284  getAttributes() const noexcept
285  {
286  return attributes_;
287  }
288 
// Returns a caller-owned Operation (presumably a copy of this call operation - the body is not
// in this header); overrides the base declaration.
289  [[nodiscard]] std::unique_ptr<Operation>
290  copy() const override;
291 
297  [[nodiscard]] static size_t
298  NumArguments(const rvsdg::Node & node) noexcept
299  {
300  JLM_ASSERT(is<CallOperation>(&node));
301  return node.ninputs() - 1;
302  }
303 
/**
 * Returns input n + 1 of the call node, i.e. argument indices skip input 0.
 *
 * @param node A node whose operation is a CallOperation (asserted).
 * @param n The argument index.
 *
 * NOTE(review): the listing jumps from line 312 to line 314; line 313 is not visible in this
 * extract.
 */
309  [[nodiscard]] static rvsdg::Input *
310  Argument(const rvsdg::Node & node, const size_t n)
311  {
312  JLM_ASSERT(is<CallOperation>(&node));
314  return node.input(n + 1);
315  }
316 
320  [[nodiscard]] static rvsdg::Input &
321  GetFunctionInput(const rvsdg::Node & node) noexcept
322  {
323  JLM_ASSERT(is<CallOperation>(&node));
324  const auto functionInput = node.input(0);
325  JLM_ASSERT(is<rvsdg::FunctionType>(functionInput->Type()));
326  return *functionInput;
327  }
328 
332  [[nodiscard]] static rvsdg::Input &
333  GetIOStateInput(const rvsdg::Node & node) noexcept
334  {
335  JLM_ASSERT(is<CallOperation>(&node));
336  const auto ioState = node.input(node.ninputs() - 2);
337  JLM_ASSERT(is<IOStateType>(ioState->Type()));
338  return *ioState;
339  }
340 
344  [[nodiscard]] static rvsdg::Output &
345  GetIOStateOutput(const rvsdg::Node & node) noexcept
346  {
347  JLM_ASSERT(is<CallOperation>(&node));
348  const auto ioState = node.output(node.noutputs() - 2);
349  JLM_ASSERT(is<IOStateType>(ioState->Type()));
350  return *ioState;
351  }
352 
356  [[nodiscard]] static rvsdg::Input &
357  GetMemoryStateInput(const rvsdg::Node & node) noexcept
358  {
359  JLM_ASSERT(is<CallOperation>(&node));
360  const auto memoryState = node.input(node.ninputs() - 1);
361  JLM_ASSERT(is<MemoryStateType>(memoryState->Type()));
362  return *memoryState;
363  }
364 
368  [[nodiscard]] static rvsdg::Output &
369  GetMemoryStateOutput(const rvsdg::Node & node) noexcept
370  {
371  JLM_ASSERT(is<CallOperation>(&node));
372  const auto memoryState = node.output(node.noutputs() - 1);
373  JLM_ASSERT(is<MemoryStateType>(memoryState->Type()));
374  return *memoryState;
375  }
376 
387  [[nodiscard]] static rvsdg::SimpleNode *
388  tryGetMemoryStateEntryMerge(const rvsdg::Node & callNode) noexcept
389  {
390  JLM_ASSERT(is<CallOperation>(&callNode));
391  const auto node =
392  rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*GetMemoryStateInput(callNode).origin());
393  return is<CallEntryMemoryStateMergeOperation>(node) ? node : nullptr;
394  }
395 
405  [[nodiscard]] static rvsdg::SimpleNode *
406  tryGetMemoryStateExitSplit(const rvsdg::Node & callNode) noexcept
407  {
408  JLM_ASSERT(is<CallOperation>(&callNode));
409 
410  // If a memory state exit split node is present, then we would expect the node to be the only
411  // user of the memory state output.
412  if (GetMemoryStateOutput(callNode).nusers() != 1)
413  return nullptr;
414 
415  const auto node =
416  rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(GetMemoryStateOutput(callNode).SingleUser());
417  return is<CallExitMemoryStateSplitOperation>(node) ? node : nullptr;
418  }
419 
// Traces function input of call node. Declared here without an inline body.
431  static rvsdg::Output &
432  TraceFunctionInput(const rvsdg::SimpleNode & callNode);
433 
// Classifies a call node; the result is returned as a caller-owned CallTypeClassifier. Declared
// here without an inline body.
441  static std::unique_ptr<CallTypeClassifier>
442  ClassifyCall(const rvsdg::SimpleNode & callNode);
443 
/**
 * Creates a three address code for a call.
 *
 * Validates the type of \p function with CheckFunctionInputType(), builds a CallOperation from
 * the function type, calling convention, and attributes, and hands it to
 * ThreeAddressCode::create() with the operand list [function, arguments...].
 *
 * NOTE(review): the line carrying the function name (445) is missing from this listing; the
 * symbol index names this function create(...).
 */
444  static std::unique_ptr<ThreeAddressCode>
446  const Variable * function,
447  std::shared_ptr<const rvsdg::FunctionType> functionType,
448  CallingConvention callingConvention,
449  AttributeList attributes,
450  const std::vector<const Variable *> & arguments)
451  {
452  CheckFunctionInputType(function->type());
453 
454  auto op = std::make_unique<CallOperation>(
455  std::move(functionType),
456  callingConvention,
457  std::move(attributes));
// The callee is always the first operand; the call arguments follow in order.
458  std::vector<const Variable *> operands({ function });
459  operands.insert(operands.end(), arguments.begin(), arguments.end());
460  return ThreeAddressCode::create(std::move(op), operands);
461  }
462 
/**
 * Creates a call node via the three-parameter CreateNode() overload and returns the result of
 * outputs() applied to that node.
 *
 * NOTE(review): the line carrying the function name (464) is missing from this listing; the
 * symbol index names this function Create(...).
 */
463  static std::vector<rvsdg::Output *>
465  rvsdg::Output * function,
466  std::shared_ptr<const rvsdg::FunctionType> functionType,
467  const std::vector<rvsdg::Output *> & arguments)
468  {
469  return outputs(&CreateNode(function, std::move(functionType), arguments));
470  }
471 
/**
 * Creates a call node with an explicit calling convention and attribute list via the matching
 * CreateNode() overload and returns the result of outputs() applied to that node.
 *
 * NOTE(review): the line carrying the function name (473) is missing from this listing; the
 * symbol index names this function Create(...).
 */
472  static std::vector<rvsdg::Output *>
474  rvsdg::Output * function,
475  std::shared_ptr<const rvsdg::FunctionType> functionType,
476  CallingConvention callingConvention,
477  AttributeList attributes,
478  const std::vector<rvsdg::Output *> & arguments)
479  {
480  return outputs(&CreateNode(
481  function,
482  std::move(functionType),
483  callingConvention,
484  std::move(attributes),
485  arguments));
486  }
487 
/**
 * Creates a simple node in \p region from an already constructed call operation.
 *
 * The operation's function type is validated with CheckFunctionType() before the node is
 * created, so an unsuitable function type results in the exception thrown there.
 *
 * NOTE(review): the line carrying the function name (489) is missing from this listing; the
 * symbol index names this function CreateNode(...).
 */
488  static rvsdg::SimpleNode &
490  rvsdg::Region & region,
491  std::unique_ptr<CallOperation> callOperation,
492  const std::vector<rvsdg::Output *> & operands)
493  {
494  CheckFunctionType(*callOperation->GetFunctionType());
495 
496  return rvsdg::SimpleNode::Create(region, std::move(callOperation), operands);
497  }
498 
/**
 * Convenience overload that forwards to another CreateNode() overload.
 *
 * NOTE(review): the line carrying the function name (505) and two argument lines of the
 * forwarded call (513-514) are missing from this listing. Presumably they supply a default
 * calling convention and an empty attribute list - TODO confirm against the original header.
 */
504  static rvsdg::SimpleNode &
506  rvsdg::Output * function,
507  std::shared_ptr<const rvsdg::FunctionType> functionType,
508  const std::vector<rvsdg::Output *> & arguments)
509  {
510  return CreateNode(
511  function,
512  std::move(functionType),
515  arguments);
516  }
517 
/**
 * Creates a call node in the region of \p function.
 *
 * Validates the type of \p function with CheckFunctionInputType(), builds a CallOperation from
 * the function type, calling convention, and attributes, and forwards to the region-based
 * CreateNode() overload with the operand list [function, arguments...].
 *
 * NOTE(review): the line carrying the function name (519) is missing from this listing; the
 * symbol index names this function CreateNode(...).
 */
518  static rvsdg::SimpleNode &
520  rvsdg::Output * function,
521  std::shared_ptr<const rvsdg::FunctionType> functionType,
522  CallingConvention callingConvention,
523  AttributeList attributes,
524  const std::vector<rvsdg::Output *> & arguments)
525  {
526  CheckFunctionInputType(*function->Type());
527 
528  auto callOperation = std::make_unique<CallOperation>(
529  std::move(functionType),
530  callingConvention,
531  std::move(attributes));
// The callee is always the first operand; the call arguments follow in order.
532  std::vector operands({ function });
533  operands.insert(operands.end(), arguments.begin(), arguments.end());
534 
535  return CreateNode(*function->region(), std::move(callOperation), operands);
536  }
537 
538 private:
539  static inline std::vector<std::shared_ptr<const rvsdg::Type>>
540  create_srctypes(const std::shared_ptr<const rvsdg::FunctionType> & functionType)
541  {
542  std::vector<std::shared_ptr<const rvsdg::Type>> types({ functionType });
543  for (auto & argumentType : functionType->Arguments())
544  types.emplace_back(argumentType);
545 
546  return types;
547  }
548 
/**
 * Throws util::Error("Expected function type.") unless the given type is a rvsdg::FunctionType.
 *
 * NOTE(review): the name/parameter line (550) is missing from this listing; the symbol index
 * gives it as CheckFunctionInputType(const jlm::rvsdg::Type & type).
 */
549  static void
551  {
552  if (!is<rvsdg::FunctionType>(type))
553  throw util::Error("Expected function type.");
554  }
555 
556  static void
558  {
559  auto CheckArgumentTypes = [](const rvsdg::FunctionType & functionType)
560  {
561  if (functionType.NumArguments() < 2)
562  throw util::Error("Expected at least three argument types.");
563 
564  auto memoryStateArgumentIndex = functionType.NumArguments() - 1;
565  auto iOStateArgumentIndex = functionType.NumArguments() - 2;
566 
567  if (!is<MemoryStateType>(functionType.ArgumentType(memoryStateArgumentIndex)))
568  throw util::Error("Expected memory state type.");
569 
570  if (!is<IOStateType>(functionType.ArgumentType(iOStateArgumentIndex)))
571  throw util::Error("Expected IO state type.");
572  };
573 
574  auto CheckResultTypes = [](const rvsdg::FunctionType & functionType)
575  {
576  if (functionType.NumResults() < 2)
577  throw util::Error("Expected at least three result types.");
578 
579  auto memoryStateResultIndex = functionType.NumResults() - 1;
580  auto iOStateResultIndex = functionType.NumResults() - 2;
581 
582  if (!is<MemoryStateType>(functionType.ResultType(memoryStateResultIndex)))
583  throw util::Error("Expected memory state type.");
584 
585  if (!is<IOStateType>(functionType.ResultType(iOStateResultIndex)))
586  throw util::Error("Expected IO state type.");
587  };
588 
589  CheckArgumentTypes(functionType);
590  CheckResultTypes(functionType);
591  }
592 
// Type of the called function; set by the constructor and returned by GetFunctionType().
// NOTE(review): the declarations on lines 594-595 are missing from this listing; the symbol
// index lists callingConvention_ and attributes_ at those lines.
593  std::shared_ptr<const rvsdg::FunctionType> FunctionType_;
596 };
597 
598 }
599 
600 #endif
static AttributeList createEmptyList()
Definition: attribute.hpp:427
Call operation class.
Definition: call.hpp:251
CallingConvention getCallingConvention() const noexcept
Definition: call.hpp:278
static rvsdg::Input * Argument(const rvsdg::Node &node, const size_t n)
Definition: call.hpp:310
static rvsdg::Input & GetIOStateInput(const rvsdg::Node &node) noexcept
Definition: call.hpp:333
static rvsdg::SimpleNode & CreateNode(rvsdg::Region &region, std::unique_ptr< CallOperation > callOperation, const std::vector< rvsdg::Output * > &operands)
Definition: call.hpp:489
std::shared_ptr< const rvsdg::FunctionType > FunctionType_
Definition: call.hpp:593
static rvsdg::SimpleNode * tryGetMemoryStateExitSplit(const rvsdg::Node &callNode) noexcept
Definition: call.hpp:406
static rvsdg::Input & GetFunctionInput(const rvsdg::Node &node) noexcept
Definition: call.hpp:321
const std::shared_ptr< const rvsdg::FunctionType > & GetFunctionType() const noexcept
Definition: call.hpp:272
static std::vector< std::shared_ptr< const rvsdg::Type > > create_srctypes(const std::shared_ptr< const rvsdg::FunctionType > &functionType)
Definition: call.hpp:540
std::unique_ptr< Operation > copy() const override
Definition: call.cpp:35
static rvsdg::SimpleNode & CreateNode(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, CallingConvention callingConvention, AttributeList attributes, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:519
static std::unique_ptr< CallTypeClassifier > ClassifyCall(const rvsdg::SimpleNode &callNode)
Classifies a call node.
Definition: call.cpp:49
static rvsdg::Input & GetMemoryStateInput(const rvsdg::Node &node) noexcept
Definition: call.hpp:357
const AttributeList & getAttributes() const noexcept
Definition: call.hpp:284
static void CheckFunctionInputType(const jlm::rvsdg::Type &type)
Definition: call.hpp:550
static size_t NumArguments(const rvsdg::Node &node) noexcept
Definition: call.hpp:298
static rvsdg::Output & TraceFunctionInput(const rvsdg::SimpleNode &callNode)
Traces function input of call node.
Definition: call.cpp:41
static rvsdg::Output & GetIOStateOutput(const rvsdg::Node &node) noexcept
Definition: call.hpp:345
static void CheckFunctionType(const rvsdg::FunctionType &functionType)
Definition: call.hpp:557
bool operator==(const Operation &other) const noexcept override
Definition: call.cpp:20
CallOperation(std::shared_ptr< const rvsdg::FunctionType > functionType, CallingConvention callingConvention, AttributeList attributes)
Definition: call.hpp:255
static std::unique_ptr< ThreeAddressCode > create(const Variable *function, std::shared_ptr< const rvsdg::FunctionType > functionType, CallingConvention callingConvention, AttributeList attributes, const std::vector< const Variable * > &arguments)
Definition: call.hpp:445
static rvsdg::Output & GetMemoryStateOutput(const rvsdg::Node &node) noexcept
Definition: call.hpp:369
std::string debug_string() const override
Definition: call.cpp:29
static std::vector< rvsdg::Output * > Create(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, CallingConvention callingConvention, AttributeList attributes, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:473
AttributeList attributes_
Definition: call.hpp:595
static rvsdg::SimpleNode & CreateNode(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:505
static rvsdg::SimpleNode * tryGetMemoryStateEntryMerge(const rvsdg::Node &callNode) noexcept
Definition: call.hpp:388
CallingConvention callingConvention_
Definition: call.hpp:594
static std::vector< rvsdg::Output * > Create(rvsdg::Output *function, std::shared_ptr< const rvsdg::FunctionType > functionType, const std::vector< rvsdg::Output * > &arguments)
Definition: call.hpp:464
Call node classifier.
Definition: call.hpp:26
rvsdg::Output & GetLambdaOutput() const noexcept
Returns the called function.
Definition: call.hpp:125
bool IsExternalCall() const noexcept
Determines whether call is an external call.
Definition: call.hpp:102
jlm::rvsdg::Output * Output_
Definition: call.hpp:244
static std::unique_ptr< CallTypeClassifier > CreateExternalCallClassifier(rvsdg::RegionArgument &argument)
Classify callee as external.
Definition: call.hpp:224
static std::unique_ptr< CallTypeClassifier > CreateIndirectCallClassifier(jlm::rvsdg::Output &output)
Classify callee as indirect.
Definition: call.hpp:237
bool IsIndirectCall() const noexcept
Determines whether call is an indirect call.
Definition: call.hpp:112
CallTypeClassifier(CallType callType, jlm::rvsdg::Output &output)
Definition: call.hpp:53
static std::unique_ptr< CallTypeClassifier > CreateNonRecursiveDirectCallClassifier(rvsdg::Output &output)
Classify callee as non-recursive.
Definition: call.hpp:193
bool IsNonRecursiveDirectCall() const noexcept
Determines whether call is a non-recursive direct call.
Definition: call.hpp:72
CallType GetCallType() const noexcept
Return call type.
Definition: call.hpp:62
rvsdg::RegionArgument & GetImport() const noexcept
Returns the imported function.
Definition: call.hpp:149
static std::unique_ptr< CallTypeClassifier > CreateRecursiveDirectCallClassifier(rvsdg::Output &output)
Classify callee as recursive.
Definition: call.hpp:209
bool IsRecursiveDirectCall() const noexcept
Determines whether call is a recursive direct call.
Definition: call.hpp:82
bool IsDirectCall() const noexcept
Definition: call.hpp:91
jlm::rvsdg::Output & GetFunctionOrigin() const noexcept
Return origin of a call node's function input.
Definition: call.hpp:164
static std::unique_ptr< llvm::ThreeAddressCode > create(std::unique_ptr< rvsdg::SimpleOperation > operation, const std::vector< const Variable * > &operands)
Definition: tac.hpp:135
Function type class.
size_t NumArguments() const noexcept
size_t NumResults() const noexcept
const jlm::rvsdg::Type & ArgumentType(size_t index) const noexcept
const jlm::rvsdg::Type & ResultType(size_t index) const noexcept
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
NodeInput * input(size_t index) const noexcept
Definition: node.hpp:615
rvsdg::Region * region() const noexcept
Definition: node.cpp:151
Represents the argument of a region.
Definition: region.hpp:41
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
Graph * graph() const noexcept
Definition: region.hpp:363
static SimpleNode & Create(Region &region, std::unique_ptr< Operation > operation, const std::vector< rvsdg::Output * > &operands)
Definition: simple-node.hpp:49
SimpleOperation(std::vector< std::shared_ptr< const jlm::rvsdg::Type >> operands, std::vector< std::shared_ptr< const jlm::rvsdg::Type >> results)
Definition: operation.hpp:61
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
Definition: node.hpp:1049
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)
Definition: node.hpp:1058