28 static rvsdg::RegionResult *
32 std::vector<rvsdg::Node *> & load_nodes,
33 const std::vector<rvsdg::Node *> & store_nodes,
34 std::vector<rvsdg::Node *> & decouple_nodes)
51 else if (
auto gammaNode = rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(user))
53 auto ip = gammaNode->AddEntryVar(new_edge);
54 std::vector<jlm::rvsdg::Output *> vec;
55 new_edge = gammaNode->AddExitVar(ip.branchArgument).output;
56 new_next.divert_to(new_edge);
58 auto rolevar = gammaNode->MapInput(user);
60 if (
auto entryvar = std::get_if<rvsdg::GammaNode::EntryVar>(&rolevar))
62 for (
size_t i = 0; i < gammaNode->nsubregions(); ++i)
65 entryvar->branchArgument[i],
70 common_edge = subres->output();
74 else if (
auto theta = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(user))
76 auto olv = theta->MapInputLoopVar(user);
77 auto lv = theta->AddLoopVar(new_edge);
78 trace_edge(olv.pre, lv.pre, load_nodes, store_nodes, decouple_nodes);
79 common_edge = olv.output;
81 new_next.divert_to(new_edge);
83 else if (
auto sn = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user))
85 auto op = &sn->GetOperation();
89 if (store_nodes.end() != std::find(store_nodes.begin(), store_nodes.end(), sn))
91 user.divert_to(new_edge);
92 sn->output(0)->divert_users(common_edge);
93 new_edge = sn->output(0);
94 new_next.divert_to(new_edge);
98 common_edge = sn->output(0);
104 if (load_nodes.end() != std::find(load_nodes.begin(), load_nodes.end(), sn))
107 user.divert_to(new_edge);
108 sn->output(1)->divert_users(common_edge);
109 new_next.divert_to(sn->output(1));
110 new_edge = sn->output(1);
111 new_next.divert_to(new_edge);
115 common_edge = sn->output(1);
120 int oi = sn->noutputs() - sn->ninputs() + user.
index();
122 if (decouple_nodes.end() != std::find(decouple_nodes.begin(), decouple_nodes.end(), sn))
125 user.divert_to(new_edge);
126 sn->output(oi)->divert_users(common_edge);
127 new_next.divert_to(sn->output(oi));
128 new_edge = new_next.origin();
132 common_edge = sn->output(oi);
138 common_edge = sn->output(0);
148 std::vector<rvsdg::Node *>
151 std::function<void(
rvsdg::Region &, std::vector<rvsdg::Node *> &)> gatherCalls =
152 [&gatherCalls](
rvsdg::Region & region, std::vector<rvsdg::Node *> & calls)
159 for (
auto & subregion : structuralNode->Subregions())
161 gatherCalls(subregion, calls);
165 if (rvsdg::is<llvm::CallOperation>(node))
168 if (functionName.rfind(
"decouple") == functionName.npos)
170 calls.push_back(node);
176 std::vector<rvsdg::Node *> calls;
177 gatherCalls(region, calls);
184 const auto lambdaSubregion = lambdaNode.
subregion();
188 for (
auto & tp : tracedPointerNodesVector)
190 auto & decouple_nodes = tp.decoupleNodes;
191 auto decouple_requests_cnt = decouple_nodes.size();
193 for (
size_t i = 0; i < decouple_requests_cnt; ++i)
195 auto req = decouple_nodes[i];
196 auto channel = req->input(1)->origin();
199 decouple_nodes.push_back(decouple_response);
205 for (
auto call : nonDecoupleCalls)
207 tracedPointerNodesVector.emplace_back();
208 tracedPointerNodesVector.back().decoupleNodes.push_back(call);
211 const size_t numMemoryStates = tracedPointerNodesVector.size() + 1;
214 std::vector<llvm::MemoryNodeId> memoryNodeIds;
215 for (
size_t i = 0; i < numMemoryStates; ++i)
217 memoryNodeIds.push_back(i);
220 auto & lambdaEntrySplitNode =
222 auto memoryStates =
outputs(&lambdaEntrySplitNode);
227 auto common_edge = memoryStates.back();
229 memoryStateArgument.divertUsersWhere(
233 return &input != lambdaEntrySplitNode.input(0);
237 memoryStates.back() = state_result->origin();
243 memoryStates.pop_back();
244 state_result->divert_to(lambdaExitMergeNode.output(0));
246 for (
auto tp : tracedPointerNodesVector)
248 auto new_edge = memoryStates.back();
249 memoryStates.pop_back();
250 trace_edge(common_edge, new_edge, tp.loadNodes, tp.storeNodes, tp.decoupleNodes);
263 const auto & graph = rvsdgModule.
Rvsdg();
265 if (rootRegion->numNodes() != 1)
267 throw std::logic_error(
"Root should have only one node now");
270 const auto lambdaNode =
274 throw std::logic_error(
"Node needs to be a lambda");
~MemoryStateSeparation() noexcept override
static std::vector< rvsdg::Node * > gatherNonDecoupleCalls(rvsdg::Region ®ion)
static void separateMemoryStates(const rvsdg::LambdaNode &lambdaNode)
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static rvsdg::SimpleNode & CreateNode(rvsdg::Output &operand, std::vector< MemoryNodeId > memoryNodeIds)
static rvsdg::Node & CreateNode(rvsdg::Region ®ion, const std::vector< rvsdg::Output * > &operands, const std::vector< MemoryNodeId > &memoryNodeIds)
Region & GetRootRegion() const noexcept
rvsdg::Region * subregion() const noexcept
size_t index() const noexcept
size_t nusers() const noexcept
Represents the result of a region.
Represent acyclic RVSDG subgraphs.
#define JLM_UNREACHABLE(msg)
rvsdg::SimpleNode * find_decouple_response(const rvsdg::LambdaNode *lambda, const llvm::IntegerConstantOperation *request_constant)
std::string get_function_name(jlm::rvsdg::Input *input)
const llvm::IntegerConstantOperation * trace_constant(const rvsdg::Output *dst)
std::vector< TracedPointerNodes > TracePointerArguments(const rvsdg::LambdaNode *lambda)
static rvsdg::Output * trace_edge(rvsdg::Input *state_edge, rvsdg::Output *new_edge, rvsdg::SimpleNode *target_call, rvsdg::Output *end)
rvsdg::Input & GetMemoryStateRegionResult(const rvsdg::LambdaNode &lambdaNode) noexcept
rvsdg::Output & GetMemoryStateRegionArgument(const rvsdg::LambdaNode &lambdaNode) noexcept
static std::vector< jlm::rvsdg::Output * > outputs(const Node *node)