Jlm
IfConversion.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Nico Reißmann <nico.reissmann@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
9 #include <jlm/rvsdg/gamma.hpp>
10 #include <jlm/rvsdg/region.hpp>
12 
13 namespace jlm::llvm
14 {
15 
20 {
  // If-Conversion Transformation statistics: records, for one source file, how long
  // IfConversion::Run spends transforming the module (timed between Start() and Stop()).
21 public:
22  ~IfConversionStatistics() override = default;
23 
  // Creates a statistics record tagged with Statistics::Id::IfConversion for sourceFile.
24  explicit IfConversionStatistics(const util::FilePath & sourceFile)
25  : Statistics(Id::IfConversion, sourceFile)
26  {}
27 
  // Begins timing; IfConversion::Run calls this immediately before HandleRegion.
  // NOTE(review): the body statement (original line 31) is missing from this listing;
  // presumably it adds and starts a util::Timer — confirm against the real source.
28  void
29  Start() noexcept
30  {
32  }
33 
  // Ends timing; IfConversion::Run calls this immediately after HandleRegion returns.
  // NOTE(review): the body statement (original line 37) is missing from this listing;
  // presumably it stops the timer started in Start() — confirm against the real source.
34  void
35  Stop() noexcept
36  {
38  }
39 
  // Factory: returns a newly allocated, caller-owned IfConversionStatistics for sourceFile.
40  static std::unique_ptr<IfConversionStatistics>
41  Create(const util::FilePath & sourceFile)
42  {
43  return std::make_unique<IfConversionStatistics>(sourceFile);
44  }
45 };
46 
// Out-of-line, defaulted destructor (declared noexcept).
47 IfConversion::~IfConversion() noexcept = default;
48 
// Constructor: passes the name "IfConversion" to the Transformation base-class constructor.
// (The constructor's signature line, original line 49, is not shown in this listing.)
50  : Transformation("IfConversion")
51 {}
52 
// Runs if-conversion over the whole module, starting from the root region of its RVSDG.
// The traversal is timed with an IfConversionStatistics object, which is then handed to
// statisticsCollector.CollectDemandedStatistics.
//
// NOTE(review): SourceFilePath() returns a const std::optional<util::FilePath> &, and
// .value() throws std::bad_optional_access when it is empty — confirm that every module
// reaching this pass carries a source file path.
// (The signature line, original line 54, is not shown in this listing.)
53 void
55 {
56  auto statistics = IfConversionStatistics::Create(module.SourceFilePath().value());
57 
58  statistics->Start();
59  HandleRegion(module.Rvsdg().GetRootRegion());
60  statistics->Stop();
61 
62  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
63 }
64 
// Visits every node of region:
//  - gamma nodes are passed to HandleGammaNode;
//  - other structural nodes are descended into, one subregion at a time (recursively);
//  - simple operations are left untouched;
//  - any other node kind aborts via JLM_UNREACHABLE.
//
// NOTE(review): the subregions of a gamma node are not visited, so gammas nested inside
// another gamma's branches are never considered for conversion — confirm this is intended.
// NOTE(review): HandleGammaNode creates new nodes in this same region while region.Nodes()
// is being iterated; verify the node range tolerates insertion during iteration (any newly
// created nodes are simple operations, which this loop ignores).
// (The signature line, original line 66, is not shown in this listing.)
65 void
67 {
68  for (auto & node : region.Nodes())
69  {
70  if (const auto gammaNode = dynamic_cast<const rvsdg::GammaNode *>(&node))
71  {
72  HandleGammaNode(*gammaNode);
73  }
74  else if (const auto structuralNode = dynamic_cast<const rvsdg::StructuralNode *>(&node))
75  {
76  for (size_t n = 0; n < structuralNode->nsubregions(); n++)
77  {
78  const auto subregion = structuralNode->subregion(n);
79  HandleRegion(*subregion);
80  }
81  }
82  else if (is<rvsdg::SimpleOperation>(&node))
83  {
84  // Nothing needs to be done
85  }
86  else
87  {
88  JLM_UNREACHABLE("Unsupported node type.");
89  }
90  }
91 }
92 
// Converts exit variables of a two-way gamma into SelectOperations computed in the gamma's
// own region.
//
// Only gammas with exactly two subregions are handled. For every exit variable:
//  - if either branch result does not come directly from a region argument (i.e. the branch
//    computes something instead of routing an entry value through), it is skipped;
//  - if both branches route the same outer value, users of the gamma output are redirected
//    straight to that value;
//  - otherwise a select over the two outer values is created, with its condition derived
//    from the gamma predicate:
//      * predicate produced by a MatchOperation with exactly one explicit mapping: either
//        the match operand is used directly (1-bit identity mapping), or it is compared
//        against the case value with an IntegerEqOperation;
//      * any other predicate: it is converted with a ControlToIntOperation.
// Users of the gamma output are diverted to the new value. The gamma node itself is not
// removed here (presumably a later dead-node pass cleans it up — confirm).
// (The signature line, original line 94, is not shown in this listing.)
93 void
95 {
96  if (gammaNode.nsubregions() != 2)
97  return;
98 
// Value feeding the gamma's predicate input; inspected below to build the select condition.
99  const auto gammaPredicate = gammaNode.predicate()->origin();
100  for (auto [branchResult, gammaOutput] : gammaNode.GetExitVars())
101  {
102  const auto region0Argument =
103  dynamic_cast<const rvsdg::RegionArgument *>(branchResult[0]->origin());
104  const auto region1Argument =
105  dynamic_cast<const rvsdg::RegionArgument *>(branchResult[1]->origin());
106 
107  if (region0Argument == nullptr || region1Argument == nullptr)
108  {
109  // This output's operands are not just values that are routed through the gamma.
110  // Nothing can be done
111  continue;
112  }
113 
// The values outside the gamma that each branch routes through to this exit variable.
114  const auto origin0 = region0Argument->input()->origin();
115  const auto origin1 = region1Argument->input()->origin();
116 
117  if (origin0 == origin1)
118  {
119  // Both input operands to the gamma are the same and therefore invariant. No select is needed.
120  gammaOutput->divert_users(origin0);
121  continue;
122  }
123 
124  auto [matchNode, matchOperation] =
125  rvsdg::TryGetSimpleNodeAndOptionalOp<rvsdg::MatchOperation>(*gammaPredicate);
126  if (matchOperation)
127  {
// Only a match with exactly one explicit (value -> subregion) mapping is handled; this
// also rejects a match with zero explicit mappings.
128  if (std::distance(matchOperation->begin(), matchOperation->end()) != 1)
129  {
130  // FIXME: This case could actually be handled if we wanted to by just making the condition
131  // to the select to be a "not-equal default value".
132 
133  // FIXME: We would NOT like to perform the match checks for each exit variable. This can be
134  // done once at the beginning.
135 
136  // The match operation has multiple alternatives that map to a single subregion.
137  // Nothing can be done.
138  continue;
139  }
140 
141  const auto matchOrigin = matchNode->input(0)->origin();
142  const auto caseValue = matchOperation->begin()->first;
143  const auto caseSubregion = matchOperation->begin()->second;
144  const auto defaultSubregion = matchOperation->default_alternative();
145  const auto numMatchBits = matchOperation->nbits();
// The single explicit case and the default are assumed to select different subregions;
// the equal case is asserted against rather than handled.
146  JLM_ASSERT(caseSubregion != defaultSubregion);
147 
148  if (numMatchBits == 1 && caseValue == caseSubregion)
149  {
150  // We have an identity mapping:
151  // 1. 0 -> 0, default 1, or
152  // 2. 1 -> 1, default 0
153  // There is no need to insert operations for the select predicate
154  auto & selectNode = rvsdg::CreateOpNode<SelectOperation>(
155  { matchOrigin, origin1, origin0 },
156  gammaOutput->Type());
157  gammaOutput->divert_users(selectNode.output(0));
158  }
159  else
160  {
161  // FIXME: This will recreate the select predicate operations for each gamma output for
162  // which we create a select.
163  auto & constantNode = rvsdg::CreateOpNode<IntegerConstantOperation>(
164  *gammaNode.region(),
165  IntegerValueRepresentation(numMatchBits, caseValue));
166  auto & eqNode = rvsdg::CreateOpNode<IntegerEqOperation>(
167  { constantNode.output(0), matchOrigin },
168  numMatchBits);
169 
// The condition is "match operand == caseValue", which selects caseSubregion. Its routed
// value is therefore the select's true alternative, and the other subregion's value is
// the false alternative.
170  auto trueAlternative = caseSubregion == 0 ? origin0 : origin1;
171  auto falseAlternative = caseSubregion == 0 ? origin1 : origin0;
172  auto & selectNode = rvsdg::CreateOpNode<SelectOperation>(
173  { eqNode.output(0), trueAlternative, falseAlternative },
174  gammaOutput->Type());
175  gammaOutput->divert_users(selectNode.output(0));
176  }
177  }
178  else
179  {
// Predicate not produced by a match: convert the control value to an integer and use it
// as the select condition. A non-zero condition picks subregion 1's value (origin1).
// NOTE(review): this assumes ControlToIntOperation maps alternative i to integer i.
// The operation's remaining constructor arguments (original lines 184-185) are not shown
// in this listing.
180  const auto falseAlternative = origin0;
181  const auto trueAlternative = origin1;
182  auto & controlToIntNode = rvsdg::CreateOpNode<ControlToIntOperation>(
183  { gammaPredicate },
186  auto & selectNode = rvsdg::CreateOpNode<SelectOperation>(
187  { controlToIntNode.output(0), trueAlternative, falseAlternative },
188  gammaOutput->Type());
189  gammaOutput->divert_users(selectNode.output(0));
190  }
191  }
192 }
193 
194 }
static jlm::util::StatisticsCollector statisticsCollector
If-Conversion Transformation statistics.
~IfConversionStatistics() override=default
static std::unique_ptr< IfConversionStatistics > Create(const util::FilePath &sourceFile)
IfConversionStatistics(const util::FilePath &sourceFile)
If-Conversion Transformation.
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static void HandleGammaNode(const rvsdg::GammaNode &gammaNode)
static void HandleRegion(rvsdg::Region &region)
~IfConversion() noexcept override
static std::shared_ptr< const BitType > Create(std::size_t nbits)
Creates bit type of specified width.
Definition: type.cpp:45
static std::shared_ptr< const ControlType > Create(std::size_t nalternatives)
Instantiates control type.
Definition: control.cpp:50
Conditional operator / pattern matching.
Definition: gamma.hpp:99
std::vector< ExitVar > GetExitVars() const
Gets all exit variables for this gamma.
Definition: gamma.cpp:381
rvsdg::Input * predicate() const noexcept
Definition: gamma.hpp:487
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
rvsdg::Region * region() const noexcept
Definition: node.hpp:761
void divert_users(jlm::rvsdg::Output *new_origin)
Definition: node.hpp:301
Represents the argument of a region.
Definition: region.hpp:41
StructuralInput * input() const noexcept
Definition: region.hpp:69
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
NodeRange Nodes() noexcept
Definition: region.hpp:328
const std::optional< util::FilePath > & SourceFilePath() const noexcept
Definition: RvsdgModule.hpp:73
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
size_t nsubregions() const noexcept
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:574
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:137
Statistics(const Statistics::Id &statisticsId, util::FilePath sourceFile)
Definition: Statistics.hpp:76
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:158
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
#define JLM_UNREACHABLE(msg)
Definition: common.hpp:43
Global memory state passed between functions.
static const char * Timer
Definition: Statistics.hpp:251