59 static std::unique_ptr<Statistics>
62 return std::make_unique<Statistics>(sourceFile);
148 throw std::logic_error(
"Output does not belong to a congruence set");
174 auto nextSet =
sets_.size();
180 if (
sets_[it->second].leader == &leader)
186 sets_[it->second].followers.Remove(&leader);
187 it->second = nextSet;
191 sets_.emplace_back(leader);
204 return *
sets_[index].leader;
220 const bool newFollower =
sets_[index].followers.insert(&follower);
233 if (
sets_[it->second].leader == &follower)
234 throw std::logic_error(
"Cannot turn a leader into a follower");
236 const bool removed =
sets_[it->second].followers.Remove(&follower);
252 return sets_[index].followers;
280 const auto o1Set = context.
getSetFor(o1);
281 const auto o2Set = context.
getSetFor(o2);
283 return o1Set == o2Set;
306 if (!simpleNode1 || !simpleNode2)
309 if (simpleNode1->ninputs() != simpleNode2->ninputs())
312 if (simpleNode1->GetOperation() != simpleNode2->GetOperation())
317 for (
auto & input : simpleNode1->Inputs())
319 const auto origin1 = input.origin();
320 const auto origin2 = simpleNode2->
input(input.index())->origin();
336 for (
auto & output : leader.
Outputs())
361 for (
size_t i = 0; i < leader.
noutputs(); i++)
363 const auto & leaderOutput = *leader.
output(i);
364 const auto & followerOutput = *follower.
output(i);
365 const auto leaderSet = context.
getSetFor(leaderOutput);
397 const auto & output0Leader = context.
getLeader(output0Set);
398 const auto & leaderNode = rvsdg::AssertGetOwnerNode<rvsdg::SimpleNode>(output0Leader);
423 if (existingLeaderNode == &node)
425 leaders.push_back(&node);
442 for (
auto leader : leaders)
453 leaders.push_back(&node);
474 if (leaderNode == &node)
487 const auto tryFindCongruentUserOf = [&](
const rvsdg::Output & output) ->
bool
491 for (
auto & user : output.Users())
493 if (user.index() != 0)
496 const auto otherNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(user);
501 if (otherNode == &node)
506 if (otherNode != otherNodeLeader)
523 const auto origin0Set = context.
getSetFor(*origin);
524 const auto & origin0Leader = context.
getLeader(origin0Set);
525 const auto & origin0Followers = context.
getFollowers(origin0Set);
526 if (tryFindCongruentUserOf(origin0Leader))
528 for (
auto follower : origin0Followers.Items())
530 if (tryFindCongruentUserOf(*follower))
573 const std::vector<CommonNodeElimination::Context::CongruenceSetIndex> & partitions,
590 const auto currentPartition = context.
tryGetSetFor(*argument);
591 const auto key = std::make_pair(currentPartition, partitions[argument->index()]);
596 if (
const auto it = newSets.find(key); it != newSets.end())
599 const auto toFollow = it->second;
602 if (currentPartition == toFollow)
641 bool anyChanges =
false;
646 std::vector<size_t> partitions(subregion.narguments());
650 for (
const auto argument : subregion.Arguments())
652 if (
const auto input = argument->input())
655 partitions[argument->index()] = context.
getSetFor(*input->origin());
660 partitions[argument->index()] = nextUniquePartitionKey++;
684 static std::optional<CommonNodeElimination::Context::CongruenceSetIndex>
689 std::optional<CommonNodeElimination::Context::CongruenceSetIndex> sharedCongruenceSet;
695 const auto inputCongruenceSet = context.
getSetFor(*argument->input()->origin());
696 if (!sharedCongruenceSet.has_value())
698 sharedCongruenceSet = inputCongruenceSet;
700 else if (*sharedCongruenceSet != inputCongruenceSet)
713 return sharedCongruenceSet;
738 entryVarCongruenceSet.has_value())
740 context.
addFollower(*entryVarCongruenceSet, *exitVar.output);
773 std::vector<CommonNodeElimination::Context::CongruenceSetIndex> partitions;
774 for (
const auto & loopVar : loopVars)
776 partitions.push_back(context.
getSetFor(*loopVar.post->origin()));
791 resultToOutputSetMapping;
792 for (
auto & loopVar : loopVars)
797 const auto inputOriginSet = context.
getSetFor(*loopVar.input->origin());
798 context.
addFollower(inputOriginSet, *loopVar.output);
803 const auto resultSet = context.
getSetFor(*loopVar.post->origin());
804 const auto it = resultToOutputSetMapping.find(resultSet);
805 if (it != resultToOutputSetMapping.end())
812 resultToOutputSetMapping.emplace(resultSet, outputSet);
876 if (node->ninputs() == 0)
878 markSimpleTopNode(simple, leaders, context);
882 markSimpleNode(simple, context);
897 const auto outputSet = context.
getSetFor(output);
899 auto & leader = context.
getLeader(outputSet);
900 if (&leader == &output)
912 bool divertInSubregions =
false;
917 divertInSubregions =
true;
921 divertInSubregions =
true;
925 divertInSubregions =
true;
929 divertInSubregions =
true;
937 if (divertInSubregions)
967 for (
auto & output : node->Outputs())
978 rvsdg::RvsdgModule & module,
979 util::StatisticsCollector & statisticsCollector)
981 const auto & rvsdg = module.Rvsdg();
982 auto & rootRegion = rvsdg.GetRootRegion();
985 auto statistics = Statistics::Create(module.SourceFilePath().value());
987 statistics->startMarkStatistics(rvsdg);
990 statistics->endMarkStatistics();
992 statistics->startDivertStatistics();
994 statistics->endDivertStatistics(rvsdg);
996 statisticsCollector.CollectDemandedStatistics(std::move(statistics));
static constexpr auto NoCongruenceSetIndex
CongruenceSetIndex getOrCreateSetForLeader(const rvsdg::Output &leader)
CongruenceSetIndex numCongruenceSets() const
void addFollower(CongruenceSetIndex index, const rvsdg::Output &follower)
const util::HashSet< const rvsdg::Output * > & getFollowers(CongruenceSetIndex index) const
size_t CongruenceSetIndex
bool hasSet(const rvsdg::Output &output) const
CongruenceSetIndex getSetFor(const rvsdg::Output &output) const
CongruenceSetIndex tryGetSetFor(const rvsdg::Output &output) const
const rvsdg::Output & getLeader(CongruenceSetIndex index) const
std::vector< CongruenceSet > sets_
std::unordered_map< const rvsdg::Output *, CongruenceSetIndex > congruenceSetMapping_
void startMarkStatistics(const rvsdg::Graph &graph) noexcept
void startDivertStatistics() noexcept
void endMarkStatistics() noexcept
void endDivertStatistics(const rvsdg::Graph &graph) noexcept
static std::unique_ptr< Statistics > Create(const util::FilePath &sourceFile)
~Statistics() override=default
const char * DivertTimerLabel_
Statistics(const util::FilePath &sourceFile)
const char * MarkTimerLabel_
Common Node Elimination Discovers simple nodes, region arguments and structural node outputs that are...
~CommonNodeElimination() noexcept override
Conditional operator / pattern matching.
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
OutputIteratorRange Outputs() noexcept
rvsdg::Region * region() const noexcept
NodeOutput * output(size_t index) const noexcept
size_t ninputs() const noexcept
size_t noutputs() const noexcept
const std::shared_ptr< const rvsdg::Type > & Type() const noexcept
void divert_users(jlm::rvsdg::Output *new_origin)
A phi node represents the fixpoint of mutually recursive definitions.
Represents the argument of a region.
Represent acyclic RVSDG subgraphs.
RegionArgumentRange Arguments() noexcept
size_t narguments() const noexcept
NodeInput * input(size_t index) const noexcept
SubregionIteratorRange Subregions()
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
rvsdg::Region * subregion() const noexcept
util::Timer & GetTimer(const std::string &name)
util::Timer & AddTimer(std::string name)
void AddMeasurement(std::string name, T value)
Global memory state passed between functions.
static bool partitionArguments(const rvsdg::Region &region, const std::vector< CommonNodeElimination::Context::CongruenceSetIndex > &partitions, CommonNodeElimination::Context &context)
static void markSimpleTopNode(const rvsdg::SimpleNode &node, TopNodeLeaderList &leaders, CommonNodeElimination::Context &context)
static bool markSubregionsFromInputs(const rvsdg::StructuralNode &node, CommonNodeElimination::Context &context)
static void markGamma(const rvsdg::GammaNode &gamma, CommonNodeElimination::Context &context)
static void markGraphImports(const rvsdg::Region &region, CommonNodeElimination::Context &context)
static bool checkNodesCongruent(const rvsdg::Node &node1, const rvsdg::Node &node2, CommonNodeElimination::Context &context)
const rvsdg::Node * tryGetLeaderNode(const rvsdg::Node &node, CommonNodeElimination::Context &context)
static void divertOutput(rvsdg::Output &output, CommonNodeElimination::Context &context)
static void divertInRegion(rvsdg::Region &, CommonNodeElimination::Context &)
void markNodeAsLeader(const rvsdg::Node &leader, CommonNodeElimination::Context &context)
std::vector< const rvsdg::Node * > TopNodeLeaderList
static void markRegion(const rvsdg::Region &, CommonNodeElimination::Context &context)
static bool areOutputsCongruent(const rvsdg::Output &o1, const rvsdg::Output &o2, CommonNodeElimination::Context &context)
void markNodesAsCongruent(const rvsdg::Node &leader, const rvsdg::Node &follower, CommonNodeElimination::Context &context)
static void divertInStructuralNode(rvsdg::StructuralNode &node, CommonNodeElimination::Context &context)
static void markTheta(const rvsdg::ThetaNode &theta, CommonNodeElimination::Context &context)
static std::optional< CommonNodeElimination::Context::CongruenceSetIndex > tryGetGammaExitVarCongruenceSet(rvsdg::GammaNode::ExitVar &exitVar, CommonNodeElimination::Context &context)
static void markStructuralNode(const rvsdg::StructuralNode &node, CommonNodeElimination::Context &context)
static void markSimpleNode(const rvsdg::SimpleNode &node, CommonNodeElimination::Context &context)
void MatchTypeOrFail(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.
static bool ThetaLoopVarIsInvariant(const ThetaNode::LoopVar &loopVar) noexcept
void MatchType(T &obj, const Fns &... fns)
Pattern match over subclass type of given object.
size_t ninputs(const rvsdg::Region *region) noexcept
CongruenceSet(const rvsdg::Output &leader)
util::HashSet< const rvsdg::Output * > followers
const rvsdg::Output * leader
A variable routed out of all gamma regions as result.
std::vector< rvsdg::Input * > branchResult
Variable exit points (results per subregion).
static const char * NumRvsdgInputsAfter
static const char * NumRvsdgInputsBefore