54 return Types_.find(&output)->second;
63 JLM_ASSERT(outputMap.find(&output) == outputMap.end());
64 outputMap[&output] = std::move(modRefChainInformation);
67 std::optional<ModRefChainInformation>
76 auto & outputMap = regionMapIt->second;
77 const auto outputMapIt = outputMap.find(&output);
78 if (outputMapIt == outputMap.end())
83 return outputMapIt->second;
92 static std::unique_ptr<Context>
95 return std::make_unique<Context>();
100 std::unordered_map<const rvsdg::Output *, ModRefChainInformation>;
113 std::unordered_map<const rvsdg::Output *, ModRefChainLink::Type>
Types_{};
114 std::unordered_map<const rvsdg::Region *, ModRefChainInformationMap>
RegionMap_{};
163 for (
auto & output : simpleNode.Outputs())
165 if (output.IsDead() && is<MemoryStateType>(output.Type()))
175 throw std::logic_error(
util::strfmt(
"Unhandled node type: ", node->DebugString()));
197 for (
auto & subregion : gammaNode.
Subregions())
202 for (
auto & [branchResults, output] : gammaNode.
GetExitVars())
204 if (is<MemoryStateType>(output->Type()))
206 for (
const auto branchResult : branchResults)
215 for (
auto & subregion : gammaNode.
Subregions())
217 Context_->dropModRefChainInformation(subregion);
230 if (!is<MemoryStateType>(loopVar.output->Type()))
234 const auto hasModificationChainLinkAboveInRegion =
242 if (loopVar.output->IsDead())
259 const bool hasModRefChainLinkAboveInRegion =
traceModRefChains(startOutput, summary);
263 for (
const auto & [links] : refSubchains)
267 std::vector<rvsdg::Output *> joinOperands;
269 for (
auto [linkOutput, linkModRefType] : links)
273 modRefChainInput.divert_to(newMemoryStateOperand);
274 joinOperands.push_back(linkOutput);
278 if (!links.front().output->IsDead())
281 links.front().output->divertUsersWhere(
285 return rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user) != &joinNode;
291 return hasModRefChainLinkAboveInRegion;
297 if (
auto [loadNode, loadOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<LoadOperation>(output);
303 if (
const auto thetaNode = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(output))
305 return *thetaNode->MapOutputLoopVar(output).input;
308 throw std::logic_error(
"Unhandled node type!");
311 std::vector<LoadChainSeparation::ModRefChain>
314 std::vector<ModRefChain> refSubchains;
315 for (
auto linkIt = modRefChain.
links.begin(); linkIt != modRefChain.
links.end();)
324 auto nextLinkIt = std::next(linkIt);
325 if (nextLinkIt == modRefChain.
links.end()
334 refSubchains.push_back({});
337 refSubchains.back().links.push_back(*linkIt);
350 if (
const auto modRefChainInformationOpt =
Context_->tryGetModRefChainInformation(startOutput))
353 return modRefChainInformationOpt.value().hasModificationChainLinkAboveInRegion;
358 bool doneTracing =
false;
359 bool hasModRefChainLinkAboveInRegion =
false;
368 auto & node = rvsdg::AssertGetOwnerNode<rvsdg::Node>(*currentOutput);
378 hasModRefChainLinkAboveInRegion =
true;
380 for (
auto [entryVarInput, _] : gammaNode.
GetEntryVars())
382 if (is<MemoryStateType>(entryVarInput->Type()))
384 hasModRefChainLinkAboveInRegion |=
385 traceModRefChains(*entryVarInput->origin(), summary);
392 const auto modRefChainLinkType =
Context_->getModRefChainLinkType(*currentOutput);
393 hasModRefChainLinkAboveInRegion |=
395 currentModRefChain.
add({ currentOutput, modRefChainLinkType });
400 auto & operation = simpleNode.GetOperation();
410 hasModRefChainLinkAboveInRegion =
true;
417 hasModRefChainLinkAboveInRegion =
true;
426 hasModRefChainLinkAboveInRegion =
true;
433 hasModRefChainLinkAboveInRegion =
true;
447 for (
auto & nodeInput : node.Inputs())
449 hasModRefChainLinkAboveInRegion |=
465 hasModRefChainLinkAboveInRegion |=
471 for (
auto & nodeInput : node.Inputs())
473 hasModRefChainLinkAboveInRegion |=
480 for (
auto & nodeInput : node.Inputs())
482 hasModRefChainLinkAboveInRegion |=
489 for (
auto & nodeInput : node.Inputs())
491 hasModRefChainLinkAboveInRegion |=
510 throw std::logic_error(
511 util::strfmt(
"Unhandled operation type: ", operation.debug_string()));
516 throw std::logic_error(
util::strfmt(
"Unhandled node type: ", node.DebugString()));
518 }
while (!doneTracing);
520 summary.add(std::move(currentModRefChain));
521 Context_->addModRefChainInformation(startOutput, { hasModRefChainLinkAboveInRegion });
522 return hasModRefChainLinkAboveInRegion;
static rvsdg::Input & GetMemoryStateInput(const rvsdg::Node &node) noexcept
static rvsdg::Input & mapMemoryStateOutputToInput(rvsdg::Output &output) noexcept
std::optional< ModRefChainInformation > tryGetModRefChainInformation(const rvsdg::Output &output) const
void dropModRefChainInformation(const rvsdg::Region ®ion)
std::unordered_map< const rvsdg::Output *, ModRefChainInformation > ModRefChainInformationMap
static std::unique_ptr< Context > create()
std::unordered_map< const rvsdg::Output *, ModRefChainLink::Type > Types_
bool hasModRefChainLinkType(const rvsdg::Output &output) const noexcept
std::unordered_map< const rvsdg::Region *, ModRefChainInformationMap > RegionMap_
void addModRefChainInformation(const rvsdg::Output &output, ModRefChainInformation modRefChainInformation)
ModRefChainInformationMap & getOrInsertModRefChainInformationMap(const rvsdg::Region ®ion)
void add(const rvsdg::Output &output, const ModRefChainLink::Type &type)
ModRefChainLink::Type getModRefChainLinkType(const rvsdg::Output &output) const
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
std::unique_ptr< Context > Context_
bool separateReferenceChains(rvsdg::Output &startOutput)
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
void separateRefenceChainsInTheta(rvsdg::ThetaNode &thetaNode)
void separateRefenceChainsInGamma(rvsdg::GammaNode &gammaNode)
static std::vector< ModRefChain > extractReferenceSubchains(const ModRefChain &modRefChain)
bool traceModRefChains(rvsdg::Output &startOutput, ModRefChainSummary &summary)
void separateReferenceChainsInLambda(rvsdg::LambdaNode &lambdaNode)
~LoadChainSeparation() noexcept override
void separateReferenceChainsInRegion(rvsdg::Region ®ion)
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
static rvsdg::Input & mapMemoryStateOutputToInput(const rvsdg::Output &output)
static rvsdg::SimpleNode & CreateNode(const std::vector< rvsdg::Output * > &operands)
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
UndefValueOperation class.
Conditional operator / pattern matching.
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
std::vector< EntryVar > GetEntryVars() const
Gets all entry variables for this gamma.
Region & GetRootRegion() const noexcept
rvsdg::Region * subregion() const noexcept
rvsdg::Region * region() const noexcept
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
A phi node represents the fixpoint of mutually recursive definitions.
Represent acyclic RVSDG subgraphs.
SubregionIteratorRange Subregions()
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
rvsdg::Region * subregion() const noexcept
Global memory state passed between functions.
rvsdg::Input & GetMemoryStateRegionResult(const rvsdg::LambdaNode &lambdaNode) noexcept
void MatchTypeWithDefault(T &obj, const Fns &... fns)
Pattern match over subclass type of given object with default handler.
Region * TryGetOwnerRegion(const rvsdg::Input &input) noexcept
static std::string strfmt(Args... args)
bool hasModificationChainLinkAboveInRegion
std::vector< ModRefChain > modRefChains
std::vector< ModRefChainLink > links
void add(ModRefChainLink modRefChainLink)