Jlm
lambda.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2018 Nico Reißmann <nico.reissmann@gmail.com>
3  * Copyright 2025 Helge Bahmann <hcb@chaoticmind.net>
4  * See COPYING for terms of redistribution.
5  */
6 
7 #include <jlm/rvsdg/lambda.hpp>
8 #include <jlm/util/strfmt.hpp>
9 
10 namespace jlm::rvsdg
11 {
12 
14 
15 LambdaOperation::LambdaOperation(std::shared_ptr<const FunctionType> type)
16  : type_(std::move(type))
17 {}
18 
19 std::string
21 {
22  return util::strfmt("Lambda[", Type()->debug_string(), "]");
23 }
24 
25 bool
26 LambdaOperation::operator==(const Operation & other) const noexcept
27 {
28  auto op = dynamic_cast<const LambdaOperation *>(&other);
29  return op && op->type() == type();
30 }
31 
32 std::unique_ptr<rvsdg::Operation>
34 {
35  return std::make_unique<LambdaOperation>(*this);
36 }
37 
38 LambdaNode::~LambdaNode() = default;
39 
40 LambdaNode::LambdaNode(rvsdg::Region & parent, std::unique_ptr<LambdaOperation> op)
41  : StructuralNode(&parent, 1),
42  Operation_(std::move(op))
43 {
44  for (auto & argumentType : GetOperation().Type()->Arguments())
45  {
46  rvsdg::RegionArgument::Create(*subregion(), nullptr, argumentType);
47  }
48 }
49 
51 LambdaNode::GetOperation() const noexcept
52 {
53  return *Operation_;
54 }
55 
56 [[nodiscard]] std::vector<rvsdg::Output *>
58 {
59  std::vector<rvsdg::Output *> arguments;
60  const auto & type = GetOperation().Type();
61  for (std::size_t n = 0; n < type->Arguments().size(); ++n)
62  {
63  arguments.push_back(subregion()->argument(n));
64  }
65  return arguments;
66 }
67 
68 [[nodiscard]] std::vector<rvsdg::Input *>
70 {
71  std::vector<rvsdg::Input *> results;
72  for (std::size_t n = 0; n < subregion()->nresults(); ++n)
73  {
74  results.push_back(subregion()->result(n));
75  }
76  return results;
77 }
78 
79 [[nodiscard]] LambdaNode::ContextVar
80 LambdaNode::MapInputContextVar(const rvsdg::Input & input) const noexcept
81 {
82  JLM_ASSERT(rvsdg::TryGetOwnerNode<LambdaNode>(input) == this);
83  return ContextVar{ const_cast<rvsdg::Input *>(&input),
84  subregion()->argument(GetOperation().Type()->NumArguments() + input.index()) };
85 }
86 
87 [[nodiscard]] std::optional<LambdaNode::ContextVar>
88 LambdaNode::MapBinderContextVar(const rvsdg::Output & output) const noexcept
89 {
90  JLM_ASSERT(rvsdg::TryGetOwnerRegion(output) == subregion());
91  auto numArguments = GetOperation().Type()->NumArguments();
92  if (output.index() >= numArguments)
93  {
94  return ContextVar{ input(output.index() - GetOperation().Type()->NumArguments()),
95  const_cast<rvsdg::Output *>(&output) };
96  }
97  else
98  {
99  return std::nullopt;
100  }
101 }
102 
103 std::variant<LambdaNode::ArgumentVar, LambdaNode::ContextVar>
105 {
106  JLM_ASSERT(rvsdg::TryGetOwnerNode<LambdaNode>(output) == this);
107  std::size_t nargs = GetOperation().Type()->NumArguments();
108  if (output.index() < nargs)
109  {
110  return ArgumentVar{ subregion()->argument(output.index()) };
111  }
112  else
113  {
114  return ContextVar{ input(output.index() - nargs), subregion()->argument(output.index()) };
115  }
116 }
117 
118 [[nodiscard]] std::vector<LambdaNode::ContextVar>
120 {
121  std::vector<ContextVar> vars;
122  for (size_t n = 0; n < ninputs(); ++n)
123  {
124  vars.push_back(
125  ContextVar{ input(n), subregion()->argument(n + GetOperation().Type()->NumArguments()) });
126  }
127  return vars;
128 }
129 
132 {
133  const auto input =
134  addInput(std::make_unique<StructuralInput>(this, &origin, origin.Type()), true);
135  const auto argument = &RegionArgument::Create(*subregion(), input, origin.Type());
136  return ContextVar{ input, argument };
137 }
138 
139 LambdaNode *
140 LambdaNode::Create(rvsdg::Region & parent, std::unique_ptr<LambdaOperation> operation)
141 {
142  return new LambdaNode(parent, std::move(operation));
143 }
144 
146 LambdaNode::finalize(const std::vector<jlm::rvsdg::Output *> & results)
147 {
148  /* check if finalized was already called */
149  if (noutputs() > 0)
150  {
151  JLM_ASSERT(noutputs() == 1);
152  return output();
153  }
154 
155  if (GetOperation().type().NumResults() != results.size())
156  throw util::Error("Incorrect number of results.");
157 
158  for (size_t n = 0; n < results.size(); n++)
159  {
160  auto & expected = GetOperation().type().ResultType(n);
161  auto & received = *results[n]->Type();
162  if (*results[n]->Type() != GetOperation().type().ResultType(n))
163  throw util::Error("Expected " + expected.debug_string() + ", got " + received.debug_string());
164 
165  if (results[n]->region() != subregion())
166  throw util::Error("Invalid operand region.");
167  }
168 
169  for (const auto & origin : results)
170  rvsdg::RegionResult::Create(*origin->region(), *origin, nullptr, origin->Type());
171 
172  return addOutput(std::make_unique<StructuralOutput>(this, GetOperation().Type()));
173 }
174 
176 LambdaNode::output() const noexcept
177 {
178  return StructuralNode::output(0);
179 }
180 
181 LambdaNode *
182 LambdaNode::copy(rvsdg::Region * region, const std::vector<jlm::rvsdg::Output *> & operands) const
183 {
184  return util::assertedCast<LambdaNode>(rvsdg::Node::copy(region, operands));
185 }
186 
187 LambdaNode *
189 {
190  const auto & op = GetOperation();
191  auto lambda = Create(
192  *region,
193  std::unique_ptr<LambdaOperation>(util::assertedCast<LambdaOperation>(op.copy().release())));
194 
195  /* add context variables */
196  rvsdg::SubstitutionMap subregionmap;
197  for (const auto & cv : GetContextVars())
198  {
199  auto origin = &smap.lookup(*cv.input->origin());
200  subregionmap.insert(cv.inner, lambda->AddContextVar(*origin).inner);
201  }
202 
203  /* collect function arguments */
204  auto args = GetFunctionArguments();
205  auto newArgs = lambda->GetFunctionArguments();
206  JLM_ASSERT(args.size() == newArgs.size());
207  for (std::size_t n = 0; n < args.size(); ++n)
208  {
209  subregionmap.insert(args[n], newArgs[n]);
210  }
211 
212  /* copy subregion */
213  subregion()->copy(lambda->subregion(), subregionmap);
214 
215  /* collect function results */
216  std::vector<jlm::rvsdg::Output *> results;
217  for (auto result : GetFunctionResults())
218  results.push_back(&subregionmap.lookup(*result->origin()));
219 
220  /* finalize lambda */
221  auto o = lambda->finalize(results);
222  smap.insert(output(), o);
223 
224  return lambda;
225 }
226 
227 LambdaBuilder::LambdaBuilder(Region & region, std::vector<std::shared_ptr<const Type>> argtypes)
228  : Node_(LambdaNode::Create(
229  region,
230  std::make_unique<LambdaOperation>(FunctionType::Create(std::move(argtypes), {}))))
231 {
232  // Note that the above inserts a "placeholder" function type, for now.
233  // This is to avoid requiring the caller to specify the return type(s)
234  // already when starting to construct the object. It is sometimes easier
235  // to let them be determined while building.
236 }
237 
238 std::vector<Output *>
240 {
241  JLM_ASSERT(Node_);
242  return Node_->GetFunctionArguments();
243 }
244 
247 {
248  JLM_ASSERT(Node_);
249  return Node_->subregion();
250 }
251 
254 {
255  JLM_ASSERT(Node_);
256  return Node_->AddContextVar(origin);
257 }
258 
259 Output &
261  const std::vector<jlm::rvsdg::Output *> & results,
262  std::unique_ptr<LambdaOperation> op)
263 {
264  JLM_ASSERT(Node_);
265  Node_->Operation_ = std::move(op);
266  auto output = Node_->finalize(results);
267  Node_ = nullptr;
268  return *output;
269 }
270 
271 [[nodiscard]] rvsdg::LambdaNode &
273 {
274  auto it = &node;
275  while (it)
276  {
277  if (auto lambda = dynamic_cast<rvsdg::LambdaNode *>(it))
278  return *lambda;
279  it = it->region()->node();
280  }
281  throw std::logic_error("node was not in a lambda");
282 }
283 
284 [[nodiscard]] const rvsdg::LambdaNode &
286 {
287  return getSurroundingLambdaNode(const_cast<rvsdg::Node &>(node));
288 }
289 
290 }
Function type class.
const jlm::rvsdg::Type & ResultType(size_t index) const noexcept
Output & Finalize(const std::vector< jlm::rvsdg::Output * > &results, std::unique_ptr< LambdaOperation > op)
Verifies well-formedness of lambda node and completes it.
Definition: lambda.cpp:260
LambdaNode::ContextVar AddContextVar(jlm::rvsdg::Output &origin)
Adds a context/free variable to the lambda node.
Definition: lambda.cpp:253
rvsdg::Region * GetRegion() noexcept
Returns region to place nodes in.
Definition: lambda.cpp:246
LambdaBuilder(Region &region, std::vector< std::shared_ptr< const Type >> argtypes)
Creates builder for a lambda construct.
Definition: lambda.cpp:227
std::vector< Output * > Arguments()
Obtains definition points of parameters to the function.
Definition: lambda.cpp:239
Lambda node.
Definition: lambda.hpp:83
LambdaNode * copy(rvsdg::Region *region, const std::vector< jlm::rvsdg::Output * > &operands) const override
Definition: lambda.cpp:182
rvsdg::Output * finalize(const std::vector< jlm::rvsdg::Output * > &results)
Definition: lambda.cpp:146
std::variant< ArgumentVar, ContextVar > MapArgument(const rvsdg::Output &output) const
Maps region argument to its disposition (formal argument or context var).
Definition: lambda.cpp:104
std::vector< rvsdg::Output * > GetFunctionArguments() const
Definition: lambda.cpp:57
rvsdg::Region * subregion() const noexcept
Definition: lambda.hpp:138
ContextVar MapInputContextVar(const rvsdg::Input &input) const noexcept
Maps input to context variable.
Definition: lambda.cpp:80
std::optional< ContextVar > MapBinderContextVar(const rvsdg::Output &output) const noexcept
Maps bound variable reference to context variable.
Definition: lambda.cpp:88
ContextVar AddContextVar(jlm::rvsdg::Output &origin)
Adds a context/free variable to the lambda node.
Definition: lambda.cpp:131
LambdaNode(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > op)
Definition: lambda.cpp:40
std::unique_ptr< LambdaOperation > Operation_
Definition: lambda.hpp:289
static LambdaNode * Create(rvsdg::Region &parent, std::unique_ptr< LambdaOperation > operation)
Definition: lambda.cpp:140
std::vector< rvsdg::Input * > GetFunctionResults() const
Definition: lambda.cpp:69
rvsdg::Output * output() const noexcept
Definition: lambda.cpp:176
std::vector< ContextVar > GetContextVars() const noexcept
Gets all bound context variables.
Definition: lambda.cpp:119
LambdaOperation & GetOperation() const noexcept override
Definition: lambda.cpp:51
Lambda operation.
Definition: lambda.hpp:29
LambdaOperation(std::shared_ptr< const FunctionType > type)
Definition: lambda.cpp:15
bool operator==(const Operation &other) const noexcept override
Definition: lambda.cpp:26
const FunctionType & type() const noexcept
Definition: lambda.hpp:36
const std::shared_ptr< const FunctionType > & Type() const noexcept
Definition: lambda.hpp:42
std::unique_ptr< Operation > copy() const override
Definition: lambda.cpp:33
std::string debug_string() const override
Definition: lambda.cpp:20
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
size_t ninputs() const noexcept
Definition: node.hpp:609
size_t noutputs() const noexcept
Definition: node.hpp:644
virtual Node * copy(rvsdg::Region *region, const std::vector< jlm::rvsdg::Output * > &operands) const
Definition: node.cpp:369
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
Definition: node.hpp:366
size_t index() const noexcept
Definition: node.hpp:274
static RegionArgument & Create(rvsdg::Region &region, StructuralInput *input, std::shared_ptr< const rvsdg::Type > type)
Creates region entry argument.
Definition: region.cpp:62
static RegionResult & Create(rvsdg::Region &region, rvsdg::Output &origin, StructuralOutput *output, std::shared_ptr< const rvsdg::Type > type)
Create region exit result.
Definition: region.cpp:111
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
RegionResult * result(size_t index) const noexcept
Definition: region.hpp:471
void copy(Region *target, SubstitutionMap &smap) const
Copy a region with substitutions.
Definition: region.cpp:314
rvsdg::StructuralNode * node() const noexcept
Definition: region.hpp:369
size_t nresults() const noexcept
Definition: region.hpp:465
RegionArgument * argument(size_t index) const noexcept
Definition: region.hpp:437
StructuralInput * addInput(std::unique_ptr< StructuralInput > input, bool notifyRegion)
StructuralOutput * addOutput(std::unique_ptr< StructuralOutput > input)
StructuralOutput * output(size_t index) const noexcept
StructuralInput * input(size_t index) const noexcept
void insert(const Output *original, Output *substitute)
Output & lookup(const Output &original) const
constexpr Type() noexcept
Definition: type.hpp:46
#define JLM_ASSERT(x)
Definition: common.hpp:16
static std::string type(const Node *n)
Definition: view.cpp:255
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
Definition: node.hpp:1049
Region * TryGetOwnerRegion(const rvsdg::Input &input) noexcept
Definition: node.hpp:1021
rvsdg::LambdaNode & getSurroundingLambdaNode(rvsdg::Node &node)
Definition: lambda.cpp:272
static std::string strfmt(Args... args)
Definition: strfmt.hpp:35
Formal argument variable.
Definition: lambda.hpp:124
Bound context variable.
Definition: lambda.hpp:100