22 #include <unordered_map>
65 static std::unique_ptr<Statistics>
68 return std::make_unique<Statistics>(sourceFile);
154 throw std::logic_error(
"Output does not belong to a congruence set");
180 auto nextSet =
sets_.size();
186 if (
sets_[it->second].leader == &leader)
192 sets_[it->second].followers.Remove(&leader);
193 it->second = nextSet;
197 sets_.emplace_back(leader);
210 return *
sets_[index].leader;
226 const bool newFollower =
sets_[index].followers.insert(&follower);
239 if (
sets_[it->second].leader == &follower)
240 throw std::logic_error(
"Cannot turn a leader into a follower");
242 const bool removed =
sets_[it->second].followers.Remove(&follower);
258 return sets_[index].followers;
286 const auto o1Set = context.
getSetFor(o1);
287 const auto o2Set = context.
getSetFor(o2);
289 return o1Set == o2Set;
312 if (!simpleNode1 || !simpleNode2)
315 if (simpleNode1->ninputs() != simpleNode2->ninputs())
318 if (simpleNode1->GetOperation() != simpleNode2->GetOperation())
323 for (
auto & input : simpleNode1->Inputs())
325 const auto origin1 = input.origin();
326 const auto origin2 = simpleNode2->
input(input.index())->origin();
342 for (
auto & output : leader.
Outputs())
367 for (
size_t i = 0; i < leader.
noutputs(); i++)
369 const auto & leaderOutput = *leader.
output(i);
370 const auto & followerOutput = *follower.
output(i);
371 const auto leaderSet = context.
getSetFor(leaderOutput);
398 const auto & output0Leader = context.
getLeader(output0Set);
399 const auto & leaderNode = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(output0Leader);
424 if (existingLeaderNode == &node)
426 leaders.push_back(&node);
443 for (
auto leader : leaders)
454 leaders.push_back(&node);
475 if (leaderNode == &node)
488 const auto tryFindCongruentUserOf = [&](
const rvsdg::Output & output) ->
bool
492 for (
auto & user : output.Users())
494 if (user.index() != 0)
497 const auto otherNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user);
502 if (otherNode == &node)
507 if (otherNode != otherNodeLeader)
524 const auto origin0Set = context.
getSetFor(*origin);
525 const auto & origin0Leader = context.
getLeader(origin0Set);
526 const auto & origin0Followers = context.
getFollowers(origin0Set);
527 if (tryFindCongruentUserOf(origin0Leader))
529 for (
auto follower : origin0Followers.Items())
531 if (tryFindCongruentUserOf(*follower))
574 const std::vector<CommonNodeElimination::Context::CongruenceSetIndex> & partitions,
591 const auto currentPartition = context.
tryGetSetFor(*argument);
592 const auto key = std::make_pair(currentPartition, partitions[argument->index()]);
597 if (
const auto it = newSets.find(key); it != newSets.end())
600 const auto toFollow = it->second;
603 if (currentPartition == toFollow)
642 bool anyChanges =
false;
646 if (subregion.narguments() == 0)
653 std::vector<size_t> partitions(subregion.narguments());
658 for (
const auto argument : subregion.Arguments())
660 if (
const auto input = argument->input())
663 partitions[argument->index()] = context.
getSetFor(*input->origin());
668 partitions[argument->index()] = nextUniquePartitionKey++;
693 static std::optional<CommonNodeElimination::Context::CongruenceSetIndex>
698 std::optional<CommonNodeElimination::Context::CongruenceSetIndex> sharedCongruenceSet;
704 const auto inputCongruenceSet = context.
getSetFor(*argument->input()->origin());
705 if (!sharedCongruenceSet.has_value())
707 sharedCongruenceSet = inputCongruenceSet;
709 else if (*sharedCongruenceSet != inputCongruenceSet)
722 return sharedCongruenceSet;
734 [[nodiscard]]
static size_t
742 const auto set = context.
getSetFor(*branchResult->origin());
760 [[nodiscard]]
static bool
767 &rvsdg::AssertGetOwnerNode<rvsdg::GammaNode>(*first.
output)
768 == &rvsdg::AssertGetOwnerNode<rvsdg::GammaNode>(*second.
output));
770 for (
size_t i = 0; i < numResults; i++)
774 if (firstSet != secondSet)
793 std::unordered_map<size_t, CommonNodeElimination::Context::CongruenceSetIndex> & leaderHashes,
801 const auto [_, inserted] = leaderHashes.emplace(hash, congruenceSet);
827 std::unordered_map<size_t, CommonNodeElimination::Context::CongruenceSetIndex> & leaderHashes,
834 const auto [it, inserted] = leaderHashes.emplace(hash, 0);
842 auto & otherLeader = context.
getLeader(it->second);
870 std::unordered_map<size_t, CommonNodeElimination::Context::CongruenceSetIndex> leaderHashes;
876 bool skipInvarianceCheck =
false;
879 const auto existingSet = context.
tryGetSetFor(*exitVar.output);
885 const auto & exisitingLeader = context.
getLeader(existingSet);
886 if (&exisitingLeader == exitVar.output)
896 if (rvsdg::TryGetOwnerNode<rvsdg::GammaNode>(exisitingLeader) == &gamma)
907 skipInvarianceCheck =
true;
911 if (!skipInvarianceCheck)
915 entryVarCongruenceSet.has_value())
917 context.
addFollower(*entryVarCongruenceSet, *exitVar.output);
952 std::vector<CommonNodeElimination::Context::CongruenceSetIndex> partitions;
953 for (
const auto & loopVar : loopVars)
955 partitions.push_back(context.
getSetFor(*loopVar.post->origin()));
970 resultToOutputSetMapping;
971 for (
auto & loopVar : loopVars)
976 const auto inputOriginSet = context.
getSetFor(*loopVar.input->origin());
977 context.
addFollower(inputOriginSet, *loopVar.output);
982 const auto resultSet = context.
getSetFor(*loopVar.post->origin());
983 const auto it = resultToOutputSetMapping.find(resultSet);
984 if (it != resultToOutputSetMapping.end())
991 resultToOutputSetMapping.emplace(resultSet, outputSet);
1055 if (node->ninputs() == 0)
1057 markSimpleTopNode(simple, leaders, context);
1061 markSimpleNode(simple, context);
1076 const auto outputSet = context.
getSetFor(output);
1078 auto & leader = context.
getLeader(outputSet);
1079 if (&leader == &output)
1091 bool divertInSubregions =
false;
1096 divertInSubregions =
true;
1100 divertInSubregions =
true;
1104 divertInSubregions =
true;
1108 divertInSubregions =
true;
1116 if (divertInSubregions)
1129 for (
auto argument : region.
Arguments())
1146 for (
auto & output : node->Outputs())
1157 rvsdg::RvsdgModule & module,
1160 const auto & rvsdg = module.Rvsdg();
1161 auto & rootRegion = rvsdg.GetRootRegion();
1164 auto statistics = Statistics::Create(module.SourceFilePath().value());
1166 statistics->startMarkStatistics(rvsdg);
1169 statistics->endMarkStatistics();
1171 statistics->startDivertStatistics();
1173 statistics->endDivertStatistics(rvsdg);
static jlm::util::StatisticsCollector statisticsCollector
static constexpr auto NoCongruenceSetIndex
CongruenceSetIndex getOrCreateSetForLeader(const rvsdg::Output &leader)
CongruenceSetIndex numCongruenceSets() const
void addFollower(CongruenceSetIndex index, const rvsdg::Output &follower)
const util::HashSet< const rvsdg::Output * > & getFollowers(CongruenceSetIndex index) const
size_t CongruenceSetIndex
bool hasSet(const rvsdg::Output &output) const
CongruenceSetIndex getSetFor(const rvsdg::Output &output) const
CongruenceSetIndex tryGetSetFor(const rvsdg::Output &output) const
const rvsdg::Output & getLeader(CongruenceSetIndex index) const
std::vector< CongruenceSet > sets_
std::unordered_map< const rvsdg::Output *, CongruenceSetIndex > congruenceSetMapping_
void startMarkStatistics(const rvsdg::Graph &graph) noexcept
void startDivertStatistics() noexcept
void endMarkStatistics() noexcept
void endDivertStatistics(const rvsdg::Graph &graph) noexcept
static std::unique_ptr< Statistics > Create(const util::FilePath &sourceFile)
~Statistics() override=default
const char * DivertTimerLabel_
Statistics(const util::FilePath &sourceFile)
const char * MarkTimerLabel_
Common Node Elimination Discovers simple nodes, region arguments and structural node outputs that are...
~CommonNodeElimination() noexcept override
Conditional operator / pattern matching.
ExitVar MapOutputExitVar(const rvsdg::Output &output) const
Maps gamma output to exit variable description.
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
OutputIteratorRange Outputs() noexcept
rvsdg::Region * region() const noexcept
NodeOutput * output(size_t index) const noexcept
size_t ninputs() const noexcept
size_t noutputs() const noexcept
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
void divert_users(jlm::rvsdg::Output *new_origin)
A phi node represents the fixpoint of mutually recursive definitions.
Represents the argument of a region.
Represent acyclic RVSDG subgraphs.
RegionArgumentRange Arguments() noexcept
size_t narguments() const noexcept
NodeInput * input(size_t index) const noexcept
NodeOutput * output(size_t index) const noexcept
SubregionIteratorRange Subregions()
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
rvsdg::Region * subregion() const noexcept
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
util::Timer & GetTimer(const std::string &name)
util::Timer & AddTimer(std::string name)
void AddMeasurement(std::string name, T value)
Global memory state passed between functions.
static void lookupOrInsertGammaExitVarInHashmap(const rvsdg::GammaNode::ExitVar &exitVar, const rvsdg::GammaNode &gamma, std::unordered_map< size_t, CommonNodeElimination::Context::CongruenceSetIndex > &leaderHashes, CommonNodeElimination::Context &context)
static bool partitionArguments(const rvsdg::Region &region, const std::vector< CommonNodeElimination::Context::CongruenceSetIndex > &partitions, CommonNodeElimination::Context &context)
static void markSimpleTopNode(const rvsdg::SimpleNode &node, TopNodeLeaderList &leaders, CommonNodeElimination::Context &context)
static bool markSubregionsFromInputs(const rvsdg::StructuralNode &node, CommonNodeElimination::Context &context)
static void markGamma(const rvsdg::GammaNode &gamma, CommonNodeElimination::Context &context)
static void markGraphImports(const rvsdg::Region &region, CommonNodeElimination::Context &context)
static bool checkNodesCongruent(const rvsdg::Node &node1, const rvsdg::Node &node2, CommonNodeElimination::Context &context)
static void divertOutput(rvsdg::Output &output, CommonNodeElimination::Context &context)
static void divertInRegion(rvsdg::Region &, CommonNodeElimination::Context &)
void markNodeAsLeader(const rvsdg::Node &leader, CommonNodeElimination::Context &context)
static void insertGammaExitVarInHashmap(rvsdg::GammaNode::ExitVar &exitVar, CommonNodeElimination::Context::CongruenceSetIndex congruenceSet, std::unordered_map< size_t, CommonNodeElimination::Context::CongruenceSetIndex > &leaderHashes, CommonNodeElimination::Context &context)
static size_t getGammaExitVariableHash(const rvsdg::GammaNode::ExitVar &exitVar, CommonNodeElimination::Context &context)
std::vector< const rvsdg::Node * > TopNodeLeaderList
static void markRegion(const rvsdg::Region &, CommonNodeElimination::Context &context)
static bool areOutputsCongruent(const rvsdg::Output &o1, const rvsdg::Output &o2, CommonNodeElimination::Context &context)
void markNodesAsCongruent(const rvsdg::Node &leader, const rvsdg::Node &follower, CommonNodeElimination::Context &context)
static void divertInStructuralNode(rvsdg::StructuralNode &node, CommonNodeElimination::Context &context)
const rvsdg::SimpleNode * tryGetLeaderNode(const rvsdg::SimpleNode &node, CommonNodeElimination::Context &context)
static void markTheta(const rvsdg::ThetaNode &theta, CommonNodeElimination::Context &context)
static bool areGammaExitVariablesCongruent(const rvsdg::GammaNode::ExitVar &first, const rvsdg::GammaNode::ExitVar &second, CommonNodeElimination::Context &context)
static std::optional< CommonNodeElimination::Context::CongruenceSetIndex > tryGetGammaExitVarCongruenceSet(rvsdg::GammaNode::ExitVar &exitVar, CommonNodeElimination::Context &context)
static void markStructuralNode(const rvsdg::StructuralNode &node, CommonNodeElimination::Context &context)
static void markSimpleNode(const rvsdg::SimpleNode &node, CommonNodeElimination::Context &context)
void MatchTypeOrFail(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.
static bool ThetaLoopVarIsInvariant(const ThetaNode::LoopVar &loopVar) noexcept
void MatchType(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.
size_t ninputs(const rvsdg::Region *region) noexcept
std::size_t CombineHashes(std::size_t hash, Args... args)
CongruenceSet(const rvsdg::Output &leader)
util::HashSet< const rvsdg::Output * > followers
const rvsdg::Output * leader
A variable routed out of all gamma regions as result.
rvsdg::Output * output
Output of gamma.
std::vector< rvsdg::Input * > branchResult
Variable exit points (results per subregion).
static const char * NumRvsdgInputsAfter
static const char * NumRvsdgInputsBefore