31 static rvsdg::Output *
33 rvsdg::Input * state_edge,
34 std::vector<rvsdg::SimpleNode *> & mem_ops,
37 static rvsdg::Output *
49 previous_state_edge = state_edge;
51 if (state_edge->
origin() == end)
61 else if (
auto ln = rvsdg::TryGetOwnerNode<LoopNode>(*state_edge))
63 auto si = util::assertedCast<rvsdg::StructuralInput>(state_edge);
64 auto arg = si->arguments.begin().ptr();
65 std::vector<rvsdg::SimpleNode *> mem_ops;
67 if (std::count(mem_ops.begin(), mem_ops.end(), target_call))
70 auto new_out = ln->AddLoopVar(new_edge);
71 new_edge_user->divert_to(new_out);
72 auto new_in = util::assertedCast<rvsdg::StructuralInput>(
get_mem_state_user(new_edge));
77 new_in->arguments.begin().ptr(),
83 JLM_ASSERT(rvsdg::TryGetOwnerNode<LoopNode>(*out));
90 auto sn = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*si);
91 auto new_si = new_edge_user;
92 auto new_sn = &rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*new_si);
93 auto [branchNode, branchOperation] =
94 rvsdg::TryGetSimpleNodeAndOptionalOp<BranchOperation>(*state_edge);
95 auto [muxNode, muxOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<MuxOperation>(*state_edge);
103 new_edge_user->divert_to(nmux);
105 for (
size_t i = 0; i < sn->noutputs(); ++i)
110 JLM_ASSERT(rvsdg::IsOwnerNodeOperation<MuxOperation>(*out));
114 else if (branchOperation)
117 JLM_ASSERT(rvsdg::IsOwnerNodeOperation<BranchOperation>(*new_edge_user));
119 return util::assertedCast<rvsdg::RegionResult>(
get_mem_state_user(sn->output(0)))->output();
121 else if (muxOperation && !muxOperation->loop)
124 JLM_ASSERT(rvsdg::IsOwnerNodeOperation<MuxOperation>(*new_edge_user));
125 return sn->output(0);
127 else if (muxOperation)
130 JLM_ASSERT(rvsdg::IsOwnerNodeOperation<MuxOperation>(*new_edge_user));
133 new_edge = new_sn->output(0);
135 else if (rvsdg::IsOwnerNodeOperation<LoopConstantBufferOperation>(*state_edge))
138 JLM_ASSERT(rvsdg::IsOwnerNodeOperation<MuxOperation>(*new_edge_user));
140 new_edge = new_sn->output(0);
142 else if (rvsdg::IsOwnerNodeOperation<llvm::MemoryStateSplitOperation>(*state_edge))
145 for (
size_t i = 0; i < sn->noutputs(); ++i)
148 std::vector<rvsdg::SimpleNode *> mem_ops;
151 JLM_ASSERT(rvsdg::IsOwnerNodeOperation<llvm::MemoryStateMergeOperation>(*after_merge));
152 if (std::count(mem_ops.begin(), mem_ops.end(), target_call))
163 new_edge = new_edge_user->origin();
166 rvsdg::IsOwnerNodeOperation<llvm::MemoryStateMergeOperation>(*state_edge)
167 || rvsdg::IsOwnerNodeOperation<llvm::LambdaExitMemoryStateMergeOperation>(*state_edge))
170 return sn->output(0);
172 else if (rvsdg::IsOwnerNodeOperation<StateGateOperation>(*state_edge))
176 else if (rvsdg::IsOwnerNodeOperation<llvm::LoadNonVolatileOperation>(*state_edge))
180 else if (rvsdg::IsOwnerNodeOperation<llvm::CallOperation>(*state_edge))
182 auto state_origin = state_edge->
origin();
183 if (sn == target_call)
188 si->divert_to(new_edge);
189 new_edge_user->divert_to(sn->output(sn->noutputs() - 1));
190 new_edge = new_edge_user->origin();
206 std::vector<std::tuple<rvsdg::SimpleNode *, rvsdg::Input *>> & outstanding_dec_reqs,
207 std::vector<rvsdg::SimpleNode *> & mem_ops,
213 for (
auto op : mem_ops)
217 outstanding_dec_reqs.push_back(std::make_tuple(op, state_edge_before));
220 for (
auto op : mem_ops)
222 if (rvsdg::is<const llvm::StoreNonVolatileOperation>(op))
228 for (
auto resp : mem_ops)
234 for (
auto [req, state_edge_req] : outstanding_dec_reqs)
237 if (*req_constant == *res_constant)
242 state_edge_req->divert_to(split_outputs[0]);
245 std::vector<rvsdg::Output *>
operands(
246 { state_edge_after, split_outputs[1], split_outputs[2] });
248 after_user->divert_to(merge_out);
249 trace_edge(state_edge_req, split_outputs[1], req, state_edge_after);
250 trace_edge(state_edge_req, split_outputs[2], resp, state_edge_after);
258 std::vector<rvsdg::SimpleNode *> & mem_ops,
267 if (mem_ops.size() == 1
268 && (rvsdg::is<const llvm::StoreNonVolatileOperation>(mem_ops[0])
269 || rvsdg::is<const llvm::LoadNonVolatileOperation>(mem_ops[0])))
272 JLM_ASSERT(rvsdg::TryGetOwnerNode<LoopNode>(*state_edge_before));
274 rvsdg::TryGetOwnerNode<LoopNode>(*state_edge_before)
275 == rvsdg::TryGetOwnerNode<LoopNode>(*state_edge_after));
283 std::vector<rvsdg::SimpleNode *> & mem_ops,
309 std::vector<std::tuple<rvsdg::SimpleNode *, rvsdg::Input *>> outstanding_dec_reqs;
316 else if (rvsdg::TryGetOwnerNode<LoopNode>(*state_edge))
318 std::vector<rvsdg::SimpleNode *> loop_mem_ops;
319 auto si = jlm::util::assertedCast<rvsdg::StructuralInput>(state_edge);
320 auto arg = si->arguments.begin().ptr();
329 state_edge = new_state_edge;
330 mem_ops.insert(mem_ops.cend(), loop_mem_ops.begin(), loop_mem_ops.end());
333 auto si = state_edge;
334 auto sn = &rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(*si);
335 auto [branchNode, branchOperation] =
336 rvsdg::TryGetSimpleNodeAndOptionalOp<BranchOperation>(*state_edge);
337 auto [muxNode, muxOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<MuxOperation>(*state_edge);
342 std::vector<rvsdg::SimpleNode *> gamma_mem_ops;
345 for (
size_t i = 0; i < sn->noutputs(); ++i)
349 JLM_ASSERT(rvsdg::IsOwnerNodeOperation<MuxOperation>(*out));
355 state_edge = new_state_edge;
356 mem_ops.insert(mem_ops.cend(), gamma_mem_ops.begin(), gamma_mem_ops.end());
358 else if (branchOperation)
362 return util::assertedCast<rvsdg::RegionResult>(
get_mem_state_user(sn->output(0)))->output();
364 else if (muxOperation && !muxOperation->loop)
367 return sn->output(0);
369 else if (muxOperation)
375 else if (rvsdg::IsOwnerNodeOperation<LoopConstantBufferOperation>(*state_edge))
381 rvsdg::IsOwnerNodeOperation<llvm::MemoryStateSplitOperation>(*state_edge)
382 || rvsdg::IsOwnerNodeOperation<llvm::LambdaEntryMemoryStateSplitOperation>(*state_edge))
384 for (
size_t i = 0; i < sn->noutputs(); ++i)
390 rvsdg::IsOwnerNodeOperation<llvm::MemoryStateMergeOperation>(*followed)
391 || rvsdg::IsOwnerNodeOperation<llvm::LambdaExitMemoryStateMergeOperation>(*followed));
397 rvsdg::IsOwnerNodeOperation<llvm::MemoryStateMergeOperation>(*state_edge)
398 || rvsdg::IsOwnerNodeOperation<llvm::LambdaExitMemoryStateMergeOperation>(*state_edge))
400 return sn->output(0);
402 else if (rvsdg::IsOwnerNodeOperation<StateGateOperation>(*state_edge))
404 mem_ops.push_back(sn);
407 else if (rvsdg::IsOwnerNodeOperation<llvm::LoadNonVolatileOperation>(*state_edge))
409 mem_ops.push_back(sn);
412 else if (rvsdg::IsOwnerNodeOperation<llvm::CallOperation>(*state_edge))
414 mem_ops.push_back(sn);
417 else if (rvsdg::IsOwnerNodeOperation<llvm::StoreNonVolatileOperation>(*state_edge))
419 mem_ops.push_back(sn);
423 if (rvsdg::IsOwnerNodeOperation<llvm::MemoryStateSplitOperation>(*state_edge))
427 get_mem_state_user(rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*state_edge)->output(0));
441 JLM_ASSERT(rvsdg::TryGetOwnerNode<LoopNode>(*loop_state_input));
442 auto si = util::assertedCast<rvsdg::StructuralInput>(loop_state_input);
443 auto arg = si->arguments.begin().ptr();
445 auto [muxNode, muxOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<MuxOperation>(*user);
446 JLM_ASSERT(muxOperation && muxOperation->loop);
447 auto mux_node = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*user);
449 *mux_node->input(0)->origin(),
450 *mux_node->input(1)->origin())[0];
451 mux_node->output(0)->divert_users(lcb_out);
459 const auto & graph = rvsdgModule.
Rvsdg();
461 if (rootRegion->numNodes() != 1)
463 throw std::logic_error(
"Root should have only one node now");
466 const auto lambda =
dynamic_cast<const rvsdg::LambdaNode *
>(rootRegion->Nodes().begin().ptr());
469 throw std::logic_error(
"Node needs to be a lambda");
492 JLM_ASSERT(entryNode->noutputs() == exitNode->ninputs());
494 for (
size_t i = 0; i < entryNode->noutputs(); ++i)
496 std::vector<rvsdg::SimpleNode *> mem_ops;
497 std::vector<std::tuple<rvsdg::SimpleNode *, rvsdg::Input *>> dummy;
504 exitNode->input(i)->
origin());
509 dne.
Run(*lambda->subregion(), statisticsCollector);
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &predicate, jlm::rvsdg::Output &value, bool loop=false)
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &predicate, jlm::rvsdg::Output &value)
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
~MemoryStateDecoupling() noexcept override
static std::vector< jlm::rvsdg::Output * > create(jlm::rvsdg::Output &predicate, const std::vector< jlm::rvsdg::Output * > &alternatives, bool discarding, bool loop=false)
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static rvsdg::Output * Create(const std::vector< rvsdg::Output * > &operands)
static std::vector< rvsdg::Output * > Create(rvsdg::Output &operand, const size_t numResults)
Region & GetRootRegion() const noexcept
rvsdg::Region * region() const noexcept
Represents the result of a region.
#define JLM_UNREACHABLE(msg)
static rvsdg::Output * follow_state_edge(rvsdg::Input *state_edge, std::vector< rvsdg::SimpleNode * > &mem_ops, bool modify)
bool is_dec_res(rvsdg::SimpleNode *node)
const llvm::IntegerConstantOperation * trace_constant(const rvsdg::Output *dst)
static void optimize_single_mem_op_loop(std::vector< rvsdg::SimpleNode * > &mem_ops, rvsdg::Input *state_edge_before, rvsdg::Output *state_edge_after)
static void decouple_mem_state(rvsdg::RvsdgModule &rvsdgModule)
static void handle_structural(std::vector< std::tuple< rvsdg::SimpleNode *, rvsdg::Input * >> &outstanding_dec_reqs, std::vector< rvsdg::SimpleNode * > &mem_ops, rvsdg::Input *state_edge_before, rvsdg::Output *state_edge_after)
rvsdg::Input * get_mem_state_user(rvsdg::Output *state_edge)
bool is_dec_req(rvsdg::SimpleNode *node)
static rvsdg::Output * trace_edge(rvsdg::Input *state_edge, rvsdg::Output *new_edge, rvsdg::SimpleNode *target_call, rvsdg::Output *end)
void convert_loop_state_to_lcb(rvsdg::Input *loop_state_input)
rvsdg::SimpleNode * tryGetMemoryStateEntrySplit(const rvsdg::LambdaNode &lambdaNode) noexcept
rvsdg::SimpleNode * tryGetMemoryStateExitMerge(const rvsdg::LambdaNode &lambdaNode) noexcept
rvsdg::Output & GetMemoryStateRegionArgument(const rvsdg::LambdaNode &lambdaNode) noexcept
static void remove(Node *node)
@ State
Designate a state type.
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)