Jlm
lambda.hpp
Go to the documentation of this file.
1 /*
2  * Copyright 2018 Nico Reißmann <nico.reissmann@gmail.com>
3  * Copyright 2025 Helge Bahmann <hcb@chaoticmind.net>
4  * See COPYING for terms of redistribution.
5  */
6 
7 #ifndef JLM_RVSDG_LAMBDA_HPP
8 #define JLM_RVSDG_LAMBDA_HPP
9 
11 #include <jlm/rvsdg/graph.hpp>
15 
16 #include <optional>
17 #include <utility>
18 
19 namespace jlm::rvsdg
20 {
21 
22 class LambdaBuilder;
23 
// Body of the lambda operation class ("Lambda operation." in the cross-reference index).
// NOTE(review): the class-head line (listing line 28) is absent from this extract; the
// constructor/destructor names below show the class is LambdaOperation. Its base class is
// not visible here -- confirm against the real header.
29 {
30 public:
31  ~LambdaOperation() override;
32 
    // Constructs the operation from a shared pointer to a const FunctionType.
    // `explicit` prevents implicit conversion from the shared pointer.
    // NOTE(review): defined out of line; presumably stores the pointer in type_ -- confirm.
33  explicit LambdaOperation(std::shared_ptr<const FunctionType> type);
34 
    // Returns a reference to the FunctionType that type_ points to (dereferences type_).
35  [[nodiscard]] const FunctionType &
36  type() const noexcept
37  {
38  return *type_;
39  }
40 
    // Returns the shared pointer itself (by const reference), as opposed to type(),
    // which returns the pointed-to object.
41  [[nodiscard]] const std::shared_ptr<const FunctionType> &
42  Type() const noexcept
43  {
44  return type_;
45  }
46 
    // Overridden virtual; returns a std::string. Defined out of line.
47  [[nodiscard]] std::string
48  debug_string() const override;
49 
    // Overridden equality comparison against an arbitrary Operation. Defined out of line.
50  bool
51  operator==(const Operation & other) const noexcept override;
52 
    // Overridden virtual; returns a newly owned Operation via std::unique_ptr.
    // Defined out of line.
53  [[nodiscard]] std::unique_ptr<Operation>
54  copy() const override;
55 
56 private:
    // Function type handed back by type() / Type().
57  std::shared_ptr<const FunctionType> type_;
58 };
59 
/**
 * Lambda node: a structural node (derives from rvsdg::StructuralNode) that owns a
 * LambdaOperation and exposes a single subregion via subregion().
 *
 * NOTE(review): this listing is an extract; the original documentation comments and several
 * declarator/member lines are missing. Comments below are limited to what the visible
 * declarations and the cross-reference index state.
 */
82 class LambdaNode final : public rvsdg::StructuralNode
83 {
84 public:
85  ~LambdaNode() override;
86 
87 private:
    // Private constructor; the public entry point visible in this class is the static
    // Create(), and LambdaBuilder is a friend.
88  LambdaNode(rvsdg::Region & parent, std::unique_ptr<LambdaOperation> op);
89 
90 public:
    // Bound context variable.
    // NOTE(review): the member lines are missing from this extract. The index lists
    // `rvsdg::Input * input` (input variable bound into lambda node) and
    // `rvsdg::Output * inner` (access to bound object in subregion).
99  struct ContextVar
100  {
109 
118  };
119 
    // Formal argument variable.
    // NOTE(review): the member line is missing from this extract. The index lists
    // `rvsdg::Output * arg` (access argument object in subregion).
123  struct ArgumentVar
124  {
129  };
130 
    // Returns the function's arguments as a vector of Output pointers. Defined out of line.
131  [[nodiscard]] std::vector<rvsdg::Output *>
132  GetFunctionArguments() const;
133 
    // Returns the function's results as a vector of Input pointers. Defined out of line.
134  [[nodiscard]] std::vector<rvsdg::Input *>
135  GetFunctionResults() const;
136 
    // Convenience accessor: forwards to StructuralNode::subregion(0).
137  [[nodiscard]] rvsdg::Region *
138  subregion() const noexcept
139  {
140  return StructuralNode::subregion(0);
141  }
142 
    // Returns the node's operation as a LambdaOperation reference. Defined out of line.
143  [[nodiscard]] LambdaOperation &
144  GetOperation() const noexcept override;
145 
    // Adds a context/free variable to the lambda node, with `origin` as the value bound in.
    // Returns the resulting ContextVar.
157  ContextVar
158  AddContextVar(jlm::rvsdg::Output & origin);
159 
    // Maps input to context variable.
178  [[nodiscard]] ContextVar
179  MapInputContextVar(const rvsdg::Input & input) const noexcept;
180 
    // Maps bound variable reference to context variable. Returns an optional, so the
    // mapping can fail. NOTE(review): the failure condition is not visible here -- confirm
    // in lambda.cpp.
198  [[nodiscard]] std::optional<ContextVar>
199  MapBinderContextVar(const rvsdg::Output & output) const noexcept;
200 
    // Gets all bound context variables.
209  [[nodiscard]] std::vector<ContextVar>
210  GetContextVars() const noexcept;
211 
    // Maps region argument to its disposition (formal argument or context var),
    // returned as a variant of the two descriptor structs.
221  std::variant<ArgumentVar, ContextVar>
222  MapArgument(const rvsdg::Output & output) const;
223 
    // Removes context-variable inputs selected by the predicate `match` and returns the
    // number of inputs removed. The definition follows the class in this header; it only
    // removes inputs whose subregion argument is dead.
233  template<typename F>
234  size_t
235  RemoveLambdaInputsWhere(const F & match);
236 
    // NOTE(review): the declarator line (listing line 245) and one body line (listing line
    // 252) are missing from this extract. The index names this member PruneLambdaInputs().
    // The visible part builds a predicate that accepts every input; presumably the missing
    // line passes it to RemoveLambdaInputsWhere and returns the result -- confirm.
244  size_t
246  {
247  auto match = [](const rvsdg::Input &)
248  {
249  return true;
250  };
251 
253  }
254 
    // Returns an Output pointer for this node. Defined out of line.
255  [[nodiscard]] rvsdg::Output *
256  output() const noexcept;
257 
    // Copy overload taking the target region and an explicit operand list.
258  LambdaNode *
259  copy(rvsdg::Region * region, const std::vector<jlm::rvsdg::Output *> & operands) const override;
260 
    // Copy overload taking the target region and a substitution map.
261  LambdaNode *
262  copy(rvsdg::Region * region, rvsdg::SubstitutionMap & smap) const override;
263 
    // Static factory taking the parent region and ownership of the operation.
    // The constructor itself is private.
275  static LambdaNode *
276  Create(rvsdg::Region & parent, std::unique_ptr<LambdaOperation> operation);
277 
    // Takes the result values and returns an Output pointer. Defined out of line.
    // NOTE(review): presumably completes construction of the lambda and yields its
    // output -- confirm in lambda.cpp.
285  rvsdg::Output *
286  finalize(const std::vector<jlm::rvsdg::Output *> & results);
287 
288 private:
    // Operation owned by this node.
289  std::unique_ptr<LambdaOperation> Operation_;
290 
    // Grants LambdaBuilder access to private members, including the private constructor.
291  friend class LambdaBuilder;
292 };
293 
/**
 * Out-of-class definition of the member template declared in LambdaNode as
 * RemoveLambdaInputsWhere(const F & match).
 *
 * NOTE(review): the declarator line (listing line 296) is missing from this extract; the
 * cross-reference index places RemoveLambdaInputsWhere at that line.
 *
 * Removes every context variable whose subregion argument is dead and whose input is
 * accepted by `match`, and returns the number of inputs removed.
 */
294 template<typename F>
295 size_t
297 {
    // Indices are collected first and removed in bulk afterwards, so nothing is removed
    // while GetContextVars() is being iterated.
298  util::HashSet<size_t> inputIndices;
299  util::HashSet<size_t> argumentIndices;
300  for (auto [input, argument] : GetContextVars())
301  {
    // `&&` short-circuits: `match` is only invoked for inputs whose argument is dead.
302  if (argument->IsDead() && match(*input))
303  {
304  inputIndices.insert(input->index());
305  argumentIndices.insert(argument->index());
306  }
307  }
308 
    // Remove the subregion arguments first, then the corresponding node inputs.
    // [[maybe_unused]] because each count is otherwise only read by JLM_ASSERT.
309  [[maybe_unused]] const auto numRemoveArguments = subregion()->RemoveArguments(argumentIndices);
310  JLM_ASSERT(numRemoveArguments == argumentIndices.Size());
311 
312  [[maybe_unused]] const auto numRemovedInputs = RemoveInputs(inputIndices);
313  JLM_ASSERT(numRemovedInputs == inputIndices.Size());
314 
315  return numRemovedInputs;
316 }
317 
// Body of LambdaBuilder ("Constructs a lambda node." in the cross-reference index).
// NOTE(review): the class-head line (listing line 321), the private member line (listing
// line 393) and the original documentation comments are missing from this extract. The
// numbering gap before Finalize may hide further declarations -- confirm against the real
// header.
322 {
323 public:
    // Takes the region to build in and the list of argument types.
331  LambdaBuilder(Region & region, std::vector<std::shared_ptr<const Type>> argtypes);
332 
    // Returns a vector of Output pointers.
    // NOTE(review): presumably the formal arguments of the lambda being built -- confirm.
343  std::vector<Output *>
344  Arguments();
345 
    // Returns a Region pointer.
    // NOTE(review): presumably the subregion of the lambda being built -- confirm.
356  rvsdg::Region *
357  GetRegion() noexcept;
358 
373 
    // Takes the result values and ownership of the LambdaOperation, and returns a
    // reference to an Output.
    // NOTE(review): presumably the output of the finished lambda node -- confirm.
389  Output &
390  Finalize(const std::vector<jlm::rvsdg::Output *> & results, std::unique_ptr<LambdaOperation> op);
391 
392 private:
394 };
395 
// Free-function declarations returning a reference to a LambdaNode.
// NOTE(review): the declarator lines (listing lines 404 and 407) are missing from this
// extract. The cross-reference index lists
// `rvsdg::LambdaNode & getSurroundingLambdaNode(rvsdg::Node & node)`, defined in
// lambda.cpp. The second declaration returns a const reference and is presumably the
// overload taking a const node -- confirm against the real header.
403 [[nodiscard]] rvsdg::LambdaNode &
405 
406 [[nodiscard]] const rvsdg::LambdaNode &
408 
409 }
410 
411 #endif
Function type class.
size_t index() const noexcept
Definition: node.hpp:52
Constructs a lambda node.
Definition: lambda.hpp:322
Lambda node.
Definition: lambda.hpp:83
LambdaNode * copy(rvsdg::Region *region, const std::vector< jlm::rvsdg::Output * > &operands) const override
Definition: lambda.cpp:182
rvsdg::Output * finalize(const std::vector< jlm::rvsdg::Output * > &results)
Definition: lambda.cpp:146
std::variant< ArgumentVar, ContextVar > MapArgument(const rvsdg::Output &output) const
Maps region argument to its disposition (formal argument or context var).
Definition: lambda.cpp:104
size_t RemoveLambdaInputsWhere(const F &match)
Definition: lambda.hpp:296
std::vector< rvsdg::Output * > GetFunctionArguments() const
Definition: lambda.cpp:57
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
ContextVar MapInputContextVar(const rvsdg::Input &input) const noexcept
Maps input to context variable.
Definition: lambda.cpp:80
friend class LambdaBuilder
Definition: lambda.hpp:291
std::optional< ContextVar > MapBinderContextVar(const rvsdg::Output &output) const noexcept
Maps bound variable reference to context variable.
Definition: lambda.cpp:88
ContextVar AddContextVar(jlm::rvsdg::Output &origin)
Adds a context/free variable to the lambda node.
Definition: lambda.cpp:131
LambdaNode(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > op)
Definition: lambda.cpp:40
size_t PruneLambdaInputs()
Definition: lambda.hpp:245
std::unique_ptr< LambdaOperation > Operation_
Definition: lambda.hpp:289
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
std::vector< rvsdg::Input * > GetFunctionResults() const
Definition: lambda.cpp:69
rvsdg::Output * output() const noexcept
Definition: lambda.cpp:176
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
LambdaOperation & GetOperation() const noexcept override
Definition: lambda.cpp:51
Lambda operation.
Definition: lambda.hpp:29
LambdaOperation(std::shared_ptr< const FunctionType > type)
Definition: lambda.cpp:15
bool operator==(const Operation &other) const noexcept override
Definition: lambda.cpp:26
const FunctionType & type() const noexcept
Definition: lambda.hpp:36
const std::shared_ptr< const FunctionType > & Type() const noexcept
Definition: lambda.hpp:42
std::unique_ptr< Operation > copy() const override
Definition: lambda.cpp:33
std::shared_ptr< const FunctionType > type_
Definition: lambda.hpp:57
std::string debug_string() const override
Definition: lambda.cpp:20
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
size_t RemoveInputs(const util::HashSet< size_t > &indices)
Definition: node.cpp:306
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
size_t RemoveArguments(const util::HashSet< size_t > &indices)
Definition: region.cpp:210
StructuralInput * input(size_t index) const noexcept
rvsdg::Region * subregion(size_t index) const noexcept
bool insert(ItemType item)
Definition: HashSet.hpp:210
std::size_t Size() const noexcept
Definition: HashSet.hpp:187
#define JLM_ASSERT(x)
Definition: common.hpp:16
jlm::rvsdg::Output * match(size_t nbits, const std::unordered_map< uint64_t, uint64_t > &mapping, uint64_t default_alternative, size_t nalternatives, jlm::rvsdg::Output *operand)
Definition: control.cpp:179
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
Definition: node.hpp:1049
rvsdg::LambdaNode & getSurroundingLambdaNode(rvsdg::Node &node)
Definition: lambda.cpp:272
Formal argument variable.
Definition: lambda.hpp:124
rvsdg::Output * arg
Access argument object in subregion.
Definition: lambda.hpp:128
Bound context variable.
Definition: lambda.hpp:100
rvsdg::Input * input
Input variable bound into lambda node.
Definition: lambda.hpp:108
rvsdg::Output * inner
Access to bound object in subregion.
Definition: lambda.hpp:117