Jlm
IfConversion.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
9 #include <jlm/rvsdg/gamma.hpp>
10 #include <jlm/rvsdg/region.hpp>
12 
13 namespace jlm::llvm
14 {
15 
// Statistics collected for the IfConversion transformation (Doxygen brief:
// "If-Conversion Transformation statistics.").
// NOTE(review): the class head (original lines 16-19) is elided in this listing; the
// constructor's member initializer below shows the class derives from Statistics.
20 {
21 public:
22  ~IfConversionStatistics() override = default;
23 
 // Tags these statistics with Id::IfConversion and the given source file.
24  explicit IfConversionStatistics(const util::FilePath & sourceFile)
25  : Statistics(Id::IfConversion, sourceFile)
26  {}
27 
 // Called by IfConversion::Run immediately before the region traversal begins.
 // NOTE(review): the body (original line 31) is elided in this listing; the symbols this
 // file references (AddTimer, util::Timer::start, Label::Timer) suggest it starts a timer
 // — confirm against the real source.
28  void
29  Start() noexcept
30  {
32  }
33 
 // Called by IfConversion::Run immediately after the region traversal finishes.
 // NOTE(review): the body (original line 37) is elided in this listing; the referenced
 // symbols (GetTimer, util::Timer::stop, Label::Timer) suggest it stops the timer started
 // in Start() — confirm.
34  void
35  Stop() noexcept
36  {
38  }
39 
 // Factory: returns a newly allocated statistics object for sourceFile.
40  static std::unique_ptr<IfConversionStatistics>
41  Create(const util::FilePath & sourceFile)
42  {
43  return std::make_unique<IfConversionStatistics>(sourceFile);
44  }
45 };
46 
// Defaulted out-of-line destructor.
47 IfConversion::~IfConversion() noexcept = default;
48 
// Constructor: passes the name "IfConversion" to the Transformation base class.
// NOTE(review): the constructor's signature line (original line 49) is elided in this
// listing — presumably IfConversion::IfConversion(); confirm against the header.
50  : Transformation("IfConversion")
51 {}
52 
53 void
55 {
56  auto statistics = IfConversionStatistics::Create(module.SourceFilePath().value());
57 
58  statistics->Start();
59  HandleRegion(module.Rvsdg().GetRootRegion());
60  statistics->Stop();
61 
62  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
63 }
64 
65 void
67 {
68  for (auto & node : region.Nodes())
69  {
70  if (const auto gammaNode = dynamic_cast<const rvsdg::GammaNode *>(&node))
71  {
72  HandleGammaNode(*gammaNode);
73  }
74  else if (const auto structuralNode = dynamic_cast<const rvsdg::StructuralNode *>(&node))
75  {
76  for (size_t n = 0; n < structuralNode->nsubregions(); n++)
77  {
78  const auto subregion = structuralNode->subregion(n);
79  HandleRegion(*subregion);
80  }
81  }
82  else if (is<rvsdg::SimpleOperation>(&node))
83  {
84  // Nothing needs to be done
85  }
86  else
87  {
88  JLM_UNREACHABLE("Unsupported node type.");
89  }
90  }
91 }
92 
/**
 * Performs if-conversion on a single gamma node.
 *
 * Only gamma nodes with exactly two subregions are considered. An exit variable is a
 * candidate when its branch results in both subregions are fed directly by region
 * arguments, i.e. the values are merely routed through the gamma. For such an exit
 * variable, the users of the gamma output are diverted to:
 * - the shared origin, if the gamma inputs behind both arguments have the same origin, or
 * - a new SelectOperation in the gamma's region that picks between the two input origins.
 * All other exit variables are left untouched.
 *
 * The gamma node itself is not removed. Presumably a later dead node elimination cleans it
 * up once its outputs have no users — TODO confirm.
 *
 * NOTE(review): the signature line (original line 94) is elided in this listing; the class
 * declares it as static void HandleGammaNode(const rvsdg::GammaNode & gammaNode).
 */
93 void
95 {
96  if (gammaNode.nsubregions() != 2)
97  return;
98 
// Origin of the gamma's predicate operand; the select condition is derived from it.
99  const auto gammaPredicate = gammaNode.predicate()->origin();
100  for (auto [branchResult, gammaOutput] : gammaNode.GetExitVars())
101  {
102  const auto region0Argument =
103  dynamic_cast<const rvsdg::RegionArgument *>(branchResult[0]->origin());
104  const auto region1Argument =
105  dynamic_cast<const rvsdg::RegionArgument *>(branchResult[1]->origin());
106 
107  if (region0Argument == nullptr || region1Argument == nullptr)
108  {
109  // This output's operands are not just values that are routed through the gamma.
110  // Nothing can be done
111  continue;
112  }
113 
// Map each subregion argument back to the value feeding the corresponding gamma input.
114  const auto origin0 = region0Argument->input()->origin();
115  const auto origin1 = region1Argument->input()->origin();
116 
117  if (origin0 == origin1)
118  {
119  // Both input operands to the gamma are the same and therefore invariant. No select is needed.
120  gammaOutput->divert_users(origin0);
121  continue;
122  }
123 
// Derive the select condition. If the predicate is produced by a match node, the condition
// is computed from the match's operand; otherwise the control value is converted to an
// integer (else-branch below).
// NOTE(review): matchNode is null when the predicate is not a simple node's output; this
// relies on is<>() returning false for a null node — confirm.
124  const auto matchNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*gammaPredicate);
125  if (is<rvsdg::MatchOperation>(matchNode))
126  {
127  const auto matchOperation =
128  util::assertedCast<const rvsdg::MatchOperation>(&matchNode->GetOperation());
// NOTE(review): these assertions abort on a two-alternative match with more than one
// explicit mapping (e.g. 0 -> 0, 1 -> 1, plus a default) instead of falling back to the
// ControlToInt path below — confirm that such predicates cannot reach this pass.
129  JLM_ASSERT(matchOperation->nalternatives() == 2);
130  JLM_ASSERT(std::distance(matchOperation->begin(), matchOperation->end()) == 1);
131 
// The single explicit mapping is caseValue -> caseSubregion; every other value selects
// defaultSubregion.
132  const auto matchOrigin = matchNode->input(0)->origin();
133  const auto caseValue = matchOperation->begin()->first;
134  const auto caseSubregion = matchOperation->begin()->second;
135  const auto defaultSubregion = matchOperation->default_alternative();
136  const auto numMatchBits = matchOperation->nbits();
137  JLM_ASSERT(caseSubregion != defaultSubregion);
138 
139  if (numMatchBits == 1 && caseValue == caseSubregion)
140  {
141  // We have an identity mapping:
142  // 1. 0 -> 0, default 1, or
143  // 2. 1 -> 1, default 0
144  // There is no need to insert operations for the select predicate
// Select operands are (condition, true value, false value), as the named alternatives in
// the other branches show; a 1-bit value of 1 therefore picks subregion 1's value.
145  auto & selectNode = rvsdg::CreateOpNode<SelectOperation>(
146  { matchOrigin, origin1, origin0 },
147  gammaOutput->Type());
148  gammaOutput->divert_users(selectNode.output(0));
149  }
150  else
151  {
152  // FIXME: This will recreate the select predicate operations for each gamma output for
153  // which we create a select.
// Condition: matchOrigin == caseValue, which is true exactly when caseSubregion is taken.
154  auto & constantNode = rvsdg::CreateOpNode<IntegerConstantOperation>(
155  *gammaNode.region(),
156  IntegerValueRepresentation(numMatchBits, caseValue));
157  auto & eqNode = rvsdg::CreateOpNode<IntegerEqOperation>(
158  { constantNode.output(0), matchOrigin },
159  numMatchBits);
160 
161  auto trueAlternative = caseSubregion == 0 ? origin0 : origin1;
162  auto falseAlternative = caseSubregion == 0 ? origin1 : origin0;
163  auto & selectNode = rvsdg::CreateOpNode<SelectOperation>(
164  { eqNode.output(0), trueAlternative, falseAlternative },
165  gammaOutput->Type());
166  gammaOutput->divert_users(selectNode.output(0));
167  }
168  }
169  else
170  {
// The predicate is not a match result. Convert the control value to an integer and select
// subregion 1's value when it is set, subregion 0's value otherwise. This presumes the
// conversion yields the taken alternative's index — confirm.
// NOTE(review): like the FIXME above, this creates a new ControlToInt node for every
// converted exit variable.
171  const auto falseAlternative = origin0;
172  const auto trueAlternative = origin1;
173  auto & controlToIntNode = rvsdg::CreateOpNode<ControlToIntOperation>(
174  { gammaPredicate },
// NOTE(review): original lines 175-176 (the remaining ControlToIntOperation arguments) are
// elided in this listing.
177  auto & selectNode = rvsdg::CreateOpNode<SelectOperation>(
178  { controlToIntNode.output(0), trueAlternative, falseAlternative },
179  gammaOutput->Type());
180  gammaOutput->divert_users(selectNode.output(0));
181  }
182  }
183 }
184 
185 }
If-Conversion Transformation statistics.
~IfConversionStatistics() override=default
static std::unique_ptr< IfConversionStatistics > Create(const util::FilePath &sourceFile)
IfConversionStatistics(const util::FilePath &sourceFile)
If-Conversion Transformation.
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static void HandleGammaNode(const rvsdg::GammaNode &gammaNode)
static void HandleRegion(rvsdg::Region &region)
~IfConversion() noexcept override
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates bit type of specified width.
Definition: type.cpp:45
static std::shared_ptr< const ControlType > Create(std::size_t nalternatives)
Instantiates control type.
Definition: control.cpp:50
Conditional operator / pattern matching.
Definition: gamma.hpp:99
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
Definition: gamma.cpp:361
rvsdg::Input * predicate() const noexcept
Definition: gamma.hpp:398
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
Represents the argument of a region.
Definition: region.hpp:41
StructuralInput * input() const noexcept
Definition: region.hpp:69
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
NodeRange Nodes() noexcept
Definition: region.hpp:328
const std::optional< util::FilePath > & SourceFilePath() const noexcept
Definition: RvsdgModule.hpp:73
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
size_t nsubregions() const noexcept
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:563
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:134
Statistics(const Statistics::Id &statisticsId, util::FilePath sourceFile)
Definition: Statistics.hpp:73
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:155
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
#define JLM_UNREACHABLE(msg)
Definition: common.hpp:43
Global memory state passed between functions.
static const char * Timer
Definition: Statistics.hpp:240