52 const size_t numAggregateAllocaNodes,
53 const size_t numAggregateStructAllocaNodes,
54 const size_t numSplitableTypeAggregateAllocaNodes,
55 const size_t numSplitAggregateAllocaNodes)
62 numSplitableTypeAggregateAllocaNodes);
66 static std::unique_ptr<Statistics>
69 return std::make_unique<Statistics>(std::move(filePath));
102 const auto structType =
dynamic_cast<const StructType *
>(&type);
106 for (
const auto & elementType : structType->elementTypes())
118 std::optional<AggregateAllocaSplitting::AllocaTraceInfo>
121 [[maybe_unused]]
auto allocaOperation =
131 std::deque<rvsdg::Output *> toVisit{ &address };
136 toVisit.push_back(&output);
140 auto removeFromVisitSet = [&]()
142 const auto output = toVisit.front();
149 const auto currentOutput = removeFromVisitSet();
151 for (
auto & user : currentOutput->Users())
168 auto & gammaOutput = gammaNode.mapBranchResultToOutput(user);
169 addToVisitSet(gammaOutput);
174 const auto loopVar = thetaNode.MapPostLoopVar(user);
175 addToVisitSet(*loopVar.pre);
176 addToVisitSet(*loopVar.output);
185 throw std::logic_error(util::strfmt(
186 "Unhandled owner region node type: ",
187 userRegion->node()->DebugString()));
192 else if (
auto userNode = rvsdg::TryGetOwnerNode<rvsdg::Node>(user))
198 auto roleVar = gammaNode.
MapInput(user);
199 if (
auto entryVar = std::get_if<rvsdg::GammaNode::EntryVar>(&roleVar))
201 for (
auto argument : entryVar->branchArgument)
203 addToVisitSet(*argument);
208 throw std::logic_error(
util::strfmt(
"Unhandled role variable."));
215 const auto loopVar = thetaNode.MapInputLoopVar(user);
216 addToVisitSet(*loopVar.pre);
221 auto & operation = simpleNode.GetOperation();
227 if (
const auto indicesOpt =
229 !indicesOpt.has_value())
242 throw std::logic_error(
243 util::strfmt(
"Unhandled node type: ", userNode->DebugString()));
250 throw std::logic_error(
"Unhandled owner type");
264 return std::make_optional(allocaTraceInfo);
270 [[maybe_unused]]
auto gepOperation =
272 auto & address = *gepNode.
output(0);
274 bool hasOnlyLoadsAndStores =
true;
277 std::deque<rvsdg::Output *> toVisit{ &address };
282 toVisit.push_back(&output);
286 auto removeFromVisitSet = [&]()
288 const auto output = toVisit.front();
293 while (!toVisit.empty() && hasOnlyLoadsAndStores)
295 const auto currentOutput = removeFromVisitSet();
296 for (
auto & user : currentOutput->Users())
298 if (!hasOnlyLoadsAndStores)
313 auto & gammaOutput = gammaNode.mapBranchResultToOutput(user);
314 addToVisitSet(gammaOutput);
319 const auto loopVar = thetaNode.MapPostLoopVar(user);
320 addToVisitSet(*loopVar.pre);
321 addToVisitSet(*loopVar.output);
330 throw std::logic_error(util::strfmt(
331 "Unhandled owner region node type: ",
332 userRegion->node()->DebugString()));
337 else if (
auto userNode = rvsdg::TryGetOwnerNode<rvsdg::Node>(user))
343 auto roleVar = gammaNode.
MapInput(user);
344 if (
auto entryVar = std::get_if<rvsdg::GammaNode::EntryVar>(&roleVar))
346 for (
auto argument : entryVar->branchArgument)
348 addToVisitSet(*argument);
353 throw std::logic_error(
util::strfmt(
"Unhandled role variable."));
360 const auto loopVar = thetaNode.MapInputLoopVar(user);
361 addToVisitSet(*loopVar.pre);
366 auto & operation = simpleNode.GetOperation();
382 addToVisitSet(*simpleNode.output(0));
392 throw std::logic_error(
393 util::strfmt(
"Unhandled node type: ", userNode->DebugString()));
400 throw std::logic_error(
"Unhandled owner type");
405 return hasOnlyLoadsAndStores;
408 std::vector<AggregateAllocaSplitting::AllocaTraceInfo>
411 std::function<void(
rvsdg::Region &, std::vector<AllocaTraceInfo> &)> findAllocaNodes =
412 [&](
rvsdg::Region & region, std::vector<AllocaTraceInfo> & traceInfo)
414 for (
auto & node : region.
Nodes())
420 for (
auto & subregion : gammaNode.
Subregions())
421 findAllocaNodes(subregion, traceInfo);
425 findAllocaNodes(*thetaNode.subregion(), traceInfo);
429 findAllocaNodes(*lambdaNode.subregion(), traceInfo);
433 findAllocaNodes(*phiNode.subregion(), traceInfo);
441 const auto allocaOperation =
443 if (!allocaOperation)
447 if (is<StructType>(allocaType))
449 context_->numAggregateStructAllocaNodes++;
450 context_->numAggregateAllocaNodes++;
454 context_->numAggregateAllocaNodes++;
459 context_->numSplitableTypeAggregateAllocaNodes++;
460 if (
auto allocaTraceInfo =
isSplitable(simpleNode))
462 context_->numSplitAggregateAllocaNodes++;
463 traceInfo.emplace_back(*allocaTraceInfo);
469 throw std::logic_error(
"Unhandled node type.");
474 std::vector<AllocaTraceInfo> traceInfo;
475 findAllocaNodes(region, traceInfo);
482 auto & allocaNode = *allocaTraceInfo.
allocaNode;
483 const auto allocaOperation =
dynamic_cast<const AllocaOperation *
>(&allocaNode.GetOperation());
485 auto & allocaType = *std::static_pointer_cast<const StructType>(allocaOperation->allocatedType());
487 const auto alignment = allocaOperation->alignment();
490 std::vector<rvsdg::Node *> elementAllocaNodes;
491 std::vector<rvsdg::Output *> allocaMemoryStates;
492 for (
const auto & elementType : allocaType.elementTypes())
494 auto & elementAlloca =
496 elementAllocaNodes.push_back(&elementAlloca);
508 allocaConsumer->GetOperation(),
511 JLM_ASSERT(GetElementPtrOperation::numIndices(*allocaConsumer) == 2);
512 auto & consumerRegion = *allocaConsumer->region();
515 GetElementPtrOperation::tryGetConstantIndices(*allocaConsumer).value();
516 JLM_ASSERT(indices.size() == 2 && indices[0] == 0);
518 auto elementAlloca = elementAllocaNodes[indices[1]];
520 auto & routedAddress = rvsdg::RouteToRegion(
521 AllocaOperation::getPointerOutput(*elementAlloca),
523 allocaConsumer->output(0)->divert_users(&routedAddress);
527 throw std::logic_error(
528 util::strfmt(
"Unhandled node type: ", allocaConsumer->DebugString()));
537 for (
const auto & allocaTraceInfo : traceInfo)
551 context_ = std::make_unique<Context>();
558 context_->numAggregateStructAllocaNodes,
559 context_->numSplitableTypeAggregateAllocaNodes,
560 context_->numSplitAggregateAllocaNodes);
static jlm::util::StatisticsCollector statisticsCollector
const char * numSplitAggregateAllocaNodesLabel_
~Statistics() noexcept override=default
void stop(const size_t numAggregateAllocaNodes, const size_t numAggregateStructAllocaNodes, const size_t numSplitableTypeAggregateAllocaNodes, const size_t numSplitAggregateAllocaNodes)
const char * numAggregateAllocaNodesLabel_
const char * numSplitableTypeAggregateAllocaNodesLabel_
static std::unique_ptr< Statistics > create(util::FilePath filePath)
const char * numAggregateStructAllocaNodesLabel_
const char * aggregateAllocaSplittingTimerLabel_
Aggregate Alloca Splitting Transformation.
std::unique_ptr< Context > context_
static bool checkGetElementPtrUsers(const rvsdg::SimpleNode &gepNode)
void Run(rvsdg::RvsdgModule &module, util::StatisticsCollector &statisticsCollector) override
Perform RVSDG transformation.
static void splitAllocaNode(const AllocaTraceInfo &allocaTraceInfo)
void splitAllocaNodes(rvsdg::RvsdgModule &rvsdgModule)
static bool isSplitableType(const rvsdg::Type &type)
static std::optional< AllocaTraceInfo > isSplitable(rvsdg::SimpleNode &allocaNode)
~AggregateAllocaSplitting() noexcept override
std::vector< AllocaTraceInfo > findSplitableAllocaNodes(rvsdg::Region ®ion) const
static rvsdg::Input & getCountInput(rvsdg::Node &node)
static rvsdg::SimpleNode & createNode(std::shared_ptr< const rvsdg::Type > allocatedType, rvsdg::Output &count, const size_t alignment)
static rvsdg::Output & getMemoryStateOutput(rvsdg::Node &node)
static rvsdg::Output & getPointerOutput(rvsdg::Node &node)
const std::shared_ptr< const rvsdg::Type > & allocatedType() const noexcept
static std::optional< std::vector< uint64_t > > tryGetConstantIndices(const rvsdg::Node &node) noexcept
static rvsdg::Output * Create(const std::vector< rvsdg::Output * > &operands)
static rvsdg::Input & AddressInput(const rvsdg::Node &node) noexcept
Conditional operator / pattern matching.
std::variant< MatchVar, EntryVar > MapInput(const rvsdg::Input &input) const
Maps gamma input to its role (match variable or entry variable).
Region & GetRootRegion() const noexcept
void divert_users(jlm::rvsdg::Output *new_origin)
A phi node represents the fixpoint of mutually recursive definitions.
Represent acyclic RVSDG subgraphs.
NodeRange Nodes() noexcept
const std::optional< util::FilePath > & SourceFilePath() const noexcept
const SimpleOperation & GetOperation() const noexcept override
NodeOutput * output(size_t index) const noexcept
SubregionIteratorRange Subregions()
bool insert(ItemType item)
bool Contains(const ItemType &item) const noexcept
void CollectDemandedStatistics(std::unique_ptr< Statistics > statistics)
util::Timer & GetTimer(const std::string &name)
util::Timer & AddTimer(std::string name)
void AddMeasurement(std::string name, T value)
Global memory state passed between functions.
bool IsAggregateType(const jlm::rvsdg::Type &type)
void MatchTypeWithDefault(T &obj, const Fns &... fns)
Pattern match over subclass type of given object with default handler.
Region * TryGetOwnerRegion(const rvsdg::Input &input) noexcept
static std::string strfmt(Args... args)
std::vector< rvsdg::SimpleNode * > allocaConsumers
rvsdg::SimpleNode * allocaNode
AllocaTraceInfo(rvsdg::SimpleNode &allocaNode)
size_t numSplitableTypeAggregateAllocaNodes
size_t numAggregateAllocaNodes
size_t numSplitAggregateAllocaNodes
size_t numAggregateStructAllocaNodes