22 return operation && operation->
narguments() == narguments();
28 return "MemoryStateMerge";
31 std::unique_ptr<rvsdg::Operation>
34 return std::make_unique<MemoryStateMergeOperation>(*
this);
37 std::optional<std::vector<rvsdg::Output *>>
40 const std::vector<rvsdg::Output *> &
operands)
48 std::optional<std::vector<rvsdg::Output *>>
51 const std::vector<rvsdg::Output *> &
operands)
62 template<
class TMemoryStateMergeOrJoinOperation>
63 std::vector<rvsdg::Output *>
67 std::is_same_v<TMemoryStateMergeOrJoinOperation, MemoryStateMergeOperation>
68 || std::is_same_v<TMemoryStateMergeOrJoinOperation, MemoryStateJoinOperation>,
69 "Template parameter T must be a MemoryStateMergeOperation or a MemoryStateJoinOperation!");
71 std::vector<rvsdg::Output *> newOperands;
74 auto [node, operation] =
75 rvsdg::TryGetSimpleNodeAndOptionalOp<TMemoryStateMergeOrJoinOperation>(*operand);
79 CollectNestedMemoryStateMergeOrJoinOperands<TMemoryStateMergeOrJoinOperation>(
81 newOperands.insert(newOperands.end(), nodeOperands.begin(), nodeOperands.end());
85 newOperands.emplace_back(operand);
92 std::optional<std::vector<rvsdg::Output *>>
95 const std::vector<rvsdg::Output *> &
operands)
98 CollectNestedMemoryStateMergeOrJoinOperands<MemoryStateMergeOperation>(
operands);
107 std::optional<std::vector<rvsdg::Output *>>
110 const std::vector<rvsdg::Output *> &
operands)
112 std::vector<rvsdg::Output *> newOperands;
115 auto [splitNode, splitOperation] =
116 rvsdg::TryGetSimpleNodeAndOptionalOp<MemoryStateSplitOperation>(*operand);
119 newOperands.emplace_back(splitNode->input(0)->origin());
123 newOperands.emplace_back(operand);
140 return operation && operation->
narguments() == narguments();
146 return "MemoryStateJoin";
149 std::unique_ptr<rvsdg::Operation>
152 return std::make_unique<MemoryStateJoinOperation>(*
this);
155 std::optional<std::vector<rvsdg::Output *>>
158 const std::vector<rvsdg::Output *> &
operands)
166 std::optional<std::vector<rvsdg::Output *>>
169 const std::vector<rvsdg::Output *> &
operands)
171 std::vector<rvsdg::Output *> newOperands;
178 seenOperands.
insert(operand);
179 newOperands.emplace_back(operand);
182 if (newOperands.size() ==
operands.size())
185 if (newOperands.size() == 1)
194 std::optional<std::vector<rvsdg::Output *>>
197 const std::vector<rvsdg::Output *> &
operands)
200 CollectNestedMemoryStateMergeOrJoinOperands<MemoryStateJoinOperation>(
operands);
205 const auto & memoryStateJoinNode =
CreateNode(std::move(newOperands));
206 return { { memoryStateJoinNode.output(0) } };
215 return operation && operation->
nresults() == nresults();
221 return "MemoryStateSplit";
224 std::unique_ptr<rvsdg::Operation>
227 return std::make_unique<MemoryStateSplitOperation>(*
this);
230 std::optional<std::vector<rvsdg::Output *>>
233 const std::vector<rvsdg::Output *> &
operands)
243 std::optional<std::vector<rvsdg::Output *>>
246 const std::vector<rvsdg::Output *> &
operands)
251 auto [splitNode, splitOperation] =
252 rvsdg::TryGetSimpleNodeAndOptionalOp<MemoryStateSplitOperation>(*operand);
256 const auto numResults = splitOperation->nresults() + operation.
nresults();
257 auto & newOperand = *splitNode->input(0)->origin();
258 auto results =
Create(newOperand, numResults);
260 for (
size_t n = 0; n < splitNode->noutputs(); n++)
262 const auto output = splitNode->output(n);
263 output->divert_users(results[n]);
266 return { { std::next(results.begin(), splitNode->noutputs()), results.end() } };
269 std::optional<std::vector<rvsdg::Output *>>
272 const std::vector<rvsdg::Output *> &
operands)
277 auto [mergeNode, mergeOperation] =
278 rvsdg::TryGetSimpleNodeAndOptionalOp<MemoryStateMergeOperation>(*operand);
279 if (!mergeOperation || mergeOperation->narguments() != operation.
nresults())
289 { memoryNodeIds.begin(), memoryNodeIds.end() });
291 if (memoryNodeIdsSet.Size() != memoryNodeIds.size())
292 throw std::logic_error(
"Found duplicated memory node identifiers.");
296 ToString(
const std::vector<MemoryNodeId> & memoryNodeIds)
299 for (
size_t n = 0; n < memoryNodeIds.size(); n++)
302 if (n != memoryNodeIds.size() - 1)
310 const std::vector<MemoryNodeId> & memoryNodeIds)
314 for (
size_t n = 0; n < memoryNodeIds.size(); n++)
326 return operation && operation->
nresults() == nresults()
327 && operation->memoryNodeIdToIndexMap_ == memoryNodeIdToIndexMap_;
336 std::unique_ptr<rvsdg::Operation>
339 return std::make_unique<LambdaEntryMemoryStateSplitOperation>(*
this);
347 const auto operation =
354 if (!operation->memoryNodeIdToIndexMap_.HasKey(memoryNodeId))
360 return node.
output(index);
366 auto [_, operation] =
367 rvsdg::TryGetSimpleNodeAndOptionalOp<LambdaEntryMemoryStateSplitOperation>(output);
370 return operation->memoryNodeIdToIndexMap_.LookupValue(output.
index());
373 std::optional<std::vector<rvsdg::Output *>>
376 const std::vector<rvsdg::Output *> &
operands)
380 auto [callEntryMergeNode, callEntryMergeOperation] =
381 rvsdg::TryGetSimpleNodeAndOptionalOp<CallEntryMemoryStateMergeOperation>(*
operands[0]);
382 if (!callEntryMergeOperation)
385 if (callEntryMergeOperation->narguments() != lambdaEntrySplitOperation.
nresults())
410 std::vector<rvsdg::Output *> newOperands;
411 for (
const auto & memoryNodeId : lambdaEntrySplitOperation.
getMemoryNodeIds())
417 newOperands.push_back(input->origin());
424 const std::vector<MemoryNodeId> & memoryNodeIds)
428 for (
size_t n = 0; n < memoryNodeIds.size(); n++)
449 std::unique_ptr<rvsdg::Operation>
452 return std::make_unique<LambdaExitMemoryStateMergeOperation>(*
this);
460 const auto operation =
467 if (!operation->MemoryNodeIdToIndex_.HasKey(memoryNodeId))
474 return node.
input(index);
480 auto [_, operation] =
481 rvsdg::TryGetSimpleNodeAndOptionalOp<LambdaExitMemoryStateMergeOperation>(input);
484 return operation->MemoryNodeIdToIndex_.LookupValue(input.
index());
487 std::optional<std::vector<rvsdg::Output *>>
490 const std::vector<rvsdg::Output *> &
operands)
495 bool replacedOperands =
false;
496 std::vector<rvsdg::Output *> newOperands;
499 auto [loadNode, loadOperation] = rvsdg::TryGetSimpleNodeAndOptionalOp<LoadOperation>(*operand);
502 newOperands.push_back(operand);
507 if (!rvsdg::IsOwnerNodeOperation<AllocaOperation>(*loadAddress))
509 newOperands.push_back(operand);
514 newOperands.push_back(newOperand);
515 replacedOperands =
true;
518 if (!replacedOperands)
526 std::optional<std::vector<rvsdg::Output *>>
529 const std::vector<rvsdg::Output *> &
operands)
534 bool replacedOperands =
false;
535 std::vector<rvsdg::Output *> newOperands;
538 auto [storeNode, storeOperation] =
539 rvsdg::TryGetSimpleNodeAndOptionalOp<StoreOperation>(*operand);
542 newOperands.push_back(operand);
547 if (!rvsdg::IsOwnerNodeOperation<AllocaOperation>(*storeAddress))
549 newOperands.push_back(operand);
554 newOperands.push_back(newOperand);
555 replacedOperands =
true;
558 if (!replacedOperands)
566 std::optional<std::vector<rvsdg::Output *>>
569 const std::vector<rvsdg::Output *> &
operands)
574 bool replacedOperands =
false;
575 std::vector<rvsdg::Output *> newOperands;
578 auto [allocaNode, allocaOperation] =
579 rvsdg::TryGetSimpleNodeAndOptionalOp<AllocaOperation>(*operand);
584 newOperands.push_back(newOperand);
585 replacedOperands =
true;
589 newOperands.push_back(operand);
593 if (!replacedOperands)
608 for (
size_t n = 0; n < memoryNodeIds.size(); n++)
610 MemoryNodeIdToIndex_.Insert(memoryNodeIds[n], n);
627 std::unique_ptr<rvsdg::Operation>
630 return std::make_unique<CallEntryMemoryStateMergeOperation>(*
this);
638 const auto operation =
645 if (!operation->MemoryNodeIdToIndex_.HasKey(memoryNodeId))
652 return node.
input(index);
662 for (
size_t n = 0; n < memoryNodeIds.size(); n++)
664 memoryNodeIdToIndexMap_.Insert(memoryNodeIds[n], n);
681 std::unique_ptr<rvsdg::Operation>
684 return std::make_unique<CallExitMemoryStateSplitOperation>(*
this);
692 const auto operation =
699 if (!operation->memoryNodeIdToIndexMap_.HasKey(memoryNodeId))
706 return node.
output(index);
712 auto [_, operation] =
713 rvsdg::TryGetSimpleNodeAndOptionalOp<CallExitMemoryStateSplitOperation>(output);
716 return operation->memoryNodeIdToIndexMap_.LookupValue(output.
index());
719 std::optional<std::vector<rvsdg::Output *>>
722 const std::vector<rvsdg::Output *> &
operands)
726 auto [lambdaExitMergeNode, lambdaExitMergeOperation] =
727 rvsdg::TryGetSimpleNodeAndOptionalOp<LambdaExitMemoryStateMergeOperation>(*
operands[0]);
728 if (!lambdaExitMergeOperation)
731 if (lambdaExitMergeNode->ninputs() != callExitSplitOperation.
nresults())
740 std::vector<rvsdg::Output *> newOperands;
744 *lambdaExitMergeNode,
747 newOperands.push_back(input->origin());
756 for (
auto & input : node.
Inputs())
758 if (is<MemoryStateType>(input.Type()))
764 for (
auto & output : node.
Outputs())
766 if (is<MemoryStateType>(output.Type()))
std::string debug_string() const override
bool operator==(const Operation &other) const noexcept override
util::BijectiveMap< MemoryNodeId, size_t > MemoryNodeIdToIndex_
std::unique_ptr< Operation > copy() const override
~CallEntryMemoryStateMergeOperation() noexcept override
std::vector< MemoryNodeId > getMemoryNodeIds() const noexcept
static rvsdg::Input * tryMapMemoryNodeIdToInput(const rvsdg::SimpleNode &node, MemoryNodeId memoryNodeId)
~CallExitMemoryStateSplitOperation() noexcept override
static MemoryNodeId mapOutputToMemoryNodeId(const rvsdg::Output &output)
static std::optional< std::vector< rvsdg::Output * > > NormalizeLambdaExitMemoryStateMerge(const CallExitMemoryStateSplitOperation &callExitSplitOperation, const std::vector< rvsdg::Output * > &operands)
util::BijectiveMap< MemoryNodeId, size_t > memoryNodeIdToIndexMap_
std::unique_ptr< Operation > copy() const override
static rvsdg::Output * tryMapMemoryNodeIdToOutput(const rvsdg::SimpleNode &node, MemoryNodeId memoryNodeId)
std::string debug_string() const override
bool operator==(const Operation &other) const noexcept override
std::vector< MemoryNodeId > getMemoryNodeIds() const noexcept
util::BijectiveMap< MemoryNodeId, size_t > memoryNodeIdToIndexMap_
static rvsdg::Output * tryMapMemoryNodeIdToOutput(const rvsdg::SimpleNode &node, MemoryNodeId memoryNodeId)
~LambdaEntryMemoryStateSplitOperation() noexcept override
std::vector< MemoryNodeId > getMemoryNodeIds() const noexcept
LambdaEntryMemoryStateSplitOperation(const std::vector< MemoryNodeId > &memoryNodeIds)
static std::optional< std::vector< rvsdg::Output * > > NormalizeCallEntryMemoryStateMerge(const LambdaEntryMemoryStateSplitOperation &lambdaEntrySplitOperation, const std::vector< rvsdg::Output * > &operands)
std::unique_ptr< Operation > copy() const override
static MemoryNodeId mapOutputToMemoryNodeId(const rvsdg::Output &output)
std::string debug_string() const override
static std::optional< std::vector< rvsdg::Output * > > NormalizeLoadFromAlloca(const LambdaExitMemoryStateMergeOperation &operation, const std::vector< rvsdg::Output * > &operands)
std::vector< MemoryNodeId > getMemoryNodeIds() const noexcept
static std::optional< std::vector< rvsdg::Output * > > NormalizeStoreToAlloca(const LambdaExitMemoryStateMergeOperation &operation, const std::vector< rvsdg::Output * > &operands)
util::BijectiveMap< MemoryNodeId, size_t > MemoryNodeIdToIndex_
static rvsdg::Input * tryMapMemoryNodeIdToInput(const rvsdg::SimpleNode &node, MemoryNodeId memoryNodeId)
static rvsdg::Node & CreateNode(rvsdg::Region ®ion, const std::vector< rvsdg::Output * > &operands, const std::vector< MemoryNodeId > &memoryNodeIds)
LambdaExitMemoryStateMergeOperation(const std::vector< MemoryNodeId > &memoryNodeIds)
std::unique_ptr< Operation > copy() const override
static std::optional< std::vector< rvsdg::Output * > > NormalizeAlloca(const LambdaExitMemoryStateMergeOperation &operation, const std::vector< rvsdg::Output * > &operands)
~LambdaExitMemoryStateMergeOperation() noexcept override
static MemoryNodeId mapInputToMemoryNodeId(const rvsdg::Input &input)
std::string debug_string() const override
static rvsdg::Input & AddressInput(const rvsdg::Node &node) noexcept
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
std::unique_ptr< Operation > copy() const override
static std::optional< std::vector< rvsdg::Output * > > NormalizeDuplicateOperands(const MemoryStateJoinOperation &operation, const std::vector< rvsdg::Output * > &operands)
Removes duplicated operands from the MemoryStateJoinOperation.
static std::optional< std::vector< rvsdg::Output * > > NormalizeNestedJoins(const MemoryStateJoinOperation &operation, const std::vector< rvsdg::Output * > &operands)
Fuses nested MemoryStateJoinOperation nodes into a single node.
static std::optional< std::vector< rvsdg::Output * > > NormalizeSingleOperand(const MemoryStateJoinOperation &operation, const std::vector< rvsdg::Output * > &operands)
Removes the MemoryStateJoinOperation as it has only a single operand, i.e., no joining is performed.
std::string debug_string() const override
static rvsdg::SimpleNode & CreateNode(const std::vector< rvsdg::Output * > &operands)
~MemoryStateJoinOperation() noexcept override
static std::optional< std::vector< rvsdg::Output * > > NormalizeNestedMerges(const MemoryStateMergeOperation &operation, const std::vector< rvsdg::Output * > &operands)
Fuses nested merges into a single merge.
static rvsdg::Output * Create(const std::vector< rvsdg::Output * > &operands)
~MemoryStateMergeOperation() noexcept override
std::unique_ptr< Operation > copy() const override
static std::optional< std::vector< rvsdg::Output * > > NormalizeMergeSplit(const MemoryStateMergeOperation &operation, const std::vector< rvsdg::Output * > &operands)
Fuses nested splits into a single merge.
static std::optional< std::vector< rvsdg::Output * > > NormalizeSingleOperand(const MemoryStateMergeOperation &operation, const std::vector< rvsdg::Output * > &operands)
Removes the MemoryStateMergeOperation as it has only a single operand, i.e., no merging is performed.
static std::optional< std::vector< rvsdg::Output * > > NormalizeDuplicateOperands(const MemoryStateMergeOperation &operation, const std::vector< rvsdg::Output * > &operands)
Removes duplicated operands from the MemoryStateMergeOperation.
std::string debug_string() const override
static std::optional< std::vector< rvsdg::Output * > > NormalizeNestedSplits(const MemoryStateSplitOperation &operation, const std::vector< rvsdg::Output * > &operands)
Fuses nested splits into a single split.
~MemoryStateSplitOperation() noexcept override
static std::optional< std::vector< rvsdg::Output * > > NormalizeSingleResult(const MemoryStateSplitOperation &operation, const std::vector< rvsdg::Output * > &operands)
Removes the MemoryStateSplitOperation as it has only a single result, i.e., no splitting is performed...
static std::vector< rvsdg::Output * > Create(rvsdg::Output &operand, const size_t numResults)
std::unique_ptr< Operation > copy() const override
std::string debug_string() const override
static std::optional< std::vector< rvsdg::Output * > > NormalizeSplitMerge(const MemoryStateSplitOperation &operation, const std::vector< rvsdg::Output * > &operands)
Removes an idempotent split-merge pair.
static std::shared_ptr< const MemoryStateType > Create()
static rvsdg::Input & MapMemoryStateOutputToInput(const rvsdg::Output &output)
static rvsdg::Input & AddressInput(const rvsdg::Node &node) noexcept
static jlm::rvsdg::Output * Create(rvsdg::Region ®ion, std::shared_ptr< const jlm::rvsdg::Type > type)
OutputIteratorRange Outputs() noexcept
InputIteratorRange Inputs() noexcept
NodeOutput * output(size_t index) const noexcept
size_t ninputs() const noexcept
size_t noutputs() const noexcept
size_t index() const noexcept
const SimpleOperation & GetOperation() const noexcept override
NodeInput * input(size_t index) const noexcept
NodeOutput * output(size_t index) const noexcept
const std::shared_ptr< const rvsdg::Type > & result(size_t index) const noexcept
size_t nresults() const noexcept
size_t narguments() const noexcept
const V & LookupKey(const K &key) const
bool Insert(const K &key, const V &value)
bool insert(ItemType item)
std::size_t Size() const noexcept
bool Contains(const ItemType &item) const noexcept
IteratorRange< ItemConstIterator > Items() const noexcept
Global memory state passed between functions.
bool hasMemoryState(const rvsdg::Node &node)
std::vector< rvsdg::Output * > CollectNestedMemoryStateMergeOrJoinOperands(const std::vector< rvsdg::Output * > &operands)
static std::string ToString(const std::vector< MemoryNodeId > &memoryNodeIds)
static void CheckMemoryNodeIds(const std::vector< MemoryNodeId > &memoryNodeIds)
static std::vector< jlm::rvsdg::Output * > operands(const Node *node)
static std::string strfmt(Args... args)