49 return Types_.find(&output)->second;
52 static std::unique_ptr<Context>
55 return std::make_unique<Context>();
59 std::unordered_map<const rvsdg::Output *, ModRefChainLink::Type>
Types_{};
108 for (
auto & output : simpleNode.Outputs())
110 if (output.IsDead() && is<MemoryStateType>(output.Type()))
120 throw std::logic_error(
util::strfmt(
"Unhandled node type: ", node->DebugString()));
139 for (
auto & subregion : gammaNode.
Subregions())
144 std::vector<util::HashSet<rvsdg::Output *>> visitedOutputs(gammaNode.
nsubregions());
145 for (
auto & [branchResults, output] : gammaNode.
GetExitVars())
147 if (is<MemoryStateType>(output->Type()))
149 for (
const auto branchResult : branchResults)
151 const auto regionIndex = branchResult->region()->index();
152 JLM_ASSERT(regionIndex < visitedOutputs.size());
170 if (!is<MemoryStateType>(loopVar.output->Type()))
174 auto hasModificationChainLink =
182 if (loopVar.output->IsDead())
201 for (
const auto & [_, links] : refSubchains)
205 std::vector<rvsdg::Output *> joinOperands;
207 for (
auto [linkOutput, linkModRefType] : links)
211 modRefChainInput.divert_to(newMemoryStateOperand);
212 joinOperands.push_back(linkOutput);
216 if (!links.front().output->IsDead())
219 links.front().output->divertUsersWhere(
223 return rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user) != &joinNode;
235 if (
auto [loadNode, loadOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<LoadOperation>(output);
241 if (
const auto thetaNode = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(output))
243 return *thetaNode->MapOutputLoopVar(output).input;
246 throw std::logic_error(
"Unhandled node type!");
249 std::vector<LoadChainSeparation::ModRefChain>
252 std::vector<ModRefChain> refSubchains;
253 for (
auto linkIt = modRefChain.
links.begin(); linkIt != modRefChain.
links.end();)
262 auto nextLinkIt = std::next(linkIt);
263 if (nextLinkIt == modRefChain.
links.end()
272 refSubchains.push_back({});
275 refSubchains.back().links.push_back(*linkIt);
291 if (!visitedOutputs.
insert(&startOutput))
298 bool doneTracing =
false;
307 auto & node = rvsdg::AssertGetOwnerNode<rvsdg::Node>(*currentOutput);
318 for (
auto [entryVarInput, _] : gammaNode.
GetEntryVars())
320 if (is<MemoryStateType>(entryVarInput->Type()))
322 traceModRefChains(*entryVarInput->origin(), visitedOutputs, summary);
329 const auto modRefChainLinkType =
Context_->getModRefChainLinkType(*currentOutput);
330 currentModRefChain.
add({ currentOutput, modRefChainLinkType });
335 auto & operation = simpleNode.GetOperation();
375 for (
auto & nodeInput : node.Inputs())
397 for (
auto & nodeInput : node.Inputs())
405 for (
auto & nodeInput : node.Inputs())
413 for (
auto & nodeInput : node.Inputs())
433 throw std::logic_error(
434 util::strfmt(
"Unhandled operation type: ", operation.debug_string()));
439 throw std::logic_error(
util::strfmt(
"Unhandled node type: ", node.DebugString()));
441 }
while (!doneTracing);
443 summary.add(std::move(currentModRefChain));
static rvsdg::Input & GetMemoryStateInput(const rvsdg::Node &node) noexcept
static rvsdg::Input & mapMemoryStateOutputToInput(rvsdg::Output &output) noexcept
static std::unique_ptr< Context > create()
std::unordered_map< const rvsdg::Output *, ModRefChainLink::Type > Types_
bool hasModRefChainLinkType(const rvsdg::Output &output) const noexcept
void add(const rvsdg::Output &output, const ModRefChainLink::Type &type)
ModRefChainLink::Type getModRefChainLinkType(const rvsdg::Output &output) const
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
std::unique_ptr< Context > Context_
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
void separateRefenceChainsInGamma(rvsdg::GammaNode &gammaNode)
static std::vector< ModRefChain > extractReferenceSubchains(const ModRefChain &modRefChain)
void traceModRefChains(rvsdg::Output &startOutput, util::HashSet< rvsdg::Output * > &visitedOutputs, ModRefChainSummary &summary)
void separateReferenceChainsInLambda(rvsdg::LambdaNode &lambdaNode)
bool separateReferenceChains(rvsdg::Output &startOutput, util::HashSet< rvsdg::Output * > &visitedOutputs)
~LoadChainSeparation() noexcept override
void separateReferenceChainsInRegion(rvsdg::Region ®ion)
void separateRefenceChainsInTheta(rvsdg::ThetaNode &thetaNode, util::HashSet< rvsdg::Output * > &visitedOutputs)
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
static rvsdg::SimpleNode & CreateNode(const std::vector< rvsdg::Output * > &operands)
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
UndefValueOperation class.
Conditional operator / pattern matching.
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
std::vector< EntryVar > GetEntryVars() const
Gets all entry variables for this gamma.
Region & GetRootRegion() const noexcept
rvsdg::Region * subregion() const noexcept
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
A phi node represents the fixpoint of mutually recursive definitions.
Represent acyclic RVSDG subgraphs.
SubregionIteratorRange Subregions()
size_t nsubregions() const noexcept
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
rvsdg::Region * subregion() const noexcept
bool insert(ItemType item)
Global memory state passed between functions.
rvsdg::Input & GetMemoryStateRegionResult(const rvsdg::LambdaNode &lambdaNode) noexcept
void MatchTypeWithDefault(T &obj, const Fns &... fns)
Pattern match over subclass type of given object with default handler.
static std::string type(const Node *n)
Region * TryGetOwnerRegion(const rvsdg::Input &input) noexcept
static std::string strfmt(Args... args)
bool hasModificationChainLink
std::vector< ModRefChain > modRefChains
std::vector< ModRefChainLink > links
void add(ModRefChainLink modRefChainLink)