Jlm
ScalarEvolution.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Andreas Lilleby Hjulstad <andreas.lilleby.hjulstad@gmail.com>
3  * See COPYING for terms of redistribution.
4  */
5 
11 #include <jlm/llvm/ir/Trace.hpp>
14 #include <jlm/rvsdg/theta.hpp>
16 #include <jlm/util/Statistics.hpp>
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <queue>
21 
22 namespace jlm::llvm
23 {
24 
26 {
27 public:
28  ~Context() = default;
29 
30  Context() = default;
31 
32  Context(const Context &) = delete;
33 
34  Context(Context &&) = delete;
35 
36  Context &
37  operator=(const Context &) = delete;
38 
39  Context &
40  operator=(Context &&) = delete;
41 
42  void
44  {
45  LoopVars_.insert(&var);
46  }
47 
48  size_t
50  {
51  return LoopVars_.size();
52  }
53 
54  static std::unique_ptr<Context>
56  {
57  return std::make_unique<Context>();
58  }
59 
60  std::unique_ptr<SCEVChainRecurrence>
62  {
63  const auto it = ChrecMap_.find(&output);
64  if (it == ChrecMap_.end() || !it->second)
65  return nullptr;
66 
67  return SCEV::CloneAs<SCEVChainRecurrence>(*it->second);
68  }
69 
70  std::unique_ptr<SCEV>
72  {
73  const auto it = SCEVMap_.find(&output);
74  if (it == SCEVMap_.end() || !it->second)
75  return nullptr;
76 
77  return it->second->Clone();
78  }
79 
80  void
81  InsertChrec(rvsdg::Output & output, const std::unique_ptr<SCEVChainRecurrence> & chrec)
82  {
83  ChrecMap_.insert_or_assign(&output, SCEV::CloneAs<SCEVChainRecurrence>(*chrec));
84  }
85 
86  const std::unordered_map<rvsdg::Output *, std::unique_ptr<SCEVChainRecurrence>> &
87  GetChrecMap() const noexcept
88  {
89  return ChrecMap_;
90  }
91 
92  const std::unordered_map<rvsdg::Output *, std::unique_ptr<SCEV>> &
93  GetSCEVMap() const noexcept
94  {
95  return SCEVMap_;
96  }
97 
98  int
99  GetNumInductionVariablesWithOrder(const size_t n) const
100  {
101  int count = 0;
102  for (auto & [out, chrec] : ChrecMap_)
103  {
104  // Count induction variables (loop variables with a computed recurrence) with specific order
105  if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*out)
106  && out->Type()->Kind() != rvsdg::TypeKind::State)
107  {
108  if (chrec->GetOperands().size() == n + 1 && !IsUnknown(*chrec))
109  count++;
110  }
111  }
112  return count;
113  }
114 
115  size_t
117  {
118  int count = 0;
119  for (auto & [out, chrec] : ChrecMap_)
120  {
121  if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*out)
122  && out->Type()->Kind() != rvsdg::TypeKind::State)
123  {
124  // Only count chrecs that are not unknown
125  if (!IsUnknown(*chrec))
126  count++;
127  }
128  }
129  return count;
130  }
131 
132  void
133  InsertSCEV(rvsdg::Output & output, const std::unique_ptr<SCEV> & scev)
134  {
135  SCEVMap_.insert_or_assign(&output, scev->Clone());
136  }
137 
138  void
140  {
141  NumLoops_++;
142  }
143 
144  size_t
145  GetNumLoops() const
146  {
147  return NumLoops_;
148  }
149 
150  void
151  SetTripCount(const rvsdg::ThetaNode & thetaNode, const size_t tripCount)
152  {
153  TripCountMap_.insert_or_assign(&thetaNode, tripCount);
154  }
155 
156  size_t
157  GetTripCount(const rvsdg::ThetaNode & thetaNode) const
158  {
159  return TripCountMap_.at(&thetaNode);
160  }
161 
162  const std::unordered_map<const rvsdg::ThetaNode *, size_t> &
163  GetTripCountMap() const noexcept
164  {
165  return TripCountMap_;
166  }
167 
168 private:
169  std::unordered_map<rvsdg::Output *, std::unique_ptr<SCEVChainRecurrence>> ChrecMap_;
170  std::unordered_map<rvsdg::Output *, std::unique_ptr<SCEV>> SCEVMap_;
171  std::unordered_map<const rvsdg::ThetaNode *, size_t> TripCountMap_;
172  std::unordered_set<const rvsdg::Output *> LoopVars_;
173 
174  size_t NumLoops_ = 0;
175 };
176 
178 {
179 
180 public:
181  ~Statistics() noexcept override = default;
182 
183  explicit Statistics(const util::FilePath & sourceFile)
184  : util::Statistics(Id::ScalarEvolution, sourceFile)
185  {}
186 
187  void
188  Start() noexcept
189  {
191  }
192 
193  void
194  Stop(const Context & context) noexcept
195  {
197  AddMeasurement(Label::NumTotalInductionVariables, context.GetNumTotalInductionVariables());
200  context.GetNumInductionVariablesWithOrder(0));
203  context.GetNumInductionVariablesWithOrder(1));
206  context.GetNumInductionVariablesWithOrder(2));
207  AddMeasurement(Label::NumLoopVariablesTotal, context.GetNumTotalLoopVars());
208  AddMeasurement(Label::NumLoops, context.GetNumLoops());
209  AddMeasurement(Label::TripCounts, GetTripCountString(context.GetTripCountMap()));
210  }
211 
212  static std::string
213  GetTripCountString(const std::unordered_map<const rvsdg::ThetaNode *, size_t> & tripCountMap)
214  {
215  std::string s = "";
216  bool first = true;
217  for (auto & [thetaNode, tripCount] : tripCountMap)
218  {
219  if (!first)
220  s += ',';
221  first = false;
222 
223  s += "ID(" + std::to_string(thetaNode->subregion()->getRegionId())
224  + ")=" + std::to_string(tripCount);
225  }
226  return s;
227  }
228 
229  static std::unique_ptr<Statistics>
230  Create(const util::FilePath & sourceFile)
231  {
232  return std::make_unique<Statistics>(sourceFile);
233  }
234 };
235 
237  : rvsdg::Transformation("ScalarEvolution")
238 {}
239 
// Out-of-class definition of the destructor, using the compiler-generated implementation.
// NOTE(review): presumably defined here so that Context is a complete type when Context_ is
// destroyed - confirm against the class declaration in the header.
240 ScalarEvolution::~ScalarEvolution() noexcept = default;
241 
242 std::unordered_map<const rvsdg::Output *, std::unique_ptr<SCEVChainRecurrence>>
243 ScalarEvolution::GetChrecMap() const
244 {
245  std::unordered_map<const rvsdg::Output *, std::unique_ptr<SCEVChainRecurrence>> mapCopy{};
246  for (auto & [output, chrec] : Context_->GetChrecMap())
247  {
248  mapCopy.emplace(output, SCEV::CloneAs<SCEVChainRecurrence>(*chrec));
249  }
250  return mapCopy;
251 }
252 
253 std::unordered_map<const rvsdg::Output *, std::unique_ptr<SCEV>>
255 {
256  std::unordered_map<const rvsdg::Output *, std::unique_ptr<SCEV>> mapCopy{};
257  for (auto & [output, scev] : Context_->GetSCEVMap())
258  {
259  mapCopy.emplace(output, scev->Clone());
260  }
261  return mapCopy;
262 }
263 
264 std::unordered_map<const rvsdg::ThetaNode *, size_t>
266 {
267  return Context_->GetTripCountMap();
268 }
269 
270 void
272  rvsdg::RvsdgModule & rvsdgModule,
274 {
275  auto statistics = Statistics::Create(rvsdgModule.SourceFilePath().value());
276  statistics->Start();
277 
279  rvsdg::Region & rootRegion = rvsdgModule.Rvsdg().GetRootRegion();
280  AnalyzeRegion(rootRegion);
282 
283  statistics->Stop(*Context_);
284  statisticsCollector.CollectDemandedStatistics(std::move(statistics));
285 };
286 
287 void
289 {
290  for (auto & node : region.Nodes())
291  {
292  if (auto structuralNode = dynamic_cast<rvsdg::StructuralNode *>(&node))
293  {
294  for (auto & subregion : structuralNode->Subregions())
295  {
296  AnalyzeRegion(subregion);
297  }
298  if (auto thetaNode = dynamic_cast<rvsdg::ThetaNode *>(structuralNode))
299  {
300  Context_->AddLoopToCount();
301  // Add number of loop vars in theta (for statistics)
302  for (const auto loopVar : thetaNode->GetLoopVars())
303  {
304  if (loopVar.pre->Type()->Kind() != rvsdg::TypeKind::State)
305  {
306  // Only add loop variables that are not states
307  Context_->AddLoopVar(*loopVar.pre);
308  }
309  }
310 
311  PerformSCEVAnalysis(*thetaNode);
312 
313  auto tripCount = GetPredictedTripCount(*thetaNode);
314 
315  if (tripCount.has_value())
316  Context_->SetTripCount(*thetaNode, *tripCount);
317  }
318  }
319  }
320 }
321 
322 bool
324 {
325  if (auto constantStep = dynamic_cast<const SCEVConstant *>(&stepSCEV))
326  {
327  return constantStep->GetValue() < 0;
328  }
329  if (auto recurrenceStep = dynamic_cast<const SCEVChainRecurrence *>(&stepSCEV))
330  {
331  JLM_ASSERT(SCEVChainRecurrence::IsAffine(*recurrenceStep));
332 
333  const auto start = dynamic_cast<const SCEVConstant *>(recurrenceStep->GetStartValue());
334  auto stepPtr = recurrenceStep->GetStep();
335  const auto step = dynamic_cast<const SCEVConstant *>(stepPtr->get());
336 
337  if (!start || !step)
338  throw std::logic_error("Step can only contain constant SCEVs!");
339 
340  const auto a = start->GetValue();
341  const auto b = step->GetValue();
342 
343  return a <= 0 && b <= 0 && !(a == 0 && b == 0);
344  }
345  throw std::logic_error("Wrong type for step!");
346 }
347 
348 bool
350 {
351  if (auto constantStep = dynamic_cast<const SCEVConstant *>(&stepSCEV))
352  {
353  return constantStep->GetValue() > 0;
354  }
355  if (auto recurrenceStep = dynamic_cast<const SCEVChainRecurrence *>(&stepSCEV))
356  {
357  JLM_ASSERT(SCEVChainRecurrence::IsAffine(*recurrenceStep));
358 
359  const auto start = dynamic_cast<const SCEVConstant *>(recurrenceStep->GetStartValue());
360  auto stepPtr = recurrenceStep->GetStep();
361  const auto step = dynamic_cast<const SCEVConstant *>(stepPtr->get());
362 
363  if (!start || !step)
364  throw std::logic_error("Step can only contain constant SCEVs!");
365 
366  const auto a = start->GetValue();
367  const auto b = step->GetValue();
368 
369  return a >= 0 && b >= 0 && !(a == 0 && b == 0);
370  }
371  throw std::logic_error("Wrong type for step!");
372 }
373 
374 bool
376 {
377  if (auto constantStep = dynamic_cast<const SCEVConstant *>(&stepSCEV))
378  {
379  return constantStep->GetValue() == 0;
380  }
381  if (auto recurrenceStep = dynamic_cast<const SCEVChainRecurrence *>(&stepSCEV))
382  {
383  JLM_ASSERT(SCEVChainRecurrence::IsAffine(*recurrenceStep));
384 
385  const auto start = dynamic_cast<const SCEVConstant *>(recurrenceStep->GetStartValue());
386  auto stepPtr = recurrenceStep->GetStep();
387  const auto step = dynamic_cast<const SCEVConstant *>(stepPtr->get());
388 
389  if (!start || !step)
390  throw std::logic_error("Step can only contain constant SCEVs!");
391 
392  const auto a = start->GetValue();
393  const auto b = step->GetValue();
394 
395  return a == 0 && b == 0;
396  }
397  throw std::logic_error("Wrong type for step!");
398 }
399 
400 std::optional<size_t>
402 {
403  const auto pred = thetaNode.predicate();
404  const auto & [node, matchOperation] =
405  rvsdg::TryGetSimpleNodeAndOptionalOp<rvsdg::MatchOperation>(*pred->origin());
406  if (!matchOperation)
407  return std::nullopt;
408 
409  JLM_ASSERT(node->ninputs() == 1); // Match node only has 1 input
410 
411  const auto origin = node->input(0)->origin();
412  const auto comparisonNode = rvsdg::TryGetOwnerNode<rvsdg::SimpleNode>(*origin);
413  if (!comparisonNode)
414  return std::nullopt;
415 
416  const auto * comparisonOperation = &comparisonNode->GetOperation();
417  if (!(rvsdg::is<IntegerSltOperation>(*comparisonOperation)
418  || rvsdg::is<IntegerSleOperation>(*comparisonOperation)
419  || rvsdg::is<IntegerUltOperation>(*comparisonOperation)
420  || rvsdg::is<IntegerUleOperation>(*comparisonOperation)
421  || rvsdg::is<IntegerSgtOperation>(*comparisonOperation)
422  || rvsdg::is<IntegerSgeOperation>(*comparisonOperation)
423  || rvsdg::is<IntegerUgtOperation>(*comparisonOperation)
424  || rvsdg::is<IntegerUgeOperation>(*comparisonOperation)
425  || rvsdg::is<IntegerNeOperation>(*comparisonOperation)
426  || rvsdg::is<IntegerEqOperation>(*comparisonOperation)))
427  return std::nullopt;
428 
429  auto & lhs = *comparisonNode->input(0)->origin();
430  auto & rhs = *comparisonNode->input(1)->origin();
431  auto lhsChrec = Context_->TryGetChrecForOutput(lhs);
432  auto rhsChrec = Context_->TryGetChrecForOutput(rhs);
433 
434  if (!lhsChrec)
435  lhsChrec = GetOrCreateChainRecurrence(lhs, *GetOrCreateSCEVForOutput(lhs), thetaNode);
436 
437  if (!rhsChrec)
438  rhsChrec = GetOrCreateChainRecurrence(rhs, *GetOrCreateSCEVForOutput(rhs), thetaNode);
439 
440  int64_t bound = 0;
441  std::unique_ptr<SCEVChainRecurrence> chrec{};
442 
443  if (SCEVChainRecurrence::IsConstant(*lhsChrec))
444  {
445  const auto constantSCEV = dynamic_cast<SCEVConstant *>(lhsChrec->GetOperand(0));
446  if (!constantSCEV)
447  return std::nullopt;
448 
449  bound = constantSCEV->GetValue();
450  chrec = SCEV::CloneAs<SCEVChainRecurrence>(*rhsChrec);
451  }
452  else if (SCEVChainRecurrence::IsConstant(*rhsChrec))
453  {
454  const auto constantSCEV = dynamic_cast<SCEVConstant *>(rhsChrec->GetOperand(0));
455  if (!constantSCEV)
456  return std::nullopt;
457 
458  bound = constantSCEV->GetValue();
459  chrec = SCEV::CloneAs<SCEVChainRecurrence>(*lhsChrec);
460  }
461  else
462  {
463  // None of them are invariant, we can't reliably compute the backedge taken count
464  return std::nullopt;
465  }
466 
468  {
469  // We can only compute the trip count reliably for affine and quadratic recurrences. In other
470  // cases, we return nullopt
471  return std::nullopt;
472  }
473 
474  for (const auto op : chrec->GetOperands())
475  {
476  if (!dynamic_cast<const SCEVConstant *>(op))
477  {
478  // If any of the operands is not a constant, we cannot compute the trip count, and should
479  // return early
480  return std::nullopt;
481  }
482  }
483 
484  const auto start = dynamic_cast<const SCEVConstant *>(chrec->GetStartValue())->GetValue();
485  const auto stepOpt = chrec->GetStep();
486  if (!stepOpt)
487  return std::nullopt;
488 
489  const auto & stepSCEV = **stepOpt;
490 
491  if (rvsdg::is<IntegerSltOperation>(*comparisonOperation)
492  || rvsdg::is<IntegerUltOperation>(*comparisonOperation))
493  {
494  // Trivial case (backedge is not taken and the only iteration is the first one)
495  if (start >= bound)
496  return 1;
497  if (start < bound && IsStepPositive(stepSCEV))
498  {
499  const auto backedgeTakenCount =
500  ComputeBackedgeTakenCountForChrec(*chrec, bound, comparisonOperation);
501  if (backedgeTakenCount.has_value())
502  {
503  // The trip count for a loop is the backedge taken count plus one
504  return *backedgeTakenCount + 1;
505  }
506  }
507  }
508  if (rvsdg::is<IntegerSleOperation>(*comparisonOperation)
509  || rvsdg::is<IntegerUleOperation>(*comparisonOperation))
510  {
511  if (start > bound)
512  return 1;
513  if (start <= bound && IsStepPositive(stepSCEV))
514  {
515  const auto backedgeTakenCount =
516  ComputeBackedgeTakenCountForChrec(*chrec, bound, comparisonOperation);
517  if (backedgeTakenCount.has_value())
518  {
519  return *backedgeTakenCount + 1;
520  }
521  }
522  }
523  if (rvsdg::is<IntegerSgtOperation>(*comparisonOperation)
524  || rvsdg::is<IntegerUgtOperation>(*comparisonOperation))
525  {
526  if (start <= bound)
527  return 1;
528  if (start > bound && IsStepNegative(stepSCEV))
529  {
530  const auto backedgeTakenCount =
531  ComputeBackedgeTakenCountForChrec(*chrec, bound, comparisonOperation);
532  if (backedgeTakenCount.has_value())
533  {
534  return *backedgeTakenCount + 1;
535  }
536  }
537  }
538  if (rvsdg::is<IntegerSgeOperation>(*comparisonOperation)
539  || rvsdg::is<IntegerUgeOperation>(*comparisonOperation))
540  {
541  if (start < bound)
542  return 1;
543  if (start >= bound && IsStepNegative(stepSCEV))
544  {
545  const auto backedgeTakenCount =
546  ComputeBackedgeTakenCountForChrec(*chrec, bound, comparisonOperation);
547  if (backedgeTakenCount.has_value())
548  {
549  return *backedgeTakenCount + 1;
550  }
551  }
552  }
553 
554  if (rvsdg::is<IntegerNeOperation>(*comparisonOperation))
555  {
556  if (SCEVChainRecurrence::IsAffine(*chrec))
557  {
558  // With Ne and Eq comparisons, we only compute non-trivial backedge counts for affine
559  // recurrences as there is no general way to compute it for quadratic recurrences.
560  const auto step = dynamic_cast<const SCEVConstant *>(&stepSCEV)->GetValue();
561  if (IsStepPositive(stepSCEV))
562  {
563  const auto backedgeTakenCount =
564  ComputeBackedgeTakenCountForChrec(*chrec, bound, comparisonOperation);
565  // We need to make sure that it does not pass the bound value (results infinite loop)
566  if (start <= bound && (bound - start) % step == 0)
567  return *backedgeTakenCount + 1;
568  }
569  if (IsStepNegative(stepSCEV))
570  {
571  const auto backedgeTakenCount =
572  ComputeBackedgeTakenCountForChrec(*chrec, bound, comparisonOperation);
573  if (start >= bound && (bound - start) % step == 0)
574  return *backedgeTakenCount + 1;
575  }
576  }
577  if (start == bound)
578  return 1;
579  }
580 
581  if (rvsdg::is<IntegerEqOperation>(*comparisonOperation))
582  {
583  if (start == bound)
584  {
585  if (!IsStepZero(stepSCEV))
586  return 2; // Backedge taken once
587  }
588  else
589  return 1;
590  }
591 
593  {
594  // For quadratic recurrences, if the step is neither positive, negative or zero, we are not able
595  // to accurately compute the trip count.
596  if (!(IsStepPositive(stepSCEV) || IsStepNegative(stepSCEV) || IsStepZero(stepSCEV)))
597  {
598  return std::nullopt;
599  }
600  }
601 
602  // If we have not returned a value at this point, we have an infinite loop.
603  return std::nullopt;
604 }
605 
606 std::optional<size_t>
608  const SCEVChainRecurrence & chrec,
609  const int64_t bound,
610  const rvsdg::SimpleOperation * comparisonOperation)
611 {
612  const auto start = dynamic_cast<const SCEVConstant *>(chrec.GetStartValue())->GetValue();
613  const auto stepOpt = chrec.GetStep();
614  if (!stepOpt)
615  return std::nullopt;
616 
617  const auto & stepSCEV = *stepOpt;
618 
619  bool isEqualsComparison = rvsdg::is<IntegerSleOperation>(*comparisonOperation)
620  || rvsdg::is<IntegerUleOperation>(*comparisonOperation)
621  || rvsdg::is<IntegerSgeOperation>(*comparisonOperation)
622  || rvsdg::is<IntegerUgeOperation>(*comparisonOperation);
623 
624  // Check the size of the step recurrence: 1 -> Affine, 2 -> Quadratic
625  // We can only compute the backedge taken count for these two cases
627  {
628  const auto stepConstant = dynamic_cast<const SCEVConstant *>(stepSCEV.get());
629  const auto step = stepConstant->GetValue();
630 
631  // f(i) = a + b * i
632  // f(i) = k => a + b * i = k => i = (k - a)/b
633  size_t result = std::ceil(static_cast<double>(bound - start) / step);
634 
635  if (isEqualsComparison)
636  {
637  // If we have an equals comparison and the value of the difference between the bound and the
638  // start is a whole multiple of the step size, we get another backedge taken
639  if ((bound - start) % step == 0)
640  result += 1;
641  }
642  return result;
643  }
645  {
646  // Create a quadratic equation for the recurrence {a,+,b,+,c}
647  // The start value is a, and the increments are b, b+c, b+2c, ..., so the accumulated values are
648  // a+b, (a+b)+(b+c), (a+b)+(b+c)+(b+2c), ..., that is,
649  // a+b, a+2b+c, a+3b+3c, ...
650  // After i iterations the value is a + ib + i(i-1)/2 c = f(i).
651  const auto stepRecurrence = dynamic_cast<const SCEVChainRecurrence *>(stepSCEV.get());
652  const int64_t stepFirst =
653  dynamic_cast<const SCEVConstant *>(stepRecurrence->GetStartValue())->GetValue();
654 
655  const int64_t stepSecond =
656  dynamic_cast<const SCEVConstant *>(stepRecurrence->GetStep()->get())->GetValue();
657 
658  // Let f(i) = a + ib + i(i-1)/2 c
659  //
660  // We want to find out when this polynomial is equal to the compare value, i.e. f(i) = k.
661  // This is equivalent with the expression f(i) - k "switching sign" from positive to negative.
662  // Conversely, this is also when the predicate condition will no longer hold.
663  //
664  // The equation f(i) - k = 0 is written as:
665  // a + ib + i(i-1)/2 c - k = 0, or 2(a-k) + 2b i + i(i-1) c = 0.
666  // In a quadratic form it becomes:
667  // c i^2 + (2b - c) i + 2(a - k) = 0.
668  //
669  // We use the quadratic formula to solve this.
670 
671  const int64_t a = stepSecond;
672  const int64_t b = 2 * stepFirst - stepSecond;
673  const int64_t c = 2 * (start - bound);
674 
675  const auto quadraticResult = SolveQuadraticEquation(a, b, c);
676  if (!quadraticResult.has_value())
677  return std::nullopt;
678 
679  size_t result = *quadraticResult;
680 
681  if (isEqualsComparison)
682  {
683  // Same as for affine, but instead of checking using modulo, we evaluate the value at the
684  // result and check
685  const int64_t valueAtResult =
686  start + result * stepFirst + result * (result - 1) / 2 * stepSecond;
687  if (valueAtResult == bound)
688  result += 1;
689  }
690  return result;
691  }
692  return std::nullopt;
693 }
694 
695 std::optional<size_t>
696 ScalarEvolution::SolveQuadraticEquation(int64_t a, int64_t b, int64_t c)
697 {
698  // If a is negative, negate all the coefficients to simplify the math
699  if (a < 0)
700  {
701  a = -a;
702  b = -b;
703  c = -c;
704  }
705 
706  const auto d = b * b - 4 * a * c; // Discriminant
707 
708  if (d < 0)
709  return std::nullopt;
710 
711  // Integer square root of the discriminant
712  int64_t sq = std::floor(std::sqrt(d));
713 
714  // Check if square root is exact
715  const bool inexactSq = (sq * sq != d);
716 
717  // Adjust if sq^2 > discriminant (shouldn't happen with floor, but just to be safe)
718  if (sq * sq > d)
719  sq -= 1;
720 
721  int64_t x = 0;
722  int64_t rem = 0;
723 
724  // The vertex (min/max value) of the parabola f(x) = Ax^2 + Bx + C is at -B/2A. Since A > 0, the
725  // vertex is at a non-positive x location iff B >= 0. In that case the first zero crossing is the
726  // greater root. If B < 0, the vertex is at a positive x location, meaning both roots are positive
727  // and the smaller root is the first crossing.
728  if (b < 0)
729  {
730  // The square root is rounded down, so the roots may be inexact. When using the quadratic
731  // formula, the low root could be greater than the exact one. To make sure this does not happen,
732  // we add 1 if the root is inexact when calculating the low root.
733  x = (-b - (sq + (inexactSq ? 1 : 0))) / (2 * a);
734  rem = (-b - sq) % (2 * a);
735  }
736  else
737  {
738  x = (-b + sq) / (2 * a);
739  rem = (-b + sq) % (2 * a);
740  }
741 
742  // Result should be non-negative
743  if (x < 0)
744  x = 0;
745 
746  // Check for exact solution
747  if (!inexactSq && rem == 0)
748  {
749  return x;
750  }
751 
752  // The exact value of the square root should be between sq and sq + 1
753  // Check for sign change between f(x) and f(x+1)
754  const int64_t valueAtX = (a * x + b) * x + c;
755  const int64_t valueAtXPlusOne = (a * (x + 1) + b) * (x + 1) + c;
756 
757  const bool signChange =
758  ((valueAtX < 0) != (valueAtXPlusOne < 0)) || ((valueAtX == 0) != (valueAtXPlusOne == 0));
759  // Sign did not change, not a valid solution
760  if (!signChange)
761  return std::nullopt;
762 
763  x += 1;
764  return x;
765 }
766 
767 void
769 {
770  bool changed{};
771  do
772  {
773  changed = false;
774 
775  std::vector<std::pair<rvsdg::Output *, std::unique_ptr<SCEV>>> pending;
776  for (auto & [output, chrec] : Context_->GetChrecMap())
777  {
778  if (auto newSCEV = TryReplaceInitForSCEV(*chrec, *output))
779  {
780  pending.emplace_back(output, std::move(*newSCEV));
781  changed = true;
782  }
783  }
784 
785  for (auto & [output, scev] : pending)
786  {
787  // Check if the result is actually a chrec
788  if (auto * chrec = dynamic_cast<SCEVChainRecurrence *>(scev.get()))
789  {
790  Context_->InsertChrec(*output, SCEV::CloneAs<SCEVChainRecurrence>(*chrec));
791  }
792  else
793  {
794  // The transformation produced a non-chrec SCEV (n-ary expression), store it in the SCEV
795  // map instead
796  Context_->InsertSCEV(*output, std::move(scev));
797  }
798  }
799  } while (changed);
800 }
801 
802 std::optional<std::unique_ptr<SCEV>>
804 {
805  if (const auto initSCEV = dynamic_cast<const SCEVInit *>(&scev))
806  {
807  // Found an Init node, find the origin of its input value and get or create its chain
808  // recurrence
809  const auto & initPrePointer = initSCEV->GetPrePointer();
810  if (const auto innerTheta = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(initPrePointer))
811  {
812  const auto correspondingInput = innerTheta->MapPreLoopVar(initPrePointer).input;
813  auto & inputOrigin = rvsdg::traceOutput(*correspondingInput->origin());
814  if (const auto originSCEV = Context_->TryGetSCEVForOutput(inputOrigin))
815  {
816  // We have found a SCEV for the origin of the input, find the corresponding theta node so
817  // we can create a recurrence for it
818  const auto thetaParent = rvsdg::TryGetOwnerNode<rvsdg::ThetaNode>(inputOrigin);
819  const auto outerTheta =
820  thetaParent ? thetaParent
821  : util::assertedCast<rvsdg::ThetaNode>(inputOrigin.region()->node());
822 
823  const auto chrec = GetOrCreateChainRecurrence(inputOrigin, *originSCEV, *outerTheta);
824 
825  // Create a chain recurrence for the SCEV, with the outer theta as the loop
826  return chrec->Clone();
827  }
828  }
829  }
830  if (const auto nArySCEV = dynamic_cast<const SCEVNAryExpr *>(&scev))
831  {
832  // An n-ary scev is any scev with an arbitrary number of operands: chain recurrence, n-ary add
833  // and n-ary mult. We want to recursively check all its operands for Init nodes
834  auto clone = SCEV::CloneAs<SCEVNAryExpr>(*nArySCEV);
835  const auto operands = nArySCEV->GetOperands();
836  bool changed = false;
837  for (size_t i = 0; i < operands.size(); ++i)
838  {
839  if (auto result = TryReplaceInitForSCEV(*operands[i], output))
840  {
841  if (*result)
842  {
843  // Replace the Init operand with the chrec
844  changed = true;
845  clone->ReplaceOperand(i, std::move(*result));
846  }
847  }
848  }
849  if (!changed)
850  return std::nullopt;
851 
852  if (dynamic_cast<const SCEVChainRecurrence *>(&scev))
853  {
854  // Result is a new chain recurrence, return it
855  return clone;
856  }
857  // If it is an n-ary expression (Add or Mul), we try to fold the operands into themselves,
858  // e.g. if, after replacing Init nodes with recurrences, we have ({0,+,1} + {1,+,2}) in an
859  // n-ary add expression, we can fold this into {1,+,3}.
860  return FoldNAryExpression(*clone, output);
861  }
862  // Default is to just return nothing
863  return std::nullopt;
864 }
865 
866 void
868 {
869  for (const auto loopVar : thetaNode.GetLoopVars())
870  {
871  // In some cases (e.g. with store operations), we still want to create a SCEV tree for the loop
872  // variable even though it is a state variable. However, we still want to filter out state
873  // variables that are purely for scaffolding as they are uninteresting for the analysis.
874  if (loopVar.pre->Type()->Kind() == rvsdg::TypeKind::State
875  && rvsdg::ThetaLoopVarIsInvariant(loopVar))
876  {
877  continue;
878  }
879  const auto post = loopVar.post;
880  // We compute the SCEV for each loop variable in a recursive bottom up fashion,
881  // starting at the post's origin
882  auto scev = GetOrCreateSCEVForOutput(*post->origin());
883  Context_->InsertSCEV(*loopVar.output, scev); // Save the SCEV at the theta outputs as well
884  }
885 
886  auto dependencyGraph = CreateDependencyGraph(thetaNode);
887 
889  for (const auto & [output, deps] : dependencyGraph)
890  {
891  if (CanCreateChainRecurrence(*output, dependencyGraph))
892  validOutputs.insert(output);
893  }
894 
895  // Filter the dependency graph to only contain the outputs of the SCEVs that are valid chain
896  // recurrences and update dependencies accordingly
897  auto filteredDependencyGraph = dependencyGraph;
898  for (auto it = filteredDependencyGraph.begin(); it != filteredDependencyGraph.end();)
899  {
900  if (!validOutputs.Contains(it->first))
901  {
902  for (auto & [node, deps] : filteredDependencyGraph)
903  deps.erase(it->first);
904  it = filteredDependencyGraph.erase(it);
905  }
906  else
907  ++it;
908  }
909 
910  const auto order = TopologicalSort(filteredDependencyGraph);
911 
912  for (auto output : order)
913  {
914  std::unique_ptr<SCEV> scev{};
915  if (const auto theta = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*output);
916  &thetaNode == theta)
917  {
918  // For loop variables, we need to retrieve and use the SCEV saved at the post's origin,
919  // equivalent to a "backedge" which describes how the value at the pre pointer is updated
920  auto & newOutput = *thetaNode.MapPreLoopVar(*output).post->origin();
921  scev = Context_->TryGetSCEVForOutput(newOutput);
922  }
923  else
924  scev = Context_->TryGetSCEVForOutput(*output);
925 
926  JLM_ASSERT(scev);
927 
928  auto chrec = GetOrCreateChainRecurrence(*output, *scev, thetaNode);
929  Context_->InsertChrec(*output, chrec);
930  }
931 
932  for (auto & [output, scev] : Context_->GetSCEVMap())
933  {
934  if (std::find(order.begin(), order.end(), output) == order.end())
935  {
936  auto unknownChainRecurrence =
937  SCEVChainRecurrence::Create(thetaNode, *output, SCEVUnknown::Create());
938  Context_->InsertChrec(*output, unknownChainRecurrence);
939  }
940  }
941 }
942 
943 std::unique_ptr<SCEV>
945 {
946  if (const auto existing = Context_->TryGetSCEVForOutput(output))
947  return existing->Clone();
948 
949  std::unique_ptr<SCEV> result{};
950  if (rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(output))
951  {
952  // We know this is a loop variable, create a placeholder SCEV for now, and compute the
953  // expression later
954  result = SCEVPlaceholder::Create(output);
955  }
956 
957  const auto & [simpleNode, simpleOperation] =
958  rvsdg::TryGetSimpleNodeAndOptionalOp<rvsdg::SimpleOperation>(output);
959 
960  if (simpleNode)
961  {
962  if (rvsdg::is<IOBarrierOperation>(*simpleOperation))
963  {
964  const auto barredInputOrigin = IOBarrierOperation::BarredInput(*simpleNode).origin();
965  result = GetOrCreateSCEVForOutput(*barredInputOrigin);
966  }
967  else if (
968  rvsdg::is<SExtOperation>(*simpleOperation) || rvsdg::is<ZExtOperation>(*simpleOperation))
969  {
970  JLM_ASSERT(simpleNode->ninputs() == 1);
971  result = GetOrCreateSCEVForOutput(*simpleNode->input(0)->origin());
972  }
973  else if (const auto gepOp = dynamic_cast<const GetElementPtrOperation *>(&*simpleOperation))
974  {
975  JLM_ASSERT(simpleNode->ninputs() >= 2);
976  const auto baseIndex = simpleNode->input(0)->origin();
977  JLM_ASSERT(is<PointerType>(baseIndex->Type()));
978 
979  const auto & pointeeType = gepOp->getPointeeType();
980 
981  auto baseScev = GetOrCreateSCEVForOutput(*baseIndex);
982 
983  auto wholeTypeIndex = GetOrCreateSCEVForOutput(*simpleNode->input(1)->origin());
984  const auto wholeTypeSize = GetTypeAllocSize(pointeeType);
985 
986  std::unique_ptr<SCEV> offset =
987  SCEVMulExpr::Create(std::move(wholeTypeIndex), SCEVConstant::Create(wholeTypeSize));
988  if (auto innerOffset = ComputeSCEVForGepInnerOffset(*simpleNode, 2, pointeeType))
989  offset = SCEVAddExpr::Create(std::move(offset), std::move(innerOffset));
990 
991  result = SCEVAddExpr::Create(std::move(baseScev), std::move(offset));
992  }
993  else if (const auto constOp = dynamic_cast<const IntegerConstantOperation *>(&*simpleOperation))
994  {
995  const auto value = constOp->Representation().to_int();
996  result = SCEVConstant::Create(value);
997  }
998  else if (rvsdg::is<IntegerBinaryOperation>(*simpleOperation))
999  {
1000  JLM_ASSERT(simpleNode->ninputs() == 2);
1001  const auto lhs = simpleNode->input(0)->origin();
1002  const auto rhs = simpleNode->input(1)->origin();
1003 
1004  auto lhsScev = GetOrCreateSCEVForOutput(*lhs);
1005  auto rhsScev = GetOrCreateSCEVForOutput(*rhs);
1006  if (rvsdg::is<IntegerAddOperation>(*simpleOperation))
1007  {
1008  result = SCEVAddExpr::Create(std::move(lhsScev), std::move(rhsScev));
1009  }
1010  else if (rvsdg::is<IntegerSubOperation>(*simpleOperation))
1011  {
1012  auto rhsNegativeScev = GetNegativeSCEV(*rhsScev);
1013 
1014  result = SCEVAddExpr::Create(std::move(lhsScev), std::move(rhsNegativeScev));
1015  }
1016  else if (rvsdg::is<IntegerMulOperation>(*simpleOperation))
1017  {
1018  result = SCEVMulExpr::Create(std::move(lhsScev), std::move(rhsScev));
1019  }
1020  else if (rvsdg::is<IntegerShlOperation>(*simpleOperation))
1021  {
1022  if (const auto * rhsConst = dynamic_cast<SCEVConstant *>(rhsScev.get()))
1023  {
1024  const auto shiftAmount = rhsConst->GetValue();
1025  auto factor = SCEVConstant::Create(1ULL << shiftAmount);
1026  result = SCEVMulExpr::Create(std::move(lhsScev), std::move(factor));
1027  }
1028  }
1029  }
1030  else
1031  {
1032  // Unknown operation, we traverse through to its inputs
1033  for (auto & input : simpleNode->Inputs())
1034  {
1035  GetOrCreateSCEVForOutput(*input.origin());
1036  }
1037  }
1038  }
1039 
1040  if (!result)
1041  // If none of the cases match, return an unknown SCEV expression
1042  result = SCEVUnknown::Create();
1043 
1044  // Save the result in the cache
1045  Context_->InsertSCEV(output, result);
1046 
1047  return result;
1048 }
1049 
// Builds the SCEV for the offset contributed by the GEP index at `inputIndex` and by every
// index after it, interpreting the indices against `type`.
//
// gepNode    - the GEP node whose index inputs are inspected
// inputIndex - position of the index input to process at this level; asserted to be >= 2
// type       - the type being indexed into at this level (ArrayType or StructType)
//
// Returns nullptr when `inputIndex` is past the last input of the node, or when a struct is
// indexed by a value that is not a known constant integer.
// Throws std::logic_error when `type` is neither an ArrayType nor a StructType.
1050 std::unique_ptr<SCEV>
1052  const rvsdg::SimpleNode & gepNode,
1053  const size_t inputIndex,
1054  const rvsdg::Type & type)
1055 {
1056  JLM_ASSERT(inputIndex >= 2);
1057 
      // No index left to process: this is how the recursion terminates
1058  if (inputIndex >= gepNode.ninputs())
1059  {
1060  return nullptr;
1061  }
1062 
1063  const auto gepInput = gepNode.input(inputIndex);
1064  if (const auto arrayType = dynamic_cast<const ArrayType *>(&type))
1065  {
      // Array level: offset = index * GetTypeAllocSize(element type)
1066  const auto & elementType = *arrayType->GetElementType();
1067  auto offset = SCEVMulExpr::Create(
1068  GetOrCreateSCEVForOutput(*gepInput->origin()),
1069  SCEVConstant::Create(GetTypeAllocSize(elementType)));
1070 
      // The remaining indices select inside the element type
1071  auto subOffset = ComputeSCEVForGepInnerOffset(gepNode, inputIndex + 1, elementType);
1072 
      // NOTE(review): nullptr here means either "no further indices" or "a deeper struct index
      // was not constant"; both cases return only the offset computed so far - confirm intended
1073  if (!subOffset)
1074  return offset;
1075 
1076  return SCEVAddExpr::Create(std::move(offset), std::move(subOffset));
1077  }
1078  if (const auto structType = dynamic_cast<const StructType *>(&type))
1079  {
      // Struct level: the field is selected by a constant index, so the offset is a constant
1080  const auto indexingValue = tryGetConstantSignedInteger(*gepInput->origin());
1081 
1082  if (!indexingValue.has_value())
1083  return nullptr;
1084 
1085  const auto & fieldType = structType->getElementType(*indexingValue);
1086 
1087  auto offset = SCEVConstant::Create(structType->GetFieldOffset(*indexingValue));
1088 
1089  auto subOffset = ComputeSCEVForGepInnerOffset(gepNode, inputIndex + 1, *fieldType);
1090 
1091  if (!subOffset)
1092  return offset;
1093 
1094  return SCEVAddExpr::Create(std::move(offset), std::move(subOffset));
1095  }
1096  throw std::logic_error("Unknown GEP type!");
1097 }
1098 
// Recursively walks a SCEV tree and records, in `dependencies`, every placeholder it contains.
//
// scev         - the (sub)tree to inspect
// dependencies - map from the placeholder's pre-pointer to its dependency info; updated in place
// op           - the operation of the expression that directly contains `scev`
//                (DependencyOp::None at the root)
//
// Only placeholders, add expressions and mul expressions are handled; any other SCEV kind
// contributes nothing.
1099 void
1101  const SCEV & scev,
1102  DependencyMap & dependencies,
1103  const DependencyOp op = DependencyOp::None)
1104 {
1105  if (const auto placeholderSCEV = dynamic_cast<const SCEVPlaceholder *>(&scev))
1106  {
1107  auto & dependency = placeholderSCEV->GetPrePointer();
1108  // Retrieves dependency info struct from the map
1109  // In the case where the dependency does not already exist, a new struct is created with the
1110  // default count being 0 and the default operation being None
1111  auto & depInfo = dependencies[&dependency];
      // NOTE(review): the operation is overwritten on every occurrence, so when the same
      // placeholder appears several times only the operation of the last visit is kept
1112  depInfo.operation = op;
1113  depInfo.count++;
1114  }
1115 
1116  if (const auto addSCEV = dynamic_cast<const SCEVAddExpr *>(&scev))
1117  {
1118  FindDependenciesForSCEV(*addSCEV->GetLeftOperand(), dependencies, DependencyOp::Add);
1119  FindDependenciesForSCEV(*addSCEV->GetRightOperand(), dependencies, DependencyOp::Add);
1120  }
1121 
1122  if (const auto mulSCEV = dynamic_cast<const SCEVMulExpr *>(&scev))
1123  {
1124  FindDependenciesForSCEV(*mulSCEV->GetLeftOperand(), dependencies, DependencyOp::Mul);
1125  FindDependenciesForSCEV(*mulSCEV->GetRightOperand(), dependencies, DependencyOp::Mul);
1126  }
1127 }
1128 
// Builds the dependency graph for `thetaNode`: one entry per output stored in the context's
// SCEV map, mapped to the placeholders that its SCEV depends on.
1131 {
1132  DependencyGraph graph{};
1133 
1134  for (const auto & [output, scev] : Context_->GetSCEVMap())
1135  {
1136  DependencyMap dependencies{};
1137  if (const auto theta = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(*output);
1138  theta == &thetaNode)
1139  {
1140  // We know this is a pre pointer, so we map it to loop var and use the SCEV for the
1141  // post's origin (backedge) instead
1142  const auto loopVar = theta->MapPreLoopVar(*output);
1143  auto newScev = Context_->TryGetSCEVForOutput(*loopVar.post->origin());
1144 
      // NOTE(review): the context lookup appears to return nullptr when no SCEV is stored for
      // the output; newScev is dereferenced without a check - confirm an entry always exists
1145  FindDependenciesForSCEV(*newScev.get(), dependencies);
1146  }
1147  else
1148  FindDependenciesForSCEV(*scev.get(), dependencies);
1149 
      // Copies the per-output map into the graph
1150  graph[output] = dependencies;
1151  }
1152  return graph;
1153 }
1154 
1155 // Implementation of Kahn's algorithm for topological sort
// Returns the outputs of `dependencyGraph` ordered so that an output appears after every
// output it depends on. Self-dependencies are ignored. Asserts that every vertex was emitted.
//
// NOTE(review): a dependency that is not itself a key of `dependencyGraph` is counted in the
// indegree but can never be dequeued, so the final assertion would fail - confirm that every
// dependency is also a vertex.
1156 std::vector<rvsdg::Output *>
1158 {
1159  const size_t numVertices = dependencyGraph.size();
1160  std::unordered_map<const rvsdg::Output *, int> indegree(numVertices);
1161  std::queue<rvsdg::Output *> q{};
1162  for (auto & [node, deps] : dependencyGraph)
1163  {
1164  for (auto & dep : deps)
1165  {
1166  if (const auto ptr = dep.first; ptr == node)
1167  continue; // Ignore self-edges
1168  // To begin with, the indegree is just the number of incoming edges
1169  indegree[node] += 1;
1170  }
      // operator[] also creates the zero entry for nodes without dependencies
1171  if (indegree[node] == 0)
1172  {
1173  // Add nodes with no incoming edges to the queue, we know that these have no dependencies
1174  q.push(node);
1175  }
1176  }
1177 
1178  std::vector<rvsdg::Output *> result{};
1179  while (!q.empty())
1180  {
1181  rvsdg::Output * currentNode = q.front();
1182  q.pop();
1183  result.push_back(currentNode);
1184 
      // The graph stores "depends on" edges only, so the dependents of currentNode are found
      // by scanning every vertex's dependency list
1185  for (const auto & [node, deps] : dependencyGraph)
1186  {
1187  if (node == currentNode)
1188  continue;
1189 
1190  for (const auto & dep : deps)
1191  {
1192  const auto ptr = dep.first;
1193  if (ptr == node)
1194  continue; // Skip self-edges
1195  if (ptr == currentNode)
1196  {
1197  // Update the indegree of nodes depending on this one
1198  indegree[node] -= 1;
1199  if (indegree[node] == 0)
1200  q.push(node);
1201  }
1202  }
1203  }
1204  }
      // Fewer emitted vertices than graph entries means some indegree never reached zero
1205  JLM_ASSERT(result.size() == numVertices);
1206  return result;
1207 }
1208 
// Returns the chain recurrence for `output`.
//
// If the context already holds a recurrence for `output`, a copy of it is returned. Otherwise
// the step part is computed from `scev` via GetOrCreateStepForSCEV and, when `output` belongs
// to `thetaNode`, the start value is inserted as the first operand.
// This block does not store the computed recurrence in the context.
1209 std::unique_ptr<SCEVChainRecurrence>
1211  rvsdg::Output & output,
1212  const SCEV & scev,
1213  rvsdg::ThetaNode & thetaNode)
1214 {
1215  if (const auto existing = Context_->TryGetChrecForOutput(output))
1216  {
      // NOTE(review): the context lookup appears to hand out a clone already, so this is a
      // second copy - confirm before simplifying
1217  return SCEV::CloneAs<SCEVChainRecurrence>(*existing);
1218  }
1219 
1220  auto stepRecurrence = GetOrCreateStepForSCEV(output, scev, thetaNode);
1221 
1222  if (const auto theta = rvsdg::TryGetRegionParentNode<rvsdg::ThetaNode>(output);
1223  theta == &thetaNode)
1224  {
1225  // Find the start value for the recurrence
1226  const auto inputOrigin = thetaNode.MapPreLoopVar(output).input->origin();
1227  if (const auto constantInteger = tryGetConstantSignedInteger(*inputOrigin))
1228  {
1229  // If the input value is a constant, create a SCEV representation and set it as start
1230  // value (first operand in rec)
1231  stepRecurrence->AddOperandToFront(SCEVConstant::Create(*constantInteger));
1232  }
1233  else
1234  {
1235  // If not, create a SCEVInit node representing the start value
1236  stepRecurrence->AddOperandToFront(SCEVInit::Create(output));
1237  }
1238  }
1239  return stepRecurrence;
1240 }
1241 
// Computes the step part of the recurrence of `output` from the SCEV tree `scevTree`.
//
// output    - the output whose recurrence is being built; a placeholder referring to it is
//             treated as the identity element
// scevTree  - the (sub)tree to translate
// thetaNode - the loop the resulting recurrences are created for
//
// Constants become a one-operand recurrence, placeholders of other outputs are replaced by the
// recurrence stored for them in the context, add/mul expressions are folded recursively, and
// everything else yields a recurrence holding a single SCEVUnknown.
1242 std::unique_ptr<SCEVChainRecurrence>
1244  rvsdg::Output & output,
1245  const SCEV & scevTree,
1246  rvsdg::ThetaNode & thetaNode)
1247 {
1248  if (const auto scevConstant = dynamic_cast<const SCEVConstant *>(&scevTree))
1249  {
1250  // This is a constant, we add it as the only operand
1251  return SCEVChainRecurrence::Create(thetaNode, output, scevConstant->Clone());
1252  }
1253  if (const auto scevPlaceholder = dynamic_cast<const SCEVPlaceholder *>(&scevTree))
1254  {
1255  if (&scevPlaceholder->GetPrePointer() == &output)
1256  {
1257  // Since we are only interested in the step value, and not the initial value, we can ignore
1258  // ourselves by returning an empty chain recurrence (treated as the identity element - 0 for
1259  // addition and 1 for multiplication)
1260  return SCEVChainRecurrence::Create(thetaNode, output);
1261  }
1262  if (auto storedRec = Context_->TryGetChrecForOutput(scevPlaceholder->GetPrePointer()))
1263  {
1264  // We have a dependency of another IV
1265  // Get its saved value. This is safe to do due to the topological ordering
1266  return storedRec;
1267  }
      // No recurrence stored for the referenced output: the step is unknown
1268  return SCEVChainRecurrence::Create(thetaNode, output, SCEVUnknown::Create());
1269  }
1270  if (const auto scevAddExpr = dynamic_cast<const SCEVAddExpr *>(&scevTree))
1271  {
1272  const auto lhsStep = GetOrCreateStepForSCEV(output, *scevAddExpr->GetLeftOperand(), thetaNode);
1273  const auto rhsStep = GetOrCreateStepForSCEV(output, *scevAddExpr->GetRightOperand(), thetaNode);
1274 
      // Both operands are chain recurrences here, so the folding goes through the
      // recurrence + recurrence branch of ApplyAddFolding
1275  return SCEV::CloneAs<SCEVChainRecurrence>(
1276  *ApplyAddFolding(lhsStep.get(), rhsStep.get(), output));
1277  }
1278  if (const auto scevMulExpr = dynamic_cast<const SCEVMulExpr *>(&scevTree))
1279  {
1280  const auto lhsStep = GetOrCreateStepForSCEV(output, *scevMulExpr->GetLeftOperand(), thetaNode);
1281  const auto rhsStep = GetOrCreateStepForSCEV(output, *scevMulExpr->GetRightOperand(), thetaNode);
1282 
1283  return SCEV::CloneAs<SCEVChainRecurrence>(
1284  *ApplyMulFolding(lhsStep.get(), rhsStep.get(), output));
1285  }
      // Any other SCEV kind cannot be expressed as a step
1286  return SCEVChainRecurrence::Create(thetaNode, output, SCEVUnknown::Create());
1287 }
1288 
// Folds the operands of an n-ary add or multiply expression pairwise, modifying `expression`
// in place, and returns a copy of the result.
//
// Returns a clone of the single remaining operand when only one is left, otherwise a clone of
// the (folded) expression.
// Throws std::logic_error when `expression` is neither a SCEVNAryAddExpr nor a SCEVNAryMulExpr
// and a foldable pair is found.
1289 std::unique_ptr<SCEV>
1291 {
1292  // In some cases, we end up with an n-ary expression like (1 + Init(a1) + 2).
1293  // This method folds the constant operands, turning it into (3 + Init(a1)).
1294  bool folded{};
1295  do
1296  {
1297  folded = false;
1298  for (size_t i = 0; i < expression.NumOperands(); ++i)
1299  {
      // Snapshot of the operand pointers; the scan restarts after every fold, so the snapshot
      // is never used after the expression has been modified
1300  std::vector<SCEV *> ops = expression.GetOperands();
1301  if (dynamic_cast<const SCEVInit *>(ops[i]))
1302  continue; // Cannot fold init
1303  for (size_t j = i + 1; j < expression.NumOperands(); ++j)
1304  {
1305  if (dynamic_cast<const SCEVInit *>(ops[j]))
1306  continue;
1307 
1308  // Both are foldable (constants or recurrences) fold them according to the rules
      // NOTE(review): the code only excludes SCEVInit operands; any other operand kind is
      // passed to the folding functions as well
1309  std::unique_ptr<SCEV> foldedOperand{};
1310  if (dynamic_cast<SCEVNAryAddExpr *>(&expression))
1311  {
1312  foldedOperand = ApplyAddFolding(ops[i], ops[j], output);
1313  }
1314  else if (dynamic_cast<SCEVNAryMulExpr *>(&expression))
1315  {
1316  foldedOperand = ApplyMulFolding(ops[i], ops[j], output);
1317  }
1318  else
1319  {
1320  throw std::logic_error("Invalid n-ary SCEV expression type in FoldNAryExpression!");
1321  }
      // Operand j is removed first (j > i, so index i stays valid), then operand i is
      // replaced by the folded value
1322  expression.RemoveOperand(j);
1323  expression.ReplaceOperand(i, foldedOperand);
1324  folded = true;
1325  break;
1326  }
1327  if (folded)
1328  break;
1329  }
1330  } while (folded);
1331 
1332  if (expression.NumOperands() == 1)
1333  {
1334  // If there is only one operand in the n-ary expression, we just return the operand
1335  return expression.GetOperand(0)->Clone();
1336  }
1337 
1338  return expression.Clone();
1339 }
1340 
// Folds the sum of two SCEV operands into a single, newly allocated SCEV.
//
// lhsOperand, rhsOperand - the operands to add. Either may be nullptr: the recurrence branch
//                          below passes nullptr for the shorter recurrence, and a null operand
//                          then simply fails every dynamic_cast
// output                 - the output that newly created chain recurrences are created for
//
// Returns SCEVUnknown when either operand is a SCEVUnknown or when no rule applies. Neither
// operand is modified; results are built from clones.
1341 std::unique_ptr<SCEV>
1342 ScalarEvolution::ApplyAddFolding(SCEV * lhsOperand, SCEV * rhsOperand, rvsdg::Output & output)
1343 {
1344  // We have the following folding rules from the CR algebra:
1345  // G + {e,+,f} => {G + e,+,f} (1)
1346  // {e,+,f} + {g,+,h} => {e + g,+,f + h} (2)
1347  //
1348  // And by generalizing rule 2, we have that:
1349  // {G,+,0} + {e,+,f} = {G + e,+,0 + f} = {G + e,+,f}
1350  //
1351  // Since we represent constants in the SCEVTree as recurrences consisting of only a SCEVConstant
1352  // node, we can therefore pad the constant recurrence with however many zeroes we need for the
1353  // length of the other recurrence. This effectively lets us apply both rules in one go.
1354  //
1355  // For constants and unknowns this is trivial, however it becomes a bit complicated when we
1356  // factor in SCEVInit nodes. These nodes represent the initial value of an IV in the case where
1357  // the exact value is unknown at compile time. E.g. function argument or result from a
1358  // call-instruction. In the cases where we have to fold one or more of these init-nodes, we
1359  // create an n-ary add expression (add expression with an arbitrary number of operands), and add
1360  // this to the chrec. Folding two of these n-ary add expressions will result in another n-ary
1361  // add expression, which consists of all the operands in both the left and the right expression.
1362 
1363  // The if-chain below goes through each of the possible combinations of lhs and rhs values
1364  if (const auto *lhsUnknown = dynamic_cast<const SCEVUnknown *>(lhsOperand),
1365  *rhsUnknown = dynamic_cast<const SCEVUnknown *>(rhsOperand);
1366  lhsUnknown || rhsUnknown)
1367  {
1368  // If one of the sides is unknown. Return unknown
1369  return SCEVUnknown::Create();
1370  }
1371 
1372  auto lhsChrec = dynamic_cast<SCEVChainRecurrence *>(lhsOperand);
1373  auto rhsChrec = dynamic_cast<SCEVChainRecurrence *>(rhsOperand);
1374  if (lhsChrec && rhsChrec)
1375  {
1376  if (&lhsChrec->GetLoop() != &rhsChrec->GetLoop())
1377  {
      // Recurrences of different loops are not folded element-wise; clones of both are kept
      // as the operands of an n-ary add expression
1379  lhsChrec->GetLoop(),
1380  output,
1381  SCEVNAryAddExpr::Create(lhsChrec->Clone(), rhsChrec->Clone()));
1382  }
1383 
      // Same loop: add operand by operand. The shorter recurrence contributes nullptr for the
      // missing positions
1384  auto newChrec = SCEVChainRecurrence::Create(lhsChrec->GetLoop(), output);
1385  const auto lhsSize = lhsChrec->NumOperands();
1386  const auto rhsSize = rhsChrec->NumOperands();
1387  for (size_t i = 0; i < std::max(lhsSize, rhsSize); ++i)
1388  {
1389  SCEV * lhs{};
1390  SCEV * rhs{};
1391  if (i < lhsSize)
1392  lhs = lhsChrec->GetOperand(i);
1393 
1394  if (i < rhsSize)
1395  rhs = rhsChrec->GetOperand(i);
1396  newChrec->AddOperand(ApplyAddFolding(lhs, rhs, output));
1397  }
1398  return newChrec;
1399  }
1400 
1401  // Chrec + any other operand
1402  // This handles Init, Constant, and any other SCEV type uniformly
1403  if (lhsChrec || rhsChrec)
1404  {
1405  auto * chrec = lhsChrec ? lhsChrec : rhsChrec;
1406  auto * otherOperand = lhsChrec ? rhsOperand : lhsOperand;
1407 
1408  // Skip if otherOperand is zero constant (identity for addition)
1409  if (const auto constant = dynamic_cast<const SCEVConstant *>(otherOperand))
1410  {
1411  if (!SCEVConstant::IsNonZero(constant))
1412  {
1413  return chrec->Clone();
1414  }
1415  }
1416  auto newChrec = SCEVChainRecurrence::Create(chrec->GetLoop(), output);
1417  const auto chrecOperands = chrec->GetOperands();
1418 
1419  bool isFirst = true;
1420  for (const auto operand : chrecOperands)
1421  {
1422  if (isFirst)
1423  {
1424  // Recursively fold the start value with the other operand
1425  newChrec->AddOperand(ApplyAddFolding(operand, otherOperand, output));
1426  isFirst = false;
1427  }
1428  else
1429  {
1430  newChrec->AddOperand(operand->Clone());
1431  }
1432  }
1433  return newChrec;
1434  }
1435 
      // From here on neither operand is a chain recurrence
1436  const auto lhsNAryMulExpr = dynamic_cast<const SCEVNAryMulExpr *>(lhsOperand);
1437  const auto rhsNAryMulExpr = dynamic_cast<const SCEVNAryMulExpr *>(rhsOperand);
1438  // Handle n-ary multiply expressions - they become terms in an n-ary add expression
1439  if (lhsNAryMulExpr && rhsNAryMulExpr)
1440  {
1441  // Two multiply expressions - create add expression with both
1442  return SCEVNAryAddExpr::Create(lhsNAryMulExpr->Clone(), rhsNAryMulExpr->Clone());
1443  }
1444 
1445  const auto lhsNAryAddExpr = dynamic_cast<const SCEVNAryAddExpr *>(lhsOperand);
1446  const auto rhsNAryAddExpr = dynamic_cast<const SCEVNAryAddExpr *>(rhsOperand);
1447  if ((lhsNAryMulExpr && rhsNAryAddExpr) || (rhsNAryMulExpr && lhsNAryAddExpr))
1448  {
1449  // Multiply expression with add expression - Clone the add expression and add the multiply as
1450  // a term
1451  const auto * mulExpr = lhsNAryMulExpr ? lhsNAryMulExpr : rhsNAryMulExpr;
1452  auto * addExpr = lhsNAryAddExpr ? lhsNAryAddExpr : rhsNAryAddExpr;
      // NOTE(review): the sibling branches clone as SCEVNAryAddExpr; here the base type
      // SCEVNAryExpr is used - presumably equivalent, confirm
1453  auto newAddExpr = SCEV::CloneAs<SCEVNAryExpr>(*addExpr);
1454  newAddExpr->AddOperand(mulExpr->Clone());
      // NOTE(review): newAddExpr is already a fresh clone; cloning it again looks redundant
1455  return newAddExpr->Clone();
1456  }
1457 
1458  const auto lhsInit = dynamic_cast<const SCEVInit *>(lhsOperand);
1459  const auto rhsInit = dynamic_cast<const SCEVInit *>(rhsOperand);
1460  if ((lhsNAryMulExpr && rhsInit) || (rhsNAryMulExpr && lhsInit))
1461  {
1462  // Multiply expression with init - create add expression
1463  const auto * mulExpr = lhsNAryMulExpr ? lhsNAryMulExpr : rhsNAryMulExpr;
1464  const auto * init = lhsInit ? lhsInit : rhsInit;
1465  return SCEVNAryAddExpr::Create(mulExpr->Clone(), init->Clone());
1466  }
1467 
1468  const auto lhsConstant = dynamic_cast<SCEVConstant *>(lhsOperand);
1469  const auto rhsConstant = dynamic_cast<SCEVConstant *>(rhsOperand);
      // IsNonZero is called with pointers that may be nullptr here - presumably it returns
      // false for nullptr; TODO confirm
1470  if ((lhsNAryMulExpr && SCEVConstant::IsNonZero(rhsConstant))
1471  || (rhsNAryMulExpr && SCEVConstant::IsNonZero(lhsConstant)))
1472  {
1473  // Multiply expression with nonzero constant - create add expression
1474  const auto * mulExpr = lhsNAryMulExpr ? lhsNAryMulExpr : rhsNAryMulExpr;
1475  const auto * constant = lhsConstant ? lhsConstant : rhsConstant;
1476  return SCEVNAryAddExpr::Create(mulExpr->Clone(), constant->Clone());
1477  }
1478 
1479  if (lhsNAryMulExpr || rhsNAryMulExpr)
1480  {
1481  // Single multiply expression, no folding necessary
      // Reached when the other operand is absent or did not match any rule above
      // (for example a zero constant)
1482  const auto * mulExpr = lhsNAryMulExpr ? lhsNAryMulExpr : rhsNAryMulExpr;
1483  return mulExpr->Clone();
1484  }
1485 
1486  if (lhsInit && rhsInit)
1487  {
1488  // We have two init nodes. Create a nAryAdd with lhsInit and rhsInit
1489  return SCEVNAryAddExpr::Create(lhsInit->Clone(), rhsInit->Clone());
1490  }
1491 
1492  if ((lhsInit && rhsNAryAddExpr) || (rhsInit && lhsNAryAddExpr))
1493  {
1494  // We have an init and an add expr. Clone the add expression and add the init as an operand
1495  const auto * init = lhsInit ? lhsInit : rhsInit;
1496  auto * nAryAddExpr = lhsNAryAddExpr ? lhsNAryAddExpr : rhsNAryAddExpr;
1497  auto newAddExpr = SCEV::CloneAs<SCEVNAryAddExpr>(*nAryAddExpr);
1498  newAddExpr->AddOperand(init->Clone());
      // NOTE(review): second clone of an already fresh clone; looks redundant
1499  return newAddExpr->Clone();
1500  }
1501 
1502  if ((lhsInit && SCEVConstant::IsNonZero(rhsConstant))
1503  || (rhsInit && SCEVConstant::IsNonZero(lhsConstant)))
1504  {
1505  // We have an init and a nonzero constant. Create a nAryAdd with init and constant
1506  const auto * init = lhsInit ? lhsInit : rhsInit;
1507  const auto * constant = lhsConstant ? lhsConstant : rhsConstant;
1508  return SCEVNAryAddExpr::Create(init->Clone(), constant->Clone());
1509  }
1510 
1511  if (lhsInit || rhsInit)
1512  {
1513  // Only one operand. Add it
1514  const auto * init = lhsInit ? lhsInit : rhsInit;
1515  return init->Clone();
1516  }
1517 
1518  if (lhsNAryAddExpr && rhsNAryAddExpr)
1519  {
1520  // We have two add expressions. Clone the lhs and add the rhs operands
1521  auto lhsNewNAryAddExpr = SCEV::CloneAs<SCEVNAryAddExpr>(*lhsNAryAddExpr);
1522  for (auto op : rhsNAryAddExpr->GetOperands())
1523  {
1524  lhsNewNAryAddExpr->AddOperand(op->Clone());
1525  }
1526  return lhsNewNAryAddExpr;
1527  }
1528 
1529  if ((lhsNAryAddExpr && SCEVConstant::IsNonZero(rhsConstant))
1530  || (rhsNAryAddExpr && SCEVConstant::IsNonZero(lhsConstant)))
1531  {
1532  // We have an add expr and a nonzero constant. Clone the add expr and add the constant
1533  auto * nAryAddExpr = lhsNAryAddExpr ? lhsNAryAddExpr : rhsNAryAddExpr;
1534  auto * constant = lhsConstant ? lhsConstant : rhsConstant;
1535  auto newNAryAddExpr = SCEV::CloneAs<SCEVNAryAddExpr>(*nAryAddExpr);
1536 
1537  // Check if there is already a constant operand in the n-ary expression
1538  // If so, fold the new constant with the old one instead of adding it as an operand
1539  bool folded = false;
1540  for (size_t i = 0; i < newNAryAddExpr->NumOperands(); ++i)
1541  {
1542  if (auto existingConstant = dynamic_cast<SCEVConstant *>(newNAryAddExpr->GetOperands()[i]))
1543  {
1544  // Fold the two constants together directly
1545  auto foldedConstant = ApplyAddFolding(existingConstant, constant, output);
1546  newNAryAddExpr->ReplaceOperand(i, foldedConstant);
1547  folded = true;
1548  break;
1549  }
1550  }
1551 
1552  if (!folded)
1553  {
1554  // No existing constant to fold with, just append
1555  newNAryAddExpr->AddOperand(constant->Clone());
1556  }
1557 
1558  return newNAryAddExpr;
1559  }
1560 
1561  if (lhsNAryAddExpr || rhsNAryAddExpr)
1562  {
      // Single add expression: the other operand is absent or did not match any rule above
1563  const auto * nAryAddExpr = lhsNAryAddExpr ? lhsNAryAddExpr : rhsNAryAddExpr;
1564  return nAryAddExpr->Clone();
1565  }
1566  if (lhsConstant && rhsConstant)
1567  {
1568  // Two constants, get their value, and combine them (fold)
1569  const auto lhsValue = lhsConstant->GetValue();
1570  const auto rhsValue = rhsConstant->GetValue();
1571 
      // NOTE(review): plain addition of the two values; overflow behaviour depends on the
      // type returned by GetValue(), which is not visible here
1572  return SCEVConstant::Create(lhsValue + rhsValue);
1573  }
1574 
1575  if (lhsConstant || rhsConstant)
1576  {
      // Single constant: return a copy of it
1577  const auto * constant = lhsConstant ? lhsConstant : rhsConstant;
1578  return constant->Clone();
1579  }
1580 
      // Nothing matched (for example both operands are nullptr)
1581  return SCEVUnknown::Create();
1582 }
1583 
// Multiplies two chain recurrences and returns the product as a new chain recurrence.
//
// lhsChrec, rhsChrec - the factors; both are dereferenced, so neither may be nullptr
// output             - the output that newly created chain recurrences are created for
//
// An empty recurrence is treated as the identity: the result is a copy of the other factor.
// A one-operand recurrence is treated as a loop-invariant factor that scales every operand of
// the other recurrence. Longer recurrences are handled by the recursive CRProd scheme
// described in the body.
// Throws std::logic_error when the step of a factor cannot be obtained.
1584 std::unique_ptr<SCEVChainRecurrence>
1586  SCEVChainRecurrence * lhsChrec,
1587  SCEVChainRecurrence * rhsChrec,
1588  rvsdg::Output & output)
1589 {
1590  const auto lhsSize = lhsChrec->NumOperands();
1591  const auto rhsSize = rhsChrec->NumOperands();
1592 
1593  if (rhsSize == 0)
1594  return SCEV::CloneAs<SCEVChainRecurrence>(*lhsChrec);
1595  if (lhsSize == 0)
1596  return SCEV::CloneAs<SCEVChainRecurrence>(*rhsChrec);
1597 
1598  // Handle G * {e,+,f,...} where G is loop invariant
1599  if (lhsSize == 1)
1600  {
1601  auto newChrec = SCEVChainRecurrence::Create(lhsChrec->GetLoop(), output);
1602  // G * {e,+,f,...} = {G * e,+,G * f,...}
1603  auto lhs = lhsChrec->GetOperand(0);
1604 
1605  for (auto rhs : rhsChrec->GetOperands())
1606  {
1607  newChrec->AddOperand(ApplyMulFolding(lhs, rhs, output));
1608  }
1609  return newChrec;
1610  }
1611  if (rhsSize == 1)
1612  {
1613  auto newChrec = SCEVChainRecurrence::Create(lhsChrec->GetLoop(), output);
1614  // {e,+,f,...} * G = {e * G,+,f * G,...}
1615  auto rhs = rhsChrec->GetOperand(0);
1616 
1617  for (auto lhs : lhsChrec->GetOperands())
1618  {
1619  newChrec->AddOperand(ApplyMulFolding(lhs, rhs, output));
1620  }
1621  return newChrec;
1622  }
1623 
1624  // Below is an implementation of the algorithm CRProd from Bachmann et al., ‘Chains of recurrences
1625  // — a method to expedite the evaluation of closed-form functions’
1626  // (https://doi.org/10.1145/190347.190423)
1627  //
1628  // Let lhs = F, rhs = G be CR’s of length k and l.
1629  //
1630  // The product of F and G, S = F * G can be constructed using the algorithm CRProd given below
1631  //
1632  // Algorithm CRProd: Let
1633  // F = {a0, +, a1, +, ..., +, ak} and G = {b0, +, b1, +, ..., +, bl}
1634  // with k ≥ l. This algorithm returns a simple CR S of length k + l such that F * G = S.
1635  //
1636  // P1 [Base case]
1637  // If l = 1 return {a0*b0, +, a1*b0, +, ..., +, ak*b0}
1638  //
1639  // P2 [Prepare recursive calls] Let
1640  // f = {a1, +, a2, +, ..., +, ak}
1641  // g = {b1, +, b2, +, ..., +, bl}
1642  //
1643  // G' = G + g = {b0 + b1, +, b1 + b2, +, ..., +, bl}
1644  //
1645  // P3 [Recursive calls] Set
1646  // {x1'', +, x2'', +, ..., +, x(k+l)''} ← CRProd(F, g)
1647  // {x1', +, x2', +, ..., +, x(k+l)'} ← CRProd(f, G')
1648  //
1649  // P4 [Fold the results together and return]
1650  // return {a0*b0, +, x1' + x1'', +, ..., +, x(k+l)' + x(k+l)''}
1651 
1652  JLM_ASSERT(lhsSize >= 2);
1653  JLM_ASSERT(rhsSize >= 2);
1654 
      // Make lhs the longer recurrence (k >= l). lhsSize and rhsSize are not used after this
      // point, so they are not swapped
1655  if (rhsSize > lhsSize)
1656  std::swap(lhsChrec, rhsChrec);
1657 
1658  std::unique_ptr<SCEVChainRecurrence> lhsStepRecurrence, rhsStepRecurrence;
1659 
      // NOTE(review): the result of GetStep() is dereferenced before the null check below;
      // its return type is not visible here - confirm that the dereference is safe when no
      // step exists
1660  const auto lhsStep = *lhsChrec->GetStep();
1661  if (!lhsStep)
1662  {
1663  // This should not happen since we check size above
1664  throw std::logic_error("Could not get step for LHS in ComputeProductOfChrecs!");
1665  }
1666 
1667  const auto rhsStep = *rhsChrec->GetStep();
1668  if (!rhsStep)
1669  {
1670  throw std::logic_error("Could not get step for RHS in ComputeProductOfChrecs!");
1671  }
1672 
      // P2: f and g. A step that is not itself a recurrence is wrapped into a one-operand
      // recurrence
1673  if (dynamic_cast<SCEVChainRecurrence *>(lhsStep.get()))
1674  lhsStepRecurrence = SCEV::CloneAs<SCEVChainRecurrence>(*lhsStep);
1675  else
1676  lhsStepRecurrence = SCEVChainRecurrence::Create(lhsChrec->GetLoop(), output, lhsStep->Clone());
1677 
1678  if (dynamic_cast<SCEVChainRecurrence *>(rhsStep.get()))
1679  rhsStepRecurrence = SCEV::CloneAs<SCEVChainRecurrence>(*rhsStep);
1680  else
1681  rhsStepRecurrence = SCEVChainRecurrence::Create(rhsChrec->GetLoop(), output, rhsStep->Clone());
1682 
      // P2: G' = G + g
1683  const auto rhsMarked = SCEV::CloneAs<SCEVChainRecurrence>(
1684  *ApplyAddFolding(rhsChrec, rhsStepRecurrence.get(), output));
1685 
      // P3: res1 = F * g and res2 = G' * f
1686  const auto res1 = ComputeProductOfChrecs(lhsChrec, rhsStepRecurrence.get(), output);
1687  const auto res2 = ComputeProductOfChrecs(rhsMarked.get(), lhsStepRecurrence.get(), output);
1688 
      // P4: add the two partial results and put a0 * b0 in front
1689  auto resFolded =
1690  SCEV::CloneAs<SCEVChainRecurrence>(*ApplyAddFolding(res1.get(), res2.get(), output));
1691 
1692  const auto first = ApplyMulFolding(lhsChrec->GetOperand(0), rhsChrec->GetOperand(0), output);
1693  resFolded->AddOperandToFront(first);
1694 
1695  return resFolded;
1696 }
1697 
// Folds the product of two SCEV operands into a single, newly allocated SCEV.
//
// lhsOperand, rhsOperand - the factors; a nullptr operand fails every dynamic_cast below
// output                 - the output that newly created chain recurrences are created for
//
// Returns SCEVUnknown when either operand is a SCEVUnknown or when no rule applies. Neither
// operand is modified; results are built from clones.
1698 std::unique_ptr<SCEV>
1699 ScalarEvolution::ApplyMulFolding(SCEV * lhsOperand, SCEV * rhsOperand, rvsdg::Output & output)
1700 {
1701  // We have the following folding rules from the CR algebra:
1702  // G * {e,+,f} => {G * e,+,G * f}
1703  // {e,+,f} * {g,+,h} => {e * g,+,e * h + f * g + f * h,+,2*f*h}
1704  //
1705  // Similar to addition, we need to handle SCEVInit nodes and n-ary expressions.
1706  // For multiplication with init nodes, we create n-ary multiply expressions.
1707 
1708  if (const auto *lhsUnknown = dynamic_cast<const SCEVUnknown *>(lhsOperand),
1709  *rhsUnknown = dynamic_cast<const SCEVUnknown *>(rhsOperand);
1710  lhsUnknown || rhsUnknown)
1711  {
1712  return SCEVUnknown::Create();
1713  }
1714 
1715  auto lhsChrec = dynamic_cast<SCEVChainRecurrence *>(lhsOperand);
1716  auto rhsChrec = dynamic_cast<SCEVChainRecurrence *>(rhsOperand);
1717  if (lhsChrec && rhsChrec)
1718  {
1719  if (&lhsChrec->GetLoop() != &rhsChrec->GetLoop())
1720  {
      // Recurrences of different loops are not multiplied out; clones of both are kept as
      // the operands of an n-ary multiply expression
1722  lhsChrec->GetLoop(),
1723  output,
1724  SCEVNAryMulExpr::Create(lhsChrec->Clone(), rhsChrec->Clone()));
1725  }
1726 
1727  return ComputeProductOfChrecs(lhsChrec, rhsChrec, output);
1728  }
1729 
1730  // Chrec * any other operand
1731  // This handles Init, Constant, and any other SCEV type uniformly
1732  if (lhsChrec || rhsChrec)
1733  {
1734  auto * chrec = lhsChrec ? lhsChrec : rhsChrec;
1735  auto * otherOperand = lhsChrec ? rhsOperand : lhsOperand;
1736 
1737  if (auto constant = dynamic_cast<const SCEVConstant *>(otherOperand))
1738  {
1739  if (constant->GetValue() == 1)
1740  {
1741  // Don't fold if operand is constant one (identity for multiplication)
1742  return chrec->Clone();
1743  }
1744 
1745  if (constant->GetValue() == 0)
1746  {
1747  // Fold to zero
      // The result is a plain constant here, not a chain recurrence
1748  return SCEVConstant::Create(0);
1749  }
1750  }
1751  auto newChrec = SCEVChainRecurrence::Create(chrec->GetLoop(), output);
1752  const auto chrecOperands = chrec->GetOperands();
1753 
1754  for (auto & operand : chrecOperands)
1755  {
1756  // Multiply every operand of the recurrence (not only the start value) by the other operand
1757  newChrec->AddOperand(ApplyMulFolding(operand, otherOperand, output));
1758  }
1759  return newChrec;
1760  }
1761 
      // From here on neither operand is a chain recurrence
1762  const auto lhsNAryAddExpr = dynamic_cast<const SCEVNAryAddExpr *>(lhsOperand);
1763  const auto rhsNAryAddExpr = dynamic_cast<const SCEVNAryAddExpr *>(rhsOperand);
1764  if (lhsNAryAddExpr || rhsNAryAddExpr)
1765  {
1766  // Handle n-ary add expressions - distribute multiplication
1767  // (a + b + c) × G = a×G + b×G + c×G
      // When both operands are add expressions, the left one is distributed over and the
      // right one is passed down as G; each product is then itself an n-ary add expression
1768  const auto nAryAddExpr = lhsNAryAddExpr ? lhsNAryAddExpr : rhsNAryAddExpr;
1769  const auto other = lhsNAryAddExpr ? rhsOperand : lhsOperand;
1770 
1771  auto resultAddExpr = SCEVNAryAddExpr::Create();
1772  for (auto operand : nAryAddExpr->GetOperands())
1773  {
1774  auto product = ApplyMulFolding(operand, other, output);
1775  resultAddExpr->AddOperand(std::move(product));
1776  }
1777  return resultAddExpr;
1778  }
1779 
1780  const auto lhsInit = dynamic_cast<const SCEVInit *>(lhsOperand);
1781  const auto rhsInit = dynamic_cast<const SCEVInit *>(rhsOperand);
1782  if (lhsInit && rhsInit)
1783  {
1784  // Two init nodes - create n-ary multiply expression
1785  return SCEVNAryMulExpr::Create(lhsInit->Clone(), rhsInit->Clone());
1786  }
1787 
1788  const auto lhsNAryMulExpr = dynamic_cast<const SCEVNAryMulExpr *>(lhsOperand);
1789  const auto rhsNAryMulExpr = dynamic_cast<const SCEVNAryMulExpr *>(rhsOperand);
1790  if ((lhsInit && rhsNAryMulExpr) || (rhsInit && lhsNAryMulExpr))
1791  {
1792  // Init node with n-ary multiply expression - Clone mult expr and add init as an operand
1793  const auto * init = lhsInit ? lhsInit : rhsInit;
1794  auto * nAryMulExpr = lhsNAryMulExpr ? lhsNAryMulExpr : rhsNAryMulExpr;
1795  auto newNAryMulExpr = SCEV::CloneAs<SCEVNAryMulExpr>(*nAryMulExpr);
1796  newNAryMulExpr->AddOperand(init->Clone());
      // NOTE(review): second clone of an already fresh clone; looks redundant
1797  return newNAryMulExpr->Clone();
1798  }
1799 
1800  auto lhsConstant = dynamic_cast<SCEVConstant *>(lhsOperand);
1801  auto rhsConstant = dynamic_cast<SCEVConstant *>(rhsOperand);
      // NOTE(review): only the constant 1 is special-cased; a zero constant yields the n-ary
      // product (init, 0) rather than the constant 0
1802  if ((lhsInit && rhsConstant && rhsConstant->GetValue() != 1)
1803  || (rhsInit && lhsConstant && lhsConstant->GetValue() != 1))
1804  {
1805  // Init node with non-one constant - create n-ary multiply expression
1806  const auto * init = lhsInit ? lhsInit : rhsInit;
1807  const auto * constant = lhsConstant ? lhsConstant : rhsConstant;
1808  return SCEVNAryMulExpr::Create(init->Clone(), constant->Clone());
1809  }
1810 
1811  if (lhsInit || rhsInit)
1812  {
1813  // Single init node, no folding necessary
1814  const auto * init = lhsInit ? lhsInit : rhsInit;
1815  return init->Clone();
1816  }
1817 
1818  if (lhsNAryMulExpr && rhsNAryMulExpr)
1819  {
1820  // Two n-ary mult expressions - combine operands
1821  auto lhsNewNAryMulExpr = SCEV::CloneAs<SCEVNAryMulExpr>(*lhsNAryMulExpr);
1822  for (auto op : rhsNAryMulExpr->GetOperands())
1823  {
1824  lhsNewNAryMulExpr->AddOperand(op->Clone());
1825  }
1826  return lhsNewNAryMulExpr;
1827  }
1828 
1829  if ((lhsNAryMulExpr && rhsConstant && rhsConstant->GetValue() != 1)
1830  || (rhsNAryMulExpr && lhsConstant && lhsConstant->GetValue() != 1))
1831  {
1832  // N-ary mult expression with non-one constant - Clone mult expression and add constant
1833  auto * nAryMulExpr = lhsNAryMulExpr ? lhsNAryMulExpr : rhsNAryMulExpr;
1834  auto * constant = lhsConstant ? lhsConstant : rhsConstant;
1835 
1836  auto newNAryMulExpr = SCEV::CloneAs<SCEVNAryMulExpr>(*nAryMulExpr);
1837 
      // If the product already contains a constant factor, merge the new constant into it
1838  bool folded = false;
1839  for (size_t i = 0; i < newNAryMulExpr->NumOperands(); ++i)
1840  {
1841  if (auto existingConstant = dynamic_cast<SCEVConstant *>(newNAryMulExpr->GetOperands()[i]))
1842  {
1843  // Fold the two constants together directly
1844  auto foldedConstant = ApplyMulFolding(existingConstant, constant, output);
1845  newNAryMulExpr->ReplaceOperand(i, foldedConstant);
1846  folded = true;
1847  break;
1848  }
1849  }
1850 
1851  if (!folded)
1852  {
1853  // No existing constant to fold with, just append
1854  newNAryMulExpr->AddOperand(constant->Clone());
1855  }
1856 
1857  return newNAryMulExpr;
1858  }
1859 
1860  if (lhsNAryMulExpr || rhsNAryMulExpr)
1861  {
      // Single multiply expression: the other operand is absent, the constant 1, or of a
      // kind no rule above handles
1862  const auto * nAryMulExpr = lhsNAryMulExpr ? lhsNAryMulExpr : rhsNAryMulExpr;
1863  return nAryMulExpr->Clone();
1864  }
1865 
1866  if (lhsConstant && rhsConstant)
1867  {
1868  // Two constants - fold by multiplying values together
1869  const auto lhsValue = lhsConstant->GetValue();
1870  const auto rhsValue = rhsConstant->GetValue();
1871  return SCEVConstant::Create(lhsValue * rhsValue);
1872  }
1873 
1874  if (lhsConstant || rhsConstant)
1875  {
      // Single constant: return a copy of it
1876  const auto * constant = lhsConstant ? lhsConstant : rhsConstant;
1877  return constant->Clone();
1878  }
1879 
      // Nothing matched (for example both operands are nullptr)
1880  return SCEVUnknown::Create();
1881 }
1882 
// Returns a newly allocated SCEV representing the negation of `scev`.
//
// Constants are negated directly, a product with the constant -1 is unwrapped, the negation
// is distributed over add expressions, and every other SCEV is wrapped as (-1) * scev.
1883 std::unique_ptr<SCEV>
1885 {
1886  // -(c)
1887  if (const auto c = dynamic_cast<const SCEVConstant *>(&scev))
1888  {
1889  const auto value = c->GetValue();
1890  return SCEVConstant::Create(-value);
1891  }
1892  // -(-x) -> x
      // Only recognised when one operand of the product is the constant -1
1893  if (const auto mul = dynamic_cast<const SCEVMulExpr *>(&scev))
1894  {
1895  if (const auto c = dynamic_cast<const SCEVConstant *>(mul->GetLeftOperand());
1896  c && c->GetValue() == -1)
1897  {
1898  return mul->GetRightOperand()->Clone();
1899  }
1900  if (const auto c = dynamic_cast<const SCEVConstant *>(mul->GetRightOperand());
1901  c && c->GetValue() == -1)
1902  {
1903  return mul->GetLeftOperand()->Clone();
1904  }
      // Any other product falls through to the general case at the end
1905  } // -(x + y) -> (-x) + (-y)
1906  if (const auto add = dynamic_cast<const SCEVAddExpr *>(&scev))
1907  {
1908  return SCEVAddExpr::Create(
1909  GetNegativeSCEV(*add->GetLeftOperand()),
1910  GetNegativeSCEV(*add->GetRightOperand()));
1911  }
1912  // General case: -(x) -> (-1) * x
1913  return SCEVMulExpr::Create(SCEVConstant::Create(-1), scev.Clone());
1914 }
1915 
1916 bool
1918 {
1919  auto deps = dependencyGraph[&output];
1920  if (deps.find(&output) != deps.end())
1921  {
1922  if (deps[&output].count != 1)
1923  {
1924  // First check that variable has only one self-reference
1925  return false;
1926  }
1927  if (deps[&output].operation == DependencyOp::Mul)
1928  {
1929  // A variable cannot have a self-depencency via multiplication (results in a geometric
1930  // induction variable)
1931  return false;
1932  }
1933  }
1934 
1935  // Then check for cycles through other variables
1936  std::unordered_set<const rvsdg::Output *> visited{};
1937  std::unordered_set<const rvsdg::Output *> recursionStack{};
1938  return !HasCycleThroughOthers(output, output, dependencyGraph, visited, recursionStack);
1939 }
1940 
1941 bool
1943  rvsdg::Output & currentOutput,
1944  const rvsdg::Output & originalOutput,
1945  DependencyGraph & dependencyGraph,
1946  std::unordered_set<const rvsdg::Output *> & visited,
1947  std::unordered_set<const rvsdg::Output *> & recursionStack)
1948 {
1949  visited.insert(&currentOutput);
1950  recursionStack.insert(&currentOutput);
1951 
1952  for (const auto & [depPtr, depCount] : dependencyGraph[&currentOutput])
1953  {
1954  // Ignore self-references
1955  if (depPtr == &currentOutput)
1956  continue;
1957 
1958  // Found a cycle back to the ORIGINAL node we started from
1959  // This means the original output is explicitly part of the cycle
1960  if (depPtr == &originalOutput)
1961  return true;
1962 
1963  // Already explored this branch, no cycle containing the original output
1964  if (visited.find(depPtr) != visited.end())
1965  continue;
1966 
1967  // Recursively check dependencies, keeping track of the original node
1968  if (HasCycleThroughOthers(*depPtr, originalOutput, dependencyGraph, visited, recursionStack))
1969  return true;
1970  }
1971 
1972  recursionStack.erase(&currentOutput);
1973  return false;
1974 }
1975 
1976 bool
1978 {
1979  if (dynamic_cast<const SCEVUnknown *>(&scev))
1980  return true;
1981 
1982  if (dynamic_cast<const SCEVInit *>(&scev) || dynamic_cast<const SCEVConstant *>(&scev)
1983  || dynamic_cast<const SCEVPlaceholder *>(&scev))
1984  {
1985  return false;
1986  }
1987 
1988  if (auto * binaryExpr = dynamic_cast<const SCEVBinaryExpr *>(&scev))
1989  {
1990  return IsUnknown(*binaryExpr->GetLeftOperand()) || IsUnknown(*binaryExpr->GetLeftOperand());
1991  }
1992 
1993  if (auto * nAryExpr = dynamic_cast<const SCEVNAryExpr *>(&scev))
1994  {
1995  for (const auto operand : nAryExpr->GetOperands())
1996  {
1997  if (IsUnknown(*operand))
1998  return true;
1999  }
2000  return false;
2001  }
2002 
2003  throw std::logic_error("Invalid SCEV type in IsUnknown!\n");
2004 }
2005 
2006 bool
2008 {
2009  if (typeid(a) != typeid(b))
2010  return false;
2011 
2012  if (dynamic_cast<const SCEVUnknown *>(&a))
2013  return true;
2014 
2015  if (auto * constantA = dynamic_cast<const SCEVConstant *>(&a))
2016  {
2017  auto * constantB = dynamic_cast<const SCEVConstant *>(&b);
2018  return constantA->GetValue() == constantB->GetValue();
2019  }
2020 
2021  if (auto * initA = dynamic_cast<const SCEVInit *>(&a))
2022  {
2023  auto * initB = dynamic_cast<const SCEVInit *>(&b);
2024  return &initA->GetPrePointer() == &initB->GetPrePointer();
2025  }
2026 
2027  if (auto * binaryExprA = dynamic_cast<const SCEVBinaryExpr *>(&a))
2028  {
2029  auto * binaryExprB = dynamic_cast<const SCEVBinaryExpr *>(&b);
2030  return StructurallyEqual(*binaryExprA->GetLeftOperand(), *binaryExprB->GetLeftOperand())
2031  && StructurallyEqual(*binaryExprA->GetRightOperand(), *binaryExprB->GetRightOperand());
2032  }
2033 
2034  if (auto * chrecA = dynamic_cast<const SCEVChainRecurrence *>(&a))
2035  {
2036  auto * chrecB = dynamic_cast<const SCEVChainRecurrence *>(&b);
2037  if (&chrecA->GetLoop() != &chrecB->GetLoop())
2038  return false;
2039  if (&chrecA->GetOutput() != &chrecB->GetOutput())
2040  return false;
2041  if (chrecA->NumOperands() != chrecB->NumOperands())
2042  return false;
2043  for (size_t i = 0; i < chrecA->NumOperands(); ++i)
2044  {
2045  if (!StructurallyEqual(*chrecA->GetOperands()[i], *chrecB->GetOperands()[i]))
2046  return false;
2047  }
2048  return true;
2049  }
2050 
2051  if (auto * nAryExprA = dynamic_cast<const SCEVNAryExpr *>(&a))
2052  {
2053  auto * nAryExprB = dynamic_cast<const SCEVNAryExpr *>(&b);
2054  if (nAryExprA->NumOperands() != nAryExprB->NumOperands())
2055  return false;
2056  for (size_t i = 0; i < nAryExprA->NumOperands(); ++i)
2057  {
2058  if (!StructurallyEqual(*nAryExprA->GetOperands()[i], *nAryExprB->GetOperands()[i]))
2059  return false;
2060  }
2061  return true;
2062  }
2063 
2064  return false;
2065 }
2066 }
static jlm::util::StatisticsCollector statisticsCollector
static rvsdg::Input & BarredInput(const rvsdg::SimpleNode &node) noexcept
Definition: IOBarrier.hpp:70
static std::unique_ptr< SCEVAddExpr > Create(std::unique_ptr< SCEV > left, std::unique_ptr< SCEV > right)
rvsdg::ThetaNode & GetLoop() const noexcept
static bool IsQuadratic(const SCEVChainRecurrence &chrec)
static bool IsConstant(const SCEVChainRecurrence &chrec)
std::optional< std::unique_ptr< SCEV > > GetStep() const
static std::unique_ptr< SCEVChainRecurrence > Create(rvsdg::ThetaNode &loop, rvsdg::Output &output)
static bool IsAffine(const SCEVChainRecurrence &chrec)
static bool IsNonZero(const SCEVConstant *c)
static std::unique_ptr< SCEVConstant > Create(const int64_t value)
rvsdg::Output & GetPrePointer() const noexcept
static std::unique_ptr< SCEVInit > Create(rvsdg::Output &prePointer)
static std::unique_ptr< SCEVMulExpr > Create(std::unique_ptr< SCEV > left, std::unique_ptr< SCEV > right)
static std::unique_ptr< SCEVNAryAddExpr > Create(Args &&... operands)
SCEV * GetOperand(const size_t index) const
void RemoveOperand(const size_t index)
std::vector< SCEV * > GetOperands() const
void ReplaceOperand(const size_t index, const std::unique_ptr< SCEV > &operand)
static std::unique_ptr< SCEVNAryMulExpr > Create(Args &&... operands)
static std::unique_ptr< SCEVPlaceholder > Create(rvsdg::Output &PrePointer_)
static std::unique_ptr< SCEVUnknown > Create()
virtual std::unique_ptr< SCEV > Clone() const =0
size_t GetTripCount(const rvsdg::ThetaNode &thetaNode) const
std::unordered_map< const rvsdg::ThetaNode *, size_t > TripCountMap_
const std::unordered_map< rvsdg::Output *, std::unique_ptr< SCEVChainRecurrence > > & GetChrecMap() const noexcept
const std::unordered_map< rvsdg::Output *, std::unique_ptr< SCEV > > & GetSCEVMap() const noexcept
std::unique_ptr< SCEVChainRecurrence > TryGetChrecForOutput(rvsdg::Output &output) const
std::unordered_set< const rvsdg::Output * > LoopVars_
Context & operator=(const Context &)=delete
const std::unordered_map< const rvsdg::ThetaNode *, size_t > & GetTripCountMap() const noexcept
void SetTripCount(const rvsdg::ThetaNode &thetaNode, const size_t tripCount)
void InsertSCEV(rvsdg::Output &output, const std::unique_ptr< SCEV > &scev)
std::unordered_map< rvsdg::Output *, std::unique_ptr< SCEVChainRecurrence > > ChrecMap_
Context(const Context &)=delete
void AddLoopVar(const rvsdg::Output &var)
std::unordered_map< rvsdg::Output *, std::unique_ptr< SCEV > > SCEVMap_
Context & operator=(Context &&)=delete
std::unique_ptr< SCEV > TryGetSCEVForOutput(rvsdg::Output &output) const
static std::unique_ptr< Context > Create()
void InsertChrec(rvsdg::Output &output, const std::unique_ptr< SCEVChainRecurrence > &chrec)
int GetNumInductionVariablesWithOrder(const size_t n) const
static std::unique_ptr< Statistics > Create(const util::FilePath &sourceFile)
~Statistics() noexcept override=default
static std::string GetTripCountString(const std::unordered_map< const rvsdg::ThetaNode *, size_t > &tripCountMap)
void Stop(const Context &context) noexcept
std::unordered_map< rvsdg::Output *, DependencyInfo > DependencyMap
std::unordered_map< rvsdg::Output *, DependencyMap > DependencyGraph
static std::unique_ptr< SCEVChainRecurrence > ComputeProductOfChrecs(SCEVChainRecurrence *lhsChrec, SCEVChainRecurrence *rhsChrec, rvsdg::Output &output)
std::optional< std::unique_ptr< SCEV > > TryReplaceInitForSCEV(const SCEV &scev, rvsdg::Output &output)
std::unique_ptr< SCEVChainRecurrence > GetOrCreateStepForSCEV(rvsdg::Output &output, const SCEV &scevTree, rvsdg::ThetaNode &thetaNode)
void PerformSCEVAnalysis(rvsdg::ThetaNode &thetaNode)
void Run(rvsdg::RvsdgModule &rvsdgModule, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static void FindDependenciesForSCEV(const SCEV &scev, DependencyMap &dependencies, DependencyOp op)
static bool IsStepZero(const SCEV &stepSCEV)
static std::unique_ptr< SCEV > ApplyMulFolding(SCEV *lhsOperand, SCEV *rhsOperand, rvsdg::Output &output)
Apply folding rules for multiplication to combine two SCEV operands into one.
static std::unique_ptr< SCEV > FoldNAryExpression(SCEVNAryExpr &expression, rvsdg::Output &output)
Try to combine the constants in an n-ary expression (Add or Mul) into themselves.
static std::optional< size_t > SolveQuadraticEquation(int64_t a, int64_t b, int64_t c)
Tries to find a solution to the quadratic equation a x^2 + b x + c = 0 using integer arithmetic.
DependencyGraph CreateDependencyGraph(const rvsdg::ThetaNode &thetaNode) const
static std::unique_ptr< SCEV > ApplyAddFolding(SCEV *lhsOperand, SCEV *rhsOperand, rvsdg::Output &output)
Apply folding rules for addition to combine two SCEV operands into one.
static bool HasCycleThroughOthers(rvsdg::Output &currentOutput, const rvsdg::Output &originalOutput, DependencyGraph &dependencyGraph, std::unordered_set< const rvsdg::Output * > &visited, std::unordered_set< const rvsdg::Output * > &recursionStack)
~ScalarEvolution() noexcept override
std::unique_ptr< Context > Context_
static bool IsStepPositive(const SCEV &stepSCEV)
static bool CanCreateChainRecurrence(rvsdg::Output &output, DependencyGraph &dependencyGraph)
std::unique_ptr< SCEV > ComputeSCEVForGepInnerOffset(const rvsdg::SimpleNode &gepNode, size_t inputIndex, const rvsdg::Type &type)
std::optional< size_t > GetPredictedTripCount(rvsdg::ThetaNode &thetaNode)
static bool StructurallyEqual(const SCEV &a, const SCEV &b)
static bool IsUnknown(const SCEV &scev)
std::unordered_map< const rvsdg::Output *, std::unique_ptr< SCEV > > GetSCEVMap() const
static bool IsStepNegative(const SCEV &stepSCEV)
static std::vector< rvsdg::Output * > TopologicalSort(DependencyGraph &dependencyGraph)
std::unique_ptr< SCEV > GetOrCreateSCEVForOutput(rvsdg::Output &output)
std::unique_ptr< SCEVChainRecurrence > GetOrCreateChainRecurrence(rvsdg::Output &output, const SCEV &scev, rvsdg::ThetaNode &thetaNode)
std::unordered_map< const rvsdg::ThetaNode *, size_t > GetTripCountMap() const noexcept
static std::unique_ptr< SCEV > GetNegativeSCEV(const SCEV &scev)
void AnalyzeRegion(rvsdg::Region &region)
static std::optional< size_t > ComputeBackedgeTakenCountForChrec(const SCEVChainRecurrence &chrec, int64_t bound, const rvsdg::SimpleOperation *comparisonOperation)
StructType class.
Definition: types.hpp:184
Region & GetRootRegion() const noexcept
Definition: graph.hpp:99
Output * origin() const noexcept
Definition: node.hpp:58
size_t ninputs() const noexcept
Definition: node.hpp:609
Represent acyclic RVSDG subgraphs.
Definition: region.hpp:213
NodeRange Nodes() noexcept
Definition: region.hpp:328
const std::optional< util::FilePath > & SourceFilePath() const noexcept
Definition: RvsdgModule.hpp:73
Graph & Rvsdg() noexcept
Definition: RvsdgModule.hpp:57
NodeInput * input(size_t index) const noexcept
Definition: simple-node.hpp:82
RegionResult * predicate() const noexcept
Definition: theta.hpp:85
LoopVar MapPreLoopVar(const rvsdg::Output &argument) const
Maps variable at start of loop iteration to full variable description.
Definition: theta.cpp:140
std::vector< LoopVar > GetLoopVars() const
Returns all loop variables.
Definition: theta.cpp:176
bool insert(ItemType item)
Definition: HashSet.hpp:210
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
Definition: Statistics.hpp:574
Statistics Interface.
Definition: Statistics.hpp:31
util::Timer & GetTimer(const std::string &name)
Definition: Statistics.cpp:137
util::Timer & AddTimer(std::string name)
Definition: Statistics.cpp:158
void AddMeasurement(std::string name, T value)
Definition: Statistics.hpp:177
void start() noexcept
Definition: time.hpp:54
void stop() noexcept
Definition: time.hpp:67
#define JLM_ASSERT(x)
Definition: common.hpp:16
Global memory state passed between functions.
size_t GetTypeAllocSize(const rvsdg::Type &type)
Definition: types.cpp:473
std::optional< int64_t > tryGetConstantSignedInteger(const rvsdg::Output &output)
Definition: Trace.cpp:62
static bool ThetaLoopVarIsInvariant(const ThetaNode::LoopVar &loopVar) noexcept
Definition: theta.hpp:227
Output & traceOutput(Output &output, const rvsdg::Region *withinRegion)
Definition: Trace.cpp:292
@ State
Designate a state type.
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
Definition: node.hpp:1049
rvsdg::Input * input
Variable at loop entry (input to theta).
Definition: theta.hpp:54
rvsdg::Input * post
Variable after iteration (output result from subregion).
Definition: theta.hpp:62
static const char * TripCounts
Definition: Statistics.hpp:242
static const char * NumFirstOrderInductionVariables
Definition: Statistics.hpp:239
static const char * NumConstantInductionVariables
Definition: Statistics.hpp:238
static const char * NumTotalInductionVariables
Definition: Statistics.hpp:237
static const char * Timer
Definition: Statistics.hpp:251
static const char * NumSecondOrderInductionVariables
Definition: Statistics.hpp:240
static const char * NumLoopVariablesTotal
Definition: Statistics.hpp:236
static const char * NumLoops
Definition: Statistics.hpp:241